In [1]:
import numpy as np
import torch

import sys, os

# Make the sibling model directory importable. The path is relative, so it
# resolves against the working directory the kernel was started in.
folder_path = "../models/"
if sys.path.count(folder_path) == 0:
    sys.path.append(folder_path)
from cnmp import CNMP

In [2]:
torch.set_float32_matmul_precision('high')

def get_free_gpu():
    """Return the index of the CUDA device with the lowest reported utilisation.

    Every device is queried by index through ``torch.cuda.utilization(i)``, so
    the process-wide current CUDA device is not changed as a side effect of
    probing. Ties are resolved in favour of the lowest device index.

    Returns:
        int: index of the least-utilised visible CUDA device.

    Raises:
        RuntimeError: if ``torch.cuda.device_count()`` reports no device.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        raise RuntimeError("get_free_gpu() was called but no CUDA device is visible")

    # min() returns the first minimum it meets, i.e. the lowest index among
    # equally idle devices (same result as a stable sort on utilisation).
    return min(range(num_gpus), key=lambda gpu_id: torch.cuda.utilization(gpu_id))

# Choose the compute device: the least-utilised GPU when CUDA is usable,
# otherwise fall back to the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    available_gpu = get_free_gpu()
    device = torch.device(f"cuda:{available_gpu}")

print("Device :", device)


Device : cuda:0


In [3]:
# Load the trajectory array from disk and move it onto the selected device.
# NOTE(review): the path is absolute and machine-specific; adjust it before
# running this notebook elsewhere.
data_path = '/home/yigit/projects/cnep/data/32trajectory/y_normalized.npy'

all_y = torch.from_numpy(np.load(data_path))
all_y = all_y.to(device)
print(all_y.shape)

torch.Size([32, 200, 8])


In [4]:
# all_y is unpacked as (num_trajs, t_steps, dy).
num_trajs, t_steps, dy = all_y.shape

# Hold out a quarter of the trajectories (integer division) for validation.
num_val = num_trajs // 4
num_demos = num_trajs - num_val

# Random split: the first num_demos shuffled indices are used for training,
# the remaining num_val for validation.
inds = torch.randperm(num_trajs)
demo_inds = inds[:num_demos]
val_inds = inds[num_demos:]

y, vy = all_y[demo_inds], all_y[val_inds]

In [None]:
# Hyper-parameters handed to the CNMP constructor below.
dx = 1
n_max, m_max = 10, 10        # sizes of the observation / target buffers per trajectory
enc_dims = [128, 128, 128]
dec_dims = [128, 128, 128]
batch_size = num_val         # one validation batch covers the whole validation split

# Time axis: t_steps evenly spaced values in [0, 1], shaped (1, t_steps, 1)
# and tiled once per trajectory.
x_single = torch.linspace(0, 1, t_steps).reshape(1, t_steps, 1)
x = x_single.repeat(num_demos, 1, 1).to(device)
vx = x_single.repeat(num_val, 1, 1).to(device)

# NOTE(review): the saved output of the training cell reports a shape mismatch
# inside the decoder ([80, 256] X [128, 16]); confirm these constructor
# arguments against models/cnmp.py.
cnmp_ = CNMP(dx, dy, n_max, m_max, enc_dims, dec_dims, batch_size, device)
optimizer = torch.optim.Adam(cnmp_.parameters(), lr=3e-4)

In [6]:
# Pre-allocated training buffers; prepare_masked_batch resets and refills them
# in place on every call.
_float_buf = {"dtype": torch.float32, "device": device}
_bool_buf = {"dtype": torch.bool, "device": device}

obs = torch.zeros(batch_size, n_max, dx + dy, **_float_buf)   # observed (x, y) pairs
tar_x = torch.zeros(batch_size, m_max, dx, **_float_buf)      # target inputs
tar_y = torch.zeros(batch_size, m_max, dy, **_float_buf)      # target outputs
obs_mask = torch.zeros(batch_size, n_max, **_bool_buf)        # True where obs rows are filled
tar_mask = torch.zeros(batch_size, m_max, **_bool_buf)        # True where target rows are filled

def prepare_masked_batch(t: torch.Tensor, traj_ids: torch.Tensor):
    """Fill the module-level training buffers with a freshly sampled batch.

    The buffers ``obs``, ``tar_x``, ``tar_y``, ``obs_mask`` and ``tar_mask``
    are zeroed and then written in place; nothing is returned.

    For every selected trajectory a random number of observation points
    ``n`` (1..n_max inclusive) and target points ``m`` (1..m_max inclusive)
    is drawn. Both sets come from one random permutation of the time steps,
    so observation and target indices never overlap. Rows beyond ``n`` / ``m``
    stay zero and are flagged ``False`` in the masks.

    Args:
        t: trajectories indexed as ``t[traj_id][time_index]``; each time step
            must provide ``dy`` values.
        traj_ids: indices into ``t``; entry ``i`` fills batch row ``i``, so at
            most ``batch_size`` ids may be passed.
    """
    obs.fill_(0)
    tar_x.fill_(0)
    tar_y.fill_(0)
    obs_mask.fill_(False)
    tar_mask.fill_(False)

    for i, traj_id in enumerate(traj_ids):
        traj = t[traj_id]

        # torch.randint's upper bound is exclusive: use max + 1 so that the
        # buffers can actually be filled completely (and a max of 1 is valid).
        n = torch.randint(1, n_max + 1, (1,)).item()
        m = torch.randint(1, m_max + 1, (1,)).item()

        # One permutation, two disjoint slices: observations first, targets next.
        permuted_ids = torch.randperm(t_steps)
        n_ids = permuted_ids[:n]
        m_ids = permuted_ids[n:n+m]

        # The x coordinate is the time index scaled by t_steps, i.e. in [0, 1).
        obs[i, :n, :dx] = (n_ids/t_steps).unsqueeze(1)  # X
        obs[i, :n, dx:] = traj[n_ids]  # Y
        obs_mask[i, :n] = True

        tar_x[i, :m] = (m_ids/t_steps).unsqueeze(1)
        tar_y[i, :m] = traj[m_ids]
        tar_mask[i, :m] = True

# Pre-allocated validation buffers; prepare_masked_val_batch resets and refills
# them in place. Targets span the full trajectory (t_steps rows), so no target
# mask is kept.
val_obs = torch.zeros(batch_size, n_max, dx + dy, dtype=torch.float32, device=device)
val_tar_x = torch.zeros(batch_size, t_steps, dx, dtype=torch.float32, device=device)
val_tar_y = torch.zeros(batch_size, t_steps, dy, dtype=torch.float32, device=device)
val_obs_mask = torch.zeros(batch_size, n_max, dtype=torch.bool, device=device)

def prepare_masked_val_batch(t: torch.Tensor, traj_ids: torch.Tensor):
    """Fill the module-level validation buffers for the given trajectories.

    The buffers ``val_obs``, ``val_tar_x``, ``val_tar_y`` and ``val_obs_mask``
    are zeroed and then written in place; nothing is returned.

    For every selected trajectory a random number of observation points
    ``n`` (1..n_max inclusive) is drawn at random time steps, while the
    targets are always the complete trajectory (all ``t_steps`` points).

    Args:
        t: trajectories indexed as ``t[traj_id][time_index]``; each time step
            must provide ``dy`` values.
        traj_ids: indices into ``t``; entry ``i`` fills batch row ``i``, so at
            most ``batch_size`` ids may be passed.
    """
    val_obs.fill_(0)
    val_tar_x.fill_(0)
    val_tar_y.fill_(0)
    val_obs_mask.fill_(False)

    for i, traj_id in enumerate(traj_ids):
        traj = t[traj_id]

        # torch.randint's upper bound is exclusive: use n_max + 1 so that the
        # observation buffer can be filled completely (and n_max == 1 is valid).
        n = torch.randint(1, n_max + 1, (1,)).item()

        permuted_ids = torch.randperm(t_steps)
        n_ids = permuted_ids[:n]
        m_ids = torch.arange(t_steps)  # every time step is a validation target

        # The x coordinate is the time index scaled by t_steps, i.e. in [0, 1).
        val_obs[i, :n, :dx] = (n_ids/t_steps).unsqueeze(1)
        val_obs[i, :n, dx:] = traj[n_ids]
        val_obs_mask[i, :n] = True

        val_tar_x[i] = (m_ids/t_steps).unsqueeze(1)
        val_tar_y[i] = traj[m_ids]

In [7]:
# torch.compile was introduced in PyTorch 2.0. Detect the feature itself rather
# than comparing version strings: string comparison is lexicographic, so e.g.
# "10.0" >= "2.0" evaluates to False.
if hasattr(torch, "compile"):
    cnmp = torch.compile(cnmp_)
else:
    cnmp = cnmp_

# Output layout (relative to the working directory):
#   <root_folder>               loss histories and the training split
#   <root_folder>saved_models/  model checkpoints
#   <root_folder>img/           created here; not written to in this cell
root_folder = "output/cnmp/last/"
for sub_dir in ("", "saved_models/", "img/"):
    # exist_ok=True makes creation idempotent without a separate exists() check.
    os.makedirs(f"{root_folder}{sub_dir}", exist_ok=True)

# Persist the training trajectories used for this run.
torch.save(y, f"{root_folder}y.pt")

# --- Schedule ---------------------------------------------------------------
epochs = 1_000_000
val_per_epoch = 1000                      # validate and report every this many epochs
epoch_iter = num_demos // batch_size      # full training batches per epoch
v_epoch_iter = num_val // batch_size      # full validation batches per validation pass

# --- Running statistics -----------------------------------------------------
avg_loss = 0                              # training loss accumulated between reports
min_vl = 1_000_000                        # best validation error seen so far
tl, ve = [], []                           # per-epoch training loss / validation error history

mse_loss = torch.nn.MSELoss()

# --- Where the loss histories are written -----------------------------------
cnmp_tl_path = f'{root_folder}cnmp_training_loss.pt'
cnmp_ve_path = f'{root_folder}cnmp_validation_error.pt'

# Number of epochs accumulated into avg_loss since the last progress report.
# The report at epoch 0 covers a single epoch, every later one val_per_epoch
# epochs, so the average must be taken over this count rather than a constant.
epochs_since_report = 0

for epoch in range(epochs):
    epoch_loss = 0

    # Shuffle the training trajectories and keep only as many as fill whole
    # batches, because their number may not be divisible by batch_size.
    traj_ids = torch.randperm(x.shape[0])[:batch_size*epoch_iter].chunk(epoch_iter)

    for i in range(epoch_iter):
        # Refills obs / tar_x / tar_y / obs_mask / tar_mask in place.
        prepare_masked_batch(y, traj_ids[i])

        optimizer.zero_grad()
        pred = cnmp(obs, tar_x, obs_mask)
        loss = cnmp.loss(pred, tar_y, tar_mask)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss = epoch_loss/epoch_iter
    tl.append(epoch_loss)
    avg_loss += epoch_loss
    epochs_since_report += 1

    if epoch % val_per_epoch == 0:
        with torch.no_grad():
            v_traj_ids = torch.randperm(vx.shape[0])[:batch_size*v_epoch_iter].chunk(v_epoch_iter)
            val_err = 0

            for j in range(v_epoch_iter):
                # Refills val_obs / val_tar_x / val_tar_y / val_obs_mask in place.
                prepare_masked_val_batch(vy, v_traj_ids[j])

                p = cnmp.val(val_obs, val_tar_x, val_obs_mask)
                # Only the first dy output channels are compared with the
                # ground truth (presumably the predicted means - confirm in cnmp.py).
                vp_means = p[:, :, :dy]
                val_err += mse_loss(vp_means, val_tar_y).item()

            val_err = val_err/v_epoch_iter

            if val_err < min_vl:
                min_vl = val_err
                print(f'New best: {min_vl}')
                # Save the weights of the underlying (uncompiled) module.
                torch.save(cnmp_.state_dict(), f'{root_folder}saved_models/cnmp.pt')

            ve.append(val_err)

        print("Epoch: {}, CNMP Losses: {}".format(epoch, avg_loss/epochs_since_report))
        avg_loss = 0
        epochs_since_report = 0

# Loss histories are written once, after the last epoch.
torch.save(torch.Tensor(tl), cnmp_tl_path)
torch.save(torch.Tensor(ve), cnmp_ve_path)


# torch.save(cnmp_.state_dict(), f"{root_folder}saved_models/cnmp.pt")

E0315 17:21:10.877000 26667 torch/_subclasses/fake_tensor.py:2017] [0/0] failed while attempting to run meta for aten.mm.default
E0315 17:21:10.877000 26667 torch/_subclasses/fake_tensor.py:2017] [0/0] Traceback (most recent call last):
E0315 17:21:10.877000 26667 torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/home/yigit/.local/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2013, in _dispatch_impl
E0315 17:21:10.877000 26667 torch/_subclasses/fake_tensor.py:2017] [0/0]     r = func(*args, **kwargs)
E0315 17:21:10.877000 26667 torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/home/yigit/.local/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
E0315 17:21:10.877000 26667 torch/_subclasses/fake_tensor.py:2017] [0/0]     return self._op(*args, **kwargs)
E0315 17:21:10.877000 26667 torch/_subclasses/fake_tensor.py:2017] [0/0]   File "/home/yigit/.local/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
E0315 17:2

TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(8, 10, 256), grad_fn=<ReluBackward0>), Parameter(FakeTensor(..., device='cuda:0', size=(16, 128), requires_grad=True)), Parameter(FakeTensor(..., device='cuda:0', size=(16,), requires_grad=True))), **{}):
a and b must have same reduction dim, but got [80, 256] X [128, 16].

from user code:
   File "/home/yigit/projects/cnep/training_examples/../models/cnmp.py", line 62, in forward
    pred = self.decoder(rep_tar)  # (batch_size, m_max, output_dim*2)
  File "/home/yigit/.local/lib/python3.10/site-packages/torch/nn/modules/container.py", line 250, in forward
    input = module(input)
  File "/home/yigit/.local/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
