In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as pretrained
from torchvision import transforms
from tqdm import tqdm

import requests
from PIL import Image
import pandas as pd
import numpy as np
import imageio
import cv2
import os
from torch.utils.data import Dataset, DataLoader

In [2]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


In [3]:
# Frame preprocessing
IMG_SIZE = 64                    # frames are centre-cropped, then resized to IMG_SIZE x IMG_SIZE
MAX_SEQ_LENGTH = 20              # frames kept per clip
NUM_FEATURES = (IMG_SIZE**2)*3   # flattened length of one RGB frame

# Training
BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE = 0.001

# Reproducibility: weight init and DataLoader shuffling draw from torch's global
# RNG, so seed it (and numpy) once, before any model or loader is created.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
VIDEO_ROOT = './JHMDB_video/ReCompress_Videos/'

labels = []   # one entry per action-class directory
data = []     # (label, path-to-.avi) pairs

# Sort both listings: os.listdir order is arbitrary, and the order of `labels`
# fixes the class indices used for the one-hot targets later on.
for label in sorted(os.listdir(VIDEO_ROOT)):
    label_dir = os.path.join(VIDEO_ROOT, label)
    if not os.path.isdir(label_dir):
        continue  # skip stray non-directory entries in the root
    labels.append(label)
    for video in sorted(os.listdir(label_dir)):
        if video.endswith('.avi'):
            data.append((label, os.path.join(label_dir, video)))

In [5]:
def crop_center_square(frame):
    """Return the largest square region centred in ``frame``.

    Only the first two axes (rows, columns) are cropped; any trailing axes
    such as colour channels are kept as-is. The result is a view into ``frame``.
    """
    height, width = frame.shape[:2]
    side = min(height, width)
    top = height // 2 - side // 2
    left = width // 2 - side // 2
    return frame[top:top + side, left:left + side]


def load_video(path, max_frames=0, resize=(IMG_SIZE, IMG_SIZE)):
    """Decode a video into an array of centre-cropped, resized RGB frames.

    Args:
        path: video file path passed to ``cv2.VideoCapture``.
        max_frames: stop after this many frames; ``0`` (or any value <= 0)
            reads the whole clip.
        resize: ``(width, height)`` target size passed to ``cv2.resize``.

    Returns:
        uint8 array of shape ``(n_frames, height, width, 3)`` in RGB order.
        If no frame can be decoded (unreadable/empty file), an empty array of
        shape ``(0, height, width, 3)`` is returned, so callers can still
        reshape it by frame count.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]  # OpenCV decodes BGR; reorder to RGB
            frames.append(frame)

            if max_frames > 0 and len(frames) >= max_frames:
                break
    finally:
        cap.release()

    if not frames:
        # np.array([]) would have shape (0,) and break reshape(n_frames, -1) downstream.
        width, height = resize
        return np.empty((0, height, width, 3), dtype=np.uint8)
    return np.array(frames)

In [6]:
# Accumulates the outputs appended by the forward hook registered on resnet50.avgpool.
buffer = list()

In [7]:
resnet50 = pretrained.resnet50(pretrained=True)
# Used for feature extraction only: eval mode makes BatchNorm use its stored
# running statistics instead of per-batch statistics.
resnet50.eval()


def get_embedding(module, inputs, output):
    """Forward hook on ``resnet50.avgpool``: append its output to ``buffer``.

    ``buffer`` is never cleared here; callers must empty it between uses.
    """
    # detach() so the stored features do not keep the autograd graph alive.
    buffer.append(output.detach())


resnet50.avgpool.register_forward_hook(get_embedding)

# Standard ImageNet preprocessing for torchvision's pretrained weights.
t = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [8]:
def prepare_video(videos=None, max_frames=None):
    """Decode, flatten, normalise and pad every clip to a fixed-length sequence.

    Args:
        videos: iterable of ``(label, path)`` pairs; defaults to the global
            ``data`` list built by the directory scan.
        max_frames: sequence length to truncate/pad to; defaults to
            ``MAX_SEQ_LENGTH``. Must be positive.

    Returns:
        list of ``(label, frames)`` where ``frames`` is a float32 array of shape
        ``(max_frames, pixels_per_frame)`` with values in [0, 1]. Clips shorter
        than ``max_frames`` are zero-padded at the end so that the DataLoader's
        default collate can stack them. Clips that yield no frames are skipped.

    Raises:
        ValueError: if ``max_frames`` is not positive.
    """
    if videos is None:
        videos = data
    if max_frames is None:
        max_frames = MAX_SEQ_LENGTH
    if max_frames <= 0:
        raise ValueError("max_frames must be positive")

    data_list = []
    for label, path in tqdm(videos):
        frames = load_video(path, max_frames=max_frames)
        if len(frames) == 0:
            # Nothing to learn from; reshape(0, -1) would also fail on it.
            tqdm.write(f"Skipping video with no decodable frames: {path}")
            continue

        frames = frames[:max_frames]
        flat = frames.reshape(len(frames), -1).astype(np.float32) / 255.0

        # Fixed-length sequence: trailing frames stay zero for short clips.
        padded = np.zeros((max_frames, flat.shape[1]), dtype=np.float32)
        padded[:len(flat)] = flat
        data_list.append((label, padded))

    return data_list

In [9]:
# Decode every clip found by the directory scan into (label, frames) pairs.
data_list = list(prepare_video())

100%|██████████| 928/928 [00:07<00:00, 116.31it/s]


In [10]:
# Class index for each action label, following the order of the `labels` list.
labels_dict = dict(zip(labels, range(len(labels))))

In [11]:
class VideoDataset(Dataset):
    """Map-style dataset over pre-processed ``(label, frames)`` pairs.

    Each item is ``(frames, target)``: ``frames`` as a float32 tensor, and
    ``target`` as a float32 one-hot vector over the label mapping.

    Args:
        data: sequence of ``(label, frames)`` pairs, e.g. from ``prepare_video()``.
        label_to_index: optional mapping from label to class index
            (0..n_classes-1). When omitted, the global ``labels_dict`` is looked
            up at access time.
    """

    def __init__(self, data, label_to_index=None):
        super().__init__()
        self.data = data
        self.label_to_index = label_to_index

    def __len__(self):
        # One item per (label, frames) pair.
        return len(self.data)

    def __getitem__(self, idx):
        label, frames = self.data[idx]
        mapping = labels_dict if self.label_to_index is None else self.label_to_index
        target = torch.zeros(len(mapping), dtype=torch.float32)
        # An unknown label raises KeyError instead of silently giving an all-zero target.
        target[mapping[label]] = 1.0
        return torch.tensor(frames, dtype=torch.float32), target

In [12]:
# NOTE(review): this rebinds `data`, which until now held the (label, path) list
# that prepare_video() iterates over; re-running the earlier cells after this one
# will fail. A distinct name (e.g. `train_dataset`) would be safer.
data = VideoDataset(data=data_list)

In [17]:
class Attention(nn.Module):
  """Element-wise attention gate.

  Computes gate = sigmoid(wx(X) + wh(h)) and returns X scaled element-wise by
  that gate, so every input feature is multiplied by a value in (0, 1) that
  depends on both the input and the hidden state.

  Args:
    embedding_dim: size of the last dimension of X (and of the gate).
    n_hidden: size of the last dimension of h.
  """

  def __init__(self, embedding_dim, n_hidden):
    super().__init__()
    self.embedding_dim, self.n_hidden = embedding_dim, n_hidden

    # wx is created before wh so seeded parameter initialisation is unchanged.
    self.wx = nn.Linear(embedding_dim, embedding_dim)
    self.wh = nn.Linear(n_hidden, embedding_dim)
    self.sigmoid = nn.Sigmoid()

  def forward(self, X, h):
    gate = self.sigmoid(self.wx(X) + self.wh(h))
    return X * gate

In [18]:

class EleAttG_GRU(nn.Module):
  """GRU classifier with an element-wise attention gate on each time step.

  At every step the current frame features are gated by ``Attention``
  (conditioned on the previous hidden state) and fed to a ``GRUCell``. The
  final hidden state goes through a two-layer MLP with a softmax output.

  Args:
    embedding_dim: feature size of each time step.
    n_hidden: GRU hidden size.
    n_classes: number of output classes (required).
  """

  def __init__(self, embedding_dim, n_hidden=128, n_classes=None):
    super().__init__()

    assert n_classes is not None, "n_classes must be given"

    self.embedding_dim = embedding_dim
    self.n_hidden = n_hidden
    self.n_classes = n_classes

    self.attention = Attention(self.embedding_dim, self.n_hidden)
    self.grucell = nn.GRUCell(self.embedding_dim, self.n_hidden)
    self.fc = nn.Sequential(
        nn.Linear(self.n_hidden, self.n_hidden),
        nn.ReLU(),
        nn.Linear(self.n_hidden, self.n_classes),
        nn.Softmax(dim=1)
    )

  def forward(self, X):
    '''Run the gated GRU over the sequence and classify the last hidden state.

    Args:
      X: tensor of shape (batch_size, frames, embedding_dim). Not modified.

    Returns:
      (batch_size, n_classes) tensor of softmax probabilities.
    '''
    # Allocate h on X's device/dtype instead of the global `device`.
    h = X.new_zeros(X.shape[0], self.n_hidden)
    for i in range(X.shape[1]):
      # Gate frame i, then step the GRU; the caller's tensor is never written to.
      x_t = self.attention(X[:, i, :], h)
      h = self.grucell(x_t, h)

    return self.fc(h)

In [19]:
# Input size is the flattened frame length from the config cell (IMG_SIZE**2 * 3),
# so it stays in sync if IMG_SIZE changes.
model = EleAttG_GRU(NUM_FEATURES, n_hidden=256, n_classes=len(labels)).to(device)

In [20]:
# MSE between the model's softmax output and the one-hot targets from VideoDataset.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
training_loader = DataLoader(dataset=data, batch_size=BATCH_SIZE, shuffle=True)

In [22]:
model.train()
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    for i, (X, y) in enumerate(tqdm(training_loader)):
        # The DataLoader yields CPU tensors; move them next to the model.
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        y_pred = model(X)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        if i % 100 == 0:
            print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss.item()}")

    # Per-batch prints are sparse, so also report the mean over the whole epoch.
    print(f"Epoch: {epoch}, Mean loss: {epoch_loss / max(len(training_loader), 1)}")