In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as pretrained
from torchvision import transforms
from tqdm import tqdm

import requests
from PIL import Image
import pandas as pd
import numpy as np
import imageio
import cv2
import os
from torch.utils.data import Dataset, DataLoader


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [3]:
IMG_SIZE = 64
BATCH_SIZE = 8
EPOCHS = 10

MAX_SEQ_LENGTH = 20
NUM_FEATURES = (IMG_SIZE**2) * 3
LEARNING_RATE = 0.001


In [4]:
labels = []
data = []

for label in os.listdir('./JHMDB_video/ReCompress_Videos/'):
    labels.append(label)
    for video in os.listdir(f'./JHMDB_video/ReCompress_Videos/{label}'):
        if video.endswith('.avi'):
            data.append((label, f'./JHMDB_video/ReCompress_Videos/{label}/{video}'))


In [5]:
label_map = {label:torch.tensor(idx) for idx, (label) in enumerate(labels)}

In [6]:
class VideoDataset(Dataset):

    def __init__(self, video_dir_path):

        self.resnet = pretrained.resnet50(pretrained=True).to(device)
        self.resnet.eval()
        self.layer = self.resnet.avgpool
        self.video_embeddings = []
        self.labels = []

        self.transforms = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                0.229, 0.224, 0.255]),
        ])

        self.label_map = {}

        def hook(module, inputs, outputs):
            self.video_embeddings[-1].append(outputs.detach().cpu().squeeze())

        self.handle = self.layer.register_forward_hook(hook)

        for label in os.listdir(video_dir_path):
            # print(label)
            for video in tqdm(os.listdir(os.path.join(video_dir_path, label))):
                if not video.endswith('.avi'):
                    continue
                frames = self.get_frames(os.path.join(video_dir_path, f'{label}/{video}'))
                self.video_embeddings.append([])
                self.labels.append(label)

                for frame in frames:
                    inp = self.transforms(Image.fromarray(
                        frame)).to(device).unsqueeze(0)
                    self.resnet(inp)

        self.handle.remove()
        del self.handle
        del self.resnet
        del self.layer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.vstack(self.video_embeddings[idx]), self.label_map[self.labels[idx]]

    def get_frames(self, path, max_frames=20):
        vidObj = cv2.VideoCapture(path)
        success = 1
        frames = []
        count = 0
        while success:
            success, image = vidObj.read()
            count += 1
            if not success or count > max_frames:
                break
            frames.append(image)

        for i in range(len(frames), max_frames + 1):
            frames.append(np.zeros_like(frames[0]))
        return frames


In [8]:
import pickle

with open('train_dataset_jhmdb.pkl','rb') as f:
    train_dataset = pickle.load(f)

with open('test_dataset_jhmdb.pkl', 'rb') as f:
    test_dataset = pickle.load(f)

In [10]:
train_dataset.label_map = label_map
test_dataset.label_map = label_map

In [11]:
class Attention(nn.Module):
  def __init__(self, embedding_dim, n_hidden):
    super().__init__()

    self.embedding_dim = embedding_dim
    self.n_hidden = n_hidden

    self.wx = nn.Linear(self.embedding_dim, self.embedding_dim)
    self.wh = nn.Linear(self.n_hidden, self.embedding_dim)
    self.wc = nn.Linear(self.n_hidden, self.embedding_dim)
    self.sigmoid = nn.Sigmoid()

  def forward(self, X, h, c):
    out1 = self.wx(X)
    out2 = self.wh(h)
    out3 = self.wc(c)
    a = self.sigmoid(out1+out2+out3)
    
    return torch.mul(a,X)

In [23]:

class EleAttG_LSTM(nn.Module):
  def __init__(self, embedding_dim, n_hidden=128, n_classes=None):
    super().__init__()

    assert n_classes is not None

    self.embedding_dim = embedding_dim
    self.n_hidden = n_hidden
    self.n_classes = n_classes

    self.attention = Attention(self.embedding_dim, self.n_hidden)
    self.lstmcell = nn.LSTMCell(self.embedding_dim, self.n_hidden)
    self.fc = nn.Sequential(
        nn.Linear(self.n_hidden, self.n_hidden),
        nn.ReLU(),
        nn.Linear(self.n_hidden, self.n_classes),
        nn.Softmax(dim=1)
    )

  def forward(self, X):
    '''
      x = batch_size * frames * embedding_dim

    '''
    h = torch.zeros(X.shape[0], self.n_hidden).to(device)
    c = torch.zeros(X.shape[0], self.n_hidden).to(device)
    for i in range(X.shape[1]):
      X[:, i, :] = self.attention(X[:, i, :].clone(), h, c)
      h,c  = self.lstmcell(X[:, i, :].clone(), (h, c))

    return self.fc(h)


In [24]:
class Vanilla_LSTM(nn.Module):
    def __init__(self, embedding_dim, n_hidden, n_classes=None):
        super().__init__()

        assert n_classes is not None
        self.embedding_dim = embedding_dim
        self.n_hidden = n_hidden
        self.n_classes = n_classes

        self.lstmcell = nn.LSTMCell(self.embedding_dim, self.n_hidden)
        self.fc = nn.Sequential(
            nn.Linear(self.n_hidden, self.n_hidden),
            nn.ReLU(),
            nn.Linear(self.n_hidden, self.n_classes),
            nn.Softmax(dim=1)
        )

    def forward(self, X):
        h = torch.zeros(X.shape[0], self.n_hidden).to(device)
        c = torch.zeros(X.shape[0], self.n_hidden).to(device)
        for i in range(X.shape[1]):
            h, c = self.lstmcell(X[:, i, :].clone(), (h, c))

        return self.fc(h)


In [25]:
model = EleAttG_LSTM(2048, 256, 21).to(device)


In [16]:
training_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [26]:
def train(model, dataloader, device, n_epochs=10):
    optimizer = torch.optim.AdamW(model.parameters())
    criterion = nn.CrossEntropyLoss()

    for epoch in range(n_epochs):
        loss_val = 0
        for i, (X, y) in enumerate(tqdm(dataloader)):
            optimizer.zero_grad()
            X = X.to(device)
            y = y.to(device)
            y_pred = model(X)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            loss_val += loss.item()

        print(f'Epoch:{epoch} Loss:{loss_val}')


In [27]:
train(model, training_loader, device, 10)

100%|██████████| 93/93 [00:03<00:00, 30.86it/s]


Epoch:0 Loss:281.832720041275


100%|██████████| 93/93 [00:02<00:00, 33.85it/s]


Epoch:1 Loss:276.42549419403076


100%|██████████| 93/93 [00:02<00:00, 32.89it/s]


Epoch:2 Loss:271.6280689239502


100%|██████████| 93/93 [00:02<00:00, 31.68it/s]


Epoch:3 Loss:266.30261421203613


100%|██████████| 93/93 [00:02<00:00, 31.10it/s]


Epoch:4 Loss:261.7419385910034


100%|██████████| 93/93 [00:02<00:00, 32.79it/s]


Epoch:5 Loss:259.3659362792969


100%|██████████| 93/93 [00:02<00:00, 33.28it/s]


Epoch:6 Loss:254.7185959815979


100%|██████████| 93/93 [00:02<00:00, 32.46it/s]


Epoch:7 Loss:252.5578374862671


100%|██████████| 93/93 [00:02<00:00, 33.05it/s]


Epoch:8 Loss:250.9123728275299


100%|██████████| 93/93 [00:02<00:00, 32.46it/s]

Epoch:9 Loss:250.9273064136505





In [29]:
vanilla_model = Vanilla_LSTM(2048, 256, 21)
vanilla_model.to(device)

Vanilla_LSTM(
  (lstmcell): LSTMCell(2048, 256)
  (fc): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=21, bias=True)
    (3): Softmax(dim=1)
  )
)

In [30]:
train(vanilla_model, training_loader, device, 10)

100%|██████████| 93/93 [00:01<00:00, 68.66it/s]


Epoch:0 Loss:282.89400124549866


100%|██████████| 93/93 [00:01<00:00, 92.81it/s] 


Epoch:1 Loss:280.8013508319855


100%|██████████| 93/93 [00:00<00:00, 105.98it/s]


Epoch:2 Loss:277.8125493526459


100%|██████████| 93/93 [00:01<00:00, 87.39it/s] 


Epoch:3 Loss:274.35194396972656


100%|██████████| 93/93 [00:01<00:00, 87.40it/s] 


Epoch:4 Loss:270.8425028324127


100%|██████████| 93/93 [00:01<00:00, 82.92it/s]


Epoch:5 Loss:266.69428181648254


100%|██████████| 93/93 [00:01<00:00, 90.86it/s] 


Epoch:6 Loss:263.75190353393555


100%|██████████| 93/93 [00:01<00:00, 82.36it/s] 


Epoch:7 Loss:262.0747709274292


100%|██████████| 93/93 [00:00<00:00, 100.88it/s]


Epoch:8 Loss:259.2995517253876


100%|██████████| 93/93 [00:00<00:00, 97.64it/s] 

Epoch:9 Loss:259.65176224708557





In [31]:
preds = list()
y_true = list()

for (X, y) in tqdm(test_dataset):
    X = X.unsqueeze(0).to(device)
    out = model(X).squeeze(0)
    y_pred = out.argmax()
    preds.append(y_pred.item())
    y_true.append(y.item())


100%|██████████| 186/186 [00:01<00:00, 110.90it/s]


In [32]:
from sklearn.metrics import accuracy_score, f1_score, classification_report, precision_score


In [33]:
att_gru_accu = accuracy_score(y_true, preds)
print(classification_report(y_true, preds))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00        10
           1       0.46      1.00      0.63        11
           2       0.15      0.78      0.25         9
           3       0.00      0.00      0.00         5
           4       0.56      0.56      0.56         9
           5       0.93      0.93      0.93        14
           6       0.45      0.71      0.56         7
           7       0.79      0.73      0.76        15
           8       0.00      0.00      0.00         9
           9       0.00      0.00      0.00         9
          10       0.00      0.00      0.00         7
          11       0.00      0.00      0.00         8
          12       0.92      0.80      0.86        15
          13       0.00      0.00      0.00         6
          14       0.00      0.00      0.00         8
          15       0.00      0.00      0.00         8
          16       0.36      1.00      0.53         4
          17       0.40    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [34]:
preds = list()
y_true = list()

for (X, y) in tqdm(test_dataset):
    X = X.unsqueeze(0).to(device)
    out = vanilla_model(X).squeeze(0)
    y_pred = out.argmax()
    preds.append(y_pred.item())
    y_true.append(y.item())


100%|██████████| 186/186 [00:00<00:00, 288.59it/s]


In [35]:
vanilla_accu = accuracy_score(y_true, preds)
print(classification_report(y_true, preds))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00        10
           1       0.64      0.82      0.72        11
           2       0.00      0.00      0.00         9
           3       0.00      0.00      0.00         5
           4       0.20      0.89      0.32         9
           5       0.92      0.79      0.85        14
           6       0.00      0.00      0.00         7
           7       0.38      0.60      0.46        15
           8       0.00      0.00      0.00         9
           9       0.00      0.00      0.00         9
          10       0.12      0.71      0.21         7
          11       0.00      0.00      0.00         8
          12       0.00      0.00      0.00        15
          13       0.00      0.00      0.00         6
          14       0.00      0.00      0.00         8
          15       0.00      0.00      0.00         8
          16       0.22      1.00      0.36         4
          17       0.36    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [36]:
print(vanilla_accu)
print(att_gru_accu)


0.3333333333333333
0.44623655913978494
