In [None]:
from abc import ABC

from gulpio2 import GulpDirectory
from pathlib import Path

from collections import defaultdict

import pickle
import pandas as pd
from pathlib import Path

import torch as t
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from omegaconf import OmegaConf
from typing import Any, Dict, List, Sequence, Union

from systems import EpicActionRecognitionSystem
from systems import EpicActionRecogintionDataModule

from utils.metrics import compute_metrics
from utils.actions import action_id_from_verb_noun
from scipy.special import softmax

from GPUtil import showUtilization as gpu_usage
from tqdm import tqdm

from frame_sampling import RandomSampler
from torchvideo.samplers import FrameSampler
from torchvideo.samplers import frame_idx_to_list

from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
import torch.utils.data as data_utils

# Train on the first GPU when CUDA is available, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = t.device("cuda:0")
else:
    device = t.device("cpu")

In [None]:
class PickleDataset(Dataset):
    """Map-style dataset over pre-extracted frame features stored in one pickle.

    The pickle must hold a mapping with at least these keys:
      * ``'features'``: every video's frame features, concatenated along axis 0.
      * ``'labels'``: one label mapping per video, each with a ``'num_frames'`` entry.
      * ``'narration_id'``: one entry per video (used only for ``len``).

    Item ``key`` is the ``key``-th video. Its frames are sub-sampled with
    ``frame_sampler``, and the item is returned together with its label mapping.

    Warning: ``pickle.load`` can execute arbitrary code, so only load trusted files.
    """

    def __init__(self, pkl_path: Path, frame_sampler: FrameSampler, features_dim: int = 256):
        self.pkl_path = pkl_path
        self.frame_sampler = frame_sampler
        # Stored for callers; not read anywhere in this class.
        self.features_dim = features_dim
        # Real annotation (the original assigned the typing object itself).
        self.pkl_dict: Dict[str, Any] = {}
        # Video i occupies rows frame_cumsum[i] .. frame_cumsum[i + 1] of 'features'.
        self.frame_cumsum = np.zeros(1, dtype=int)
        self._load()

    def _load(self):
        """Read the pickle and build the per-video frame-offset table."""
        with open(self.pkl_path, 'rb') as f:
            self.pkl_dict = pickle.load(f)  # unsafe on untrusted files
        frame_counts = [label['num_frames'] for label in self.pkl_dict['labels']]
        self.frame_cumsum = np.cumsum(np.concatenate([[0], frame_counts]), dtype=int)

    def _normalise_key(self, key: int) -> int:
        """Map a (possibly negative) video index into range, or raise IndexError."""
        n_videos = len(self.frame_cumsum) - 1
        index = key + n_videos if key < 0 else key
        if not 0 <= index < n_videos:
            raise IndexError(f"video index {key} out of range for {n_videos} videos")
        return index

    def _video_from_narration_id(self, key: int):
        """Return the feature rows of the ``key``-th video.

        ``key`` is a positional index, not a narration id.
        """
        key = self._normalise_key(key)
        start, end = self.frame_cumsum[key], self.frame_cumsum[key + 1]
        return self.pkl_dict['features'][start:end]

    def __len__(self):
        return len(self.pkl_dict['narration_id'])

    def __getitem__(self, key: int):
        """Return ``(sampled_features, label)`` for the ``key``-th video.

        Raises:
            IndexError: if ``key`` is outside the range of stored videos.
            ValueError: if the video has fewer frames than the sampler needs.
        """
        key = self._normalise_key(key)
        features = self._video_from_narration_id(key)
        label = self.pkl_dict['labels'][key]
        video_length = features.shape[0]

        # Sanity check: the offset table must agree with the stored frame count.
        assert video_length == label['num_frames']

        # Use this dataset's own sampler, not whatever is bound globally.
        n_needed = self.frame_sampler.frame_count
        if video_length < n_needed:
            raise ValueError(f"Video too short to sample {n_needed} from")

        sample_idxs = np.array(frame_idx_to_list(self.frame_sampler.sample(video_length)))
        return features[sample_idxs], label


In [None]:
import random

SEED = 0
n_frames = 8  # frames sampled per clip

# Seed every RNG that may drive frame sampling, DataLoader shuffling and
# weight init, so that Restart & Run All reproduces the same run.
random.seed(SEED)
np.random.seed(SEED)
t.manual_seed(SEED)

frame_sampler = RandomSampler(frame_count=n_frames, snippet_length=1, test=False)

In [None]:
# Pre-extracted features for a single video file; PickleDataset declares pkl_path as a Path.
FEATURES_PKL = Path('../datasets/epic/features/p01_01_1_features.pkl')
dataset = PickleDataset(FEATURES_PKL, frame_sampler)

In [None]:
# xd, xs = dataset._video_from_narration_id(dataset.pkl_dict['narration_id'][1])

In [None]:
# Shuffled, one clip per batch. batch_size=1 is DataLoader's default and is
# spelled out here so the batch size is visible.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
)

In [None]:
# dataset.pkl_dict['labels']
# for inp, lables in :
#     print(inp, labels)

In [None]:
class Net(nn.Module):
    """Two-layer MLP mapping a clip of frame features to verb + noun logits.

    The input is flattened to ``(-1, frame_count * features_dim)``. The output
    has ``n_verbs + n_nouns`` columns, with the verb logits first. The defaults
    reproduce the original 256*frame_count -> 512 -> 397 (= 97 + 300) layout.

    Args:
        frame_count: number of frames per clip.
        features_dim: size of each frame's feature vector.
        hidden_dim: width of the hidden layer.
        n_verbs: number of leading output columns (verb logits).
        n_nouns: number of trailing output columns (noun logits).
    """

    def __init__(self, frame_count: int, features_dim: int = 256, hidden_dim: int = 512,
                 n_verbs: int = 97, n_nouns: int = 300):
        super().__init__()
        self.frame_count = frame_count
        self.features_dim = features_dim
        self.n_verbs = n_verbs
        self.n_nouns = n_nouns
        self.fc1 = nn.Linear(features_dim * frame_count, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_verbs + n_nouns)

    def forward(self, x):
        """Return logits of shape ``(batch, n_verbs + n_nouns)``."""
        # reshape (unlike view) also accepts non-contiguous inputs. The cast
        # guards against float64 features reaching the float32 layers.
        x = x.reshape(-1, self.features_dim * self.frame_count).to(self.fc1.weight.dtype)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
N_EPOCHS = 2
LOG_EVERY = 2  # batches between loss reports
N_VERBS = 97   # the first N_VERBS logits are verbs, the rest nouns

# Build the model for the same clip length the frame sampler produces.
net = Net(frame_count=n_frames).to(device)
criterion = nn.CrossEntropyLoss()
optimiser = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

net.train()
for epoch in range(N_EPOCHS):
    running_loss = 0.0

    for i, (inputs, labels) in tqdm(
        enumerate(dataloader),
        total=len(dataloader),
        dynamic_ncols=True,
    ):
        # The DataLoader collates each label mapping into a mapping of batched
        # values. Assuming the class fields are numeric, they arrive as (batch,)
        # tensors. CrossEntropyLoss needs int64 targets.
        verb_target = labels['verb_class'].to(device).long()
        noun_target = labels['noun_class'].to(device).long()

        optimiser.zero_grad()
        out = net(inputs.to(device))

        loss_v = criterion(out[:, :N_VERBS], verb_target)
        loss_n = criterion(out[:, N_VERBS:], noun_target)
        # Use a single backward over the averaged loss. Two separate .backward()
        # calls would traverse the shared fc1 graph twice and raise.
        loss = (loss_v + loss_n) / 2
        loss.backward()
        optimiser.step()

        running_loss += loss.item()
        if (i + 1) % LOG_EVERY == 0:
            # Report the mean over the last LOG_EVERY batches; tqdm.write keeps
            # the progress bar intact.
            tqdm.write('[%d, %5d] loss: %.3f' %
                       (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0