In [7]:
import pandas as pd
import yaml
from model import YNet
from datetime import datetime
from utils.preprocessing import load_raw_dataset
import time

tic = time.time()  # wall-clock start of the notebook run

FOLDERNAME = './'
time_stamp = datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
# Weights to fine-tune from, e.g. FOLDERNAME + 'pretrained_models/2022_01_27_23_58_00_weights.pt';
# None means no checkpoint is loaded and training starts from scratch.
CHECKPOINT = None
CONFIG_FILE_PATH = 'config/sdd_raw_train.yaml'  # yaml config file containing all the hyperparameters
EXPERIMENT_NAME = time_stamp  # arbitrary name for this experiment
DATASET_NAME = 'sdd'
SDD_RAW_PATH = FOLDERNAME + "data/sdd_raw"
with open(CONFIG_FILE_PATH) as file:
    # The config holds plain hyperparameters, so safe_load is sufficient and cannot
    # construct Python objects from tags (FullLoader still resolves some python/* tags).
    params = yaml.safe_load(file)
print(f"Experiment {EXPERIMENT_NAME} has started")

if params['use_raw_data']:
    # Raw SDD: both image paths point at the raw annotations folder.
    TRAIN_IMAGE_PATH = SDD_RAW_PATH + '/annotations'
    TEST_IMAGE_PATH = SDD_RAW_PATH + '/annotations'
else:
    # Pre-split TrajNet data: trajectories are pickles, scene images live in per-split folders.
    # Train paths must be defined here too, since later cells read them in this mode.
    TRAIN_DATA_PATH = FOLDERNAME + 'data/SDD/train_trajnet.pkl'
    TRAIN_IMAGE_PATH = FOLDERNAME + 'data/SDD/train'  # only needed for YNet, PECNet ignores this value
    TEST_DATA_PATH = FOLDERNAME + 'data/SDD/test_trajnet.pkl'
    TEST_IMAGE_PATH = FOLDERNAME + 'data/SDD/test'  # only needed for YNet, PECNet ignores this value
# Overrides any 'segmentation_model_fp' value coming from the YAML config.
params['segmentation_model_fp'] = FOLDERNAME + 'ynet_additional_files/segmentation_models/SDD_segmentation.pth'
OBS_LEN = 8  # in timesteps
PRED_LEN = 12  # in timesteps
NUM_GOALS = 20  # K_e
NUM_TRAJ = 1  # K_a
ROUNDS = 1  # Y-net is stochastic. How often to evaluate the whole dataset
BATCH_SIZE = 8




Experiment 2022_01_29_02_36_25 has started


In [8]:

# Load train/validation trajectories: either build them from the raw SDD annotations
# using the split options in the config, or read the pre-split TrajNet pickles.
if params['use_raw_data']:
    train_data, val_data = load_raw_dataset(
        path=SDD_RAW_PATH, step=params['step'],
        window_size=params['min_num_steps_seq'], stride=params['filter_stride'],
        train_labels=params['train_labels'], test_labels=params['test_labels'],
        test_per=params['test_per'], max_train_agents=params['max_train_agents'],
        train_set_ratio=params['train_set_ratio'], test_on_train=params['test_on_train'],
        num_train_agents=params['num_train_agents'], num_test_agents=params['num_test_agents'],
        random_train_test=params['random_train_test_split'],
    )
else:
    # Read the training trajectories from the train pickle (TRAIN_DATA_PATH), not from
    # TRAIN_IMAGE_PATH, which names the scene-image folder.
    # NOTE: read_pickle can execute arbitrary code — only load trusted files.
    train_data = pd.read_pickle(TRAIN_DATA_PATH)
    val_data = pd.read_pickle(TEST_DATA_PATH)

2662 agents for each training class, 501 agents for test class


In [9]:
# Automatically reload edited Python modules (e.g. `model`, `utils`) before cells run,
# so changes to project code are picked up without restarting the kernel.
# Re-running this cell only prints an "already loaded" notice; it is harmless.
# NOTE(review): consider moving these magics to the top config cell, ahead of the project imports.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
# Build Y-Net, optionally warm-start it from CHECKPOINT, then train the style encoder.
# Time this cell only: `tic` from the config cell would also count idle time between cell runs.
train_start = time.time()
model = YNet(obs_len=OBS_LEN, pred_len=PRED_LEN, params=params)
if CHECKPOINT:
    model.load(CHECKPOINT)
try:
    model.train_style_enc(
        train_data, val_data, params,
        train_image_path=TRAIN_IMAGE_PATH, val_image_path=TEST_IMAGE_PATH,
        experiment_name=EXPERIMENT_NAME, batch_size=BATCH_SIZE,
        num_goals=NUM_GOALS, num_traj=NUM_TRAJ,
        device=None, dataset_name=DATASET_NAME, use_raw_data=params['use_raw_data'],
    )
finally:
    # Runs on normal completion and on KeyboardInterrupt, so an interrupted run still reports its duration.
    toc = time.time()
    # divmod rather than time.gmtime so durations over 24 h do not wrap in the hour field.
    hours, remainder = divmod(int(toc - train_start), 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"{hours:02d}h{minutes:02d}m{seconds:02d}s")

Preprocess data


Prepare Dataset: 100%|██████████| 424/424 [00:00<00:00, 1473.07it/s]
Prepare Dataset: 100%|██████████| 184/184 [00:00<00:00, 1348.22it/s]
Prepare Dataset: 100%|██████████| 472/472 [00:00<00:00, 1453.38it/s]
Prepare Dataset: 100%|██████████| 15/15 [00:00<00:00, 1401.06it/s]
Prepare Dataset: 100%|██████████| 16/16 [00:00<00:00, 1420.32it/s]
Prepare Dataset: 100%|██████████| 7/7 [00:00<00:00, 1227.43it/s]


Start training


 35%|███▌      | 150/424 [00:18<00:34,  8.03it/s]
 12%|█▎        | 2/16 [00:00<00:01, 13.27it/s]

Epoch 0: 	Train style accuracy: 0.32367663612231234 	(loss: 3.7747411727905273) 


100%|██████████| 16/16 [00:00<00:00, 17.83it/s]


Epoch 0: 	Valid style accuracy: 0.39460182172862207
Best Epoch 0: 
Val Accuracy: 0.39460182172862207


 35%|███▌      | 150/424 [00:18<00:34,  7.97it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.78it/s]

Epoch 1: 	Train style accuracy: 0.31889612584302 	(loss: 3.62009596824646) 


100%|██████████| 16/16 [00:00<00:00, 19.50it/s]
  0%|          | 1/424 [00:00<01:18,  5.41it/s]

Epoch 1: 	Valid style accuracy: 0.1553002063025689


 35%|███▌      | 150/424 [00:20<00:37,  7.24it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.85it/s]

Epoch 2: 	Train style accuracy: 0.3498815708998197 	(loss: 3.525845527648926) 


100%|██████████| 16/16 [00:00<00:00, 19.06it/s]
  0%|          | 2/424 [00:00<00:33, 12.69it/s]

Epoch 2: 	Valid style accuracy: 0.39460182172862207


 35%|███▌      | 150/424 [00:16<00:30,  8.90it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.83it/s]

Epoch 3: 	Train style accuracy: 0.36249534909257913 	(loss: 3.630995988845825) 


100%|██████████| 16/16 [00:00<00:00, 19.10it/s]
  0%|          | 1/424 [00:00<00:59,  7.15it/s]

Epoch 3: 	Valid style accuracy: 0.1553002063025689


 35%|███▌      | 150/424 [00:18<00:34,  7.92it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.72it/s]

Epoch 4: 	Train style accuracy: 0.35809290226226287 	(loss: 3.5999529361724854) 


100%|██████████| 16/16 [00:00<00:00, 19.21it/s]
  0%|          | 1/424 [00:00<00:56,  7.50it/s]

Epoch 4: 	Valid style accuracy: 0.1553002063025689


 35%|███▌      | 150/424 [00:18<00:33,  8.06it/s]
 12%|█▎        | 2/16 [00:00<00:00, 15.42it/s]

Epoch 5: 	Train style accuracy: 0.35510884782571794 	(loss: 3.523750066757202) 


100%|██████████| 16/16 [00:00<00:00, 18.78it/s]
  0%|          | 0/424 [00:00<?, ?it/s]

Epoch 5: 	Valid style accuracy: 0.1553002063025689


 35%|███▌      | 150/424 [00:19<00:35,  7.73it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.68it/s]

Epoch 6: 	Train style accuracy: 0.3538574943113927 	(loss: 3.447648286819458) 


100%|██████████| 16/16 [00:00<00:00, 18.68it/s]
  0%|          | 1/424 [00:00<00:58,  7.21it/s]

Epoch 6: 	Valid style accuracy: 0.2297910993841666


 35%|███▌      | 150/424 [00:19<00:34,  7.85it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.73it/s]

Epoch 7: 	Train style accuracy: 0.36363668881399624 	(loss: 3.5462779998779297) 


100%|██████████| 16/16 [00:00<00:00, 18.38it/s]
  0%|          | 0/424 [00:00<?, ?it/s]

Epoch 7: 	Valid style accuracy: 0.1553002063025689


 35%|███▌      | 150/424 [00:19<00:35,  7.65it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.33it/s]

Epoch 8: 	Train style accuracy: 0.3458556109085056 	(loss: 3.5400471687316895) 


100%|██████████| 16/16 [00:00<00:00, 18.65it/s]
  0%|          | 1/424 [00:00<01:01,  6.85it/s]

Epoch 8: 	Valid style accuracy: 0.31145724682992426


 35%|███▌      | 150/424 [00:18<00:34,  7.97it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.81it/s]

Epoch 9: 	Train style accuracy: 0.37127887945261406 	(loss: 3.478595018386841) 


100%|██████████| 16/16 [00:00<00:00, 18.24it/s]
  0%|          | 2/424 [00:00<00:33, 12.57it/s]

Epoch 9: 	Valid style accuracy: 0.2297910993841666


 35%|███▌      | 150/424 [00:18<00:34,  7.93it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.74it/s]

Epoch 10: 	Train style accuracy: 0.35518006986805717 	(loss: 3.6003777980804443) 


100%|██████████| 16/16 [00:00<00:00, 18.86it/s]


Epoch 10: 	Valid style accuracy: 0.4265329540207581
Best Epoch 10: 
Val Accuracy: 0.4265329540207581


 35%|███▌      | 150/424 [00:17<00:32,  8.38it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.48it/s]

Epoch 11: 	Train style accuracy: 0.37808947520694136 	(loss: 3.5195624828338623) 


100%|██████████| 16/16 [00:00<00:00, 19.48it/s]
  0%|          | 0/424 [00:00<?, ?it/s]

Epoch 11: 	Valid style accuracy: 0.4265329540207581


 35%|███▌      | 150/424 [00:18<00:33,  8.22it/s]
 12%|█▎        | 2/16 [00:00<00:00, 17.00it/s]

Epoch 12: 	Train style accuracy: 0.3691340563378595 	(loss: 3.544086456298828) 


100%|██████████| 16/16 [00:00<00:00, 19.11it/s]
  0%|          | 2/424 [00:00<00:32, 13.00it/s]

Epoch 12: 	Valid style accuracy: 0.4265329540207581


 35%|███▌      | 150/424 [00:18<00:33,  8.12it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.74it/s]

Epoch 13: 	Train style accuracy: 0.33908419022449093 	(loss: 3.5467984676361084) 


100%|██████████| 16/16 [00:00<00:00, 19.51it/s]


Epoch 13: 	Valid style accuracy: 0.4631055768002774
Best Epoch 13: 
Val Accuracy: 0.4631055768002774


 35%|███▌      | 150/424 [00:19<00:35,  7.64it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.51it/s]

Epoch 14: 	Train style accuracy: 0.36972127550095285 	(loss: 3.499068021774292) 


100%|██████████| 16/16 [00:00<00:00, 18.96it/s]
  0%|          | 2/424 [00:00<00:30, 14.05it/s]

Epoch 14: 	Valid style accuracy: 0.4265329540207581


 35%|███▌      | 150/424 [00:19<00:36,  7.55it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.60it/s]

Epoch 15: 	Train style accuracy: 0.3899817596553288 	(loss: 3.665250301361084) 


100%|██████████| 16/16 [00:00<00:00, 19.47it/s]
  0%|          | 2/424 [00:00<00:31, 13.52it/s]

Epoch 15: 	Valid style accuracy: 0.36278501080233333


 35%|███▌      | 150/424 [00:17<00:31,  8.60it/s]
 12%|█▎        | 2/16 [00:00<00:00, 16.60it/s]

Epoch 16: 	Train style accuracy: 0.33866659310771696 	(loss: 3.6078133583068848) 


100%|██████████| 16/16 [00:00<00:00, 19.49it/s]
  0%|          | 2/424 [00:00<00:32, 12.90it/s]

Epoch 16: 	Valid style accuracy: 0.4567664291840516


 24%|██▍       | 101/424 [00:12<00:41,  7.78it/s]


KeyboardInterrupt: 