In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
import numpy as np
from defs import ArgoverseDataset
from defs import my_collate

In [2]:
import torch, gc

gc.collect()
torch.cuda.empty_cache()

In [3]:
from tqdm import tqdm_notebook as tqdm

def train(model, train_loader, device, optimizer, epoch, log_interval=10000):
    model.train()
    criterion = nn.MSELoss()
    iterator = tqdm(train_loader, total=int(len(train_loader)))
    counter = 0
    # for i_batch, sample_batch in enumerate(train_loader):
    for batch_idx, (inp, out) in enumerate(iterator):
         
        # inp, out = sample_batch
        inp = inp.to(device)
        out = out.to(device)
        
        optimizer.zero_grad()
        
        inp = inp.reshape(inp.shape[0], inp.shape[3] * inp.shape[1] * inp.shape[2])
        pred_out = model(inp)

        pred_out = pred_out.reshape(out.shape[0], out.shape[1], out.shape[2], out.shape[3])
        out = out.reshape(out.shape[0], out.shape[1], out.shape[2], out.shape[3])
        
        loss = torch.sqrt(criterion(pred_out, out))
        
        l1_lambda = 0.0001
        l2_lambda = 1e-4
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())

        loss = loss + l1_lambda*l1_norm + l2_lambda*l2_norm
        
        
        loss.backward()
        optimizer.step()
        
        counter += 1
        iterator.set_postfix(loss=(loss.item()*inp.size(0) / (counter * train_loader.batch_size)))
        
    return (loss.item()*inp.size(0) / (counter * train_loader.batch_size))

In [4]:
def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    criterion = nn.MSELoss()
    with torch.no_grad():
        for i_batch, sample_batch in enumerate(test_loader):
            inp, out = sample_batch
            inp = inp.to(device)
            out = out.to(device)
            
            inp = inp.reshape(inp.shape[0], inp.shape[3] * inp.shape[1] * inp.shape[2])
            pred_out = model(inp)
            
            pred_out = pred_out.reshape(out.shape[0], out.shape[1], out.shape[2], out.shape[3])
            out = out.reshape(out.shape[0], out.shape[1], out.shape[2], out.shape[3])
            
        l1_lambda = 0.0001
        l2_lambda = 1e-4
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
            
        test_loss += torch.sqrt(criterion(pred_out, out)).item() + l1_lambda*l1_norm + l2_lambda*l2_norm
            
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.7f}\n'.format(
        test_loss, correct, len(test_loader.dataset)))
    return test_loss

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        
        self.model = torch.nn.Sequential(
            torch.nn.Linear(19*60*4, 60*30*2),
            torch.nn.RReLU(lower=0.05, upper=0.5, inplace=False)
        )
        
    def forward(self, x):
        
        x = self.model(x)
        return x

In [6]:
if __name__ == '__main__':
    """Change to the data folder"""
    train_path = "./new_train/new_train"
    test_path = "./new_val_in/new_val_in"
    # number of sequences in each dataset
    # train:205942  val:3200 test: 36272 
    # sequences sampled at 10HZ rate
    
    # intialize a dataset
    val_dataset  = ArgoverseDataset(data_path=train_path)
    test_dataset = ArgoverseDataset(data_path=test_path)

    TRAIN_SET, TEST_SET = torch.utils.data.random_split(val_dataset, [169670, 36272])

    batch_size_train = 64
    batch_size_test = 2048


    train_loader = DataLoader(TRAIN_SET,batch_size=batch_size_train, shuffle = True, collate_fn=my_collate, num_workers=4,pin_memory=True)
    test_loader = DataLoader(TEST_SET,batch_size=batch_size_test, shuffle = True, collate_fn=my_collate, num_workers=4, pin_memory=True)


    learning_rate = 0.001
    # momentum = 0.2
    device = "cuda"
    model = MLP().to(device) #using cpu here
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, 
                                betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    num_epoch = 35
    
    best_valid_loss = float('inf')

    for epoch in range(1, num_epoch + 1):
        print("Epoch: " + str(epoch))
        train(model, train_loader, device, optimizer, epoch)
        valid_loss = test(model, test_loader, device)
        
        if epoch == 5:
            for g in optimizer.param_groups:
                g['lr'] = 0.00005
                
        if epoch == 15:
            for g in optimizer.param_groups:
                g['lr'] = 0.00001

        if epoch == 25:
            for g in optimizer.param_groups:
                g['lr'] = 0.000005
                
        if valid_loss <= best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'MLP-model.pt')

Epoch: 1


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0023119

Epoch: 2


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0023829

Epoch: 3


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0023630

Epoch: 4


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0020780

Epoch: 5


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0022237

Epoch: 6


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0007988

Epoch: 7


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001494

Epoch: 8


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001216

Epoch: 9


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001200

Epoch: 10


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001221

Epoch: 11


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001128

Epoch: 12


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001211

Epoch: 13


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001315

Epoch: 14


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001227

Epoch: 15


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001407

Epoch: 16


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000650

Epoch: 17


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000572

Epoch: 18


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000535

Epoch: 19


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000663

Epoch: 20


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000533

Epoch: 21


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000529

Epoch: 22


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000588

Epoch: 23


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000567

Epoch: 24


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000547

Epoch: 25


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000717

Epoch: 26


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000484

Epoch: 27


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000477

Epoch: 28


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000509

Epoch: 29


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000480

Epoch: 30


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000486

Epoch: 31


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000489

Epoch: 32


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000472

Epoch: 33


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000477

Epoch: 34


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000474

Epoch: 35


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000487



In [7]:
best_valid_loss.item()

4.722451922134496e-05

In [8]:
num_epoch = 70
for epoch in range(36, num_epoch + 1):
    
    if epoch == 36:
        for g in optimizer.param_groups:
            g['lr'] = 0.00001
    
    print("Epoch: " + str(epoch))
    train(model, train_loader, device, optimizer, epoch)
    valid_loss = test(model, test_loader, device)
    
        
    if epoch == 46:
        for g in optimizer.param_groups:
            g['lr'] = 0.000001
                
    if epoch == 56:
        for g in optimizer.param_groups:
            g['lr'] = 0.0000001

    if epoch == 66:
        for g in optimizer.param_groups:
            g['lr'] = 0.00000005

            
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'MLP-model.pt')

Epoch: 36


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000650

Epoch: 37


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000533

Epoch: 38


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000633

Epoch: 39


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000556

Epoch: 40


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000548

Epoch: 41


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000573

Epoch: 42


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000527

Epoch: 43


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000529

Epoch: 44


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000530

Epoch: 45


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000504

Epoch: 46


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000555

Epoch: 47


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000454

Epoch: 48


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000447

Epoch: 49


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000428

Epoch: 50


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000431

Epoch: 51


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000436

Epoch: 52


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000437

Epoch: 53


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000434

Epoch: 54


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000542

Epoch: 55


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000433

Epoch: 56


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000425

Epoch: 57


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000464

Epoch: 58


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000423

Epoch: 59


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000418

Epoch: 60


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000433

Epoch: 61


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000464

Epoch: 62


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000528

Epoch: 63


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000425

Epoch: 64


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0005167

Epoch: 65


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000415

Epoch: 66


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000424

Epoch: 67


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000420

Epoch: 68


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000421

Epoch: 69


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000419

Epoch: 70


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000412



In [9]:
best_valid_loss.item()

4.1225845052395016e-05

In [None]:
num_epoch = 120
for epoch in range(70, num_epoch + 1):
    
    if epoch == 75:
        for g in optimizer.param_groups:
            g['lr'] = 0.00000001
            
    if epoch == 85:
        for g in optimizer.param_groups:
            g['lr'] = 0.000000005
            
    if epoch == 95:
        for g in optimizer.param_groups:
            g['lr'] = 0.000000001

    if epoch == 105:
        for g in optimizer.param_groups:
            g['lr'] = 0.0000000001
    
    print("Epoch: " + str(epoch))
    train(model, train_loader, device, optimizer, epoch)
    valid_loss = test(model, test_loader, device)
            
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'MLP-model.pt')

Epoch: 70


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000425

Epoch: 71


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000427

Epoch: 72


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000424

Epoch: 73


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000420

Epoch: 74


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000409

Epoch: 75


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000428

Epoch: 76


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000425

Epoch: 77


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000431

Epoch: 78


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000419

Epoch: 79


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000424

Epoch: 80


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000419

Epoch: 81


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000418

Epoch: 82


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000419

Epoch: 83


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000416

Epoch: 84


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000421

Epoch: 85


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000427

Epoch: 86


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000416

Epoch: 87


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000422

Epoch: 88


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000438

Epoch: 89


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000432

Epoch: 90


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000424

Epoch: 91


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000418

Epoch: 92


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))

In [11]:
best_valid_loss.item()

4.088086279807612e-05

In [12]:
def collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    inp = torch.FloatTensor(inp)
    return inp

t_loader = DataLoader(test_dataset,batch_size=1, shuffle = False, collate_fn=collate, num_workers=0)

In [13]:
import csv

model.load_state_dict(torch.load('MLP-model.pt'))

header = ['ID']
for i in range(1, 61):
    header.append('v' + str(i))
    
with open('mlp.csv', 'w') as csvfile: 
    # creating a csv writer object 
    csvwriter = csv.writer(csvfile) 
        
    # writing the fields 
    csvwriter.writerow(header) 
        
    for i_batch, sample_batch in enumerate(t_loader):
        header = []
        header.append(test_dataset[i_batch]['scene_idx'])
        
        model.eval()
        inp = sample_batch
        inp = inp.reshape(inp.shape[0], inp.shape[3] * inp.shape[1] * inp.shape[2])
        inp = inp.to(device)
        pred_out = model(inp)
        pred_out = pred_out.reshape(1, 60, 30, 2)
        pred_out = pred_out.squeeze() 
        track_id = test_dataset[i_batch]['track_id']
        track_id = track_id[:,0,0]
        index = 0
        for i in range(len(track_id)):
            if test_dataset[i_batch]['agent_id'] == track_id[i]:
                index = i
                break
        p_out = pred_out[index]
        p_out = p_out.reshape(30*2)
        for i in range(len(p_out)):
            header.append(p_out[i].item())
            
        csvwriter.writerow(header)