In [1]:
import cv2
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as pretrained

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
from tqdm import tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using {device}')

Using cuda


In [3]:
class VideoDataset(Dataset):

    def __init__(self, video_dir_path):

        self.resnet = pretrained.resnet50(pretrained=True).to(device)
        self.resnet.eval()
        self.layer = self.resnet.avgpool
        self.video_embeddings = []
        self.labels = []

        self.transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                0.229, 0.224, 0.255]),
        ])

        self.label_map = {
            'CricketShot': torch.tensor(0),
            'PlayingCello': torch.tensor(1),
            'Punch': torch.tensor(2),
            'ShavingBeard': torch.tensor(3),
            'TennisSwing': torch.tensor(4)
        }

        def hook(module, inputs, outputs):
            self.video_embeddings[-1].append(outputs.detach().cpu().squeeze())

        self.handle = self.layer.register_forward_hook(hook)

        for video in tqdm(os.listdir(video_dir_path)):
            frames = self.get_frames(os.path.join(video_dir_path, video))
            self.video_embeddings.append([])

            for frame in frames:
                inp = self.transforms(Image.fromarray(
                    frame)).to(device).unsqueeze(0)
                self.resnet(inp)

            action = video.split('_')[1]
            self.labels.append(action)

        self.handle.remove()
        del self.handle
        del self.resnet
        del self.layer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.vstack(self.video_embeddings[idx]), self.label_map[self.labels[idx]]

    def get_frames(self, path, max_frames=20):
        vidObj = cv2.VideoCapture(path)
        success = 1
        frames = []
        count = 0
        while success:
            success, image = vidObj.read()
            count += 1
            if not success or count > max_frames:
                break
            frames.append(image)
        return frames

    def unregister_hook(self):
        self.handle.remove()
        del self.handle


In [4]:
import pickle

with open('train_dataset_hcf.pkl', 'rb') as f:
    train_dataset = pickle.load(f)

with open('test_dataset_hcf.pkl', 'rb') as f:
    test_dataset = pickle.load(f)


In [5]:
class Attention(nn.Module):
  def __init__(self, embedding_dim, n_hidden):
    super().__init__()

    self.embedding_dim = embedding_dim
    self.n_hidden = n_hidden

    self.wx = nn.Linear(self.embedding_dim, self.embedding_dim)
    self.wh = nn.Linear(self.n_hidden, self.embedding_dim)
    self.wc = nn.Linear(self.n_hidden, self.embedding_dim)
    self.sigmoid = nn.Sigmoid()

  def forward(self, X, h, c):
    out1 = self.wx(X)
    out2 = self.wh(h)
    out3 = self.wc(c)
    a = self.sigmoid(out1+out2+out3)
    
    return torch.mul(a,X)

In [7]:

class EleAttG_LSTM(nn.Module):
  def __init__(self, embedding_dim, n_hidden=128, n_classes=None):
    super().__init__()

    assert n_classes is not None

    self.embedding_dim = embedding_dim
    self.n_hidden = n_hidden
    self.n_classes = n_classes

    self.attention = Attention(self.embedding_dim, self.n_hidden)
    self.lstmcell = nn.LSTMCell(self.embedding_dim, self.n_hidden)
    self.fc = nn.Sequential(
        nn.Linear(self.n_hidden, self.n_hidden),
        nn.ReLU(),
        nn.Linear(self.n_hidden, self.n_classes),
        nn.Softmax(dim=1)
    )

  def forward(self, X):
    '''
      x = batch_size * frames * embedding_dim

    '''
    h = torch.zeros(X.shape[0], self.n_hidden).to(device)
    c = torch.zeros(X.shape[0], self.n_hidden).to(device)
    for i in range(X.shape[1]):
      X[:, i, :] = self.attention(X[:, i, :].clone(), h, c)
      h, c = self.lstmcell(X[:, i, :].clone(), (h, c))

    return self.fc(h)


In [8]:
class Vanilla_LSTM(nn.Module):
    def __init__(self, embedding_dim, n_hidden, n_classes=None):
        super().__init__()

        assert n_classes is not None
        self.embedding_dim = embedding_dim
        self.n_hidden = n_hidden
        self.n_classes = n_classes

        self.lstmcell = nn.LSTMCell(self.embedding_dim, self.n_hidden)
        self.fc = nn.Sequential(
            nn.Linear(self.n_hidden, self.n_hidden),
            nn.ReLU(),
            nn.Linear(self.n_hidden, self.n_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, X):
        h = torch.zeros(X.shape[0], self.n_hidden).to(device)
        c = torch.zeros(X.shape[0], self.n_hidden).to(device)
        for i in range(X.shape[1]):
            h, c = self.lstmcell(X[:, i, :].clone(), (h, c))

        return self.fc(h)


In [9]:
model = EleAttG_LSTM(2048, 256, 5).to(device)
vanilla_model = Vanilla_LSTM(2048, 256, 5).to(device)

In [10]:
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

In [11]:
def train(model, dataloader, device, n_epochs=10):
    optimizer = torch.optim.AdamW(model.parameters())
    criterion = nn.CrossEntropyLoss()

    for epoch in range(n_epochs):
        loss_val = 0
        for i, (X, y) in enumerate(tqdm(dataloader)):
            optimizer.zero_grad()
            X = X.to(device)
            y = y.to(device)
            y_pred = model(X)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            loss_val += loss.item()

        print(f'Epoch:{epoch} Loss:{loss_val}')


In [12]:
train(model, train_dataloader, device, 10)

100%|██████████| 75/75 [00:06<00:00, 12.26it/s]


Epoch:0 Loss:87.23476886749268


100%|██████████| 75/75 [00:03<00:00, 22.14it/s]


Epoch:1 Loss:70.5823045372963


100%|██████████| 75/75 [00:03<00:00, 22.02it/s]


Epoch:2 Loss:69.50165128707886


100%|██████████| 75/75 [00:03<00:00, 22.80it/s]


Epoch:3 Loss:68.21627241373062


100%|██████████| 75/75 [00:03<00:00, 21.99it/s]


Epoch:4 Loss:68.02652633190155


100%|██████████| 75/75 [00:03<00:00, 22.22it/s]


Epoch:5 Loss:67.99511551856995


100%|██████████| 75/75 [00:03<00:00, 23.02it/s]


Epoch:6 Loss:67.99599099159241


100%|██████████| 75/75 [00:03<00:00, 21.40it/s]


Epoch:7 Loss:70.80347400903702


100%|██████████| 75/75 [00:03<00:00, 23.54it/s]


Epoch:8 Loss:69.29672348499298


100%|██████████| 75/75 [00:03<00:00, 21.98it/s]

Epoch:9 Loss:67.9708240032196





In [13]:
train(vanilla_model, train_dataloader, device, 10)

100%|██████████| 75/75 [00:01<00:00, 67.20it/s]


Epoch:0 Loss:86.68452328443527


100%|██████████| 75/75 [00:01<00:00, 65.05it/s]


Epoch:1 Loss:71.34319418668747


100%|██████████| 75/75 [00:01<00:00, 65.16it/s]


Epoch:2 Loss:68.70965242385864


100%|██████████| 75/75 [00:01<00:00, 69.34it/s]


Epoch:3 Loss:68.0726643204689


100%|██████████| 75/75 [00:01<00:00, 72.04it/s]


Epoch:4 Loss:67.88857561349869


100%|██████████| 75/75 [00:00<00:00, 82.86it/s]


Epoch:5 Loss:67.87942868471146


100%|██████████| 75/75 [00:01<00:00, 68.85it/s]


Epoch:6 Loss:67.87403935194016


100%|██████████| 75/75 [00:01<00:00, 74.24it/s]


Epoch:7 Loss:67.87047755718231


100%|██████████| 75/75 [00:00<00:00, 78.06it/s]


Epoch:8 Loss:67.86845952272415


100%|██████████| 75/75 [00:00<00:00, 77.21it/s]

Epoch:9 Loss:67.86715215444565





In [14]:
preds = list()
y_true = list()

for (X, y) in tqdm(test_dataset):
    X = X.unsqueeze(0).to(device)
    out = model(X).squeeze(0)
    y_pred = out.argmax()
    preds.append(y_pred.item())
    y_true.append(y.item())

100%|██████████| 224/224 [00:01<00:00, 118.59it/s]


In [15]:
from sklearn.metrics import accuracy_score, f1_score, classification_report, precision_score


In [17]:
att_lstm_accu = accuracy_score(y_true, preds)
print(classification_report(y_true, preds))

              precision    recall  f1-score   support

           0       1.00      0.90      0.95        49
           1       0.98      1.00      0.99        44
           2       1.00      0.95      0.97        39
           3       0.98      1.00      0.99        43
           4       0.91      1.00      0.95        49

    accuracy                           0.97       224
   macro avg       0.97      0.97      0.97       224
weighted avg       0.97      0.97      0.97       224



In [18]:
preds = list()
y_true = list()

for (X, y) in tqdm(test_dataset):
    X = X.unsqueeze(0).to(device)
    out = vanilla_model(X).squeeze(0)
    y_pred = out.argmax()
    preds.append(y_pred.item())
    y_true.append(y.item())


100%|██████████| 224/224 [00:00<00:00, 292.92it/s]


In [19]:
vanilla_accu = accuracy_score(y_true, preds)
print(classification_report(y_true, preds))

              precision    recall  f1-score   support

           0       0.98      1.00      0.99        49
           1       0.98      0.95      0.97        44
           2       0.95      0.97      0.96        39
           3       1.00      1.00      1.00        43
           4       1.00      0.98      0.99        49

    accuracy                           0.98       224
   macro avg       0.98      0.98      0.98       224
weighted avg       0.98      0.98      0.98       224



In [20]:
print(vanilla_accu)
print(att_lstm_accu)

0.9821428571428571
0.96875
