In [1]:
import random

import numpy as np
from tqdm import tqdm
import torch
from torch import nn
import torch.nn.functional as F
import torchmetrics

In [2]:
from torch.utils.data import DataLoader
from datasets.video_dataset import VideoDataset

# One batch size shared by both loaders, so changing it cannot leave the
# training and validation loaders out of sync.
batch_size = 4

# Training clips are reshuffled on every pass over the loader.
ds = VideoDataset('./data/train')
train_loader = DataLoader(ds, batch_size=batch_size, shuffle=True)

# Validation clips are always visited in dataset order.
val_ds = VideoDataset('./data/val')
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

In [3]:
# Seed every RNG this notebook draws from (Python's `random` picks the frame
# window, torch drives loader shuffling and weight initialisation) so that
# Restart & Run All reproduces the same run.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 49
num_epochs = 10

In [4]:
from models.lstm import Seq2Seq

# Build the sequence model first, then move its parameters to the chosen
# device as a separate step.
model = Seq2Seq(
    num_channels=1, num_kernels=16,
    kernel_size=3, padding=1,
    activation='relu',
    frame_size=(160, 240),
    num_layers=3,
)
model = model.to(device)

# BCEWithLogitsLoss applies the sigmoid internally, so the model output is
# treated as raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)

In [5]:
def to_binary(masks: torch.Tensor) -> torch.Tensor:
    """Split labelled masks into one binary mask per object.

    Each element along dim 0 of ``masks`` is scanned for its distinct
    non-zero values. For every such value a copy of that element is emitted
    in which positions equal to the value are 1 and all others are 0. The
    copies for all elements are concatenated along dim 0, ordered by element
    and then by ascending value.

    Args:
        masks: Label tensor whose dim 0 indexes independent samples; 0 is
            treated as background.

    Returns:
        Tensor of shape ``(total_objects, *masks.shape[1:])`` with the same
        dtype as ``masks``. ``total_objects`` is 0 when the batch is empty
        or contains only background.
    """
    results = []
    for mask in masks:
        objects = torch.unique(mask[mask != 0])
        # Reshape to (K, 1, ..., 1) so a single broadcast comparison yields
        # all K binary masks. An explicit K is used instead of -1 because -1
        # is ambiguous for a zero-element tensor.
        objects = objects.view(objects.numel(), *([1] * mask.dim()))
        results.append((mask.unsqueeze(0) == objects).to(mask.dtype))
    if not results:
        # torch.cat rejects an empty list; an empty batch maps to an empty result.
        return masks.new_zeros((0, *masks.shape[1:]))
    return torch.cat(results, 0)

# Smoke check: random labels in [0, 49) for 2 samples, then compare the
# shape before and after splitting into per-object binary masks.
masks = torch.randint(0, 49, (2, 11, 160, 240))
print(masks.shape)
to_binary(masks).shape

torch.Size([2, 11, 160, 240])


torch.Size([96, 11, 160, 240])

In [6]:
# After to_binary every sample is a single-object 0/1 mask, so validation is
# scored with a binary IoU rather than a `num_classes`-way one.
jaccard = torchmetrics.JaccardIndex(task='binary').to(device)

context_len = 11     # number of consecutive frames fed to the model
max_target_idx = 21  # largest frame index that may be picked as the target

for epoch in range(1, num_epochs + 1):
    model.train()
    for batch in tqdm(train_loader):
        masks = to_binary(batch['masks'])

        # Frames [idx - context_len, idx) are the input; frame idx is the target.
        idx = random.randint(context_len, max_target_idx)
        inputs = masks[:, idx - context_len:idx, :, :].float().to(device)
        target = masks[:, idx, :, :].float().to(device)

        # Insert a singleton dim 1 (presumably the channel axis expected by
        # Seq2Seq with num_channels=1 -- confirm against models/lstm.py).
        inputs = inputs.unsqueeze(1)
        target = target.unsqueeze(1)

        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    model.eval()
    # Start each epoch's score from a clean metric state.
    jaccard.reset()
    # Fixed-seed generator: every epoch is validated on the same frame
    # windows, so scores are comparable from one epoch to the next.
    val_rng = random.Random(0)
    with torch.inference_mode():
        for batch in tqdm(val_loader):
            masks = to_binary(batch['masks'])
            idx = val_rng.randint(context_len, max_target_idx)
            inputs = masks[:, idx - context_len:idx, :, :].float().to(device)
            # Binary Jaccard expects integer 0/1 targets.
            target = masks[:, idx, :, :].long().to(device)

            inputs = inputs.unsqueeze(1)
            target = target.unsqueeze(1)

            output = model(inputs)
            # The training loss already requires `output` to match the
            # (N, 1, H, W) target, so argmax over dim 1 would always be 0.
            # Threshold the logit instead: logit > 0 <=> sigmoid > 0.5.
            pred = (output > 0).long()
            jaccard.update(pred, target)

        # IoU accumulated over every validation sample seen this epoch.
        val_iou = jaccard.compute()
        print(val_iou)

  0%|          | 1/250 [00:31<2:09:23, 31.18s/it]


KeyboardInterrupt: 