In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
from efficientnet_pytorch import EfficientNet
from pathlib import Path
import numpy as np
import multiprocessing

In [24]:
# The optical flow input will look like this
of = torch.randn(3,2,640,480).cuda()

In [2]:
model = EfficientNet.from_pretrained("efficientnet-b3", in_channels=2, num_classes=1).cuda()

Loaded pretrained weights for efficientnet-b3


In [26]:
features = model(of)
features.shape

torch.Size([3, 1])

In [3]:
class OFDataset(Dataset):
    def __init__(self, of_dir, label_f):
        self.len = len(list(Path(of_dir).glob('*.npy')))
        self.of_dir = of_dir
        self.label_file = open(label_f).readlines()
    def __len__(self): return self.len
    def __getitem__(self, idx):
        of_array = np.load(Path(self.of_dir)/f'{idx}.npy')
        of_tensor = torch.squeeze(torch.Tensor(of_array))
        label = float(self.label_file[idx].split()[0])
        return [of_tensor, label]

In [4]:
ds = OFDataset('/home/sharif/Documents/RAFT/train_predictions', '/home/sharif/Documents/commai-challenge/data/labels/train.txt')

In [5]:
train_split = .8

In [6]:
ds_size = len(ds)
indices = list(range(ds_size))
split = int(np.floor(train_split * ds_size))
train_indices, val_indices = indices[:split], indices[split:]

In [7]:
assert len(train_indices) > len(val_indices)
assert train_indices[0] == 0
assert val_indices[0] == int(ds_size * train_split)

In [8]:
sample = ds[3]
assert type(sample[0]) == torch.Tensor
assert type(sample[1]) == float

In [9]:
cpu_cores = multiprocessing.cpu_count()
cpu_cores

6

In [10]:
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

In [11]:
train_dl = DataLoader(ds, batch_size=4, sampler=train_sampler, num_workers=cpu_cores)
val_dl = DataLoader(ds, batch_size=4, sampler=val_sampler, num_workers=cpu_cores)

In [12]:
epochs = 100
log_train_steps = 25
log_val_steps = 50
val_steps = 200

In [13]:
criterion = nn.MSELoss()
opt = optim.Adam(model.parameters())

In [14]:
writer = SummaryWriter('runs/optical_flow_exp_1')

In [15]:
n_validations = 0
validation_losses = []

for epoch in range(epochs):
    running_loss = 0.0
    for i, sample in enumerate(train_dl):
        of_tensor = sample[0].cuda()
        label = sample[1].float().cuda()
        opt.zero_grad()
        pred = torch.squeeze(model(of_tensor))
        loss = criterion(pred, label)
        loss.backward()
        opt.step()
        
        running_loss += loss.item()
        if i % log_train_steps == 0 and i != 0:
            writer.add_scalar('training loss',
                             running_loss / log_train_steps,
                             epoch * len(train_dl) + i)
            running_loss = 0.0    
        if i % 100 == 0: print(f'{i}/{len(train_dl)}')
        
        # val loss
        if i % val_steps == 0 and i != 0:
            with torch.no_grad():
                for j, val_sample in enumerate(val_dl):
                    of_tensor = val_sample[0].cuda()
                    label = val_sample[1].float().cuda()
                    pred = torch.squeeze(model(of_tensor))
                    loss = criterion(pred, label)

                    running_loss += loss.item()
                    if j % 25 == 0: print(f'{j}/{len(val_dl)}')

                    if j == log_val_steps:
                        l = running_loss / log_val_steps
                        writer.add_scalar('validation loss',
                                         l,
                                         n_validations)
                        validation_losses.append(l)
                        if min(validation_losses) == l:
                            p = f'/home/sharif/Documents/commai-challenge/trained_models/{epoch}_{i}_{l}.pth'
                            torch.save(model.state_dict(), p)
                        
                        running_loss = 0.0
                        n_validations += 1
                        break

0/4080
100/4080
200/4080
0/1020
25/1020
50/1020
300/4080
400/4080
0/1020
25/1020
50/1020
500/4080
600/4080
0/1020
25/1020
50/1020
700/4080
800/4080
0/1020
25/1020
50/1020
900/4080
1000/4080
0/1020
25/1020
50/1020
1100/4080
1200/4080
0/1020
25/1020
50/1020
1300/4080
1400/4080
0/1020
25/1020
50/1020
1500/4080
1600/4080
0/1020
25/1020
50/1020
1700/4080
1800/4080
0/1020
25/1020
50/1020
1900/4080
2000/4080
0/1020
25/1020
50/1020
2100/4080
2200/4080
0/1020
25/1020
50/1020
2300/4080
2400/4080
0/1020
25/1020
50/1020
2500/4080
2600/4080
0/1020
25/1020
50/1020
2700/4080
2800/4080
0/1020
25/1020
50/1020
2900/4080
3000/4080
0/1020
25/1020
50/1020
3100/4080
3200/4080
0/1020
25/1020
50/1020
3300/4080
3400/4080
0/1020
25/1020
50/1020
3500/4080
3600/4080
0/1020
25/1020
50/1020
3700/4080
3800/4080
0/1020
25/1020
50/1020
3900/4080
4000/4080
0/1020
25/1020
50/1020
0/4080
100/4080
200/4080
0/1020
25/1020
50/1020
300/4080
400/4080
0/1020
25/1020
50/1020
500/4080
600/4080
0/1020
25/1020
50/1020
700/4080
800

2300/4080
2400/4080
0/1020
25/1020
50/1020
2500/4080
2600/4080
0/1020
25/1020
50/1020
2700/4080
2800/4080
0/1020
25/1020
50/1020
2900/4080
3000/4080
0/1020
25/1020
50/1020
3100/4080
3200/4080
0/1020
25/1020
50/1020
3300/4080
3400/4080
0/1020
25/1020
50/1020
3500/4080
3600/4080
0/1020
25/1020
50/1020
3700/4080
3800/4080
0/1020
25/1020
50/1020
3900/4080
4000/4080
0/1020
25/1020
50/1020
0/4080
100/4080
200/4080
0/1020
25/1020
50/1020
300/4080
400/4080
0/1020
25/1020
50/1020
500/4080
600/4080
0/1020
25/1020
50/1020
700/4080
800/4080
0/1020
25/1020
50/1020
900/4080
1000/4080
0/1020
25/1020
50/1020
1100/4080
1200/4080
0/1020
25/1020
50/1020
1300/4080
1400/4080
0/1020
25/1020
50/1020
1500/4080
1600/4080
0/1020
25/1020
50/1020
1700/4080
1800/4080
0/1020
25/1020
50/1020
1900/4080
2000/4080
0/1020
25/1020
50/1020
2100/4080
2200/4080
0/1020
25/1020
50/1020
2300/4080
2400/4080
0/1020
25/1020
50/1020
2500/4080
2600/4080
0/1020
25/1020
50/1020
2700/4080
2800/4080
0/1020
25/1020
50/1020
2900/4080
300

KeyboardInterrupt: 