In [2]:
# Fetch the tutorial package: later cells import torch_data_tutorial.utils and read torch_data_tutorial/datasets/.
!git clone https://github.com/tedhuang96/torch_data_tutorial.git

Cloning into 'torch_data_tutorial'...
remote: Enumerating objects: 59, done.
remote: Counting objects: 100% (59/59), done.
remote: Compressing objects: 100% (42/42), done.
remote: Total 59 (delta 12), reused 55 (delta 12), pack-reused 0
Unpacking objects: 100% (59/59), done.


In [3]:
import torch
import argparse
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from torch_data_tutorial.utils import anorm, create_datasets_test
# Use the first GPU when CUDA is available, otherwise run everything on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [4]:
### expected output: 0.2 ###

# anorm(p1, p2) should be the inverse of the Euclidean distance between the points:
# |p1 - p2| = sqrt(3**2 + 4**2) = 5, so the expected value is 1/5 = 0.2.
p1, p2 = torch.tensor([1.,2.]), torch.tensor([4.,-2.])
inv_dist = anorm(p1, p2)
print(inv_dist) # the inverse of distance between p1 and p2
# Check the result in code instead of by eye; float() accepts a Python float or a 1-element tensor.
assert abs(float(inv_dist) - 0.2) < 1e-6, f"anorm returned {inv_dist}, expected 0.2"

0.2


In [5]:
### expected output: ###
### Namespace(attn_mech='auth', dataset='zara1', obs_seq_len=8, pred_seq_len=12) ###

def arg_parse_notebook(obs_seq_len, pred_seq_len, dataset, attn_mech):
    """Build the experiment Namespace from Python values instead of sys.argv.

    The values are rendered into a synthetic argv list and parsed by argparse.
    That gives them the same type conversion a command-line run would (the two
    sequence lengths go through ``str`` and then ``int``).

    Args:
        obs_seq_len: number of observed time steps.
        pred_seq_len: number of predicted time steps.
        dataset: scene name; the help text lists eth, hotel, univ, zara1, zara2
            (not validated).
        attn_mech: attention mechanism name; the help text lists glob_kip, auth
            (not validated).

    Returns:
        argparse.Namespace with attributes obs_seq_len, pred_seq_len, dataset
        and attn_mech.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--obs_seq_len', type=int, default=0)
    parser.add_argument('--pred_seq_len', type=int, default=0)
    parser.add_argument('--dataset', default='eth',
                        help='eth,hotel,univ,zara1,zara2')
    parser.add_argument('--attn_mech', default='glob_kip',
                        help='attention mechanism: glob_kip, auth')
    flag_value_pairs = (
        ('--obs_seq_len', str(obs_seq_len)),
        ('--pred_seq_len', str(pred_seq_len)),
        ('--dataset', dataset),
        ('--attn_mech', attn_mech),
    )
    # Flatten [(flag, value), ...] into [flag, value, flag, value, ...].
    argv = [token for pair in flag_value_pairs for token in pair]
    return parser.parse_args(argv)

# Configuration for this walkthrough: 8 observed steps, 12 predicted steps,
# the zara1 scene and the 'auth' attention mechanism.
obs_seq_len = 8
pred_seq_len = 12
dataset = 'zara1'
attn_mech = 'auth'
args = arg_parse_notebook(obs_seq_len, pred_seq_len, dataset, attn_mech)
print(args)

Namespace(attn_mech='auth', dataset='zara1', obs_seq_len=8, pred_seq_len=12)


In [None]:
### expected output: ###
# ['torch_data_tutorial/datasets/zara1/test/crowds_zara01.txt']
# Processing Data .....
# 100% #A PROCESSING BAR# 602/602 [00:13<00:00, 43.32it/s]

pkg_path = 'torch_data_tutorial'  # directory created by the `git clone` cell
# Build the datasets with `create_datasets_test`, the builder imported from
# torch_data_tutorial.utils in the imports cell (`create_datasets` is not imported there).
# NOTE(review): assumes create_datasets_test takes (args, pkg_path, save_datasets)
# and returns a dict with a 'test' split, as used below; confirm against utils.
dsets = create_datasets_test(args, pkg_path, save_datasets=False)

In [None]:
# Test split: one dataset item per batch, shuffled order, loaded by a single worker process.
dloader_test = DataLoader(dsets['test'], batch_size=1, shuffle=True, num_workers=1)

In [None]:
def get_batch_sample(loader_test, device='cuda:0', num_peds=5):
    """
    Return the first batch whose observed trajectory has ``num_peds`` pedestrians.

    Each batch yielded by ``loader_test`` is a sequence of tensors in the order
        obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,
        loss_mask, V_obs, A_obs, V_tr, A_tr
    and the pedestrian count is read from ``obs_traj.shape[1]``.

    Args:
        loader_test: iterable of batches, e.g. a torch DataLoader.
        device: device the tensors of the selected batch are moved to.
        num_peds: pedestrian count to look for (default 5).

    Returns:
        list of the selected batch's tensors, each moved to ``device``.

    Raises:
        ValueError: if no batch in ``loader_test`` has ``num_peds`` pedestrians.
    """
    for batch in loader_test:
        obs_traj = batch[0]
        # Filter before the device transfer so skipped batches are never copied.
        if obs_traj.shape[1] == num_peds:
            return [tensor.to(device) for tensor in batch]
    # Fail loudly here rather than returning None and letting the caller's unpacking fail.
    raise ValueError(f"no batch with {num_peds} pedestrians found in the loader")


def visualize_dataloader(data_loader, device='cuda:0', obs_ts=6, attn_scale=1.):
    """
    visualize data and attention in a batch generated by data_loader.

    Takes the batch returned by ``get_batch_sample`` and, for every pedestrian,
    plots:
      - the observed trajectory (black);
      - the ground-truth future trajectory (red);
      - a circle centred on the pedestrian's observed position at step ``obs_ts``,
        with radius ``abs(A_obs[0, obs_ts][ped_i, 0]) * attn_scale``.

    Args:
        data_loader: iterable of batches, forwarded to ``get_batch_sample``.
        device: device ``get_batch_sample`` moves the batch to.
        obs_ts: observed time step whose attention matrix is drawn (default 6).
            It must be a valid index into the time axis of ``obs_traj`` and into
            dim 1 of ``A_obs``.
        attn_scale: multiplier turning an attention value into a circle radius
            (default 1.).
    """
    batch = get_batch_sample(data_loader, device=device)
    obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, non_linear_ped,\
        loss_mask, V_obs, A_obs, V_tr, A_tr = batch
    # matplotlib needs host-side values: copy to CPU once, not once per pedestrian.
    # The indexing below treats obs_traj[0] / pred_traj_gt[0] as (num_peds, 2, seq_len).
    obs_cpu = obs_traj[0].detach().cpu()
    gt_cpu = pred_traj_gt[0].detach().cpu()
    attn = A_obs[0, obs_ts].detach().cpu()  # (num_peds, num_peds)
    fig, ax = plt.subplots()
    colors = 'rymbk'
    for ped_i in range(obs_cpu.shape[0]):
        x_obs_i, x_gt_i = obs_cpu[ped_i], gt_cpu[ped_i]
        first = ped_i == 0  # label only one line of each kind for the legend
        ax.plot(x_obs_i[0, :], x_obs_i[1, :], 'k.-',
                label='observed' if first else '_nolegend_')
        ax.plot(x_gt_i[0, :], x_gt_i[1, :], 'r.-',
                label='ground truth' if first else '_nolegend_')
        # Give matplotlib plain Python floats, not tensors, for the centre and radius.
        center = tuple(x_obs_i[:, obs_ts].tolist())
        radius = abs(attn[ped_i, 0].item()) * attn_scale
        # Cycle colours so more than len(colors) pedestrians do not raise IndexError;
        # add_patch (unlike add_artist) makes the circles count toward autoscaling.
        ax.add_patch(plt.Circle(center, radius,
                                color=colors[ped_i % len(colors)], fill=False))
    ax.set_aspect('equal', adjustable='box')
    ax.set_title(f'Trajectories and attention at observed step {obs_ts}')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.legend()
    plt.show()

In [None]:
### expected output: ###
### Visualization on trajectories and attention ###

# Plot one sampled test batch with its trajectories and attention circles.
visualize_dataloader(data_loader=dloader_test, device=device)