In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
import numpy as np
from defs import ArgoverseDataset
from defs import my_collate

In [2]:
import torch, gc

# Run a full garbage-collection pass first so unreachable Python objects
# (and any tensors they still reference) are destroyed ...
gc.collect()
# ... then ask PyTorch to release the cached CUDA memory it is no longer using.
torch.cuda.empty_cache()

In [3]:
from tqdm import tqdm_notebook as tqdm

def train(model, train_loader, device, optimizer, epoch, log_interval=10000,
          l1_lambda=0.0001, l2_lambda=1e-4):
    """Train ``model`` for one pass over ``train_loader``.

    The per-batch objective is RMSE(pred, target) plus an L1 and an L2 penalty
    over all model parameters.

    Args:
        model: module mapping a flattened ``(batch, features)`` input to a
            tensor that can be reshaped to the target's shape.
        train_loader: iterable of ``(inp, out)`` tensor pairs; it must support
            ``len()`` because the length is passed to the progress bar.
        device: device the batches are moved to before the forward pass.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: unused; kept so existing call sites keep working.
        log_interval: unused; kept so existing call sites keep working.
        l1_lambda: weight of the L1 parameter penalty.
        l2_lambda: weight of the L2 parameter penalty.

    Returns:
        float: sample-weighted mean of the per-batch objective over the epoch,
        or ``nan`` when the loader yields no batches.
    """
    model.train()
    criterion = nn.MSELoss()
    iterator = tqdm(train_loader, total=int(len(train_loader)))

    # Running totals so the reported value is a true average over the epoch
    # rather than the last batch's loss divided by the batch count.
    weighted_loss_sum = 0.0
    samples_seen = 0

    for batch_idx, (inp, out) in enumerate(iterator):
        inp = inp.to(device)
        out = out.to(device)

        optimizer.zero_grad()

        # Flatten everything except the batch dimension for the MLP.
        inp = inp.reshape(inp.shape[0], -1)
        # Bring the flat prediction back to the target layout.
        pred_out = model(inp).reshape(out.shape)

        loss = torch.sqrt(criterion(pred_out, out))

        l1_norm = sum(p.abs().sum() for p in model.parameters())
        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
        loss = loss + l1_lambda * l1_norm + l2_lambda * l2_norm

        loss.backward()
        optimizer.step()

        batch_size = inp.size(0)
        weighted_loss_sum += loss.item() * batch_size
        samples_seen += batch_size
        iterator.set_postfix(loss=weighted_loss_sum / samples_seen)

    if samples_seen == 0:
        # Nothing was trained on, so there is no meaningful loss to report.
        return float("nan")
    return weighted_loss_sum / samples_seen

In [4]:
def test(model, test_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    criterion = nn.MSELoss()
    with torch.no_grad():
        for i_batch, sample_batch in enumerate(test_loader):
            inp, out = sample_batch
            inp = inp.to(device)
            out = out.to(device)
            
            inp = inp.reshape(inp.shape[0], inp.shape[3] * inp.shape[1] * inp.shape[2])
            pred_out = model(inp)
            
            pred_out = pred_out.reshape(out.shape[0], out.shape[1], out.shape[2], out.shape[3])
            out = out.reshape(out.shape[0], out.shape[1], out.shape[2], out.shape[3])
            
        l1_lambda = 0.0001
        l2_lambda = 1e-4
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
            
        test_loss += torch.sqrt(criterion(pred_out, out)).item() + l1_lambda*l1_norm + l2_lambda*l2_norm
            
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.7f}\n'.format(
        test_loss, correct, len(test_loader.dataset)))
    return test_loss

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class MLP(nn.Module):
    """Single linear layer followed by a LeakyReLU.

    Args:
        in_features: size of the flattened input vector. The default,
            ``19*60*4``, presumably corresponds to 60 agents x 19 time steps x
            4 position/velocity features -- TODO confirm against the dataset.
        out_features: size of the flat output vector. The default,
            ``60*30*2``, matches the ``(60, 30, 2)`` reshape applied to the
            prediction when the submission file is written.
        negative_slope: slope of the LeakyReLU for negative inputs.
    """

    def __init__(self, in_features=19*60*4, out_features=60*30*2, negative_slope=0.5):
        super().__init__()

        # Attribute name and layer order are kept so that previously saved
        # state dicts ("model.0.weight", "model.0.bias") still load.
        self.model = torch.nn.Sequential(
            torch.nn.Linear(in_features, out_features),
            torch.nn.LeakyReLU(negative_slope=negative_slope, inplace=False),
        )

    def forward(self, x):
        """Apply the network to a ``(batch, in_features)`` tensor."""
        return self.model(x)

In [7]:
if __name__ == '__main__':
    # Change these to the data folders.
    train_path = "./new_train/new_train"
    test_path = "./new_val_in/new_val_in"
    # number of sequences in each dataset
    # train:205942  val:3200 test: 36272 
    # sequences sampled at 10HZ rate

    # Initialize the datasets. NOTE: `val_dataset` holds the full *training*
    # folder; it is split into train / hold-out parts below. The name is kept
    # because later cells may refer to it.
    val_dataset  = ArgoverseDataset(data_path=train_path)
    test_dataset = ArgoverseDataset(data_path=test_path)

    # Hold out a fixed number of sequences for validation and derive the
    # training size from the dataset instead of hard-coding both numbers.
    holdout_size = 36272
    train_size = len(val_dataset) - holdout_size
    if train_size <= 0:
        raise ValueError(
            "dataset has {} sequences, need more than {} to build a hold-out split".format(
                len(val_dataset), holdout_size))
    TRAIN_SET, TEST_SET = torch.utils.data.random_split(val_dataset, [train_size, holdout_size])

    batch_size_train = 64
    batch_size_test = 2048

    # Use the GPU when there is one; fall back to the CPU otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_pinned_memory = (device == "cuda")  # pinning only helps host->GPU copies

    train_loader = DataLoader(TRAIN_SET, batch_size=batch_size_train, shuffle=True,
                              collate_fn=my_collate, num_workers=4,
                              pin_memory=use_pinned_memory)
    # Evaluation does not benefit from shuffling; a fixed order keeps the
    # validation pass deterministic.
    test_loader = DataLoader(TEST_SET, batch_size=batch_size_test, shuffle=False,
                             collate_fn=my_collate, num_workers=4,
                             pin_memory=use_pinned_memory)

    learning_rate = 0.001
    model = MLP().to(device)
    torch.backends.cudnn.benchmark = True
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    num_epoch = 35

    # Step-wise learning-rate decay: after the given epoch has been trained
    # and validated, switch to the listed rate for the following epochs.
    lr_after_epoch = {5: 0.00005, 10: 0.00001, 15: 0.000005}

    best_valid_loss = float('inf')

    for epoch in range(1, num_epoch + 1):
        print("Epoch: " + str(epoch))
        train(model, train_loader, device, optimizer, epoch)
        valid_loss = test(model, test_loader, device)

        if epoch in lr_after_epoch:
            for g in optimizer.param_groups:
                g['lr'] = lr_after_epoch[epoch]

        # Keep only the checkpoint with the lowest validation loss so far.
        if valid_loss <= best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'MLP-model.pt')

Epoch: 1


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0016346

Epoch: 2


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0024668

Epoch: 3


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0023781

Epoch: 4


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0019257

Epoch: 5


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0020561

Epoch: 6


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001513

Epoch: 7


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001388

Epoch: 8


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001306

Epoch: 9


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001381

Epoch: 10


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0001235

Epoch: 11


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000657

Epoch: 12


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000635

Epoch: 13


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000627

Epoch: 14


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000639

Epoch: 15


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000644

Epoch: 16


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000532

Epoch: 17


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000580

Epoch: 18


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000547

Epoch: 19


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000527

Epoch: 20


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000520

Epoch: 21


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000534

Epoch: 22


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000554

Epoch: 23


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000562

Epoch: 24


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)





Test set: Average loss: 0.0000955

Epoch: 26


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000531

Epoch: 27


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))




KeyboardInterrupt: 

In [8]:
# Show the lowest validation loss recorded so far as a plain Python number
# (the stored value supports `.item()`).
best_valid_loss.item()

5.2013103413628414e-05

In [9]:
num_epoch = 50

# Learning rate to switch to once the listed epoch has been trained and
# validated; it takes effect from the following epoch.
post_epoch_lr = {32: 5e-7, 37: 1e-7, 42: 5e-8, 47: 1e-8}

for epoch in range(27, num_epoch + 1):

    # The resumed run lowers the step size before epoch 27 is trained.
    if epoch == 27:
        for group in optimizer.param_groups:
            group['lr'] = 1e-6

    print("Epoch: {}".format(epoch))
    train(model, train_loader, device, optimizer, epoch)
    valid_loss = test(model, test_loader, device)

    new_lr = post_epoch_lr.get(epoch)
    if new_lr is not None:
        for group in optimizer.param_groups:
            group['lr'] = new_lr

    # Overwrite the checkpoint whenever the validation loss ties or improves.
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'MLP-model.pt')

Epoch: 27


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000485

Epoch: 28


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000488

Epoch: 29


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000493

Epoch: 30


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000484

Epoch: 31


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000491

Epoch: 32


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000490

Epoch: 33


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000496

Epoch: 34


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000472

Epoch: 35


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000909

Epoch: 36


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000470

Epoch: 37


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000472

Epoch: 38


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000478

Epoch: 39


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000485

Epoch: 40


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000476

Epoch: 41


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000477

Epoch: 42


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000466

Epoch: 43


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000473

Epoch: 44


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000477

Epoch: 45


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000481

Epoch: 46


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000457

Epoch: 47


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000474

Epoch: 48


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000479

Epoch: 49


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000490

Epoch: 50


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))




KeyboardInterrupt: 

In [10]:
# Show the lowest validation loss recorded so far as a plain Python number
# (the stored value supports `.item()`).
best_valid_loss.item()

4.5719145418843254e-05

In [13]:
# List the hyper-parameter names stored in each optimizer parameter group.
for param_group in optimizer.param_groups:
    group_keys = param_group.keys()
    print(group_keys)

dict_keys(['params', 'lr', 'betas', 'eps', 'weight_decay', 'amsgrad'])


In [14]:
num_epoch = 55

# One-off hyper-parameter override applied before epoch 50 is trained.
for group in optimizer.param_groups:
    group.update(lr=1e-8, betas=(0.09, 0.0999), eps=1e-09)

for epoch in range(50, num_epoch + 1):
    print("Epoch: {}".format(epoch))
    train(model, train_loader, device, optimizer, epoch)
    valid_loss = test(model, test_loader, device)

    # Overwrite the checkpoint whenever the validation loss ties or improves.
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'MLP-model.pt')

Epoch: 50


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000485

Epoch: 51


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000474

Epoch: 52


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000459

Epoch: 53


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000475

Epoch: 54


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000473

Epoch: 55


HBox(children=(FloatProgress(value=0.0, max=2652.0), HTML(value='')))



Test set: Average loss: 0.0000461



In [15]:
# Show the lowest validation loss recorded so far as a plain Python number
# (the stored value supports `.item()`).
best_valid_loss.item()

4.5719145418843254e-05

In [11]:
def collate(batch):
    """Collate a list of scenes into one float32 input tensor.

    Each scene's ``'p_in'`` and ``'v_in'`` arrays are stacked along the last
    axis, and the scenes are then stacked along a new leading batch axis,
    giving ``[batch_sz x agent_sz x seq_len x feature]``.

    Args:
        batch: list of mappings, each with ``'p_in'`` and ``'v_in'`` arrays of
            matching shape.

    Returns:
        torch.Tensor: float32 tensor with one leading entry per scene; an
        empty 1-D float32 tensor when ``batch`` is empty.
    """
    scenes = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    if not scenes:
        # Same result torch.FloatTensor([]) gives for an empty list.
        return torch.empty(0, dtype=torch.float32)
    # Stack in NumPy first: building a tensor straight from a list of ndarrays
    # takes PyTorch's slow element-by-element path and emits a UserWarning.
    return torch.as_tensor(numpy.stack(scenes), dtype=torch.float32)

# One scene per batch, in dataset order, loaded in the main process.
t_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate,
    num_workers=0,
)

In [12]:
import csv

# Restore the best checkpoint onto the active device and switch to inference
# mode once, instead of on every iteration.
model.load_state_dict(torch.load('MLP-model.pt', map_location=device))
model.eval()

# Column names: an ID column followed by v1 .. v60.
header = ['ID'] + ['v' + str(i) for i in range(1, 61)]

# newline='' lets the csv module control line endings; without it some
# platforms write a blank line after every row.
with open('mlp.csv', 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(header)

    # No gradients are needed for inference.
    with torch.no_grad():
        for i_batch, sample_batch in enumerate(t_loader):
            # t_loader uses batch_size=1 and shuffle=False, so i_batch is also
            # the dataset index. Fetch the scene once instead of three times.
            scene = test_dataset[i_batch]

            inp = sample_batch.reshape(sample_batch.shape[0], -1).to(device)
            # (1, 60, 30, 2) -> (60, 30, 2): one (30, 2) block per track slot.
            pred_out = model(inp).reshape(1, 60, 30, 2).squeeze()

            # Find the slot whose track id matches the scene's agent id;
            # fall back to slot 0 when there is no match (as before).
            track_ids = scene['track_id'][:, 0, 0]
            agent_id = scene['agent_id']
            index = next(
                (i for i in range(len(track_ids)) if agent_id == track_ids[i]),
                0,
            )

            # Row layout: scene id followed by the 60 flattened prediction values.
            row = [scene['scene_idx']] + pred_out[index].reshape(30 * 2).tolist()
            csvwriter.writerow(row)