In [3]:
from google.colab import drive
import os

drive.mount('/content/drive')
os.listdir("/content/drive/MyDrive")
os.chdir("/content/drive/MyDrive/hmdb51")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
! pip install av
! pip install spikingjelly
#! wget https://raw.githubusercontent.com/pytorch/vision/6de158c473b83cf43344a0651d7c01128c7850e6/references/video_classification/transforms.py



In [5]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim.lr_scheduler import StepLR
import torchvision
from torchvision import get_video_backend
from torchvision.models.video import r3d_18
from torchvision import transforms
import os
from tqdm.auto import tqdm
import numpy as np
import time
import datetime
import random
import transforms as T
import av

from spikingjelly.activation_based import layer, neuron, surrogate, encoding, functional

In [75]:
class HMDB51CSNN(nn.Module):
  def __init__(self, channels=128):
    super().__init__()

    conv = []
    for i in range (5):
      if conv.__len__() == 0:
        in_channels = 3
      else:
        in_channels = channels

      conv.append(layer.Conv2d(in_channels, channels, kernel_size=3, padding=1, bias=False))
      conv.append(layer.BatchNorm2d(channels))
      conv.append(neuron.IFNode(surrogate_function=surrogate.ATan()))
      conv.append(layer.MaxPool2d(2,2)) # 112->56->28->14->7->3

    self.conv_fc = nn.Sequential(
        *conv,

        layer.Flatten(),
        layer.Dropout(0.5),
        layer.Linear(channels * 3 * 3, 512),

        layer.Dropout(0.5),
        layer.Linear(512, 510),
        neuron.IFNode(surrogate_function=surrogate.ATan()),

        layer.VotingLayer(10)
    )

    functional.set_step_mode(self, step_mode='m')

  def forward(self, x: torch.Tensor):
    return self.conv_fc(x)


In [93]:
# Datasets and Dataloaders for model training ..

val_split = 0.05
num_frames = 16 # 16
clip_steps = 50
num_workers = 0
pin_memory = True
train_tfms = torchvision.transforms.Compose([
                                 T.ToFloatTensorInZeroOne(),
                                 T.Resize((128, 171)),
                                 T.RandomHorizontalFlip(),
                                 T.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
                                 T.RandomCrop((112, 112))
                               ])
test_tfms =  torchvision.transforms.Compose([
                                             T.ToFloatTensorInZeroOne(),
                                             T.Resize((128, 171)),
                                             T.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
                                             T.CenterCrop((112, 112))
                                             ])
hmdb51_train = torchvision.datasets.HMDB51('video_data/', 'test_train_splits/', num_frames,
                                                step_between_clips = clip_steps, fold=1, train=True,
                                                transform=train_tfms, num_workers=num_workers)


hmdb51_test = torchvision.datasets.HMDB51('video_data/', 'test_train_splits/', num_frames,
                                                step_between_clips = clip_steps, fold=1, train=False,
                                                transform=test_tfms, num_workers=num_workers)

total_train_samples = len(hmdb51_train)
total_val_samples = round(val_split * total_train_samples)

print(f"number of train samples {total_train_samples}")
print(f"number of validation samples {total_val_samples}")
print(f"number of test samples {len(hmdb51_test)}")

100%|██████████| 417/417 [02:14<00:00,  3.11it/s]
100%|██████████| 417/417 [02:13<00:00,  3.11it/s]


number of train samples 7577
number of validation samples 379
number of test samples 3161


In [94]:
batch_size = 16
num_workers = 0

kwargs = {'num_workers':num_workers, 'pin_memory':True} if torch.cuda.is_available() else {'num_workers':num_workers}
#kwargs = {'num_workers':num_workers}
#kwargs = {}

hmdb51_train_v1, hmdb51_val_v1 = random_split(hmdb51_train, [total_train_samples - total_val_samples,
                                                                       total_val_samples])

#hmdb51_train_v1.video_clips.compute_clips(16, 1, frame_rate=30)
#hmdb51_val_v1.video_clips.compute_clips(16, 1, frame_rate=30)
#hmdb51_test.video_clips.compute_clips(16, 1, frame_rate=30)

#train_sampler = RandomClipSampler(hmdb51_train_v1.video_clips, 5)
#test_sampler = UniformClipSampler(hmdb51_test.video_clips, 5)

train_loader = DataLoader(hmdb51_train_v1, batch_size=batch_size, shuffle=True, **kwargs)
val_loader   = DataLoader(hmdb51_val_v1, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(hmdb51_test, batch_size=batch_size, shuffle=False, **kwargs)

In [95]:
batch = next(iter(train_loader))
print(f"Batch type: {type(batch)}")
print(f"Batch length: {len(batch)}")

video, audio, label = next(iter(train_loader))
print(video.shape) # (batch size, channels, frames, height, width)
print(audio.shape)
print(label.shape) # (batch size)

Batch type: <class 'list'>
Batch length: 3
torch.Size([16, 3, 16, 112, 112])
torch.Size([16, 1, 0])
torch.Size([16])


In [96]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr = 1e-3
gamma = 0.7
epochs = 10
config = {}
net = HMDB51CSNN().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
encoder = encoding.PoissonEncoder()

In [None]:
max_test_acc = -1

for epoch in range(epochs):
  start_time = time.time()
  net.train()
  train_loss = 0
  train_acc = 0
  train_samples = 0
  for frame, _, label in train_loader:
    optimizer.zero_grad()
    frame = frame.to(device)
    frame = frame.permute(2,0,1,3,4) # [N,T,C,H,W] -> [T,N,C,H,W]
    label = label.to(device)
    label_onehot = F.one_hot(label,  51).float()

    out_fr = net(frame).mean(0)
    loss = F.mse_loss(out_fr, label_onehot)
    loss.backward()
    optimizer.step()

    train_samples += label.numel()
    train_loss += loss.item() * label.numel()
    train_acc += (out_fr.argmax(1) == label).float().sum().item()

    functional.reset_net(net)

  train_time = time.time()
  train_speed = train_samples / (train_time - start_time)
  train_loss /= train_samples
  train_acc /= train_samples
  print(f'==========epoch={epoch}=============')
  print(f'train_loss={train_loss: .4f}, train_acc={train_acc: .4f}')
  print(f'train_speed={train_speed: .4f}')

  net.eval()
  test_loss = 0
  test_acc = 0
  test_samples = 0
  with torch.no_grad():
    for frame, _, label in test_loader:
      frame = frame.to(device)
      frame = frame.permute(2,0,1,3,4)  # [N, T, C, H, W] -> [T, N, C, H, W]
      label = label.to(device)
      label_onehot = F.one_hot(label, 51).float()
      out_fr = net(frame).mean(0)
      loss = F.mse_loss(out_fr, label_onehot)
      test_samples += label.numel()
      test_loss += loss.item() * label.numel()
      test_acc += (out_fr.argmax(1) == label).float().sum().item()
      functional.reset_net(net)
  test_time = time.time()
  test_speed = test_samples / (test_time - train_time)
  test_loss /= test_samples
  test_acc /= test_samples

  if test_acc > max_test_acc:
    max_test_acc = test_acc

  print(f'epoch = {epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
  print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
  print(f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')


train_loss= 0.0194, train_acc= 0.0556
train_speed= 7.0397
