In [None]:
from google.colab import drive
import os

# Mount Google Drive and make the HMDB51 folder on it the working directory, so the
# relative paths used below ('video_data/', 'test_train_splits/', *.pkl, *.pth) resolve there.
drive.mount('/content/drive')
hmdb51_dir = "/content/drive/MyDrive/hmdb51"
# Fail early with an actionable message instead of a bare chdir error.
if not os.path.isdir(hmdb51_dir):
    raise FileNotFoundError(f"{hmdb51_dir} not found - download HMDB51 to Google Drive first.")
os.chdir(hmdb51_dir)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Load HMDB51

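Before running the code below, download HMDB51 first: the videos go in `video_data/` and the official split files in `test_train_splits/`, both inside the working directory set above.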

In [None]:
! pip install av
! pip install spikingjelly
#! wget https://raw.githubusercontent.com/pytorch/vision/6de158c473b83cf43344a0651d7c01128c7850e6/references/video_classification/transforms.py



In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import random_split, DataLoader
from torch.optim.lr_scheduler import StepLR
import torchvision
from torchvision import get_video_backend
from torchvision.models.video import r3d_18
from torchvision import transforms
import os
from tqdm.auto import tqdm
import numpy as np
import time
import datetime
import random
import transforms as T
import av

from spikingjelly.activation_based import layer, neuron, surrogate, encoding, functional

In [None]:
# Datasets for model training.
# HMDB51 (fold 1) is indexed into fixed-length clips; the transforms turn each clip
# into a normalized float tensor cropped to 112x112.

val_split = 0.05    # fraction of the training split held out for validation (5%)
num_frames = 16     # frames per clip
clip_steps = 50     # step (in frames) between the starts of consecutive clips
num_workers = 8     # worker processes HMDB51 uses while indexing the videos
pin_memory = True   # page-locked host memory speeds up host-to-GPU copies

# Per-channel (RGB) normalization statistics shared by both pipelines.
clip_mean = [0.43216, 0.394666, 0.37645]
clip_std = [0.22803, 0.22145, 0.216989]

train_tfms = torchvision.transforms.Compose([
    T.ToFloatTensorInZeroOne(),                 # float32 tensor in [0, 1]
    T.Resize((128, 171)),
    T.RandomHorizontalFlip(),                   # augmentation: random left/right flip
    T.Normalize(mean=clip_mean, std=clip_std),
    T.RandomCrop((112, 112)),                   # augmentation: random 112x112 crop
])
test_tfms = torchvision.transforms.Compose([
    T.ToFloatTensorInZeroOne(),
    T.Resize((128, 171)),
    T.Normalize(mean=clip_mean, std=clip_std),
    T.CenterCrop((112, 112)),                   # deterministic crop for evaluation
])


def build_hmdb51_split(is_train, clip_transform):
    """Index the train (is_train=True) or test split of HMDB51 fold 1 using the settings above."""
    return torchvision.datasets.HMDB51(
        'video_data/', 'test_train_splits/', num_frames,
        step_between_clips=clip_steps, fold=1, train=is_train,
        transform=clip_transform, num_workers=num_workers,
    )


hmdb51_train = build_hmdb51_split(True, train_tfms)
hmdb51_test = build_hmdb51_split(False, test_tfms)

total_train_samples = len(hmdb51_train)
total_val_samples = round(val_split * total_train_samples)

print(f"number of train samples {total_train_samples}")
print(f"number of validation samples {total_val_samples}")
print(f"number of test samples {len(hmdb51_test)}")

100%|██████████| 417/417 [07:45<00:00,  1.12s/it]
100%|██████████| 417/417 [01:57<00:00,  3.55it/s]


number of train samples 7577
number of validation samples 379
number of test samples 3161


In [None]:
# Split the training set into train/validation and wrap every split in a DataLoader.
batch_size = 16
# DataLoader workers: 0 loads in the main process.
# NOTE(review): presumably 0 to keep the loaders picklable / Colab-friendly — confirm.
num_workers = 0
split_seed = 42  # fixed seed so the train/val split is reproducible across runs

if torch.cuda.is_available():
    kwargs = {'num_workers': num_workers, 'pin_memory': pin_memory}
else:
    kwargs = {'num_workers': num_workers}

hmdb51_train_v1, hmdb51_val_v1 = random_split(
    hmdb51_train,
    [total_train_samples - total_val_samples, total_val_samples],
    generator=torch.Generator().manual_seed(split_seed),
)

train_loader = DataLoader(hmdb51_train_v1, batch_size=batch_size, shuffle=True, **kwargs)
# Evaluation order does not matter, so only the training loader shuffles.
val_loader   = DataLoader(hmdb51_val_v1, batch_size=batch_size, shuffle=False, **kwargs)
test_loader  = DataLoader(hmdb51_test, batch_size=batch_size, shuffle=False, **kwargs)

In [None]:
# Sanity-check one training batch. Video decoding is expensive, so fetch the batch
# once and unpack it rather than creating a second iterator.
batch = next(iter(train_loader))
print(f"Batch type: {type(batch)}")
print(f"Batch length: {len(batch)}")

video, audio, label = batch
print(video.shape)  # (batch size, channels, frames, height, width)
print(audio.shape)  # audio tensor (last dim is 0 in the recorded output: no samples)
print(label.shape)  # (batch size,)

Batch type: <class 'list'>
Batch length: 3
torch.Size([16, 3, 16, 112, 112])
torch.Size([16, 1, 0])
torch.Size([16])


In [None]:
import joblib

# Save the DataLoaders (with their datasets) so later sessions can skip the slow
# HMDB51 indexing. The validation loader is saved too; otherwise the train/val split
# cannot be recovered after reloading.
joblib.dump(train_loader, "hmdb51_train.pkl")
joblib.dump(val_loader, "hmdb51_val.pkl")
joblib.dump(test_loader, "hmdb51_test.pkl")

['hmdb51_test.pkl']

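On later runs, skip the dataset-building cells above: running the cell below, which reloads the cached DataLoaders, is enough. (앞으로는 이것만 하면 됨)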

In [None]:
import joblib

# Reload the cached DataLoaders written by the saving cell.
# NOTE: joblib.load unpickles arbitrary objects - only load files you created yourself.
train_loader, test_loader = (
    joblib.load(path) for path in ("hmdb51_train.pkl", "hmdb51_test.pkl")
)

In [None]:
# Sanity-check the reloaded training DataLoader.
print(f'Number of batches in train_loader: {len(train_loader)}')
print(f'num_workers: {train_loader.num_workers}')
print(f'Batch size: {train_loader.batch_size}')

video, audio, label = next(iter(train_loader))
print(f"Video shape: {video.shape}")  # (batch_size, channels, frames, height, width)
print(f"Audio shape: {audio.shape}")  # audio tensor size (last dim 0 in the recorded output)
print(f"Label shape: {label.shape}")  # (batch_size,)

Number of batches in train_loader: 450
num_workers: 0
Batch size: 16




Video shape: torch.Size([16, 3, 16, 112, 112])
Audio shape: torch.Size([16, 1, 0])
Label shape: torch.Size([16])


# ANN

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import time
import matplotlib.pyplot as plt

In [None]:
class ANN(nn.Module):
    """Single-layer baseline: flatten the clip, apply one Linear layer, then ReLU.

    The layers live in ``self.network`` as Flatten -> Linear -> ReLU, so the
    state_dict keys are "network.1.weight" / "network.1.bias" and can be loaded
    into another network with the same layout.

    Args:
        input_size: number of features after flattening dims 1.. of the input.
        output_size: number of output units (classes).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        flatten = nn.Flatten()
        linear = nn.Linear(input_size, output_size)
        relu = nn.ReLU()  # outputs are always >= 0; shape (batch_size, output_size)
        self.network = nn.Sequential(flatten, linear, relu)

    def forward(self, x):
        """Return the non-negative class scores for the batch ``x``."""
        return self.network(x)

# NOTE(review): an earlier SNN draft is kept below as an inert string literal (it is
# never executed); being the cell's last expression, it is echoed as output.
'''
class SNN(nn.Module):
    def __init__(self, input_size, output_size, tau=10.0, v_th=1.0, dt=1.0):
        super(SNN, self).__init__()
        self.weights = None
        self.lif = neuron.LIFNode(tau=tau, v_threshold=v_th, detach_reset=True)

    def set_weights(self, weights, photo_responsivity):
        self.weights = weights * (max(abs(photo_responsivity)) / max(abs(weights)))

    def forward(self, x):
        if self.weights is None:
          raise valueError("Weights have not been set. Use set_weights() first.")

        membrane_potential = torch.matmul(input_spikes, self.weights)
        return self.lif(membrane_potential)
'''

'\nclass SNN(nn.Module):\n    def __init__(self, input_size, output_size, tau=10.0, v_th=1.0, dt=1.0):\n        super(SNN, self).__init__()\n        self.weights = None\n        self.lif = neuron.LIFNode(tau=tau, v_threshold=v_th, detach_reset=True)\n\n    def set_weights(self, weights, photo_responsivity):\n        self.weights = weights * (max(abs(photo_responsivity)) / max(abs(weights)))\n\n    def forward(self, x):\n        if self.weights is None:\n          raise valueError("Weights have not been set. Use set_weights() first.")\n\n        membrane_potential = torch.matmul(input_spikes, self.weights)\n        return self.lif(membrane_potential)\n'

**Train**

In [None]:
# Training configuration for the ANN baseline.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Optimisation hyper-parameters.
num_epochs = 100
learning_rate = 0.001

# The ANN flattens a whole clip: channels * frames * height * width.
input_size = 3 * 16 * 112 * 112
output_size = 51  # number of classes (HMDB51)
ann = ANN(input_size, output_size).to(device)

criterion = nn.MSELoss()  # MSE against one-hot targets
optimizer = optim.SGD(ann.parameters(), lr=learning_rate)

In [None]:
# Train the ANN baseline and evaluate it on the test split after every epoch.
# Clips arrive from the DataLoader as [N, C, T, H, W] (see the shapes printed above).
# The ANN flattens dims 1.., so each sample becomes one C*T*H*W = input_size vector.
# Do NOT move time to the front: flattening [T, N, C, H, W] mixes the samples of a batch
# together. It only "fits" input_size when N == T, so it crashes on the smaller final
# batch and otherwise trains against scrambled targets.
max_test_acc = -1
train_acc_list = []
test_acc_list = []

for epoch in range(num_epochs):
  start_time = time.time()

  # ---- training pass ----
  ann.train()
  train_loss = 0
  train_acc = 0
  train_samples = 0
  for frame, _, label in train_loader:
    optimizer.zero_grad()
    frame = frame.to(device)  # [N, C, T, H, W]
    label = label.to(device)
    label_onehot = F.one_hot(label, output_size).float()

    out_fr = ann(frame)  # (batch_size, output_size)
    loss = F.mse_loss(out_fr, label_onehot)
    loss.backward()
    optimizer.step()

    train_samples += label.numel()
    train_loss += loss.item() * label.numel()
    train_acc += (out_fr.argmax(1) == label).float().sum().item()
  train_time = time.time()
  train_speed = train_samples / (train_time - start_time)
  train_loss /= train_samples
  train_acc /= train_samples
  train_acc_list.append(train_acc)
  print(f'epoch {epoch}: train_loss={train_loss: .4f}, train_acc={train_acc: .4f}, train_speed={train_speed: .4f}')

  # ---- evaluation pass on the test split ----
  ann.eval()
  test_loss = 0
  test_acc = 0
  test_samples = 0
  with torch.no_grad():
    for frame, _, label in test_loader:
      frame = frame.to(device)  # [N, C, T, H, W]
      label = label.to(device)
      label_onehot = F.one_hot(label, output_size).float()
      out_fr = ann(frame)
      loss = F.mse_loss(out_fr, label_onehot)
      test_samples += label.numel()
      test_loss += loss.item() * label.numel()
      test_acc += (out_fr.argmax(1) == label).float().sum().item()
  test_time = time.time()
  test_speed = test_samples / (test_time - train_time)
  test_loss /= test_samples
  test_acc /= test_samples
  test_acc_list.append(test_acc)

  max_test_acc = max(max_test_acc, test_acc)
  # Report the test metrics every epoch alongside the training metrics.
  print(f'epoch {epoch}: test_loss={test_loss: .4f}, test_acc={test_acc: .4f}, test_speed={test_speed: .4f}, max_test_acc={max_test_acc: .4f}')

# Save the weights from the final epoch (loaded later for the ANN -> SNN conversion).
torch.save(ann.state_dict(), "ann_weights.pth")

AttributeError: module 'av' has no attribute 'AVError'

# SNN

In [None]:
class SNN(nn.Module):
  """Spiking counterpart of the ANN: Flatten -> Linear -> LIF neuron.

  The layers sit in ``self.network`` in the same positions as in the ANN, so an ANN
  state_dict ("network.1.weight" / "network.1.bias") can be loaded directly.
  NOTE(review): spikingjelly neurons keep membrane state across calls until reset.

  Args:
    input_size: number of features after flattening dims 1.. of the input.
    output_size: number of output neurons (classes).
    tau: membrane time constant passed to ``neuron.LIFNode``.
  """

  def __init__(self, input_size, output_size, tau=2.0):
    super().__init__()
    layers = [
      nn.Flatten(),
      nn.Linear(input_size, output_size),
      neuron.LIFNode(tau=tau),
    ]
    self.network = nn.Sequential(*layers)

  def forward(self, x):
    """Feed ``x`` through the network once and return the LIF output (spikes)."""
    return self.network(x)

In [None]:
# Load the trained ANN weights.
# NOTE: torch.load unpickles the file - only load checkpoints you created yourself.
ann_weights = torch.load("ann_weights.pth", map_location=device)

# Rescale every tensor (weights and biases) by the largest absolute value in the
# checkpoint, so all parameters lie in [-1, 1]. The absolute value must be taken
# element-wise before the max: abs(w.max()) would ignore large negative weights.
max_weight = max(w.abs().max().item() for w in ann_weights.values())
snn_weights = {k: v / max_weight for k, v in ann_weights.items()}

# Build the SNN (same layer layout as the ANN) and load the rescaled weights.
snn = SNN(input_size, output_size).to(device)
snn.load_state_dict(snn_weights)

In [None]:
# Evaluate the converted SNN on the test split (no training happens here).
# The LIF neurons are stateful, so each clip is presented for `sim_steps` time steps.
# Output spikes are averaged into per-class firing rates, and the network state is
# reset before the next batch so batches do not influence each other.
sim_steps = 16  # simulation time steps per clip (tunable: more steps -> finer rates)

snn.eval()
test_loss = 0
test_acc = 0
test_samples = 0
start_time = time.time()

with torch.no_grad():
  for frame, _, label in test_loader:
    # [N, C, T, H, W]; the SNN's Flatten consumes the whole clip, so no permute.
    frame = frame.to(device)
    label = label.to(device)
    label_onehot = F.one_hot(label, output_size).float()

    spike_sum = 0.
    for _ in range(sim_steps):
      spike_sum = spike_sum + snn(frame)
    out_fr = spike_sum / sim_steps  # firing rate per class, in [0, 1]

    loss = F.mse_loss(out_fr, label_onehot)
    test_samples += label.numel()
    test_loss += loss.item() * label.numel()
    test_acc += (out_fr.argmax(1) == label).float().sum().item()
    functional.reset_net(snn)  # clear membrane potentials before the next batch
test_time = time.time()
test_speed = test_samples / (test_time - start_time)
test_loss /= test_samples
test_acc /= test_samples
print(f'snn: test_loss={test_loss: .4f}, test_acc={test_acc: .4f}, test_speed={test_speed: .4f}')