# <span style="color:lightblue">Bimodal transformer for trajectory prediction</span>

This notebook contains the code needed to train our model on the ETH dataset. <br>
Our approach builds on the idea that `Attention is all you need`, not only in NLP but also in vision. <br>
Figure 1 shows the general diagram of our model, which is composed of a bimodal transformer encoder and a classical transformer decoder.

<figure>
<center><img src="https://raw.githubusercontent.com/schockschock/trajectory_bitransformer_/main/img/trajectory_bimodal_transformer.drawio%20(2).png" title="Alternative text" width="1300"/>
<figcaption align = "center"> Figure 1 : Diagram of our model</figcaption>
</center>
</figure>

## Importing the libraries

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import softmax
import torch.optim as optim
from torch.distributions import Normal
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.utils import save_image

# Log
from torch.utils.tensorboard import SummaryWriter

import sqlite3
import time
import math
from copy import deepcopy
import os
import datetime

from src.TrajectoryDataset import TrajectoryPredictionDataset

## Definition of environment variables

In [2]:
dataset_name = "eth" # dataset options: 'university', 'zara_01', 'zara_02', 'eth', 'hotel'

# Restrict this process to GPU 0 unless the launcher already selected devices.
# CUDA reads CUDA_VISIBLE_DEVICES when it is first initialised, so the variable
# has to be in place *before* the first CUDA query (torch.cuda.is_available()
# below); assigning it afterwards has no effect on the running process.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0')

# Build the output folder for this run; its name embeds the dataset, the name
# of the working directory and a timestamp.
now = datetime.datetime.now() # current date and time
# A notebook kernel defines no __file__, so the working directory stands in for it.
__file__ = os.getcwd()
current_time_date = now.strftime("%d_%m_%y_%H_%M_%S")
run_folder = "Outputs/traj_pred_" + dataset_name + "_" + str(os.path.basename(__file__)) + str(current_time_date)
# exist_ok keeps the cell re-runnable if it is executed twice within one second.
os.makedirs(run_folder, exist_ok=True)

# TensorBoard writer.
# NOTE(review): the log directory is a machine-specific absolute path; to keep
# logs next to the other run outputs, use `run_folder + "/log"` instead.
# NOTE: per the SummaryWriter docs, `comment` is ignored once log_dir is given.
SummaryWriter_path = '/notebook_data/work_dirs/first_test'
writer = SummaryWriter(SummaryWriter_path, comment="ADE_FDE_Train")

# Folder for the images saved by this run
image_path = run_folder + "/Visual_Prediction"
os.makedirs(image_path, exist_ok=True)

# Compute device: GPU when CUDA is usable, CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



## Connection to the database

In [3]:
# Database variables
image_folder_path = 'data_trajpred/' + dataset_name

# Train / validation databases for the selected dataset.
DB_PATH_train = "./data/Introvert_ResnetTransf/data_trajpred/" + dataset_name + "/pos_data_train.db"
DB_PATH_val = "./data/Introvert_ResnetTransf/data_trajpred/" + dataset_name + "/pos_data_val.db"

# sqlite3.connect() silently creates a new, empty database when the file does
# not exist, which would only surface much later as a "no such table" error.
# Fail here instead, with the offending path in the message.
# NOTE(review): assumes both files are pre-existing inputs — confirm.
for _db_path in (DB_PATH_train, DB_PATH_val):
    if not os.path.isfile(_db_path):
        raise FileNotFoundError("Trajectory database not found: " + _db_path)

cnx_train = sqlite3.connect(DB_PATH_train)
cnx_val = sqlite3.connect(DB_PATH_val)

# Database created inside this run's output folder (the file name suggests it
# holds the results of the run — confirm against the code that writes to cnx2).
DB_DIR = run_folder + '/database'
os.makedirs(DB_DIR, exist_ok=True)
DB_PATH2 = DB_DIR + '/db_one_ped_delta_coordinates_results.db'
cnx2 = sqlite3.connect(DB_PATH2)

## Training variables

In [4]:
# --- Sequence lengths -------------------------------------------------------
# Presumably the number of observed / predicted steps per trajectory
# (TODO confirm against the dataset class).
T_obs = 8
T_pred = 12
T_total = T_obs + T_pred  # 8 + 12 = 20

# --- Data loading -----------------------------------------------------------
data_id = 0
batch_size = 4  # previously tried: 10, 100, 15, 2
chunk_size = batch_size * T_total  # chunk size should be a multiple of T_total

# --- Model dimensions -------------------------------------------------------
in_size = 2
stochastic_out_size = in_size * 2
hidden_size = 256  # earlier note in this cell: "!64"
embed_size = 64  # previously tried: 16; earlier note: "!64"

# --- Optimisation settings --------------------------------------------------
# (A module-level `global dropout_val` statement used to precede this
# assignment; `global` is a no-op outside a function, so it was dropped.)
dropout_val = 0.2  # previously tried: 0.5
teacher_forcing_ratio = 0.7  # previously tried: 0.9
regularization_factor = 0.5  # previously tried: 0.001

# --- Evaluation / mode switches ---------------------------------------------
avg_n_path_eval = 20
bst_n_path_eval = 20
path_mode = "top5"  # options: "avg", "bst", "single", "top5"
regularization_mode = "regular"  # options: "weighted", "e_weighted", "regular"
startpoint_mode = "on"  # options: "on", "off"
enc_out = "on"  # options: "on", "off"
biased_loss_mode = 0  # options: 0, 1


## Helper functions needed for training

In [8]:
def init_weights(m):
    """Re-initialise every parameter of ``m`` in place.

    Each parameter tensor reported by the module (weights and biases alike)
    is overwritten with samples drawn uniformly from [-0.2, 0.2].

    Args:
        m: a ``torch.nn.Module`` whose parameters are to be reset.

    Returns:
        None; the module is modified in place.
    """
    lower, upper = -0.2, 0.2
    for tensor in m.parameters():
        nn.init.uniform_(tensor.data, lower, upper)

        
def distance_from_line_regularizer(input_tensor, prediction, mode=None):
    """Mean distance between predicted points and the line fitted to the observations.

    The observed deltas are accumulated (cumsum over dim 1) into positions, a
    least-squares line ``y = slope * x + intercept`` is fitted per batch
    element, and the point-to-line distance of every predicted point is
    averaged over dim 1 and then over the batch.

    Args:
        input_tensor: 3-D tensor of observed deltas; index 0 / 1 of the last
            dimension are used as x / y (assumed layout: batch, step, coord —
            TODO confirm against the caller).
        prediction: 3-D tensor of predicted offsets with the same last-dim
            convention; each offset is added to the last observed position.
            It is not modified.
        mode: ``'weighted'`` multiplies the distances by linearly decreasing
            weights (1 down to 1/n along dim 1), ``'e_weighted'`` by the
            exponential of those weights, and any other value leaves them
            unweighted. ``None`` (default) uses the module-level
            ``regularization_mode``.

    Returns:
        A 0-dim float64 tensor holding the mean distance. If the fit or the
        distance computation raises a ``RuntimeError`` (e.g. the
        pseudo-inverse fails or shapes do not broadcast), "SINGULAR VALUE" is
        printed and a float32 tensor of shape (1,) holding 20 is returned.
    """
    # Resolved outside the try block so a missing setting is reported loudly
    # instead of being mistaken for a numerical failure.
    if mode is None:
        mode = regularization_mode

    # Work on the device of the inputs rather than a hard-coded 'cuda', so the
    # function also runs when the notebook fell back to the CPU.
    device = input_tensor.device
    observed = input_tensor.double().cumsum(dim=1)
    predicted = prediction.double()

    # Design matrix [x, 1] and target y for the least-squares fit.
    X = torch.ones_like(observed)
    X[:, :, 0] = observed[:, :, 0]
    Y = observed[:, :, 1].unsqueeze(-1)

    try:
        XT = X.transpose(-1, -2)
        XTX = torch.matmul(XT, X)
        try:
            XTX_1 = XTX.inverse()
        except RuntimeError:
            # Singular normal matrix (e.g. all observed x identical): fall
            # back to the pseudo-inverse, as the original code intended.
            XTX_1 = XTX.pinverse()
        theta = torch.matmul(XTX_1, torch.matmul(XT, Y))
        slope, intercept = theta[:, 0, :], theta[:, 1, :]

        # Calculate real values of prediction instead of delta. Done out of
        # place: `.double()` returns the caller's own tensor when it is already
        # float64, and writing into it would corrupt the caller's data.
        pred_x = predicted[:, :, 0] + observed[:, -1, 0].unsqueeze(-1)
        pred_y = predicted[:, :, 1] + observed[:, -1, 1].unsqueeze(-1)

        # Distance (predicted point, fitted line) for every point in the batch.
        numerator = (slope * pred_x + intercept - pred_y).abs()
        distance = numerator / torch.sqrt(slope * slope + 1)

        if mode in ('weighted', 'e_weighted'):
            # Weight length follows the prediction itself; it equals the
            # notebook's T_pred whenever the prediction has T_pred steps.
            horizon = distance.size(1)
            steps = torch.arange(1, horizon + 1, device=device).double()
            weight = torch.flip(steps, [0]) / horizon
            if mode == 'e_weighted':
                weight = torch.exp(weight)
            distance = weight * distance  # broadcasts over the batch

        return torch.mean(torch.mean(distance, 1))
    except RuntimeError:
        # Best-effort fallback kept from the original: constant penalty.
        print("SINGULAR VALUE")
        return torch.zeros(1, device=device) + 20
    
    
