In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd

In [2]:
class EPIC(Dataset):
    """Dataset over a pre-extracted feature matrix stored in a ``.npy`` file.

    Every row of the loaded array is read as ``[features..., verb_id, noun_id]``:
    the last two columns are the labels and everything before them is the
    feature vector. Row 0 of the array is always skipped.
    NOTE(review): row 0 is presumably a placeholder/header row -- confirm
    against whatever script wrote the ``.npy`` file.

    Args:
        data_path: Path of the ``.npy`` file to load.
        p_id: Optional participant id. When given, only the rows whose
            position matches an index label of ``df`` with
            ``df["participant_id"] == p_id`` are kept.
            NOTE(review): this assumes ``df`` has a default ``RangeIndex``
            aligned row-for-row with the array -- confirm.
        df: DataFrame with a ``participant_id`` column; required when
            ``p_id`` is given.
        train_batch_size: Minimum number of rows the subset must contain for
            ``hasData`` to be True. ``None`` means "at least one row".

    Attributes:
        hasData: True when the dataset holds enough rows to be used
            (see ``train_batch_size``); always set, also without ``p_id``.
        data, verb, noun: Feature matrix and the two label vectors.
    """

    def __init__(self, data_path, p_id=None, df=None, train_batch_size=None):
        super().__init__()
        data = np.load(data_path)

        if p_id is not None:
            if df is None:
                raise ValueError("`df` is required when `p_id` is given")
            # Drop row 0 *before* the size check so that `hasData` reflects
            # the number of rows that are actually kept.
            rows = [
                row
                for row in df.index[df["participant_id"] == p_id].tolist()
                if row != 0
            ]
            min_rows = 1 if train_batch_size is None else train_batch_size
            self.hasData = len(rows) >= min_rows
            self.verb = data[rows, -2]
            self.noun = data[rows, -1]
            self.data = data[rows, :-2]
        else:
            self.hasData = data.shape[0] > 1
            self.verb = data[1:, -2]
            self.noun = data[1:, -1]
            self.data = data[1:, :-2]

    def __len__(self):
        """Number of rows kept in this dataset."""
        return np.shape(self.data)[0]

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors.

        Returns:
            ``{"data": float tensor (1, n_features),
               "verb_id": long tensor (1,),
               "noun_id": long tensor (1,)}``
        """
        features = np.squeeze(self.data[idx, :])
        verb = np.squeeze(self.verb[idx])
        noun = np.squeeze(self.noun[idx])

        # The labels are 0-d here, so the only valid axis to insert is 0.
        # (Axis 1 is out of range and raises AxisError on NumPy >= 1.18; older
        # releases silently produced the same (1,) shape.)
        return {
            "data": torch.tensor(np.expand_dims(features, 0)).float(),
            "verb_id": torch.tensor(np.expand_dims(verb, 0)).long(),
            "noun_id": torch.tensor(np.expand_dims(noun, 0)).long(),
        }

class IRM(torch.nn.Module):
    """Two independent linear classification heads over a shared input vector.

    Args:
        num_class_verb: Number of output logits of the verb head.
        num_class_noun: Number of output logits of the noun head.
        in_features: Size of the last dimension of the input. Defaults to
            2048, the width that was previously hard-coded.
    """

    def __init__(self, num_class_verb, num_class_noun, in_features=2048):
        super().__init__()
        # Attribute names are kept so existing state_dict keys stay valid.
        self.verb_fc = torch.nn.Linear(in_features, num_class_verb)
        self.noun_fc = torch.nn.Linear(in_features, num_class_noun)

    def forward(self, x):
        """Apply both heads to ``x`` (last dimension must be ``in_features``).

        Returns:
            ``{"verb_out": verb logits, "noun_out": noun logits}``; each keeps
            the leading dimensions of ``x``.
        """
        return {"verb_out": self.verb_fc(x), "noun_out": self.noun_fc(x)}


In [15]:
# Seed torch so weight initialisation and DataLoader shuffling are repeatable.
torch.manual_seed(786)

TRAIN_FEATURES_PATH = "train-unseen.npy"
VAL_FEATURES_PATH = "val-unseen.npy"
NUM_VERB_CLASSES = 125
NUM_NOUN_CLASSES = 352


def _class_frequencies(labels, num_classes):
    """Return, as a list of length ``num_classes``, the fraction of entries in
    ``labels`` that carry each class id (0.0 for ids that never occur)."""
    counts = np.bincount(labels.astype(np.int64), minlength=num_classes)
    if counts.shape[0] > num_classes:
        raise ValueError(
            "found class id %d but only %d classes were declared"
            % (counts.shape[0] - 1, num_classes)
        )
    return (counts / labels.shape[0]).tolist()


df = pd.read_csv("EPIC_train_action_labels.csv", nrows=23192)
# One training environment per participant. The ids are sorted because the
# iteration order of a plain set of strings can differ between interpreter
# runs, which would change the environment order and defeat the seed above.
p_ids = sorted(df["participant_id"].dropna().unique().tolist())

# Per-class weights for the cross-entropy losses: relative frequency of each
# label among all rows of the training array (last two columns = verb, noun).
d = np.load(TRAIN_FEATURES_PATH)
weight_verb = _class_frequencies(d[:, -2], NUM_VERB_CLASSES)
weight_noun = _class_frequencies(d[:, -1], NUM_NOUN_CLASSES)

train_batch_size = 100
train_loaders = []
for p_id in p_ids:
    dataset_train = EPIC(TRAIN_FEATURES_PATH, p_id, df, train_batch_size)
    # Participants flagged as not having enough rows are left out.
    if dataset_train.hasData:
        train_loaders.append(
            DataLoader(dataset_train, batch_size=train_batch_size, shuffle=True, drop_last=True)
        )

# NOTE(review): drop_last=True on the two evaluation loaders discards the
# final partial batch, so reported accuracies ignore up to 199 samples each.
# Kept as-is because changing it alters the reported metrics.
dataset_ver = EPIC(TRAIN_FEATURES_PATH)
ver_loader = DataLoader(dataset_ver, batch_size=200, shuffle=False, drop_last=True)

dataset_val = EPIC(VAL_FEATURES_PATH)
val_loader = DataLoader(dataset_val, batch_size=200, shuffle=False, drop_last=True)

model = IRM(NUM_VERB_CLASSES, NUM_NOUN_CLASSES)

# Scalar multiplier applied to the logits; it is not handed to the optimizer,
# so it stays at 1.0 and only serves as a differentiation target.
dummy_w = torch.nn.Parameter(torch.tensor([1.0]))
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

def compute_penalty(losses, dummy):
    """Gradient-product penalty of per-sample ``losses`` with respect to ``dummy``.

    The samples are split into the even-indexed and the odd-indexed half. For
    each half the gradient of its mean loss with respect to ``dummy`` is taken
    (with ``create_graph=True`` so the result can itself be back-propagated),
    and the summed element-wise product of the two gradients is returned.

    Args:
        losses: 1-D tensor of per-sample losses that depend on ``dummy``.
        dummy: Tensor the gradients are taken with respect to.

    Returns:
        Scalar tensor holding the penalty.
    """

    def _grad_of_mean(part):
        (grad,) = torch.autograd.grad(part.mean(), dummy, create_graph=True)
        return grad

    even_grad = _grad_of_mean(losses[0::2])
    odd_grad = _grad_of_mean(losses[1::2])
    return torch.sum(even_grad * odd_grad)

# Per-sample (reduction="none") class-weighted cross-entropy, one per head.
ce_verb = torch.nn.CrossEntropyLoss(weight=torch.tensor(weight_verb).float(), reduction="none")
ce_noun = torch.nn.CrossEntropyLoss(weight=torch.tensor(weight_noun).float(), reduction="none")

# Training hyper-parameters.
NUM_EPOCHS = 100
PENALTY_START_EPOCH = 50   # gradient penalty is added only when epoch > this (0-based)
penalty_factor = 0.15      # multiplies the classification loss, not the penalty
WEIGHT_DECAY = 0.005       # multiplies the accumulated squared parameter norm
MAX_GRAD_NORM = 20


def evaluate_top1(model, loader):
    """Top-1 accuracy of ``model`` over every batch of ``loader``.

    Each batch must be a dict with "data", "verb_id" and "noun_id"; the model
    outputs are arg-maxed over dim 2, i.e. logits are expected to be 3-D.

    Returns:
        ``(verb_acc, noun_acc, action_acc)`` where an action counts as correct
        only if both verb and noun are correct. All three are NaN when the
        loader yields no samples.
    """
    hits_verb = 0
    hits_noun = 0
    hits_action = 0
    seen = 0
    with torch.no_grad():
        for sample_batch in loader:
            out = model(sample_batch["data"])
            verb = sample_batch["verb_id"].view(-1)
            noun = sample_batch["noun_id"].view(-1)

            verb_ok = torch.argmax(out["verb_out"], dim=2).view(-1) == verb
            noun_ok = torch.argmax(out["noun_out"], dim=2).view(-1) == noun

            hits_verb += verb_ok.sum().item()
            hits_noun += noun_ok.sum().item()
            hits_action += (verb_ok & noun_ok).sum().item()
            # Count the labels actually seen instead of assuming a batch size.
            seen += verb.numel()
    if seen == 0:
        nan = float("nan")
        return nan, nan, nan
    return hits_verb / seen, hits_noun / seen, hits_action / seen


# A training step is only taken while at least two environments still have
# batches, so fewer than two loaders can never train.
if len(train_loaders) < 2:
    raise RuntimeError(
        "need at least two training environments, got %d" % len(train_loaders)
    )

for epoch in range(NUM_EPOCHS):
    print("Epoch : "+str(epoch+1))
    model.train()
    _loaders = [iter(x) for x in train_loaders]

    err_verb = 0
    err_noun = 0
    counter = 0   # number of environment batches that contributed to a step
    while True:
        verb_loss = 0
        noun_loss = 0
        verb_penalty = 0
        noun_penalty = 0
        weight_norms = 0
        envs = 0

        # Squared L2 norm of all parameters. The model is not updated inside
        # the environment loop, so it is computed once per step.
        weight_norm = sum(w.norm().pow(2) for w in model.parameters())

        for env_iter in _loaders:
            try:
                sample_batch = next(env_iter)
            except StopIteration:
                # This environment is exhausted for the current epoch.
                continue
            envs += 1
            data = sample_batch["data"]
            verb = sample_batch["verb_id"]
            noun = sample_batch["noun_id"]

            _out = model(data)
            # Drop the singleton middle dimension: (batch, 1, C) -> (batch, C).
            verb_out = _out["verb_out"].squeeze(1)
            noun_out = _out["noun_out"].squeeze(1)

            verb_loss_ = ce_verb(verb_out*dummy_w, verb.view(-1))
            noun_loss_ = ce_noun(noun_out*dummy_w, noun.view(-1))

            # Added once per environment, so the regulariser scales with the
            # number of environments contributing to this step.
            weight_norms = weight_norms + weight_norm

            if epoch > PENALTY_START_EPOCH:
                verb_penalty += compute_penalty(verb_loss_, dummy_w)
                noun_penalty += compute_penalty(noun_loss_, dummy_w)

            verb_loss += verb_loss_.mean()
            noun_loss += noun_loss_.mean()

        # Stop the epoch once fewer than two environments can still supply a
        # batch. This also covers the case where all of them ran out at once
        # and the accumulators are still plain ints.
        if envs < 2:
            break
        counter += envs

        err_verb += verb_loss.item()
        err_noun += noun_loss.item()

        opt.zero_grad()
        objective = (
            penalty_factor*(verb_loss+noun_loss)
            + (verb_penalty+noun_penalty)
            + WEIGHT_DECAY*weight_norms
        )
        objective.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        opt.step()

    # Mean loss per environment batch; NaN when no step was taken.
    if counter:
        print(err_verb/counter)
        print(err_noun/counter)
    else:
        print(float("nan"))
        print(float("nan"))

    model.eval()
    for split_name, loader in (("Train", ver_loader), ("Val", val_loader)):
        verb_acc, noun_acc, action_acc = evaluate_top1(model, loader)
        print(split_name + " Accuracy/Verb(Top 1)", verb_acc)
        print(split_name + " Accuracy/Noun(Top 1)", noun_acc)
        print(split_name + " Accuracy/Action(Top 1)", action_acc)



al Accuracy/Verb(Top 1) 0.28365384615384615
Val Accuracy/Noun(Top 1) 0.11365384615384615
Val Accuracy/Action(Top 1) 0.026153846153846153
Epoch : 40
0.1473712130290706
0.07448281260525308
Train Accuracy/Verb(Top 1) 0.38073913043478264
Train Accuracy/Noun(Top 1) 0.12578260869565216
Train Accuracy/Action(Top 1) 0.0388695652173913
Val Accuracy/Verb(Top 1) 0.28442307692307695
Val Accuracy/Noun(Top 1) 0.11384615384615385
Val Accuracy/Action(Top 1) 0.02576923076923077
Epoch : 41
0.1466542488191186
0.07384980379081355
Train Accuracy/Verb(Top 1) 0.38130434782608696
Train Accuracy/Noun(Top 1) 0.12565217391304348
Train Accuracy/Action(Top 1) 0.0388695652173913
Val Accuracy/Verb(Top 1) 0.28365384615384615
Val Accuracy/Noun(Top 1) 0.11403846153846153
Val Accuracy/Action(Top 1) 0.025961538461538463
Epoch : 42
0.1462756841647916
0.07375104492757378
Train Accuracy/Verb(Top 1) 0.382
Train Accuracy/Noun(Top 1) 0.1254782608695652
Train Accuracy/Action(Top 1) 0.03904347826086957
Val Accuracy/Verb(Top 1) 0