In [4]:
from os.path import isdir, join
import pickle
import numpy as np
import matplotlib.pyplot as plt
# import rosbag
import glob
from sklearn.model_selection import train_test_split, KFold
import os
import pdb
import math
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [5]:
# Load the serialized short-horizon prediction dataset from disk.
dataset_filepath = 'short_prediction_data.pt'
dataset = torch.load(dataset_filepath)

# Unpack the train/test splits in a fixed order: observations, prediction
# targets and intention labels for each split.
(
    xobs_train,
    xpred_train,
    yintention_train,
    xobs_test,
    xpred_test,
    yintention_test,
) = (
    dataset[split_key]
    for split_key in (
        "xobs_train",
        "xpred_train",
        "yintention_train",
        "xobs_test",
        "xpred_test",
        "yintention_test",
    )
)

# Sequence lengths are stored alongside the data.
obs_seq_len = dataset["obs_seq_len"]
pred_seq_len = dataset["pred_seq_len"]
batch_size = 32

In [3]:
class TrajectoriesDataset(Dataset):
    """Map-style dataset of (observation, prediction target, intention) triples.

    Args:
        xobs: observed trajectories; axis 0 indexes samples and axis 1 must
            have length ``obs_seq_len``.
        xpred: ground-truth future trajectories; axis 0 indexes samples and
            axis 1 must have length ``pred_seq_len``.
        yintention: per-sample intention labels; axis 0 indexes samples.
        obs_seq_len: number of observed time steps per sample.
        pred_seq_len: number of predicted time steps per sample.

    Raises:
        AssertionError: if the three inputs disagree on the number of samples
            or the time axes do not match the declared sequence lengths.
    """

    def __init__(
        self,
        xobs,
        xpred,
        yintention,
        obs_seq_len=4,
        pred_seq_len=6,
        ):
        super(TrajectoriesDataset, self).__init__()
        assert xobs.shape[0]==xpred.shape[0]==yintention.shape[0], \
            "xobs, xpred and yintention must contain the same number of samples"
        assert xobs.shape[1]==obs_seq_len and xpred.shape[1]==pred_seq_len, \
            "time axis of xobs/xpred must match obs_seq_len/pred_seq_len"
        self.obs_seq_len = obs_seq_len
        self.pred_seq_len = pred_seq_len
        self.seq_len = self.obs_seq_len + self.pred_seq_len
        self.xobs = xobs
        self.xpred = xpred
        self.yintention = yintention
        self.num_seq = self.xobs.shape[0]

    def __len__(self):
        """Return the number of samples."""
        return self.num_seq

    def __getitem__(self, index):
        """Return ``[xobs[index], xpred[index], yintention[index]]``.

        The previous implementation first read ``self.seq_start_end[index]``,
        an attribute that is never assigned, so every lookup raised
        AttributeError; the unused lookup has been removed.
        """
        return [
            self.xobs[index],
            self.xpred[index],
            self.yintention[index],
        ]

In [6]:
kf = KFold(n_splits=5, random_state=0, shuffle=True)

# Only the first of the five folds is used: take its index arrays directly
# instead of entering the split loop and leaving it after one iteration.
train_index, validation_index = next(iter(kf.split(xobs_train)))

# Slice every tensor of the training split with the same fold indices.
xobs_train_k = xobs_train[train_index]
xobs_val_k = xobs_train[validation_index]
xpred_train_k = xpred_train[train_index]
xpred_val_k = xpred_train[validation_index]
yintention_train_k = yintention_train[train_index]
yintention_val_k = yintention_train[validation_index]

dataset_train = TrajectoriesDataset(
    xobs_train_k, xpred_train_k, yintention_train_k,
    obs_seq_len=obs_seq_len, pred_seq_len=pred_seq_len,
)
dataset_val = TrajectoriesDataset(
    xobs_val_k, xpred_val_k, yintention_val_k,
    obs_seq_len=obs_seq_len, pred_seq_len=pred_seq_len,
)

# Training batches are reshuffled every epoch; validation order is fixed.
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)
loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=4)

In [None]:
# import sys
# pkg_path = '..'
# sys.path.append(pkg_path)
# import pickle
# import torch
# from torch import nn
# from src_v2.utils import average_offset_error, padding_mask
# import pdb


class IntentionLstm(nn.Module):
    """LSTM encoder/decoder predicting future 2-D positions from an observed
    trajectory and a scalar intention label.

    Positions and the intention are embedded separately (``spatial_embedding``
    and ``intention_embedding``), concatenated along the feature axis, and fed
    to a single LSTM. The same LSTM first encodes the observation and is then
    rolled forward one step at a time to produce the prediction.

    NOTE(review): the previous ``forward`` stopped at a "# PAUSE" marker and
    then referenced names that were never defined, and ``forward_origin`` was a
    copy written for a different architecture (4-feature embedding,
    ``init_hidden``). Both are completed here against this class's own layers;
    confirm that this matches the intended design.
    """

    def __init__(
        self,
        embedding_size=32,
        hidden_size=32,
        num_layers=1,
        dropout=0.,
    ):
        """
        Args:
            embedding_size: width of the position and intention embeddings.
            hidden_size: LSTM hidden state width.
            num_layers: number of stacked LSTM layers.
            dropout: dropout between stacked LSTM layers.
        """
        super(IntentionLstm, self).__init__()
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # The LSTM consumes [position embedding | intention embedding].
        self.lstm = nn.LSTM(
            input_size=2*embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=False,
            )
        self.spatial_embedding = nn.Linear(2, embedding_size)
        self.intention_embedding = nn.Linear(1, embedding_size)
        self.hidden2pos = nn.Linear(hidden_size, 2)

    def _embed(self, positions, intention_ebd):
        """Embed (batch, T, 2) positions and append the (batch, 1, E) intention
        embedding at every time step, giving (batch, T, 2*E)."""
        pos_ebd = self.spatial_embedding(positions)
        tiled_intention = intention_ebd.expand(-1, pos_ebd.shape[1], -1)
        return torch.cat((pos_ebd, tiled_intention), dim=2)

    def _rollout(self, b_obs, b_intention, pred_seq_len, b_pred_gt=None):
        """Encode the observation and decode ``pred_seq_len`` positions.

        When ``b_pred_gt`` is given, the ground-truth position of the previous
        step is fed back (teacher forcing); otherwise the model's own previous
        prediction is fed back.
        """
        if pred_seq_len < 1:
            raise ValueError("pred_seq_len must be at least 1.")
        batch_size = b_obs.shape[0]
        # One scalar intention per sample -> (batch, 1, 1) -> (batch, 1, E).
        intention_ebd = self.intention_embedding(b_intention.reshape(batch_size, 1, 1))

        out, hc_t = self.lstm(self._embed(b_obs, intention_ebd))
        # Top-layer output at the last observed step yields the first prediction.
        pred_ts = self.hidden2pos(out[:, -1:, :])  # (batch, 1, 2)
        pred_list = [pred_ts]
        for ts in range(1, pred_seq_len):
            prev_pos = pred_ts if b_pred_gt is None else b_pred_gt[:, ts-1:ts]
            out, hc_t = self.lstm(self._embed(prev_pos, intention_ebd), hc_t)
            pred_ts = self.hidden2pos(out)  # (batch, 1, 2)
            pred_list.append(pred_ts)
        return torch.cat(pred_list, dim=1)  # (batch, pred_seq_len, 2)

    def forward(self, b_obs, b_intention, device="cuda:0", pred_seq_len=6):
        """
        Forward function (free-running prediction, no ground truth needed).
        inputs:
            - b_obs: batch of observation. (batch, obs_seq_len, 2)
            - b_intention: batch of intention. (batch, 1)
            - device: device the inputs are moved to before the computation.
            - pred_seq_len: number of future steps to predict.
        outputs:
            - b_out: (batch, pred_seq_len, 2)
        raises:
            - ValueError if pred_seq_len < 1.
        """
        b_obs = b_obs.float().to(device)
        b_intention = b_intention.float().to(device)
        return self._rollout(b_obs, b_intention, pred_seq_len)

    def forward_origin(self, so, sp, si, mode='teacher_forcing', device='cuda:0'):
        """
        Training-style forward pass that also computes the loss.
        inputs:
            - so # observation sample. # tensor. size: (N, obs_seq_len, 2)
            - sp # ground truth prediction sample. # tensor. size: (N, pred_seq_len, 2)
            - si # scalar intention sample. # tensor. size: (N, 1)
            - mode # 'teacher_forcing' or 'no_forcing'.
            - device # device the inputs are moved to before the computation.

        outputs:
            - loss # mean Euclidean distance between sp_pred and sp.
            - sp_pred # prediction on sp. # tensor. size: (N, pred_seq_len, 2)

        raises:
            - ValueError on an unknown mode (previously printed a message and
              called sys.exit, which is not imported in this cell).
        """
        if mode not in ('teacher_forcing', 'no_forcing'):
            raise ValueError(
                "mode must be 'teacher_forcing' or 'no_forcing', got {!r}.".format(mode))
        so = so.float().to(device)
        sp = sp.float().to(device)
        si = si.float().to(device)

        teacher = sp if mode == 'teacher_forcing' else None
        sp_pred = self._rollout(so, si, sp.size(1), b_pred_gt=teacher)

        # Average per-step Euclidean offset between prediction and ground truth.
        loss = torch.mean((((sp_pred-sp) ** 2).sum(2))**(1./2))

        return loss, sp_pred


In [None]:
def train(args, loader_train, loader_val, device="cuda:0"):
    """Placeholder for the training routine; not implemented yet.

    All arguments are currently ignored and the function returns None.
    """
    return None

In [None]:




  # NOTE(review): the two statements below look like the tail of a training
  # entry point whose beginning is not part of this cell. Left active they
  # raise an IndentationError, they use names that no visible cell defines
  # (`args`, `train_data_loaders`, `writer`, `logdir`, `device`), and they
  # would run before the `train` definition that follows them. They are kept
  # here, disabled, until that entry point is restored.
  # train(args, train_data_loaders, writer, logdir, device=device)
  # writer.close()


def train(args, data_loaders, writer, logdir, device='cuda:0'):
    """Train a model for ``args.num_epochs`` epochs, validating after each one.

    Args:
        args: configuration object; this function reads ``num_epochs``,
            ``init_temp``, ``lr``, ``rotation_pattern``, ``deterministic``,
            ``pred_seq_len``, ``batch_size`` and ``clip_grad`` from it.
        data_loaders: pair ``(loader_train, loader_val)``.
        writer: object exposing ``add_scalars(tag, dict, step)``
            (presumably a tensorboardX SummaryWriter; confirm).
        logdir: directory under which ``checkpoint/`` is created.
        device: device the model and the batch tensors are moved to.

    Returns:
        None. Side effects: writes ``args.pickle``, a checkpoint every 10
        epochs and ``train_hist.pickle`` into ``<logdir>/checkpoint``, and logs
        loss/aoe/foe scalars through ``writer``.

    NOTE(review): ``st_model``, ``Temp_Scheduler``, ``StepLR``, ``makedirs``,
    ``time``, ``inference``, ``random_rotate_graph``,
    ``offset_error_square_full_partial``, ``negative_log_likelihood``,
    ``average_offset_error`` and ``final_offset_error`` are not defined or
    imported in the cells above this one; several are imported only in a later
    cell, and ``st_model``, ``inference`` and
    ``offset_error_square_full_partial`` are not imported anywhere visible.
    """
    print('-'*50)
    print('Training Phase')
    print('-'*50, '\n')
    loader_train, loader_val = data_loaders
    model = st_model(args, device=device).to(device)
    # args.init_temp is passed twice (2nd and 3rd positional arguments);
    # TODO confirm this is intended against Temp_Scheduler's signature.
    temperature_scheduler = Temp_Scheduler(args.num_epochs, args.init_temp, args.init_temp, temp_min=0.03)      
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # Learning rate is multiplied by 0.3 every 50 epochs.
    lr_scheduler = StepLR(optimizer, step_size=50, gamma=0.3)
    print('Model is initialized.')
    print('learning rate: ', args.lr)
    checkpoint_dir = join(logdir, 'checkpoint')
    if not isdir(checkpoint_dir):
        makedirs(checkpoint_dir)
    # Persist the run configuration next to the checkpoints.
    with open(join(checkpoint_dir, 'args.pickle'), 'wb') as f:
        pickle.dump(args, f)
    print('EPOCHS: ', args.num_epochs)
    print('Training started.\n')
    # Per-epoch history, one entry appended per epoch.
    train_loss_task, train_aoe_task, train_foe_task = [], [], []
    val_loss_task, val_aoe_task, val_foe_task = [], [], []
    for epoch in range(1, args.num_epochs+1):
        model.train()
        epoch_start_time = time.time()
        # Temperature for this epoch; forwarded to the model and to inference as `tau`.
        tau = temperature_scheduler.step()
        train_loss_epoch, train_aoe_epoch, train_foe_epoch, train_loss_mask_epoch = [], [], [], []
        for batch_idx, batch in enumerate(loader_train):
            obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_rel_gt, loss_mask_rel, loss_mask, \
            v_obs, A_obs, v_pred_gt, A_pred_gt, attn_mask_obs, attn_mask_pred_gt = batch
            # Optional augmentation, applied only when a rotation pattern is configured.
            if args.rotation_pattern is not None:
                (v_obs, A_obs, v_pred_gt, A_pred_gt), _ = \
                    random_rotate_graph(args, v_obs, A_obs, v_pred_gt, A_pred_gt)
            v_obs, A_obs, v_pred_gt, attn_mask_obs, loss_mask_rel = \
                v_obs.to(device), A_obs.to(device), v_pred_gt.to(device), \
                attn_mask_obs.to(device), loss_mask_rel.to(device)
            # Sampling is enabled exactly when the run is not deterministic.
            if args.deterministic:
                sampling = False
            else:
                sampling = True
            results = model(v_obs, A_obs, attn_mask_obs, loss_mask_rel, tau=tau, hard=False, sampling=sampling, device=device)
            gaussian_params_pred, x_sample_pred, info = results
            loss_mask_per_pedestrian = info['loss_mask_per_pedestrian']
            loss_mask_rel_full_partial = info['loss_mask_rel_full_partial']
            if args.deterministic:
                # Masked squared offset error over the last pred_seq_len steps.
                loss_mask_rel_pred = loss_mask_rel[:,:,-args.pred_seq_len:]
                offset_error_sq, eventual_loss_mask = offset_error_square_full_partial(x_sample_pred, v_pred_gt, loss_mask_rel_full_partial, loss_mask_rel_pred)
                loss = offset_error_sq.sum()/eventual_loss_mask.sum()
            else:
                loss = negative_log_likelihood(gaussian_params_pred, v_pred_gt, loss_mask=loss_mask_per_pedestrian)  
            # The logged loss is recorded before the division below.
            train_loss_epoch.append(loss.detach().to('cpu').item())
            # Scale so that gradients accumulated over args.batch_size iterations are averaged.
            loss = loss / args.batch_size
            loss.backward()
            aoe = average_offset_error(x_sample_pred, v_pred_gt, loss_mask=loss_mask_per_pedestrian)
            foe = final_offset_error(x_sample_pred, v_pred_gt, loss_mask=loss_mask_per_pedestrian)
            train_aoe_epoch.append(aoe.detach().to('cpu').numpy())
            train_foe_epoch.append(foe.detach().to('cpu').numpy())
            train_loss_mask_epoch.append(loss_mask_per_pedestrian[0].detach().to('cpu').numpy())

            # Gradient accumulation: step every args.batch_size iterations and
            # on the final iteration of the epoch so no gradients carry over.
            if (batch_idx+1) % args.batch_size == 0 or batch_idx+1 == len(loader_train):
                if args.clip_grad is not None:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.clip_grad)
                optimizer.step()
                optimizer.zero_grad()

        # Learning-rate schedule advances once per epoch.
        lr_scheduler.step()
        train_aoe_epoch, train_foe_epoch, train_loss_mask_epoch = \
            np.concatenate(train_aoe_epoch, axis=0), \
            np.concatenate(train_foe_epoch, axis=0), \
            np.concatenate(train_loss_mask_epoch, axis=0)
        # Loss is a plain mean over iterations; aoe/foe are normalised by the
        # summed loss mask.
        train_loss_epoch, train_aoe_epoch, train_foe_epoch = \
            np.mean(train_loss_epoch), \
            train_aoe_epoch.sum()/train_loss_mask_epoch.sum(), \
            train_foe_epoch.sum()/train_loss_mask_epoch.sum()
        train_loss_task.append(train_loss_epoch)
        train_aoe_task.append(train_aoe_epoch)
        train_foe_task.append(train_foe_epoch)
        training_epoch_period = time.time() - epoch_start_time
        # Divides by the number of loader iterations; this equals "per sample"
        # only if the loader yields one sample per iteration - TODO confirm.
        training_epoch_period_per_sample = training_epoch_period/len(loader_train)

        val_loss_epoch, val_aoe_epoch, val_foe_epoch = inference(loader_val, model, args, mode='val', tau=tau, device=device)
        
        print('Epoch: {0} | train loss: {1:.4f} | val loss: {2:.4f} | train aoe: {3:.4f} | val aoe: {4:.4f} | train foe: {5:.4f} | val foe: {6:.4f} | period: {7:.2f} sec | time per sample: {8:.4f} sec'\
                        .format(epoch, train_loss_epoch, val_loss_epoch,\
                        train_aoe_epoch, val_aoe_epoch,\
                        train_foe_epoch, val_foe_epoch,\
                        training_epoch_period, training_epoch_period_per_sample))
        # Checkpoint model and optimizer state every 10th epoch.
        if epoch % 10 == 0:
            model_filename = join(checkpoint_dir, 'epoch_'+str(epoch)+'.pt')
            torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': train_loss_epoch,
                    'val_loss': val_loss_epoch,
                    'train_aoe': train_aoe_epoch,
                    'val_aoe': val_aoe_epoch, 
                    'train_foe': train_foe_epoch,
                    'val_foe': val_foe_epoch,   
                    }, model_filename)
            print('epoch_'+str(epoch)+'.pt is saved.')
        
        val_loss_task.append(val_loss_epoch)
        val_aoe_task.append(val_aoe_epoch)
        val_foe_task.append(val_foe_epoch)
        writer.add_scalars('loss', {'train': train_loss_task[-1], 'val': val_loss_task[-1]}, epoch)
        writer.add_scalars('aoe', {'train': train_aoe_task[-1], 'val': val_aoe_task[-1]}, epoch)
        writer.add_scalars('foe', {'train': train_foe_task[-1], 'val': val_foe_task[-1]}, epoch)

    # Save the full per-epoch history once training has finished.
    hist = {}
    hist['train_loss'], hist['val_loss'] = train_loss_task, val_loss_task
    hist['train_aoe'], hist['val_aoe'] = train_aoe_task, val_aoe_task
    hist['train_foe'], hist['val_foe'] = train_foe_task, val_foe_task
    with open(join(checkpoint_dir, 'train_hist.pickle'), 'wb') as f:
        pickle.dump(hist, f)
        print(join(checkpoint_dir, 'train_hist.pickle')+' is saved.')
    return



In [3]:

    
    
    import pathhack
import pickle
import time
from os.path import join, isdir
from os import makedirs
import torch
import numpy as np
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import StepLR
from src.mgnn.utils import arg_parse, average_offset_error, final_offset_error, \
    negative_log_likelihood, random_rotate_graph, args2writername
from src.gumbel_social_transformer.temperature_scheduler import Temp_Scheduler
from torch.utils.data import DataLoader

def load_batch_dataset(args, pkg_path, subfolder='train', num_workers=4, shuffle=None):
    """Load a pre-batched trajectory dataset and wrap it in a DataLoader.

    Args:
        args: object whose ``dataset`` attribute names the dataset.
        pkg_path: root folder that contains the ``datasets`` directory.
        subfolder: split name used in the file name (e.g. 'train').
        num_workers: number of DataLoader worker processes.
        shuffle: explicit shuffle flag; when None, only the 'train' split is
            shuffled.

    Returns:
        A DataLoader with ``batch_size=1`` over the loaded dataset.
    """
    # 'sdd' lives in its own folder; every other dataset sits under eth_ucy/<name>.
    dataset_folderpath = (
        join(pkg_path, 'datasets/sdd/social_pool_data')
        if args.dataset == 'sdd'
        else join(pkg_path, 'datasets/eth_ucy', args.dataset)
    )
    result_filename = ''.join([args.dataset, '_dset_', subfolder, '_batch_trajectories.pt'])
    dset = torch.load(join(dataset_folderpath, result_filename))

    if shuffle is None:
        shuffle = subfolder == 'train'

    return DataLoader(dset, batch_size=1, shuffle=shuffle, num_workers=num_workers)

def main(args):
    """Validate the run configuration, seed the RNGs and build the data loaders.

    Args:
        args: configuration object; ``random_seed``, ``batch_size``,
            ``dataset`` and ``rotation_pattern`` are read here.

    Returns:
        ``(loader_train, loader_val)``, the pair that ``train`` unpacks from
        its ``data_loaders`` argument.

    Raises:
        RuntimeError: if ``args.batch_size`` is not 1, or if rotation is
            requested for the 'sdd' dataset.

    NOTE(review): this function is truncated in the notebook. It used to end in
    a stray ``break`` (a SyntaxError outside a loop) where the writer setup and
    the call to ``train`` are missing. It now returns the loaders so the cell
    is usable; restore the remainder when the entry point is completed.
    """
    print('\n\n')
    print('-'*50)
    print('arguments: ', args)
    # Seed both torch and numpy so runs are repeatable.
    torch.manual_seed(args.random_seed)
    np.random.seed(args.random_seed)
    if args.batch_size != 1:
        raise RuntimeError("Batch size must be 1 for BatchTrajectoriesDataset.")
    if args.dataset == 'sdd' and args.rotation_pattern is not None:
        raise RuntimeError("SDD should not allow rotation since it uses pixels.")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device: ', device)
    loader_train = load_batch_dataset(args, pathhack.pkg_path, subfolder='train')
    if args.dataset == 'sdd':
        loader_val = load_batch_dataset(args, pathhack.pkg_path, subfolder='test') # no val for sdd
    else:
        loader_val = load_batch_dataset(args, pathhack.pkg_path, subfolder='val')
    return loader_train, loader_val

In [5]:
# Show the shapes of the first-fold training tensors, one per line.
print(
    xobs_train_k.shape,
    xpred_train_k.shape,
    yintention_train_k.shape,
    sep='\n',
)

torch.Size([639, 4, 2])
torch.Size([639, 6, 2])
torch.Size([639, 1])


ModuleNotFoundError: No module named 'src.mgnn'

In [None]:
import sys
pkg_path = '..'
sys.path.append(pkg_path)
import pickle
import torch
from torch import nn
from src_v2.utils import average_offset_error, padding_mask
import pdb


class IntentionLstm(nn.Module):
    """LSTM encoder/decoder predicting future 2-D positions from an observed
    trajectory and a 2-D intention vector.

    Each time step is represented by the position concatenated with the
    intention (4 features), embedded by ``spatial_embedding`` and fed to a
    single LSTM. The same LSTM encodes the observation and is then rolled
    forward one step at a time; ``hidden2pos`` maps its output to a position.
    """

    _MODES = ('teacher_forcing', 'no_forcing')

    def __init__(self, embedding_size=64, hidden_size=64, num_layers=1, \
                 dropout=0., device='cuda:0'):
        """
        Args:
            embedding_size: width of the per-step input embedding.
            hidden_size: LSTM hidden state width.
            num_layers: number of LSTM layers stacked together.
            dropout: dropout between stacked LSTM layers.
            device: device the module's parameters are placed on.
        """
        super(IntentionLstm, self).__init__()
        self.embedding_size, self.hidden_size, self.num_layers = \
            embedding_size, hidden_size, num_layers

        self.lstm = nn.LSTM(
                    input_size=embedding_size,
                    hidden_size=hidden_size,
                    num_layers=num_layers,
                    batch_first=True,
                    dropout=dropout,
                    bidirectional=False,
                    )

        self.spatial_embedding = nn.Linear(4, embedding_size)
        self.hidden2pos = nn.Linear(hidden_size, 2)
        # Previously only the LSTM was moved to `device`, leaving the linear
        # layers on the CPU unless the caller moved the model again.
        self.to(device)

    @classmethod
    def _check_mode(cls, mode):
        """Raise ValueError unless ``mode`` is one of the supported modes."""
        if mode not in cls._MODES:
            raise ValueError(
                "mode must be 'teacher_forcing' or 'no_forcing', got {!r}.".format(mode))

    def init_hidden(self, batch_size, device='cuda:0'):
        """Return zero-initialised ``(h_0, c_0)``, each of shape
        (num_layers, batch_size, hidden_size), on ``device``."""
        return (
            torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device),
            torch.zeros(self.num_layers, batch_size, self.hidden_size).to(device),
        )

    def _encode_observation(self, so, si, device):
        """Run the LSTM over the observation.

        Returns ``(hf, hc_t)``: the output at the last time step, shape
        (batch_size, 1, hidden_size), and the final LSTM state.
        """
        batch_size = so.size(0)
        sio = torch.ones_like(so).to(device) * si  # intention repeated at every step
        so_ebd = torch.cat((so, sio), dim=2)  # (batch_size, obs_seq_len, 4)
        so_ebd = self.spatial_embedding(so_ebd.reshape(-1, 4))
        so_ebd = so_ebd.reshape(batch_size, -1, self.embedding_size)
        hc_0 = self.init_hidden(batch_size, device=device)
        out, hc_t = self.lstm(so_ebd, hc_0)  # out: (batch_size, obs_seq_len, hidden_size)
        return out[:, -1:, :], hc_t

    def get_hf(self, so, sp, si, mode='teacher_forcing', device='cuda:0'):
        """
        Encode the observation and return the last LSTM output.
        inputs:
            - so # observation sample. # tensor. size: (N, obs_seq_len, 2) # already .to(device)
            - sp # ground truth prediction sample. Unused; kept for a uniform signature.
            - si # intention sample. # tensor. size: (N, 1, 2) # already .to(device)
            - mode # unused; kept for a uniform signature.

        outputs:
            - hf # tensor. size: (N, 1, hidden_size)
        """
        hf, _ = self._encode_observation(so, si, device)
        return hf

    def forward(self, so, so_lens, sp, sp_lens, si, mode='teacher_forcing', device='cuda:0'):
        """
        Parallel computation over padded, variable-length sequences.
        inputs:
            - so # observation sample. # tensor. size: (N, T_obs_max, 2)
            - so_lens # T_obs for all samples in so. # tensor or sequence. size: (N,)
            - sp # ground truth prediction sample. # tensor. size: (N, T_pred_max, 2)
            - sp_lens # T_pred for all samples in sp. # tensor. size: (N,)
            - si # intention sample. # tensor. size: (N, 1, 2)
            - mode # 'teacher_forcing' or 'no_forcing'.

        outputs:
            - loss # value returned by average_offset_error(sp, sp_pred, sp_mask).
            - sp_pred # prediction on sp. # tensor. size: (N, T_pred_max, 2)

        raises:
            - ValueError on an unknown mode (previously an unbound local
              variable error inside the decoding loop).
        """
        self._check_mode(mode)
        so_intent = torch.cat((so, si * torch.ones_like(so)), dim=2).float().to(device) # N, T_obs_max, 4
        sp_intent = torch.cat((sp, si * torch.ones_like(sp)), dim=2).float().to(device) # N, T_pred_max, 4
        si = si.float().to(device)
        sp = sp.float().to(device)
        sp_mask = padding_mask(sp_lens).float().to(device) # (N, T_pred_max, 1)

        batch_size, _, input_channel = so_intent.shape
        T_pred_max = sp.shape[1]

        so_intent_ebd = self.spatial_embedding(so_intent.reshape(-1, input_channel))
        so_intent_ebd = so_intent_ebd.reshape(batch_size, -1, self.embedding_size) # (N, T_obs_max, embed_size)

        # pack_padded_sequence needs the lengths on the CPU; normalising here
        # also lets callers pass a plain list or a tensor on another device.
        so_lens = torch.as_tensor(so_lens).cpu()
        packed_so_intent_ebd = torch.nn.utils.rnn.pack_padded_sequence(
            so_intent_ebd, so_lens, batch_first=True, enforce_sorted=False)

        out, hc_t = self.lstm(packed_so_intent_ebd)
        out, _ = torch.nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        # Index of the last valid step of every sequence (len-1), repeated over
        # the hidden axis so gather picks the full hidden vector.
        so_last_idx = (so_lens - 1).long().reshape(-1, 1, 1) \
            .expand(-1, 1, self.hidden_size).to(device)
        out = out.gather(1, so_last_idx) # dim=1 -> time dimension # (batch_size, 1, hidden_size)

        pred_ts = self.hidden2pos(out.reshape(batch_size, -1)).reshape(batch_size, 1, 2) # (batch, 1, 2)
        pred_list = [pred_ts]
        for ts in range(1, T_pred_max):
            if mode == 'teacher_forcing':
                # Feed the ground-truth previous position.
                pred_ts_past_intent = sp_intent[:, ts-1:ts] # (batch, 1, 4)
            else:
                # Feed the model's own previous prediction.
                pred_ts_past_intent = torch.cat((pred_ts, si), dim=2) # (batch, 1, 4)
            pred_ts_past_intent_ebd = self.spatial_embedding(pred_ts_past_intent.reshape(-1, 4))
            pred_ts_past_intent_ebd = pred_ts_past_intent_ebd.reshape(batch_size, 1, self.embedding_size)
            out, hc_t = self.lstm(pred_ts_past_intent_ebd, hc_t) # (N, 1, hidden_size)
            pred_ts = self.hidden2pos(out.reshape(batch_size, -1)).reshape(batch_size, 1, 2)
            pred_list.append(pred_ts)

        sp_pred = torch.cat(pred_list, dim=1) # (N, T_pred_max, 2)
        loss = average_offset_error(sp, sp_pred, sp_mask)
        return loss, sp_pred

    def forward_origin(self, so, sp, si, mode='teacher_forcing', device='cuda:0'):
        """
        Fixed-length variant: every sample has the same sequence lengths.
        inputs:
            - so # observation sample. # tensor. size: (N, obs_seq_len, 2) # already .to(device)
            - sp # ground truth prediction sample. # tensor. size: (N, pred_seq_len, 2) # already .to(device)
            - si # intention sample. # tensor. size: (N, 1, 2) # already .to(device)
            - mode # 'teacher_forcing' or 'no_forcing'.

        outputs:
            - loss # mean Euclidean distance between sp_pred and sp.
            - sp_pred # prediction on sp. # tensor. size: (N, pred_seq_len, 2)

        raises:
            - ValueError on an unknown mode (previously printed a message and
              called sys.exit(1), which would terminate the caller).
        """
        self._check_mode(mode)
        batch_size = so.size(0)
        hf, hc_t = self._encode_observation(so, si, device)

        sp_pred_ts = self.hidden2pos(hf.reshape(batch_size, -1)).reshape(batch_size, 1, 2)
        pred_list = [sp_pred_ts]

        for ts in range(1, sp.size(1)):
            # Ground truth under teacher forcing, own prediction otherwise.
            prev_pos = sp[:, ts-1:ts] if mode == 'teacher_forcing' else sp_pred_ts
            prev_ebd = self.spatial_embedding(torch.cat((prev_pos, si), dim=2).reshape(-1, 4))
            prev_ebd = prev_ebd.reshape(batch_size, 1, self.embedding_size)

            out, hc_t = self.lstm(prev_ebd, hc_t)
            hf = out[:, -1:, :]

            sp_pred_ts = self.hidden2pos(hf.reshape(batch_size, -1)).reshape(batch_size, 1, 2)
            pred_list.append(sp_pred_ts)

        sp_pred = torch.cat(pred_list, dim=1) # (N, pred_seq_len, 2)
        # Average per-step Euclidean offset between prediction and ground truth.
        loss = torch.mean((((sp_pred-sp) ** 2).sum(2))**(1./2))

        return loss, sp_pred
