# <span style="color:lightblue">Bimodal transformer for trajectory prediction</span>

This notebook contains the necessary code to train the model we built to train on eth dataset. <br>
Our approach is based on the mechanism on cross attention. <br>
Figure 1 contains the general diagram of our model which is composed of a crossattention encoder and a classical transformer decoder.

<figure>
<center><img src="https://raw.githubusercontent.com/schockschock/trajectory_bitransformer_/main/img/Trajectory_cross_attention_transformer.drawio.png" title="Alternative text" width="1300"/>
<figcaption align = "center"> Figure 1 : Diagram of our model</figcaption>
</center>
</figure>

## Import of the librairies

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import softmax
import torch.optim as optim
from torch.distributions import Normal
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.utils import save_image

# Log
from torch.utils.tensorboard import SummaryWriter

import sqlite3
import time
import math
from copy import deepcopy
import os
import datetime
import shutil
import pandas as pd

from models.TrajectoryDataset import TrajectoryPredictionDataset
from models.indivSeq2SeqCA import Seq2SeqCA
from models.resnet.feature_extraction import Resnet

## Definition of environment variables

In [4]:
dataset_name = "eth" # dataset options: 'university', 'zara_01', 'zara_02', 'eth', 'hotel'

#Building of output folder
now = datetime.datetime.now() # current date and time
__file__=os.getcwd()
current_time_date = now.strftime("%d_%m_%y_%H_%M_%S")
run_folder  = "Outputs/traj_pred_"+ dataset_name + "_" + str(os.path.basename(__file__)) + str(current_time_date)
os.makedirs(run_folder)

# Make log folder for tensorboard
#SummaryWriter_path = run_folder + "/log"
#os.makedirs(SummaryWriter_path) 
SummaryWriter_path = '/notebook_data/work_dirs/first_test'
writer = SummaryWriter(SummaryWriter_path,comment="ADE_FDE_Train")

# Make image folder to save outputs
image_path  = run_folder + "/Visual_Prediction"
os.makedirs(image_path)

#cuda env
if torch.cuda.is_available():
    device = torch.device("cuda")
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    torch.cuda.device_count()
    torch.cuda.current_device()
else :
    device = torch.device("cpu")

## Connexion to database

In [10]:
#DataBase Variables
image_folder_path       = 'data//data_trajpred/'+dataset_name
DB_PATH_train     = "./data/data_trajpred/"+dataset_name+"/pos_data_train.db"
cnx_train         = sqlite3.connect(DB_PATH_train)
DB_PATH_val     = "./data/data_trajpred/"+dataset_name+"/pos_data_val.db"
cnx_val         = sqlite3.connect(DB_PATH_val)
DB_DIR      = run_folder + '/database'
os.makedirs( DB_DIR )
DB_PATH2    = DB_DIR+'/db_one_ped_delta_coordinates_results.db'
cnx2        = sqlite3.connect(DB_PATH2)

FileExistsError: [Errno 17] File exists: 'Outputs/traj_pred_eth_trajectory_bitransformer_-main08_03_23_14_54_57/database'

## Training variables

In [7]:
#Other variables
T_obs                   = 8
T_pred                  = 12
T_total                 = T_obs + T_pred #8+12=20
data_id                 = 0 
batch_size              = 4 #10#100 #15 #2
chunk_size              = batch_size * T_total # Chunksize should be multiple of T_total
in_size                 = 2
stochastic_out_size     = in_size * 2
hidden_size             = 256 #!64
embed_size              = 64 #16 #!64
global dropout_val
dropout_val             = 0.2 #0.5
teacher_forcing_ratio   = 0.7 # 0.9
regularization_factor   = 0.5 # 0.001
avg_n_path_eval         = 20
bst_n_path_eval         = 20
path_mode               = "top5" #"avg","bst","single","top5"
regularization_mode     = "regular" #"weighted","e_weighted", "regular"
startpoint_mode         = "on" #"on","off"
enc_out                 = "on" #"on","off"
biased_loss_mode        = 0 # 0 , 1

table_out   = "results_delta"
table       = "dataset_T_length_20delta_coordinates" #"dataset_T_length_"+str(T_total)+"delta_coordinates"
df_id       = pd.read_sql_query("SELECT data_id FROM "+table, cnx_train)
data_size   = df_id.data_id.max() * T_total
epoch_num   = 100
from_epoch  = 0

#Visual Variables
image_size              = 256  
image_dimension         = 3
mask_size               = 16
visual_features_size    = 128 
visual_embed_size       = 64  #128 #256 #64
vsn_module_out_size    = 256
to_pil = torchvision.transforms.ToPILImage()


#Model Path
model_path = run_folder + "/NNmodel" 
os.makedirs(model_path)   
model_path = model_path + str("/model")

## Needed methods for the training

In [8]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.2, 0.2)

        
def distance_from_line_regularizer(input_tensor, prediction):
    sum_sigma_distance = torch.zeros(1)
    input_tensor = input_tensor.double()
    prediction = prediction.double()
    input_tensor = input_tensor.cumsum(dim=1)
    X = torch.ones_like(input_tensor).to('cuda', non_blocking=True)
    X[:,:,0] = input_tensor[:,:,0]
    Y = (input_tensor[:,:,1]).unsqueeze(-1)
    try:
        try:
            XTX_1 = torch.matmul(X.transpose(-1,-2), X).inverse()
        except:
            XTX_1 = torch.matmul(X.transpose(-1,-2), X).ppinverse()
        XTY = torch.matmul( X.transpose(-1,-2), Y)
        theta = torch.matmul( XTX_1.double(), XTY)
        
        # Calculate real values of prediction instead of delta
        prediction[:,:,0] = prediction[:,:,0] + input_tensor[:,-1,0].unsqueeze(-1) 
        prediction[:,:,1] = prediction[:,:,1] + input_tensor[:,-1,1].unsqueeze(-1)
        
        # Calculate distance ( predicted_points , observation_fitted_line ) over batch
        theta0x0        = theta[:,0,:] * prediction[:,:,0]
        denominator     = torch.sqrt( theta[:,0,:] * theta[:,0,:] + 1 )
        nominator       = theta0x0 + theta[:,1,:] - prediction[:,:,1]
        distance        = nominator.abs() / denominator
        if regularization_mode =='weighted':
            weight              = torch.flip( torch.arange(1,T_pred+1).cuda().double(),[0])
            weight              = (weight / T_pred).repeat(distance.size(0)).view(-1,T_pred)
            weighted_distance   = weight * distance

        elif regularization_mode =='e_weighted':
            weight              = torch.flip( torch.arange(1,T_pred+1).cuda().double(),[0])
            weight              = (weight / T_pred).repeat(distance.size(0)).view(distance.size(0),T_pred)
            weight              = torch.exp(weight)
            weighted_distance   = weight*distance

        else:
            weighted_distance = distance
        sigma_distance  = torch.mean(weighted_distance,1)
        sum_sigma_distance  = torch.mean(sigma_distance)
        return sum_sigma_distance
    except:
        print("SINGULAR VALUE")
        sum_sigma_distance = torch.zeros(1).to('cuda', non_blocking=True) + 20
        return sum_sigma_distance
    
    


# Training of the model

### Definition of data loader

In [9]:
batch_size = 4
print("Initializing train dataset")
print("Batch_size : {}".format(batch_size))

conv_model = Resnet()

dataset_train = TrajectoryPredictionDataset(image_folder_path, cnx_train, conv_model = conv_model)
train_loader  = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True, pin_memory=True)
validation_loader = None

Initializing train dataset
Batch_size : 4


FileNotFoundError: [Errno 2] No such file or directory: 'data/Introvert_ResnetTransf/data_trajpred/eth/visual_data'

## Definition of the model

In [6]:
model = Seq2SeqCA(device,embed_size=512,code_size=512,dropout_val=0.2,batch_size=batch_size)
model = model.to(device).double()

### Training variables

In [9]:
learning_step               = 40
initial_learning_rate       = 0.01
clip                        = 1

# MSE loss
criterion                   = nn.MSELoss(reduction='mean')
criterion_vision            = nn.MSELoss(reduction='sum')

# SGD optimizer
optimizer                   = optim.SGD(model.parameters(), lr=initial_learning_rate, momentum=0.9, weight_decay=0.01) #SGD
scheduler                   = torch.optim.lr_scheduler.StepLR(optimizer, step_size=learning_step, gamma=0.1)
five_fold_cross_validation  = 0

### Training methods

In [12]:
def save_checkpoint(state, is_best, save_path, filename):
    torch.save(state, os.path.join(save_path,filename))
    if is_best:
        shutil.copyfile(os.path.join(save_path,filename), os.path.join(save_path,'model_best.pth'))
        
def train(model, optimizer, scheduler, criterion, criterion_vision, clip,train_loader, validation_loader, save_path=None):
    global batch_size
    i               = None
    checked_frame   = 0

    losses = []
    print("Data Size ",data_size,"\tChunk Size ",chunk_size)
    global teacher_forcing_ratio
    counter =0
    best_val = float("inf")
    for j in range(epoch_num):
        model.train()
        epoch_loss=0
        if j%7 == 6:
            teacher_forcing_ratio = (teacher_forcing_ratio - 0.2) if teacher_forcing_ratio>=0.1 else 0.0

        # Update TeachForce ratio to gradually change during training
        # global teacher_forcing_ratio
        # teacher_forcing_ratio-= 1/epoch_num
        print("TEACHER FORCE RATIO\t",teacher_forcing_ratio)
        #print("Learning Rate\t", scheduler.get_last_lr())

        
        if(j>=from_epoch):
            optimizer.zero_grad()
            start_time = time.time()
            ADE = 0
            FDE = 0
            i   = 0
            for i,data in enumerate(train_loader):
                #print("\n--------------- Batch %d/ %d ---------------"%(j,i)) #(epoch/i)
                # Forward
                obs, pred, visual_obs, frame_tensor = data
                
                input_tensor, output_tensor         = obs.double().squeeze(dim=1).to('cuda', non_blocking=True), pred.double().squeeze(dim=1).to('cuda', non_blocking=True)
                    
                visual_input_tensor                 = visual_obs.double().squeeze(dim=1).to('cuda', non_blocking=True)

                prediction = model(input_tensor, output_tensor, visual_input_tensor)
                
                calculated_prediction = prediction.cumsum(axis=1) #calculated_prediction
                
                loss_line_regularizer = distance_from_line_regularizer(input_tensor,calculated_prediction) #loss (regularisation term Lreg)
                
                if biased_loss_mode:
                    weight  = torch.arange(1,2*T_pred+1,2).cuda().float()
                    weight  = torch.exp(weight / T_pred).repeat(prediction.size(0)).view(prediction.size(0),T_pred,1)
                    loss    = criterion( (calculated_prediction)*weight, torch.cumsum(output_tensor,dim=-2)*weight)
                else:
                    loss    = criterion( (calculated_prediction), torch.cumsum(output_tensor,dim=-2)) #mean squared error (lmse)
                    
                out_x       = output_tensor[:,:,0].cumsum(axis=1)
                out_y       = output_tensor[:,:,1].cumsum(axis=1)
                pred_x      = calculated_prediction[:,:,0]
                pred_y      = calculated_prediction[:,:,1]
                ADE         += ((out_x.sub(pred_x)**2).add((out_y.sub(pred_y)**2))**(1/2)).mean(0).mean(0)   
                # FDE      += ((out_x.sub(pred_x)**2).add((out_y.sub(pred_y)**2))**(1/2)).mean(0)[-1]
                
                # Backward Propagation
                total_loss      = loss.double() + torch.tensor(regularization_factor).to('cuda', non_blocking=True) * loss_line_regularizer.double() #total loss
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
                optimizer.step()
                #print("Total Loss\t{:.2f}".format(total_loss.item()))
                epoch_loss += total_loss.item()
                #print("Time\t\t{:.2f} sec \n".format(time.time() - start_time))
                start_time = time.time()
                torch.cuda.empty_cache()
                writer.close()
                count_div=i
            
            # tensorboard log
            writer.add_scalar('ADE/train', ADE.item()/(count_div+1), counter)
            # writer.add_scalar('FDE/train', FDE.item()/(count_div+1), counter)
            # writer.add_scalar('LOSS/train', epoch_loss/(count_div+1), counter)
            counter += 1

        if scheduler.get_last_lr()[0]>0.001:
            scheduler.step()
        # validation(model, optimizer, criterion, criterion_vision, clip, validation_loader, j) 
        epoch_loss = epoch_loss / (int(data_size/chunk_size))
        losses.append(epoch_loss)
        display.clear_output(wait=True)
        plt.plot(losses, '--ro', label='train loss')
        plt.legend()
        plt.title(f'epoch {j}')
        plt.show()
        print("Time\t\t{:.2f} sec \n".format(time.time() - start_time))
        print("EPOCH ", j, "\tLOSS ", epoch_loss)
        writer.add_scalar('epoch_loss/train', epoch_loss/ (int(data_size/chunk_size)), j ) #see how model performs on the training dataset
        #torch.save( model.state_dict(), model_path+"_current")
        print("-----------------------------------------------\n"+"-----------------------------------------------")

        # save checkpoint for each epoch and a fine called best_model so far 
        print(np.argmin(losses))
        is_best = epoch_loss < best_val
        best_val = min(epoch_loss, best_val)
        print("bestvaleur", best_val)
        if save_path and (j+1)%5==0:
            save_checkpoint({'epoch': j+1,'state_dict': model.state_dict(),'optimizer': optimizer.state_dict(),'scheduler': scheduler.state_dict(),'best_loss': best_val}, is_best, save_path, 'epoch_{}.pth'.format(j+1))
        
    return epoch_loss / (int(data_size/chunk_size))

### training

In [8]:
print("TRAIN")
model.train()
print("path mode\t",path_mode)

save_path = "/notebook_data/Introvert_ResnetTransf/save_models/cross_attention/"
loss  = train(model, optimizer, scheduler, criterion, criterion_vision, clip, train_loader, validation_loader, save_path)
print("LOSS ",loss)

TRAIN
path mode	 top5


NameError: name 'train' is not defined

### Evaluation method

In [13]:
def validation(model, optimizer, criterion, criterion_vision, clip, validation_loader, counter):
    global batch_size
    model.eval()
    i           = None
    ADEs        = 0
    FDEs        = 0
    epoch_loss  = 0
    loss_line_regularizer = 0
    loss = 0 
    total_loss = 0
    ADE  = 0
    FDE  = 0
    for i,data in enumerate(test_loader):
        # Forward
        obs, pred, visual_obs, frame_tensor = data

        input_tensor, output_tensor         = obs.double().squeeze(dim=1).to('cuda', non_blocking=True), pred.double().squeeze(dim=1).to('cuda', non_blocking=True)

        visual_input_tensor                 = visual_obs.double().squeeze(dim=1).to('cuda', non_blocking=True)

        prediction = model(input_tensor, output_tensor, visual_input_tensor) 
        
        calculated_prediction = prediction.cumsum(axis=1) 

        loss_line_regularizer = distance_from_line_regularizer(input_tensor,calculated_prediction)
        
        if biased_loss_mode:
            weight  = torch.arange(1,2*T_pred+1,2).cuda().float()
            weight  = torch.exp(weight / T_pred).repeat(prediction.size(0)).view(prediction.size(0),T_pred,1)
            loss    = criterion( (calculated_prediction)*weight, torch.cumsum(output_tensor,dim=-2)*weight)
        else:
            loss    = criterion( (calculated_prediction), torch.cumsum(output_tensor,dim=-2))
            
        out_x       = output_tensor[:,:,0].cumsum(axis=1)
        out_y       = output_tensor[:,:,1].cumsum(axis=1)
        pred_x      = calculated_prediction[:,:,0]
        pred_y      = calculated_prediction[:,:,1]
        ADE         += ((out_x.sub(pred_x)**2).add((out_y.sub(pred_y)**2))**(1/2)).mean(0).mean(0)   
        FDE         += ((out_x.sub(pred_x)**2).add((out_y.sub(pred_y)**2))**(1/2)).mean(0)[-1]
        total_loss  += loss.double() + regularization_factor * loss_line_regularizer.double() 
        print("Total Loss\t{:.2f}".format(total_loss.item()))

    writer.add_scalar('ADE/val_'+path_mode,             ADE.item()/(i+1),             counter)
    writer.add_scalar('FDE/val_'+path_mode,             FDE.item()/(i+1),             counter)
    writer.add_scalar('LOSS/val_'+path_mode,            total_loss.item()/(i+1)   ,   counter)
    writer.add_scalar('LOSS_c/val_'+path_mode,          loss.item()/(i+1)        ,    counter)
    writer.add_scalar('L-REGULARIZER/val_'+path_mode,   loss_line_regularizer.item()/(i+1), counter)
    writer.close()
    
def evaluate_eval(model, criterion):
    global batch_size
    model.eval()
    i           = None
    ADEs        = 0
    FDEs        = 0
    epoch_loss  = 0
    list_x_obs          = ['x_obs_'+str(i)              for i in range(0,T_obs)] #x_obs_0 --> x_obs_7
    list_y_obs          = ['y_obs_'+str(i)              for i in range(0,T_obs)] #y_obs_0 --> y_obs_7
    list_c_context      = ['context_c_'+str(i)          for i in range(0,hidden_size)] #context_c_0 --> context_c_255
    list_h_context      = ['context_h_'+str(i)          for i in range(0,hidden_size)] #context_h_0 --> context_h_255
    list_x_pred         = ['x_pred_'+str(i)             for i in range(0,T_pred)] #x_pred_0 --> x_pred_11
    list_y_pred         = ['y_pred_'+str(i)             for i in range(0,T_pred)] #y_pred_0 --> y_pred_11
    list_x_stoch_pred_m = ['x_stoch_pred_m_'+str(i)     for i in range(0,T_pred)] #x_stoch_pred_m_0 --> x_stoch_pred_m_11
    list_y_stoch_pred_m = ['y_stoch_pred_m_'+str(i)     for i in range(0,T_pred)] #y_stoch_pred_m_0 --> y_stoch_pred_m_11
    list_x_stoch_pred_s = ['x_stoch_pred_s_'+str(i)     for i in range(0,T_pred)] #x_stoch_pred_s_0 --> x_stoch_pred_s_11
    list_y_stoch_pred_s = ['y_stoch_pred_s_'+str(i)     for i in range(0,T_pred)] #y_stoch_pred_s_0 --> y_stoch_pred_s_11
    list_x_out          = ['x_out_'+str(i)              for i in range(0,T_pred)] #x_out_0 --> x_out_11
    list_y_out          = ['y_out_'+str(i)              for i in range(0,T_pred)] #y_out_0 --> y_out_11
    list_vsn            = ['vsn_'+str(i)               for i in range(0,hidden_size)] #vsn_0 --> vsn_255
    df_out              = pd.DataFrame(columns=list_x_obs + list_y_obs + list_x_out + list_y_out + list_x_pred + list_y_pred + list_x_stoch_pred_m + list_y_stoch_pred_m + list_x_stoch_pred_s + list_y_stoch_pred_s + list_c_context + list_h_context + list_vsn)# + list_vsn_visual)

    for i,data in enumerate(test_loader):
        start_time = time.time()
        # Forward
        obs, pred, visual_obs, frame_tensor = data

        input_tensor, output_tensor         = obs.double().squeeze(dim=1).to('cuda', non_blocking=True), pred.double().squeeze(dim=1).to('cuda', non_blocking=True)

        visual_input_tensor                 = visual_obs.double().squeeze(dim=1).to('cuda', non_blocking=True)

        prediction = model(input_tensor, output_tensor, visual_input_tensor) 
    
        calculated_prediction =  prediction.cumsum(axis=1) 

        loss_line_regularizer = distance_from_line_regularizer(input_tensor,calculated_prediction) #lreg

        if biased_loss_mode:
            weight  = torch.arange(1,2*T_pred+1,2).cuda().float()
            weight  = torch.exp(weight / T_pred).repeat(prediction.size(0)).view(prediction.size(0),T_pred,1)
            loss    = criterion( (calculated_prediction)*weight, torch.cumsum(output_tensor,dim=-2)*weight)
        else:
            loss    = criterion( (calculated_prediction), torch.cumsum(output_tensor,dim=-2)) #lmse
        out_x           = output_tensor[:,:,0].cumsum(axis=1)
        out_y           = output_tensor[:,:,1].cumsum(axis=1)
        pred_x          = calculated_prediction[:,:,0]
        pred_y          = calculated_prediction[:,:,1]
        ADE             = ((out_x.sub(pred_x)**2).add((out_y.sub(pred_y)**2))**(1/2)).mean(0).mean(0)   
        FDE             = ((out_x.sub(pred_x)**2).add((out_y.sub(pred_y)**2))**(1/2)).mean(0)[-1]
        total_loss      = loss.double() + regularization_factor * loss_line_regularizer.double() #loss
        print("Total Loss\t{:.2f}".format(total_loss.item()))
        epoch_loss += total_loss.item()
        ADEs    += ADE.item()
        FDEs    += FDE.item()
        input_x_lin                 = input_tensor[:,:,0].view(-1, T_obs).cpu() #x_obs
        input_y_lin                 = input_tensor[:,:,1].view(-1, T_obs).cpu() #y_obs
        output_x_lin                = output_tensor[:,:,0].view(-1, T_pred).cpu() #x_out
        output_y_lin                = output_tensor[:,:,1].view(-1, T_pred).cpu() #y_out
        prediction_x_lin            = prediction[:,:,0].view(-1, T_pred).cpu() #x_pred
        prediction_y_lin            = prediction[:,:,1].view(-1, T_pred).cpu() #y_pred


        print("Time\t\t{:.2f} sec \n".format(time.time() - start_time))

    # ADE/FDE Report
    out_x  = df_out[['x_out_' +str(i) for i in range(0,T_pred)]].cumsum(axis=1)
    pred_x = df_out[['x_pred_'+str(i) for i in range(0,T_pred)]].cumsum(axis=1)
    out_y  = df_out[['y_out_' +str(i) for i in range(0,T_pred)]].cumsum(axis=1)
    pred_y = df_out[['y_pred_'+str(i) for i in range(0,T_pred)]].cumsum(axis=1)
    ADE = (out_x.sub(pred_x.values)**2).add((out_y.sub(pred_y.values)**2).values, axis=1)**(1/2)
    df_out['ADE'] = ADE.mean(axis=1)
    FDE = ADE.x_out_11
    df_out['FDE'] = FDE
    Mean_ADE = df_out.ADE.mean()
    Mean_FDE = df_out.FDE.mean()
    print("MEAN ADE/FDE\t",Mean_ADE,Mean_FDE)
    writer.add_scalar("Final_Test/ADE_"+path_mode, Mean_ADE, global_step=0)
    writer.add_scalar("Final_Test/FDE_"+path_mode, Mean_FDE, global_step=0)

    df_out.to_sql(table_out+'_'+path_mode, cnx2, if_exists="replace", index=False)
    writer.close()
    return ADEs, FDEs, int(data_size/chunk_size)

### Method to load the model

In [15]:
def load_checkpoint(model, optimizer, scheduler, filename='checkpoint.pth.tar'):
    start_epoch = 0
    best_val=-1
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        try:
            best_val=checkpoint['best_loss']
        except:
            best_val=-1
        print("=> loaded checkpoint '{}' (epoch {})".format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, scheduler, start_epoch, best_val

### Evaluation

In [None]:
#loading of model
model_dir = "/notebook_data/Introvert_ResnetTransf/save_models/cross_attention/model_best.pth"
print("LOAD MODEL")
# Change device to cpu
#del model
if torch.cuda.is_available():
    torch.cuda.empty_cache()
model = Seq2SeqCA(device,embed_size=512,code_size=512,dropout_val=dropout_val,batch_size=batch_size)
model = model.to(device).double()
model, optimizer, scheduler, start_epoch, best_val = load_checkpoint(model, optimizer, scheduler, filename= model_dir)

#test dataset and loader
print("Initializing val dataset")
dataset_val   = TrajectoryPredictionDataset(image_folder_path, cnx_val)
test_loader   = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=False, pin_memory=True)

#Evaluation
print("EVALUATE bst")
model.eval()
path_mode = 'bst'
print("path mode\t",path_mode)
evaluate_eval(model, criterion)