<a href="https://colab.research.google.com/github/passerbyWt/videoClassification/blob/main/NNDL_video_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# define constants
# All paths are relative to the notebook's working directory.
DATA_DIR = "content/UCF101"           # the download cell extracts UCF101.rar here
LABEL_DIR = "content/UCF101_labels"   # the download cell unzips the train/test split lists here
CACHE_DIR = "content/cache"           # trained model weights (*.pt) are saved here
FRAMES_PER_CLIP = 10                  # at most this many frames are sampled from each video
IMG_SIZE = 224    # video frames are resized to IMG_SIZE x IMG_SIZE pixels

# Download UCF101 dataset and extract frames
References


*   https://www.kaggle.com/pevogam/starter-ucf101-with-pytorch
*   https://blog.csdn.net/HW140701/article/details/115864277



In [None]:
import os
import sys
import copy
import subprocess
import numpy as np
from prettytable import PrettyTable
from multiprocessing import Pool
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# !pip install av
# import av
import cv2

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
from torch.optim import Adam
from torchvision import transforms
from torchvision import models
from torch.utils.data import Dataset, DataLoader

# Platform string; every shell-based setup step below is skipped on Windows ('win32').
OS = sys.platform

# Create the directory that later cells save model weights into.
if not OS == 'win32' and not os.path.exists(CACHE_DIR):
  !mkdir -p $CACHE_DIR
# Download and extract the UCF101 videos (skipped when DATA_DIR already exists).
if not OS == 'win32' and not os.path.exists(DATA_DIR):
  !mkdir -p $DATA_DIR
  !wget --no-check-certificate --limit-rate 100M -O UCF101.rar https://www.crcv.ucf.edu/data/UCF101/UCF101.rar
  !unrar x ./UCF101.rar $DATA_DIR > /dev/null
# Download and extract the train/test split lists (skipped when LABEL_DIR already exists).
if not OS == 'win32' and not os.path.exists(LABEL_DIR):
  !mkdir $LABEL_DIR
  !wget --no-check-certificate -O UCF101_labels.zip https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip
  !unzip -d $LABEL_DIR ./UCF101_labels.zip > /dev/null

# check if data is ready
# (only probes for one class folder and the split-list folder; it does not verify the full dataset)
if (os.path.exists(DATA_DIR+'/UCF-101/ApplyEyeMakeup') and os.path.exists(LABEL_DIR+'/ucfTrainTestlist')):
  print("ready to go")
else:
  print("Failed to download data\nPlease manually download files from\nhttps://www.crcv.ucf.edu/data/UCF101/UCF101.rar\nand\nhttps://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip")


In [None]:
# enable GPU
# Pick the CUDA device when PyTorch can see one, otherwise fall back to the CPU,
# and tell the user which one is in use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using GPU!" if device.type == "cuda" else "Using CPU... this is going to be slow...")

In [None]:
# video loader
def video_loader(filename):
  '''
  Load a video and sample frames evenly along its timeline.

  params:
    1. filename: path of the video relative to DATA_DIR/UCF-101
       (e.g. "ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01.avi")
  return:
    a tensor (num_frames x IMG_SIZE x IMG_SIZE x color_channels, RGB order) holding the
    sampled frames. num_frames is FRAMES_PER_CLIP when the video has more frames than
    that; otherwise every decoded frame is returned.
  raises:
    Exception if the file does not exist or no frame could be decoded
  '''
  filename = DATA_DIR + '/UCF-101/' + filename
  if not os.path.exists(filename):
    raise Exception("Cannot find file " + filename)
  raw_frames = []
  cap = cv2.VideoCapture(filename)
  try:
    while cap.isOpened():
      ret, frame = cap.read()
      if not ret:
        break
      raw_frames.append(frame)
  finally:
    # always release the capture handle, even if decoding raises
    cap.release()
  if len(raw_frames) < 1:
    raise Exception("Invalid video")
  if len(raw_frames) > FRAMES_PER_CLIP:
    # pick FRAMES_PER_CLIP indices spread evenly over the whole clip
    ratio = len(raw_frames) / FRAMES_PER_CLIP
    raw_frames = [raw_frames[int(i * ratio)] for i in range(FRAMES_PER_CLIP)]
  # convert and resize only the frames that are kept; doing this for every decoded
  # frame before sampling was wasted work
  frames = [
    cv2.resize(
      cv2.cvtColor(frame, cv2.COLOR_BGR2RGB),
      (IMG_SIZE, IMG_SIZE),
      interpolation=cv2.INTER_AREA
    )
    for frame in raw_frames
  ]
  return torch.tensor(np.stack(frames, axis=0))

In [None]:
# Smoke test: load one clip and display the shape of the returned tensor.
v_test = video_loader('ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01.avi')
v_test.shape

In [None]:
# class mapping
# Each line of classInd.txt is read as "<index> <class name>"; the index is shifted
# by -1 so that class ids start at 0.
class_to_id = {}
id_to_class = {}
with open(LABEL_DIR+'/ucfTrainTestlist/classInd.txt', 'r') as f:
  for line in f:
    fields = line.split()
    if len(fields) < 2:
      # blank or incomplete line (e.g. a trailing newline): skip instead of raising IndexError
      continue
    class_id = int(fields[0]) - 1
    class_name = fields[1]
    class_to_id[class_name] = class_id
    id_to_class[class_id] = class_name

In [None]:
# define the UCF101 dataset class
class UCF101(Dataset):
  '''
  Dataset over the UCF101 split lists stored under LABEL_DIR/ucfTrainTestlist.

  params:
    1. _class_to_id: dict mapping class name -> 0-based class id (used for the test
       list, whose lines carry no numeric label)
    2. _subset: 'train', 'dev' or 'test'. 'dev' is every 5th line (line index
       0, 5, 10, ...) of trainlist01.txt, 'train' is the remaining lines of that file,
       'test' is testlist01.txt.
    3. _video_loader: callable that receives a video filename from the list and
       returns the loaded clip
    4. _transform: optional callable applied to the loaded clip
  items:
    (label, clip) tuples
  raises:
    ValueError if _subset is not one of the three supported names
  '''
  def __init__(self, _class_to_id, _subset, _video_loader, _transform=None):
    if _subset not in ('train', 'dev', 'test'):
      raise ValueError("_subset should have value 'train', 'test', or 'dev'")
    list_dir = LABEL_DIR + '/ucfTrainTestlist/'
    if _subset == 'test':
      self.data = self._read_test_list(list_dir + 'testlist01.txt', _class_to_id)
    else:
      self.data = self._read_train_list(list_dir + 'trainlist01.txt', _subset == 'dev')
    self.video_loader = _video_loader
    self.transform = _transform

  @staticmethod
  def _read_train_list(path, keep_dev):
    '''Parse "<video> <1-based label>" lines, keeping the dev (every 5th) or the train lines.'''
    samples = []
    with open(path, 'r') as f1:
      for i, line in enumerate(f1):
        # the split is decided by the raw line index so train and dev never overlap
        if (i % 5 == 0) != keep_dev:
          continue
        fields = line.split()
        if len(fields) < 2:
          continue  # blank or incomplete line
        samples.append((int(fields[1]) - 1, fields[0]))   # (label, video_filename)
    return samples

  @staticmethod
  def _read_test_list(path, _class_to_id):
    '''Parse "<class>/<video>" lines; the label is looked up from the class folder name.'''
    samples = []
    with open(path, 'r') as f1:
      for line in f1:
        video = line.strip()
        if not video:
          continue  # blank line
        samples.append((_class_to_id[video.split('/')[0]], video))   # (label, video_filename)
    return samples

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    label, video = self.data[idx]
    enc_video = self.video_loader(video)
    if self.transform is not None:
      enc_video = self.transform(enc_video)
    return (label, enc_video)

In [None]:
def custom_collate(batch):
  '''
  Collate (label, clip) samples into one batch.

  return:
    (labels, clips): labels is a 1-D tensor with one entry per sample; clips stacks the
    per-sample frame sequences batch-first, zero-padding shorter sequences to the
    longest one in the batch.
  '''
  labels = [label for label, _ in batch]    # label of each video sample
  clips = [clip for _, clip in batch]       # sampled sequence of frames of each video
  label_tensor = torch.tensor(labels)
  padded_clips = pad_sequence(clips, batch_first=True)
  return (label_tensor, padded_clips)

In [None]:
def tfs(enc_video):
  '''
  Convert a clip from (frames x height x width x channels) with values in 0..255 to a
  float tensor (frames x channels x height x width), scaled to [0, 1] and then
  normalized per channel with the mean/std constants below.

  The normalization was previously constructed but never applied, so frames were fed
  to the model un-normalized.

  NOTE: the output is no longer in [0, 1]; undo the normalization (x * std + mean)
  before displaying a frame with matplotlib.
  '''
  enc_video = enc_video.permute(0, 3, 1, 2).float() / 255
  # shaped (1, 3, 1, 1) so the constants broadcast over frames, height and width
  mean = torch.tensor((0.485, 0.456, 0.406), dtype=enc_video.dtype, device=enc_video.device).view(1, 3, 1, 1)
  std = torch.tensor((0.229, 0.224, 0.225), dtype=enc_video.dtype, device=enc_video.device).view(1, 3, 1, 1)
  return (enc_video - mean) / std

# Load pre-trained image classifier
Model:
*   CoAtNet (Zihang Dai, et al. 2021)

References:
*   <a href="https://arxiv.org/abs/2106.04803">Research Paper</a>
*   <a href="https://github.com/chinhsuanwu/coatnet-pytorch/blob/master/coatnet.py">Code</a>

In [None]:
# download model
if not OS == 'win32' and not os.path.exists('model'):
  !mkdir -p model/coatnet
  !wget -O model/coatnet/. --no-check-certificate https://github.com/chinhsuanwu/coatnet-pytorch/raw/master/coatnet.py
from model.coatnet.coatnet import CoAtNet

**Define the image classification model**

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
  '''
  Freeze every parameter of `model` when `feature_extracting` is truthy;
  otherwise leave the model untouched.
  '''
  if not feature_extracting:
    return
  for weight in model.parameters():
    weight.requires_grad = False

def ResNet(num_classes, use_pretrained=False):
  '''
  Build a ResNet-152 whose final `fc` layer is replaced by Linear(num_classes) + Sigmoid,
  and move it to `device`.

  params:
    1. num_classes: number of outputs of the new head
    2. use_pretrained: when True, load pretrained weights and freeze the backbone
       parameters (the new head is created afterwards and stays trainable)
  '''
  backbone = models.resnet152(pretrained=use_pretrained)
  set_parameter_requires_grad(backbone, use_pretrained)
  # NOTE(review): the head ends in Sigmoid while the training cells use
  # CrossEntropyLoss, which expects unnormalized scores - confirm this is intended.
  head = torch.nn.Sequential(
      torch.nn.Linear(backbone.fc.in_features, num_classes),
      torch.nn.Sigmoid()
  )
  backbone.fc = head
  return backbone.to(device)

In [None]:
# create an instance of the image classification model
# (the CoAtNet alternative is kept commented out; the active model is the ResNet
#  built with pretrained weights requested)
# image_classifier = CoAtNet((IMG_SIZE, IMG_SIZE), 3, [2,2,3,5,2], [64,96,192,384,768], num_classes=len(id_to_class.keys())).to(device)
image_classifier = ResNet(len(id_to_class.keys()), use_pretrained=True)

# visualize the model
def count_parameters(model):
    """Print a table of the trainable parameters of `model` and return their total count.

    Parameters with requires_grad == False are left out of both the table and the total.
    """
    table = PrettyTable(["Modules", "Parameters"])
    trainable = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    for name, size in trainable:
        table.add_row([name, size])
    total_params = sum(size for _, size in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

# count_parameters(image_classifier)

**Train the image classification model with UCF101 dataset**

In [None]:
# Data loaders for the frame-level image classifier: each batch holds `batch_size`
# videos, i.e. up to batch_size * FRAMES_PER_CLIP frames.
batch_size = 2
train_loader = DataLoader(
    UCF101(class_to_id, 'train', video_loader, _transform=tfs),
    collate_fn=custom_collate,
    batch_size=batch_size,
    shuffle=True
)
dev_loader = DataLoader(
    UCF101(class_to_id, 'dev', video_loader, _transform=tfs),
    collate_fn=custom_collate,
    batch_size=batch_size,
    shuffle=True
)
# The optimizer receives every parameter of the model, including frozen ones.
optimizer = Adam(image_classifier.parameters(), lr=0.001)
criterion = CrossEntropyLoss()

def handle_batch(model, criterion, batch, _device):
  '''
  Run the frame-level model on one batch of videos.

  Every frame is treated as an independent image sample that carries the label of
  the video it came from.

  return:
    (preds, loss): the model output for every frame and the criterion value
  '''
  labels, videos = batch
  # move tensors to device
  labels = labels.to(_device)
  videos = videos.to(_device)
  # one label per frame: repeat each video label for all of its frames
  frame_labels = labels.repeat_interleave(videos.shape[1])
  # merge the batch and frame dimensions into a single image batch
  frame_batch = videos.view(-1, 3, IMG_SIZE, IMG_SIZE)
  preds = model(frame_batch)
  return preds, criterion(preds, frame_labels)

# training params
max_epoch = 10
# Relative tolerance of the stopping rule: training stops once the dev loss is above
# best_loss * (1 + early_stopping). A negative value demands a minimum improvement.
early_stopping = -0.000001
show_train_loss = False   # disabling this could speed up training

# try loading trained model
# if os.path.exists(CACHE_DIR+'/image_classifier.pt'):
#   image_classifier.load_state_dict(torch.load(CACHE_DIR+'/image_classifier.pt'))
#   print("model state loaded")

best_loss = None
for epoch in range(max_epoch):
  tot_train_loss = 0.0
  tot_eval_loss = 0.0
  image_classifier.train()
  with tqdm(total=len(train_loader)+len(dev_loader)) as pbar:
    for batch in train_loader:
      optimizer.zero_grad()
      _, loss = handle_batch(image_classifier, criterion, batch, device)
      loss.backward()
      optimizer.step()
      if show_train_loss:
        tot_train_loss += loss.item()
      pbar.update(1)
    image_classifier.eval()
    with torch.no_grad():
      for batch in dev_loader:
        _, loss = handle_batch(image_classifier, criterion, batch, device)
        tot_eval_loss += loss.item()
        pbar.update(1)

  # CrossEntropyLoss() with its default reduction already averages over the frames of
  # a batch, so the epoch average is the mean over batches (not over frames).
  avg_eval_loss = tot_eval_loss / len(dev_loader)
  if show_train_loss:
    avg_train_loss = tot_train_loss / len(train_loader)
    tqdm.write("epoch={} avg_train_loss={:.4f} avg_eval_loss={:.4f}".format(epoch, avg_train_loss, avg_eval_loss))
  else:
    tqdm.write("epoch={} avg_eval_loss={:.4f}".format(epoch, avg_eval_loss))

  # stopping criteria
  # The first epoch has no reference loss yet, so it can only set the baseline.
  # Previously the baseline was set and then compared against itself, which stopped
  # training after the first epoch without ever saving the weights.
  if best_loss is not None and tot_eval_loss > (best_loss * (1+early_stopping)):
    tqdm.write("Eval loss not improving, stop training.")
    break
  elif best_loss is None or tot_eval_loss < best_loss:
    best_loss = tot_eval_loss
    # save the weights of the best epoch so far
    torch.save(image_classifier.state_dict(), CACHE_DIR+'/image_classifier.pt')

**Evaluate the trained image classifier**

In [None]:
# Build the model to evaluate and try to restore its trained weights.
# NOTE(review): the training cell builds its model with ResNet(...) and saves it to
# image_classifier.pt; loading that file into a CoAtNet would fail and only print the
# message below, leaving the model untrained - confirm the two cells use the same model.
image_classifier = CoAtNet((IMG_SIZE, IMG_SIZE), 3, [2,2,3,5,2], [64,96,192,384,768], num_classes=len(id_to_class.keys())).to(device)
# image_classifier = ResNet(len(id_to_class.keys()))

try:
  # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
  image_classifier.load_state_dict(torch.load(CACHE_DIR+'/image_classifier.pt', map_location=device))
except Exception as e:
  # best effort: report the problem and continue with the current weights
  print("Exception: " + str(e))
  print("Please make sure the 'image_classifier.pt' file exists")

In [None]:
# Loader over the test split; kept in file order (no shuffling) for evaluation.
batch_size = 2
test_loader = DataLoader(
    UCF101(class_to_id, 'test', video_loader, _transform=tfs),
    collate_fn=custom_collate,
    batch_size=batch_size,
    shuffle=False
)
criterion = CrossEntropyLoss()

def handle_batch(model, criterion, batch, _device):
  '''
  Evaluation variant: run the frame-level model on one batch of videos.

  return:
    (acc, loss): acc is a boolean tensor with one entry per frame (True where the
    highest-scoring class equals the video label), loss is the criterion value
  '''
  labels, videos = batch
  # move tensors to device
  labels = labels.to(_device)
  videos = videos.to(_device)
  # one label per frame: repeat each video label for all of its frames
  frame_labels = labels.repeat_interleave(videos.shape[1])
  scores = model(videos.view(-1, 3, IMG_SIZE, IMG_SIZE))
  loss = criterion(scores, frame_labels)
  is_correct = torch.argmax(scores, dim=1) == frame_labels
  return is_correct, loss

# Evaluate the frame-level classifier on the test split: per-frame accuracy and
# average loss.
image_classifier.eval()
tot_eval_loss = 0.0
with torch.no_grad():
  accs = []
  for batch in tqdm(test_loader):
    acc, loss = handle_batch(image_classifier, criterion, batch, device)
    tot_eval_loss += loss.item()
    accs.append(acc)
  accs = torch.cat(accs, dim=0)
  num_of_samples = accs.shape[0]   # number of evaluated frames
  # CrossEntropyLoss() with its default reduction already averages over the frames of
  # a batch, so the overall average is the mean over batches.
  avg_eval_loss = tot_eval_loss / len(test_loader)
  tqdm.write("accuracy={:.4f}, avg loss={:.4f}".format(accs.sum().item() / num_of_samples, avg_eval_loss))

In [None]:
import matplotlib.pyplot as plt
# Test split as a plain dataset (no loader) so single samples can be inspected by index.
testset = UCF101(class_to_id, 'test', video_loader, _transform=tfs)

In [None]:
# Rebuild the test dataset and fetch its first sample as a (label, clip) tuple.
testset = UCF101(class_to_id, 'test', video_loader, _transform=tfs)
sample1 = testset[0]

In [None]:
# Show the first frame of the sampled clip (after the tfs transform).
sample1[1][0,:,:,:]

In [None]:
# Draw one random test video and flatten it into per-frame samples.
batch_size = 1
test_loader = DataLoader(
    testset,
    collate_fn=custom_collate,
    batch_size=batch_size,
    shuffle=True
)

label_0, img_0 = next(iter(test_loader))
# repeat the video label once per frame so labels line up with the flattened frames
label_0 = label_0.repeat_interleave(img_0.shape[1])
# (batch, frames, 3, H, W) -> (batch * frames, 3, H, W)
img_0 = img_0.view(-1, 3, IMG_SIZE, IMG_SIZE)

In [None]:
# Print the class name of frame `idx` of the drawn clip and display the frame.
idx = 5
print(id_to_class[label_0[idx].item()])
# imshow expects (H, W, channels), so move the channel axis last
plt.imshow(img_0[idx,:,:,:].permute(1,2,0))

In [None]:
# Inspect one test sample by index: print its class name and clip shape, then
# display its first frame.
idx = 826
label_1, img_1 = testset[idx]
print(id_to_class[label_1])
print(img_1.shape)
# imshow expects (H, W, channels), so move the channel axis last
plt.imshow(img_1[0,:,:,:].permute(1,2,0))

In [None]:
# Classify every frame of the drawn clip and report the prediction for frame `idx`.
idx = 7
# inference only: no autograd graph is needed
with torch.no_grad():
  # use IMG_SIZE (not a hard-coded 224) so this cell follows the configured frame size
  pred_ = image_classifier(img_0.view(-1, 3, IMG_SIZE, IMG_SIZE).to(device))
# highest score per frame and the class index it belongs to
prob, pred = torch.max(pred_, dim=1)
#print(pred.shape)
print("Predicted class is '{}' with {:.4f} confidence".format(id_to_class[pred[idx].item()], prob[idx].item()))
plt.imshow(img_0[idx,:,:,:].permute(1,2,0))

# LSTM video classifier

Load model

In [None]:
# Choose the frame-level model that feeds the LSTM and whether to freeze it.
model_to_use = 'ResNet'
use_pretrained = True

if model_to_use == 'CoAtNet':
  image_classifier = CoAtNet((IMG_SIZE, IMG_SIZE), 3, [2,2,3,5,2], [64,96,192,384,768], num_classes=len(id_to_class.keys())).to(device)
else:
  # NOTE(review): ResNet() is called without use_pretrained, so its default (False)
  # applies here regardless of the `use_pretrained` flag above - confirm this is intended.
  image_classifier = ResNet(len(id_to_class.keys()))

if use_pretrained:
  try:
    # NOTE(review): with the load_state_dict call commented out, this branch only
    # freezes the parameters; no trained weights are restored.
    # image_classifier.load_state_dict(torch.load(CACHE_DIR+'/image_classifier.pt'))
    set_parameter_requires_grad(image_classifier, True)
  except Exception as e:
    print("Exception: " + str(e))
    print("Please make sure the 'image_classifier.pt' file exists")

# Swap the `fc` attribute for an identity module; presumably this makes the model return
# the features that previously fed the classification head - verify for CoAtNet.
image_classifier.fc = torch.nn.Identity()

In [None]:
class LSTM(torch.nn.Module):
  '''
  Sequence classifier: a 3-layer batch-first LSTM followed by two linear layers.

  params:
    1. num_classes: size of the output vector
    2. hidden_size: size of each input feature vector, also used as the LSTM hidden size
  forward:
    takes x of shape (batch, sequence, hidden_size) and returns (batch, num_classes)
    sigmoid outputs computed from the LSTM output at the last time step
  '''
  def __init__(self, num_classes, hidden_size):
    super().__init__()
    self.lstm = torch.nn.LSTM(
      input_size=hidden_size,
      hidden_size=hidden_size,
      num_layers=3,
      batch_first=True
    )
    self.fc1 = torch.nn.Linear(hidden_size, 128)
    self.fc2 = torch.nn.Linear(128, num_classes)

  def forward(self, x):
    batch = x.shape[0]  # current batch size
    seq_out, _ = self.lstm(x)
    # keep only the output of the last time step
    last_step = seq_out[:, -1, :].view(batch, -1)
    hidden = torch.sigmoid(self.fc1(last_step))
    # p=0.0 keeps every activation
    hidden = torch.nn.functional.dropout(hidden, p=0.0)
    return torch.sigmoid(self.fc2(hidden))

Train video classifier

In [None]:
# Size of the per-frame feature vector fed to the LSTM; it depends on the chosen image
# model (presumably the width of each model's last feature layer - verify).
if model_to_use == 'CoAtNet':
  hidden_size = 768
else:
  hidden_size = 2048

video_classifier = LSTM(len(id_to_class.keys()), hidden_size).to(device)
# video_classifier
# video_classifier

In [None]:
# Print the table of trainable parameters of the LSTM video classifier.
count_parameters(video_classifier)

In [None]:
# batch_size = 3
# train_loader = DataLoader(
#     UCF101(class_to_id, 'train', video_loader, _transform=tfs),
#     collate_fn=custom_collate,
#     batch_size=batch_size,
#     shuffle=True
# )
# dev_loader = DataLoader(
#     UCF101(class_to_id, 'dev', video_loader, _transform=tfs),
#     collate_fn=custom_collate,
#     batch_size=batch_size,
#     shuffle=True
# )
# optimizer = Adam(video_classifier.parameters(), lr=0.003)
# criterion = CrossEntropyLoss()

# def handle_batch(model, criterion, batch, _device):
#   captions, clips = batch
#   # move tensors to device
#   captions = captions.to(_device)
#   clips = clips.to(_device)
#   # encode with pre-trained image classifier
#   N = clips.shape[0]  # current batch size
#   L = clips.shape[1]  # current sequence length
#   clips = image_classifier(clips.view(-1, 3, IMG_SIZE, IMG_SIZE))
#   H = clips.shape[-1] # current hidden size
#   preds = model(clips.view(N, L, H))
#   loss = criterion(preds, captions)
#   return preds, loss

# # try loading trained model
# if os.path.exists(CACHE_DIR+'/video_classifier.pt'):
#   video_classifier.load_state_dict(torch.load(CACHE_DIR+'/video_classifier.pt'))
#   print("model state loaded")

# # training params
# max_epoch = 5
# early_stopping = -0.000001
# show_train_loss = True   # disable this could speedup training


# best_loss = None
# for epoch in range(max_epoch):
#   tot_train_loss = 0.0
#   tot_eval_loss = 0.0
#   video_classifier.train()
#   with tqdm(total=len(train_loader)+len(dev_loader)) as pbar:
#     for batch in train_loader:
#       optimizer.zero_grad()
#       _, loss = handle_batch(video_classifier, criterion, batch, device)
#       loss.backward()
#       optimizer.step()
#       if show_train_loss:
#         tot_train_loss += loss.item()
#       pbar.update(1)
#     video_classifier.eval()
#     with torch.no_grad():
#       for batch in dev_loader:
#         _, loss = handle_batch(video_classifier, criterion, batch, device)
#         tot_eval_loss += loss.item()
#         pbar.update(1)
  
#   if show_train_loss:
#     avg_train_loss = tot_train_loss / (len(train_loader) * batch_size)
#     avg_eval_loss = tot_eval_loss / (len(dev_loader) * batch_size)
#     tqdm.write("epoch={} avg_train_loss={:.4f} avg_eval_loss={:.4f}".format(epoch, avg_train_loss, avg_eval_loss))
#   else:
#     avg_eval_loss = tot_eval_loss / (len(dev_loader) * batch_size)
#     tqdm.write("epoch={} avg_eval_loss={:.4f}".format(epoch, avg_eval_loss))

#   # stopping criteria
#   if best_loss is None:
#     best_loss = tot_eval_loss
#   if tot_eval_loss > (best_loss * (1+early_stopping)):
#     tqdm.write("Eval loss not improving, stop training.")
#     break
#   elif tot_eval_loss < best_loss:
#     best_loss = tot_eval_loss
#     # save weights periodically
#     torch.save(video_classifier.state_dict(), CACHE_DIR+'/video_classifier.pt')

In [None]:
# # try loading trained model
# if os.path.exists(CACHE_DIR+'/video_classifier.pt'):
#   video_classifier.load_state_dict(torch.load(CACHE_DIR+'/video_classifier.pt'))
#   print("model state loaded")

# batch_size = 3
# test_loader = DataLoader(
#     UCF101(class_to_id, 'test', video_loader, _transform=tfs),
#     collate_fn=custom_collate,
#     batch_size=batch_size,
#     shuffle=False
# )
# criterion = CrossEntropyLoss()

# def handle_batch(model, criterion, batch, _device):
#   captions, clips = batch
#   # move tensors to device
#   captions = captions.to(_device)
#   clips = clips.to(_device)
#   # encode with pre-trained image classifier
#   N = clips.shape[0]  # current batch size
#   L = clips.shape[1]  # current sequence length
#   clips = image_classifier(clips.view(-1, 3, IMG_SIZE, IMG_SIZE))
#   H = clips.shape[-1] # current hidden size
#   preds = model(clips.view(N, L, H))
#   loss = criterion(preds, captions)
#   preds_ = torch.argmax(preds, dim=1)
#   acc = (preds_ == captions)
#   return acc, loss

# video_classifier.eval()
# tot_eval_loss = 0.0
# with torch.no_grad():
#   accs = []
#   for batch in tqdm(test_loader):
#     acc, loss = handle_batch(video_classifier, criterion, batch, device)
#     tot_eval_loss += loss.cpu().item()
#     accs.append(acc)
#   accs = torch.cat(accs, dim=0)
#   num_of_samples = accs.shape[0]
#   avg_eval_loss = tot_eval_loss / (len(test_loader) * batch_size * FRAMES_PER_CLIP)
#   tqdm.write("accuracy={:.4f}, avg loss={:.4f}".format(accs.sum().item() / num_of_samples, avg_eval_loss))

# CNN+RNN

In [None]:
class ResCNNEncoder(nn.Module):
    """Encode every frame of a clip with ResNet-152 followed by three FC layers.

    The ResNet runs under torch.no_grad(), so only the FC layers (and their batch
    norms) receive gradients.
    """

    def __init__(self, fc_hidden1=512, fc_hidden2=512, drop_p=0.3, CNN_embed_dim=300):
        """Load the pretrained ResNet-152 and replace top fc layer.

        fc_hidden1 / fc_hidden2: widths of the two hidden FC layers
        drop_p: dropout probability applied before the last FC layer
        CNN_embed_dim: size of the embedding produced for each frame
        """
        super(ResCNNEncoder, self).__init__()

        self.fc_hidden1, self.fc_hidden2 = fc_hidden1, fc_hidden2
        self.drop_p = drop_p

        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]      # delete the last fc layer.
        self.resnet = nn.Sequential(*modules)
        self.fc1 = nn.Linear(resnet.fc.in_features, fc_hidden1)
        self.bn1 = nn.BatchNorm1d(fc_hidden1, momentum=0.01)
        self.fc2 = nn.Linear(fc_hidden1, fc_hidden2)
        self.bn2 = nn.BatchNorm1d(fc_hidden2, momentum=0.01)
        self.fc3 = nn.Linear(fc_hidden2, CNN_embed_dim)

    def forward(self, x_3d):
        """Return per-frame embeddings of shape (batch, time_step, CNN_embed_dim).

        x_3d is indexed as (batch, time_step, channels, height, width).
        """
        cnn_embed_seq = []
        for t in range(x_3d.size(1)):
            # ResNet CNN
            with torch.no_grad():
                x = self.resnet(x_3d[:, t, :, :, :])  # ResNet
                x = x.view(x.size(0), -1)             # flatten output of conv

            # FC layers
            # `F` was never imported in this notebook; use nn.functional, which is in scope
            x = self.bn1(self.fc1(x))
            x = nn.functional.relu(x)
            x = self.bn2(self.fc2(x))
            x = nn.functional.relu(x)
            x = nn.functional.dropout(x, p=self.drop_p, training=self.training)
            x = self.fc3(x)

            cnn_embed_seq.append(x)

        # stack along a new time dimension right after the batch dimension
        # cnn_embed_seq: shape=(batch, time_step, input_size)
        cnn_embed_seq = torch.stack(cnn_embed_seq, dim=1)

        return cnn_embed_seq

In [None]:
class DecoderRNN(nn.Module):
    """Classify a sequence of frame embeddings with an LSTM and two FC layers.

    forward() returns raw class scores of shape (batch, num_classes); no softmax or
    sigmoid is applied.
    """

    def __init__(self, CNN_embed_dim=300, h_RNN_layers=3, h_RNN=256, h_FC_dim=128, drop_p=0.3, num_classes=50):
        super(DecoderRNN, self).__init__()

        self.RNN_input_size = CNN_embed_dim
        self.h_RNN_layers = h_RNN_layers   # RNN hidden layers
        self.h_RNN = h_RNN                 # RNN hidden nodes
        self.h_FC_dim = h_FC_dim
        self.drop_p = drop_p
        self.num_classes = num_classes

        self.LSTM = nn.LSTM(
            input_size=self.RNN_input_size,
            hidden_size=self.h_RNN,
            num_layers=h_RNN_layers,
            batch_first=True,       # input & output have the batch size as first dimension, e.g. (batch, time_step, input_size)
        )

        self.fc1 = nn.Linear(self.h_RNN, self.h_FC_dim)
        self.fc2 = nn.Linear(self.h_FC_dim, self.num_classes)

    def forward(self, x_RNN):
        """x_RNN: (batch, time_step, CNN_embed_dim) -> scores of shape (batch, num_classes)."""
        self.LSTM.flatten_parameters()
        # None represents a zero initial hidden state.
        # RNN_out has shape (batch, time_step, hidden_size);
        # h_n and h_c have shape (n_layers, batch, hidden_size).
        RNN_out, (h_n, h_c) = self.LSTM(x_RNN, None)

        # FC layers
        # `F` was never imported in this notebook; use nn.functional, which is in scope
        x = self.fc1(RNN_out[:, -1, :])   # choose RNN_out at the last time step
        x = nn.functional.relu(x)
        x = nn.functional.dropout(x, p=self.drop_p, training=self.training)
        x = self.fc2(x)

        return x

In [None]:
# EncoderCNN architecture
CNN_fc_hidden1, CNN_fc_hidden2 = 1024, 768
CNN_embed_dim = 512      # latent dim extracted by 2D CNN
img_x, img_y = 256, 342  # resize video 2d frame size
dropout_p = 0.0          # dropout probability

# DecoderRNN architecture
RNN_hidden_layers = 3
RNN_hidden_nodes = 512
RNN_FC_dim = 256

# NOTE(review): both models are built with an embedding size of 300 rather than the
# CNN_embed_dim constant above - confirm which value is intended.
# The models are moved to `device` because train()/eval() move every batch there;
# leaving them on the CPU fails as soon as a GPU is used.
image_classifier = ResCNNEncoder(fc_hidden1=CNN_fc_hidden1, fc_hidden2=CNN_fc_hidden2, drop_p=dropout_p, CNN_embed_dim=300).to(device)
video_classifier = DecoderRNN(CNN_embed_dim=300, h_RNN_layers=RNN_hidden_layers, h_RNN=RNN_hidden_nodes, h_FC_dim=RNN_FC_dim, drop_p=dropout_p, num_classes=len(id_to_class.keys())).to(device)

In [None]:
def train(encoder_model, decoder_model, optimizer, criterion, train_loader, dev_loader, _device):
  '''
  Train the encoder/decoder pair for one epoch, then evaluate it on the dev set.

  params:
    encoder_model / decoder_model: called as decoder_model(encoder_model(clips))
    optimizer, criterion: used for the training step; criterion is also the dev loss
    train_loader / dev_loader: yield (labels, clips) batches
    _device: device every batch is moved to
  return:
    (dev_loss, dev_acc): mean of the per-batch dev losses, and the fraction of dev
    samples whose highest-scoring class equals the label
  '''
  encoder_model.train()
  decoder_model.train()

  losses = []
  num_correct = 0
  num_samples = 0

  with tqdm(total=len(train_loader)+len(dev_loader)) as pbar:
    for captions, clips in train_loader:
      captions = captions.to(_device)
      clips = clips.to(_device)

      optimizer.zero_grad()
      outputs = decoder_model(encoder_model(clips))
      loss = criterion(outputs, captions)
      loss.backward()
      optimizer.step()
      pbar.update(1)

    encoder_model.eval()
    decoder_model.eval()
    with torch.no_grad():
      for captions, clips in dev_loader:
        captions = captions.to(_device)
        clips = clips.to(_device)

        outputs = decoder_model(encoder_model(clips))
        losses.append(criterion(outputs, captions).item())
        # Count correct predictions with tensor ops: the previous
        # accuracy_score(...squeeze()...) call failed on a batch holding a single
        # sample, because squeeze() turned the labels into a 0-d array.
        preds = torch.max(outputs, 1)[1]
        num_correct += (preds == captions).sum().item()
        num_samples += captions.shape[0]
        pbar.update(1)

  return sum(losses) / len(losses), num_correct / num_samples

In [None]:
def eval(encoder_model, decoder_model, criterion, test_loader, _device):
  '''
  Evaluate the encoder/decoder pair on test_loader without updating any weights.

  NOTE: this name shadows the builtin eval() in the notebook; it is kept because
  other cells call it by this name.

  params:
    encoder_model / decoder_model: called as decoder_model(encoder_model(clips))
    criterion: loss computed for every batch
    test_loader: yields (labels, clips) batches
    _device: device every batch is moved to
  return:
    (loss, acc): mean of the per-batch losses, and the fraction of samples whose
    highest-scoring class equals the label
  '''
  encoder_model.eval()
  decoder_model.eval()

  losses = []
  num_correct = 0
  num_samples = 0

  with torch.no_grad():
    for captions, clips in tqdm(test_loader):
      captions = captions.to(_device)
      clips = clips.to(_device)

      outputs = decoder_model(encoder_model(clips))
      losses.append(criterion(outputs, captions).item())
      # Count correct predictions with tensor ops: the previous
      # accuracy_score(...squeeze()...) call failed on a batch holding a single
      # sample, because squeeze() turned the labels into a 0-d array.
      preds = torch.max(outputs, 1)[1]
      num_correct += (preds == captions).sum().item()
      num_samples += captions.shape[0]

  return sum(losses) / len(losses), num_correct / num_samples

In [None]:
# Hyper-parameters of the CNN+RNN training run.
batch_size = 2
max_epoch = 10
early_stopping = 0.01    # relative tolerance used by the stopping rule in the loop below
learning_rate = 0.001

train_loader = DataLoader(
    UCF101(class_to_id, 'train', video_loader, _transform=tfs),
    collate_fn=custom_collate,
    batch_size=batch_size,
    shuffle=True
)
dev_loader = DataLoader(
    UCF101(class_to_id, 'dev', video_loader, _transform=tfs),
    collate_fn=custom_collate,
    batch_size=batch_size,
    shuffle=True
)
test_loader = DataLoader(
    UCF101(class_to_id, 'test', video_loader, _transform=tfs),
    collate_fn=custom_collate,
    batch_size=batch_size,
    shuffle=False
)

# One optimizer over the parameters of both the encoder and the decoder.
train_params = list(image_classifier.parameters()) + list(video_classifier.parameters())
optimizer = Adam(train_params, lr=learning_rate)
criterion = CrossEntropyLoss()

# Start from +inf so that the first epoch always counts as an improvement, whatever
# the scale of the loss is.
best_loss = float("inf")
best_acc = 0.0

for epoch in range(max_epoch):
  # train() runs one training epoch and returns the loss/accuracy measured on
  # dev_loader, so despite their names these two values are dev metrics.
  train_loss, train_acc = train(image_classifier, video_classifier, optimizer, criterion, train_loader, dev_loader, device)
  print("Epoch {}: dev loss = {:.4f}, dev acc = {:.4f}".format(epoch, train_loss, train_acc))
  if train_loss < best_loss:
    # new best epoch: keep its weights and metrics
    torch.save(image_classifier.state_dict(), CACHE_DIR+'/image_classifier.pt')
    torch.save(video_classifier.state_dict(), CACHE_DIR+'/video_classifier.pt')
    best_loss = train_loss
    best_acc = train_acc
  elif train_loss > best_loss * (1 + early_stopping):
    print("Dev loss not improving, stop training.")
    break

# NOTE: this evaluates the weights of the last epoch, not the saved best checkpoint.
test_loss, test_acc = eval(image_classifier, video_classifier, criterion, test_loader, device)
print("Test loss = {:.4f}, test acc = {:.4f}".format(test_loss, test_acc))