In [30]:
from abc import ABC

from gulpio2 import GulpDirectory
from pathlib import Path

from collections import defaultdict

import pickle
import pandas as pd
from pathlib import Path

import torch as t
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf
from typing import Any, Dict, List, Sequence, Union, Tuple

from systems import EpicActionRecognitionSystem
from systems import EpicActionRecogintionDataModule

from utils.metrics import compute_metrics
from utils.actions import action_id_from_verb_noun
from scipy.special import softmax

from GPUtil import showUtilization as gpu_usage
from tqdm import tqdm

from frame_sampling import RandomSampler
from torchvideo.samplers import FrameSampler
from torchvideo.samplers import frame_idx_to_list

from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.utils.data as data_utils

device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"


In [161]:
class PickleDataset(Dataset):
    
    def __init__(self, pkl_path: Path, frame_sampler: FrameSampler, features_dim: int = 256):
        self.pkl_path = pkl_path
        self.frame_sampler = frame_sampler
        self.features_dim = features_dim
        self.pkl_dict = Dict[str, Any]
        self.frame_cumsum = np.array([0.])
        self._load()
        
    def _load(self):
        with open(self.pkl_path, 'rb') as f:
            self.pkl_dict = pickle.load(f)
            frame_counts = [label['num_frames'] for label in self.pkl_dict['labels']]
            self.frame_cumsum = np.cumsum(np.concatenate([self.frame_cumsum, frame_counts]), dtype=int)
    
    def _video_from_narration_id(self, key: int):
        l = self.frame_cumsum[key]
        r = self.frame_cumsum[key+1]
        return self.pkl_dict['features'][l:r]
    
    def __len__(self):
        return len(self.pkl_dict['narration_id'])
    
    def __getitem__(self, key: int):
        features = self._video_from_narration_id(key)
        video_length = features.shape[0]
        
        assert video_length == self.pkl_dict['labels'][key]['num_frames']
        if video_length < self.frame_sampler.frame_count:
            raise ValueError(f"Video too short to sample {self.frame_sampler.frame_count} from")
        
        sample_idxs = np.array(frame_idx_to_list(frame_sampler.sample(video_length)))
        return (features[sample_idxs], { k: self.pkl_dict['labels'][key][k] for k in ['narration_id','verb_class','noun_class'] })


In [162]:
class Net(nn.Module):
    
    def __init__(self, frame_count: int):
        super().__init__()
        self.frame_count = frame_count
        self.fc1 = nn.Linear(256 * frame_count, 512)
        self.fc2 = nn.Linear(512, 397)
    
    def forward(self, x):
        x = x.view(-1, 256 * self.frame_count)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x

In [163]:
n_frames = 8
frame_sampler = RandomSampler(frame_count=n_frames, snippet_length=1, test=False)

# def collate(data):
#     inputs, labels = zip(*data)
    
#     inp = t.tensor(inputs)
#     print(labels)

In [172]:
dataset = PickleDataset('../datasets/epic/features/p01_features.pkl', frame_sampler)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [173]:
class Classifier:
    
    def __init__(self, 
        model: nn.Module, 
        dataloader: DataLoader, 
        optimiser: optim.Adadelta, 
        device: t.device, 
        log_interval: int = 100
    ):
        self.model = model
        self.dataloader = dataloader
        self.optimiser = optimiser
        self.device = device
        self.log_interval = log_interval
        
    def _step(self, batch: Tuple[t.Tensor, Dict[str, Any]]) -> Dict[str, Any]:

        data, labels = batch
        self.optimiser.zero_grad()
        outputs = self.model(data.to(self.device))
        tasks = {
            'verb': {
                'output': outputs[:,:97],
                'preds': outputs[:,:97].argmax(-1),
                'labels': labels['verb_class'],
                'weight': 1
            },
            'noun': {
                'output': outputs[:,97:],
                'preds': outputs[:,97:].argmax(-1),
                'labels': labels['noun_class'],
                'weight': 1
            },
        }
        
        step_results = dict()
        loss = 0.0
        n_tasks = len(tasks)
        for task, d in tasks.items():
            task_loss = F.cross_entropy(d['output'], d['labels'].to(device))
            loss += d['weight'] * task_loss
            
        step_results['narration_id'] = labels['narration_id']
        step_results['loss'] = loss / n_tasks
        return step_results
        
    def train(self, epoch):
        self.model.train()
        running_loss = 0.0
        for batch_idx, data in enumerate(self.dataloader):
            
            step_results = self._step(data)
            loss = step_results['loss']
            
            loss.backward()
            self.optimiser.step()
            
            running_loss += loss.item()
            
            if batch_idx % self.log_interval == 0:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, batch_idx + 1, running_loss / self.log_interval))
                running_loss = 0.0

In [174]:
model = Net(frame_count=8).to(device)
optimiser = optim.Adadelta(model.parameters(), lr=0.005)
classifier = Classifier(model, dataloader, optimiser, device, log_interval=100)

for epoch in range(2):
    classifier.train(epoch)
    
# dataiter = iter(dataloader)
# inputs, labels = dataiter.next()

# out = model(inputs.to(device))

# nouns = out[:,:97]

# nouns.argmax(-1)

[1,     1] loss: 0.237
[1,   101] loss: 9.541
[1,   201] loss: 6.007
[1,   301] loss: 5.143
[1,   401] loss: 4.564
[1,   501] loss: 4.461
[1,   601] loss: 4.341
[1,   701] loss: 4.411
[1,   801] loss: 4.211
[1,   901] loss: 3.993
[1,  1001] loss: 4.157
[1,  1101] loss: 4.143
[1,  1201] loss: 4.061
[1,  1301] loss: 3.679
[1,  1401] loss: 4.289
[1,  1501] loss: 4.056
[1,  1601] loss: 3.700
[1,  1701] loss: 3.787
[1,  1801] loss: 4.019
[1,  1901] loss: 3.939
[1,  2001] loss: 3.848
[1,  2101] loss: 3.979
[1,  2201] loss: 3.974
[1,  2301] loss: 3.945
[1,  2401] loss: 3.755
[1,  2501] loss: 3.820
[1,  2601] loss: 3.653
[1,  2701] loss: 3.717
[1,  2801] loss: 3.973
[1,  2901] loss: 3.579
[1,  3001] loss: 3.755
[1,  3101] loss: 3.654
[1,  3201] loss: 4.076
[1,  3301] loss: 3.512
[1,  3401] loss: 3.813
[1,  3501] loss: 3.577
[1,  3601] loss: 3.646
[1,  3701] loss: 3.630
[1,  3801] loss: 3.513
[1,  3901] loss: 3.806
[1,  4001] loss: 3.655
[1,  4101] loss: 3.631
[1,  4201] loss: 3.511
[1,  4301] 