In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from data import GODData
from tqdm.notebook import tqdm, trange
from sklearn.metrics import accuracy_score
from transformers import VideoMAEConfig, VideoMAEForVideoClassification

In [2]:
# better plots
sns.set_style("whitegrid")

# better progress
tqdm.pandas()

In [3]:
# use GPU if available
device = torch.device("cpu")
if torch.cuda.is_available():
    print("Using GPU")
    device = torch.device("cuda")

Using GPU


In [4]:
# load data
print("Loading data...")
train_dataset = GODData(
    subject="01", 
    session_id="01", 
    task="perception", 
    train=True, 
    limit_size=200,
)
eval_dataset = GODData(
    subject="01", 
    session_id="01", 
    task="perception", 
    train=False, 
    limit_size=50,
)

train_dataloader = DataLoader(train_dataset, batch_size=8)
eval_dataloader = DataLoader(eval_dataset, batch_size=8)

print(f"# train: {len(train_dataset):>5}\n# test: {len(eval_dataset):>5}")

Loading data...
# train:   200
# test:    50


In [5]:
# instantiate model
print("Instantiating model...")
config = VideoMAEConfig(
    image_size=64,
    num_channels=3,
    num_frames=50,
    num_labels=150,
    problem_type="single_label_classification",
)

model = VideoMAEForVideoClassification(config).to(device)

Instantiating model...


In [6]:
def evaluate(model, dataloader):
    model.eval()
    accuracy = 0

    with torch.no_grad():
        for features, targets in dataloader:
            batch = {}
            batch["pixel_values"] = torch.stack([f.permute(1, 0, 2, 3) for f in features]).to(device)
            batch["labels"] = targets.to(device)

            outputs = model(**batch)
            predictions = np.argmax(outputs.logits.detach().cpu(), axis=-1)
            print(predictions, targets)
            accuracy += accuracy_score(targets.cpu(), predictions.cpu())

    accuracy /= len(dataloader)

    return {"accuracy": accuracy}

In [7]:
def train(model, dataloader, num_epochs, optimizer, eval_freq=10):
    model.train()
    loss_history = [] 
    metrics_history = []

    pbar = trange(num_epochs)
    for epoch in pbar:
        loss_epoch = 0
        for features, targets in dataloader:
            batch = {}
            batch["pixel_values"] = torch.stack([f.permute(1, 0, 2, 3) for f in features]).to(device)
            batch["labels"] = targets.to(device)

            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            loss_epoch += loss.item()

        loss_epoch /= len(dataloader)

        if epoch % eval_freq == 0:
            loss_history.append(loss_epoch)
            metrics_history.append(evaluate(model, eval_dataloader))
            pbar.set_postfix(loss=f"{loss_history[-1]:.4f}", accuracy=f"{metrics_history[-1]['accuracy']*100:.4f}%")

    return loss_history

In [8]:
num_epochs = 1000
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [9]:
loss_history = train(model, train_dataloader, num_epochs, optimizer)

  0%|          | 0/1000 [00:00<?, ?it/s]

tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([34, 27, 39, 28, 12, 12, 42, 10])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([40, 33, 11, 44,  6, 22,  2, 46])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([15, 32, 49, 49, 14,  0,  4, 17])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([17, 25, 13, 13, 18, 48, 48,  1])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([35, 31, 29, 16,  9,  7, 45, 37])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([47, 23,  5,  8, 38, 24, 26, 41])
tensor([68, 68]) tensor([20, 19])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([34, 27, 39, 28, 12, 12, 42, 10])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([40, 33, 11, 44,  6, 22,  2, 46])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([15, 32, 49, 49, 14,  0,  4, 17])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([17, 25, 13, 13, 18, 48, 48,  1])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([35, 31, 29, 16,  9,  7, 45, 37])
tensor([68, 68, 68, 68, 68, 68, 68, 68]) tensor([47, 23,  5,  8,