<a href="https://colab.research.google.com/github/s275090/MLDL-First-Person-Action-Recognition/blob/main/main_run_self_supervision.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Install requirements**

In [None]:
!pip3 install 'tensorboardX' 

**Mount Google Drive**

In [None]:
# Mount Google Drive and move into the project folder.
# To download the repository https://drive.google.com/drive/folders/1_NAcoR0UGH1eLsiWMOx_Py8yeAocknA2?usp=sharing
from google.colab import drive
import os

DRIVE_MOUNT = '/content/drive'
drive.mount(DRIVE_MOUNT)

# Absolute path: a relative one fails when this cell is re-run, because the
# working directory is then already the project folder.
path = os.path.join(DRIVE_MOUNT, 'My Drive', 'ego-rnn')
os.chdir(path)
cwd = os.getcwd()
print("Current dir: " + cwd)

**Import libraries**

In [None]:
from __future__ import print_function, division
from spatial_transforms import (Compose, ToTensor, CenterCrop, Scale, Normalize, MultiScaleCornerCrop,
                                RandomHorizontalFlip)
from tensorboardX import SummaryWriter
from makeDatasetRGB import *
from makeDatasetMmaps import *
from MyConvLSTMCell import *

import argparse
import sys
import matplotlib.pyplot as plt

import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import glob
import random

import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo

In [None]:
class MyMotionSegCell(nn.Module):
    """Motion-segmentation head used for the self-supervised auxiliary loss.

    Maps a 512-channel feature map to per-location logits over 2 classes,
    returned with shape ``(batch, 2, 7, 7)``.

    Args:
        kernel_size, stride, padding: configuration of the 512 -> 100 channel
            convolution. The fully connected layer expects the convolution
            output to be 7x7, so non-default values must preserve that size.
    """

    def __init__(self, kernel_size=1, stride=1, padding=0):
        super(MyMotionSegCell, self).__init__()

        self.relu = nn.ReLU()
        # The constructor arguments used to be accepted but silently ignored.
        self.ms_conv = nn.Conv2d(512, 100, kernel_size=kernel_size, stride=stride,
                                 padding=padding, bias=False)
        self.ms_fc = nn.Linear(100 * 7 * 7, 2 * 7 * 7)

    def forward(self, x):
        """Return motion-segmentation logits of shape (batch, 2, 7, 7) for ``x``."""
        x = self.relu(x)
        x = self.ms_conv(x)
        x = x.view(x.size(0), 100 * 7 * 7)
        x = self.ms_fc(x)
        return x.view(x.size(0), 2, 7, 7)


In [None]:
from torch.autograd import Variable
from torch.nn import functional as F
from resnetMod import *

class convLSTMModel(nn.Module):
    """Frame-wise ResNet-34 features fed to a ConvLSTM, with optional
    class-activation-map (CAM) attention and an auxiliary motion-segmentation
    (MS) head."""

    def __init__(self, num_classes=61, mem_size=512):
        """
        Args:
            num_classes: number of action classes produced by the classifier.
            mem_size: number of channels of the ConvLSTM state (also the input
                size of the final linear layer).
        """
        super(convLSTMModel, self).__init__()
        self.num_classes = num_classes
        # Positional flags as defined by resnetMod.resnet34 (TODO: confirm meaning).
        self.resNet = resnet34(False, True)
        self.mem_size = mem_size
        # Class weights of the ResNet fc layer, reused below to build the CAM.
        self.weight_softmax = self.resNet.fc.weight
        self.lstm_cell = MyConvLSTMCell(512, mem_size)
        self.ms_cell = MyMotionSegCell()
        self.avgpool = nn.AvgPool2d(7)
        self.dropout = nn.Dropout(0.7)
        self.fc = nn.Linear(mem_size, self.num_classes)
        # dropout/fc are registered both directly and inside `classifier`;
        # renaming or removing either would change the state_dict keys.
        self.classifier = nn.Sequential(self.dropout, self.fc)

    def forward(self, inputVariable, CAM=False, MS=False):
        """Run the clip through the network one time step at a time.

        Args:
            inputVariable: time-first clip tensor; ``inputVariable[t]`` is the
                batch of frames at step ``t`` and ``size(1)`` is the batch size.
            CAM: if True, weight the ResNet's non-BN features with a softmax
                attention map built from the top-scoring class before the LSTM.
            MS: if True, also apply the motion-segmentation head to every frame.

        Returns:
            (feats, feats_ms, feats1): class logits; motion-segmentation logits
            stacked over time as (time, batch, 2, 7, 7), or an empty list when
            ``MS`` is False; and the pooled ``state[1]`` given to the classifier.
        """
        batch_size = inputVariable.size(1)
        device = inputVariable.device
        # Allocate the initial state on the input's device instead of a
        # hard-coded .cuda(), so the model also runs on CPU / other GPUs.
        state = (torch.zeros((batch_size, self.mem_size, 7, 7), device=device),
                 torch.zeros((batch_size, self.mem_size, 7, 7), device=device))
        feats_ms = []

        for t in range(inputVariable.size(0)):
            logit, feature_conv, feature_convNBN = self.resNet(inputVariable[t])

            if MS:
                feats_ms.append(self.ms_cell(feature_conv))

            if CAM:
                bz, nc, h, w = feature_conv.size()
                feature_conv1 = feature_conv.view(bz, nc, h * w)
                # Highest-scoring class for every sample in the batch.
                _, idxs = logit.sort(1, True)
                class_idx = idxs[:, 0]
                # Dot product of that class's fc weights with each spatial feature.
                cam = torch.bmm(self.weight_softmax[class_idx].unsqueeze(1), feature_conv1)
                attentionMAP = F.softmax(cam.squeeze(1), dim=1)
                attentionMAP = attentionMAP.view(attentionMAP.size(0), 1, 7, 7)
                attentionFeat = feature_convNBN * attentionMAP.expand_as(feature_conv)
                state = self.lstm_cell(attentionFeat, state)
            else:
                state = self.lstm_cell(feature_conv, state)

        if MS:
            feats_ms = torch.stack(feats_ms, 0)

        feats1 = self.avgpool(state[1]).view(state[1].size(0), -1)
        feats = self.classifier(feats1)
        return feats, feats_ms, feats1

**Set Arguments**

In [None]:
# Paths and experiment settings
data_dir = "GTEA61/processed_frames2"
out_dir = 'experiments'

user_train = ['S1', 'S3', 'S4']   # subjects used for training
user_val = ['S2']                 # held-out subject used for validation
trainBatchSize = 64
valBatchSize = 64
memSize = 512
num_classes = 61

frame = 16
seqLen = frame

MS = True    # add the motion-segmentation (self-supervised) loss
CAM = True   # use class-activation-map attention

# Dir for saving models and log files. It is derived from `frame` so that a
# different clip length does not overwrite another run's results.
# ('ConvLSMT' spelling kept for compatibility with existing folders.)
model_folder = os.path.join('./', out_dir, 'self-supervised', 'ConvLSMT-Attention',
                            '{}frm'.format(frame))

# Seed the Python, NumPy and PyTorch RNGs for reproducible runs
# (cuDNN kernels may still be non-deterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

**Prepare Dataset and Dataloader**

In [None]:
# ImageNet statistics, shared by the training and validation pipelines.
normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training set: multi-scale corner crops and random flips. The dataset
# receives `normalize` separately, together with the motion maps.
spatial_transform = Compose([
    Scale(256),
    RandomHorizontalFlip(),
    MultiScaleCornerCrop([1, 0.875, 0.75, 0.65625], 224),
])
vid_seq_train = makeDatasetMmaps(
    data_dir, user_train, frame,
    spatial_transform=spatial_transform,
    normalize=normalize,
    seqLen=seqLen,
    fmt='.png',
)
train_loader = torch.utils.data.DataLoader(
    vid_seq_train,
    batch_size=trainBatchSize,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# Validation set: RGB only, deterministic center crop, normalised in the transform.
vid_seq_val = makeDataset(
    data_dir, user_val, frame,
    spatial_transform=Compose([Scale(256), CenterCrop(224), ToTensor(), normalize]),
    seqLen=seqLen,
    fmt='.png',
)
val_loader = torch.utils.data.DataLoader(
    vid_seq_val,
    batch_size=valBatchSize,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

valInstances = len(vid_seq_val)
trainInstances = len(vid_seq_train)
print('Number of samples in the dataset: training = {} | validation = {}'.format(trainInstances, valInstances))

**Stage 1** (path to the checkpoint produced by stage-1 training)

In [None]:
# Stage-1 checkpoint, loaded (strict=False) to initialise the stage-2 model.
stage1_dict = '/'.join([out_dir, 'rgb', 'ConvLSMT-Attention', '16frame', 'stage1', 'model_rgb_state_dict.pth'])

**Stage 2**

**Set Parameters**

In [None]:
# Stage-2 optimisation schedule
numEpochs = 150
lr1 = 1e-5             # initial Adam learning rate (previous value: 1e-4)
decay_step = [25, 75]  # epochs at which the LR is multiplied by decay_factor
decay_factor = 0.1

In [None]:
# Create the output dir explicitly (idempotent thanks to exist_ok) so that
# opening the log files below never depends on SummaryWriter side effects.
if os.path.exists(model_folder):
    print('Directory {} exists!'.format(model_folder))
os.makedirs(model_folder, exist_ok=True)

# Log files ('w' mode: logs from a previous run in this folder are overwritten)
writer = SummaryWriter(model_folder)
train_log_loss = open(os.path.join(model_folder, 'train_log_loss.txt'), 'w')
train_log_acc = open(os.path.join(model_folder, 'train_log_acc.txt'), 'w')
val_log_loss = open(os.path.join(model_folder, 'val_log_loss.txt'), 'w')
val_log_acc = open(os.path.join(model_folder, 'val_log_acc.txt'), 'w')

**Prepare Network for Training**

In [None]:
# Stage-2 network: start from the stage-1 weights, freeze everything, then
# unfreeze only the modules fine-tuned in this stage.
model = convLSTMModel(num_classes=num_classes, mem_size=memSize)
# strict=False tolerates keys missing from the stage-1 checkpoint (presumably
# the new ms_cell head), which keep their fresh initialisation.
model.load_state_dict(torch.load(stage1_dict), strict=False)
model.train(False)
for param in model.parameters():
    param.requires_grad = False

# Modules to fine-tune, in the order their parameters go to the optimizer:
# conv1/conv2 of each layer4 block, the ResNet fc (whose weights also build
# the CAM), the ConvLSTM cell, the classifier and the motion-segmentation head.
layer4_convs = [getattr(model.resNet.layer4[block], conv_name)
                for block in range(3)
                for conv_name in ('conv1', 'conv2')]
trainable_modules = layer4_convs + [model.resNet.fc, model.lstm_cell,
                                    model.classifier, model.ms_cell]

train_params = []
for module in trainable_modules:
    for param in module.parameters():
        param.requires_grad = True
        train_params.append(param)
    # Only these sub-modules switch to training mode; the rest stay in eval mode.
    module.train(True)

model.cuda()

**Define Loss, Optimizer and LR Scheduler**

In [None]:
# Cross-entropy serves both the action labels and, in the training loop,
# the per-location motion-map targets.
loss_fn = nn.CrossEntropyLoss()

# Adam over the unfrozen parameters only.
optimizer_fn = torch.optim.Adam(params=train_params, lr=lr1,
                                weight_decay=4e-5, eps=1e-4)

# Multiply the LR by `decay_factor` at every epoch listed in `decay_step`.
optim_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer_fn,
    milestones=decay_step,
    gamma=decay_factor,
)

**Train**

In [None]:
# Stage-2 training: classification loss (+ motion-map loss when MS is set),
# validation every `val_every` epochs, checkpoint on best validation accuracy.
train_iter = 0
min_accuracy = 0   # best validation accuracy so far (name kept for compatibility)
loss_mmaps = []    # per-epoch average motion-map loss
val_every = 1      # validate every `val_every` epochs

# Sub-modules that must be in training mode while fitting (the rest of the
# model is left as configured when the network was prepared).
train_mode_modules = [model.lstm_cell, model.ms_cell, model.classifier, model.resNet.fc] + \
    [getattr(model.resNet.layer4[block], conv_name)
     for block in range(3) for conv_name in ('conv1', 'conv2')]

for epoch in range(numEpochs):
    epoch_loss = 0
    epoch_loss_mmap = 0
    numCorrTrain = 0
    trainSamples = 0
    iterPerEpoch = 0
    # Validation switches the whole model to eval mode, so restore it every epoch.
    for module in train_mode_modules:
        module.train(True)
    writer.add_scalar('lr', optimizer_fn.param_groups[0]['lr'], epoch + 1)

    for inputs, mmaps, targets in train_loader:
        train_iter += 1
        iterPerEpoch += 1
        optimizer_fn.zero_grad()
        # The model iterates over dim 0, so put time first: (time, batch, ...).
        inputVariable = inputs.permute(1, 0, 2, 3, 4).cuda()
        labelVariable = targets.cuda()
        trainSamples += inputs.size(0)

        output_label, output_mmaps, _ = model(inputVariable, CAM, MS)
        loss = loss_fn(output_label, labelVariable)

        if MS:
            # (time, batch, 2, 7, 7) -> (batch, 2, time, 7, 7): classes on dim 1.
            output_mmaps = output_mmaps.permute(1, 2, 0, 3, 4)
            # Reshape targets to (batch, time, 7, 7) explicitly; a blanket
            # squeeze() also removed the batch dim when a batch had size 1.
            # Assumes mmaps holds batch*time*7*7 values, batch-first (TODO confirm).
            target_shape = (output_mmaps.size(0),) + tuple(output_mmaps.shape[2:])
            mmapsVariable = mmaps.cuda().reshape(target_shape).long()
            loss_mmap = loss_fn(output_mmaps, mmapsVariable)
            loss = loss + loss_mmap
            # Reuse this value; it used to be recomputed outside `if MS`, which
            # raised NameError whenever MS was False.
            epoch_loss_mmap += loss_mmap.item()

        loss.backward()
        optimizer_fn.step()
        _, predicted = torch.max(output_label.data, 1)
        numCorrTrain += (predicted == labelVariable).sum().item()
        epoch_loss += loss.item()

    avg_loss = epoch_loss / iterPerEpoch
    avg_loss_mmap = epoch_loss_mmap / iterPerEpoch
    loss_mmaps.append(avg_loss_mmap)
    trainAccuracy = (numCorrTrain / trainSamples) * 100

    print('Train: Epoch = {} | Loss = {} | Accuracy = {}'.format(epoch + 1, avg_loss, trainAccuracy))
    train_log_loss.write('Train Loss after {} epochs = {}\n'.format(epoch + 1, avg_loss))
    train_log_acc.write('Train Accuracy after {} epochs = {}%\n'.format(epoch + 1, trainAccuracy))
    writer.add_scalar('train/epoch_loss', avg_loss, epoch + 1)
    writer.add_scalar('train/accuracy', trainAccuracy, epoch + 1)

    if (epoch + 1) % val_every == 0:
        model.train(False)
        val_loss_epoch = 0
        val_iter = 0
        val_samples = 0
        numCorr = 0
        # Evaluation needs no gradients; avoid building autograd graphs.
        with torch.no_grad():
            for inputs, targets in val_loader:
                val_iter += 1
                val_samples += inputs.size(0)
                inputVariable = inputs.permute(1, 0, 2, 3, 4).cuda()
                labelVariable = targets.cuda(non_blocking=True)
                output_label, _, _ = model(inputVariable, CAM, False)
                val_loss = loss_fn(output_label, labelVariable)
                val_loss_epoch += val_loss.item()
                _, predicted = torch.max(output_label.data, 1)
                numCorr += (predicted == labelVariable).sum().item()
        val_accuracy = (numCorr / val_samples) * 100
        avg_val_loss = val_loss_epoch / val_iter
        print('Val: Epoch = {} | Loss {} | Accuracy = {}'.format(epoch + 1, avg_val_loss, val_accuracy))
        writer.add_scalar('val/epoch_loss', avg_val_loss, epoch + 1)
        writer.add_scalar('val/accuracy', val_accuracy, epoch + 1)
        val_log_loss.write('Val Loss after {} epochs = {}\n'.format(epoch + 1, avg_val_loss))
        val_log_acc.write('Val Accuracy after {} epochs = {}%\n'.format(epoch + 1, val_accuracy))

        if val_accuracy > min_accuracy:
            save_path_model = os.path.join(model_folder, 'model_rgb_state_dict.pth')
            torch.save(model.state_dict(), save_path_model)
            min_accuracy = val_accuracy

    # Flush so the logs survive an interrupted or crashed run.
    for log_file in (train_log_loss, train_log_acc, val_log_loss, val_log_acc):
        log_file.flush()

    # Step the scheduler once per epoch
    optim_scheduler.step()

train_log_loss.close()
train_log_acc.close()
val_log_acc.close()
val_log_loss.close()
writer.export_scalars_to_json(os.path.join(model_folder, "all_scalars.json"))
writer.close()

# Report the number of epochs run (the loop index `epoch` is zero-based and is
# undefined when numEpochs == 0).
print('Best accuracy after {} epochs = {}'.format(numEpochs, min_accuracy))