<a href="https://colab.research.google.com/github/rachit2005/UNET-/blob/main/video_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [103]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("trainingdatapro/aggressive-behavior-video-classification")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'aggressive-behavior-video-classification' dataset.
Path to dataset files: /kaggle/input/aggressive-behavior-video-classification


In [104]:
from torchvision.transforms.v2 import RandomHorizontalFlip, RandomVerticalFlip, Compose, Normalize, Resize, ToPILImage
from torch.utils.data import Dataset, ConcatDataset
from torchvision.transforms import ToTensor
from glob import glob
from PIL import Image
import shutil
import torch
import cv2
import os
import gc

In [105]:
from cv2.gapi import video
dataset_url = '/kaggle/input/aggressive-behavior-video-classification/files'

def extract_frames(video_path, output_dir, frame_rate=1):
  cap = cv2.VideoCapture(video_path)
  frame_num = 0
  frame_skip = 4

  while True:
    ret,frame = cap.read()
    if not ret:
      break
    if frame_num % frame_skip == 0:
      frame_path = os.path.join(output_dir , f'frame_{frame_num}.jpg')
      cv2.imwrite(frame_path, frame)
    frame_num += 1

  cap.release() # Moved outside the loop


def get_video(dataset_url):
  video_paths = glob(os.path.join(dataset_url, '*.mp4'))
  return video_paths


In [106]:
class VideoDataset(Dataset):
  def __init__(self,dataset_url , video_url , output_folder , transforms = None):
    self.dataset_url = dataset_url
    self.video_url = video_url
    self.output_folder = output_folder
    self.transforms = transforms

  def __getitem__(self, index):
    label = self.dataset_url.split('/')[-1]
    video_path = self.video_url[index]

    if self.transforms is not None:
      os.makedirs(self.output_folder, exist_ok=True)
      extract_frames(video_path, self.output_folder)

      frames_list = [] # Initialize an empty list to collect frames

      # Sort frame files to maintain temporal order if desired
      frame_files = sorted([f for f in os.listdir(self.output_folder) if f.startswith('frame_') and f.endswith('.jpg')])

      if not frame_files:
          # Handle case where no frames were extracted (e.g., corrupt video or extract_frames failed)
          # You might want to skip this item or return a placeholder/error
          print(f"Warning: No frames found for video: {video_path}")
          # For now, let's return a dummy tensor to allow the pipeline to continue, but a robust solution would handle this gracefully
          dummy_video = torch.zeros(3, 1, 112, 112) # (C, D, H, W)
          dummy_label = torch.tensor(0) # Or -1 for unknown
          shutil.rmtree(self.output_folder) # Clean up
          return dummy_video, dummy_label

      for frame_file in frame_files:
        frame_path = os.path.join(self.output_folder , frame_file)

        image = Image.open(frame_path).convert("RGB") # Ensure image is RGB

        frame_tensor = self.transforms(image) # Apply transforms, should result in (C, H, W)
        frames_list.append(frame_tensor)

      # Stack all frames along a new dimension (depth/time) to get (D, C, H, W)
      video_tensor = torch.stack(frames_list, dim=0)

      # Permute to (C, D, H, W) as expected by Conv3d layers
      video = video_tensor.permute(1, 0, 2, 3)

      shutil.rmtree(self.output_folder) # Clean up temporary frames

      if label == 'aggressive': # Corrected spelling for consistency, original was 'agressive'
        label = torch.tensor(1)
      else:
        label = torch.tensor(0)

      return video , label

  def __len__(self):
    return len(self.video_url)


In [107]:
output_agg = '/kaggle/working/agg_frames'
output_non_agg = '/kaggle/working/non_agg_frames'

effects = Compose([
    ToPILImage(),
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    Resize(size = (112, 112), antialias = True),
    ToTensor(),
    Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
])

output_agg = '/kaggle/working/agg_frames'
output_non = '/kaggle/working/non_frames'

agg_video_paths = get_video(os.path.join(dataset_url, 'aggressive'))
nonagg_video_paths = get_video(os.path.join(dataset_url, 'non_aggressive'))

agg_dataset = VideoDataset(os.path.join(dataset_url, 'aggressive'), agg_video_paths, output_agg,transforms=effects)
nonagg_dataset = VideoDataset(os.path.join(dataset_url, 'non_aggressive'), nonagg_video_paths, output_non_agg,transforms=effects)

dataset = ConcatDataset([agg_dataset, nonagg_dataset])

print("Total videos in dataset:", len(dataset))
print("Dataset base path:", dataset_url)

Total videos in dataset: 11
Dataset base path: /kaggle/input/aggressive-behavior-video-classification/files


In [108]:
import os

# Walk through the directory and print all files found
for root, dirs, files in os.walk(path):
    print(f"Directory: {root}")
    for file in files:
        print(f"  - {file}")

Directory: /kaggle/input/aggressive-behavior-video-classification
  - aggressive_behavior.csv
Directory: /kaggle/input/aggressive-behavior-video-classification/files
Directory: /kaggle/input/aggressive-behavior-video-classification/files/aggressive
  - 3.mp4
  - 1.mp4
  - 4.mp4
  - 0.mp4
  - 2.mp4
Directory: /kaggle/input/aggressive-behavior-video-classification/files/non_aggressive
  - 5.mp4
  - 3.mp4
  - 1.mp4
  - 4.mp4
  - 0.mp4
  - 2.mp4


In [109]:
from torch.utils.data import DataLoader, random_split
print(len(dataset))

train_size = 8
eval_size = 3

train_dataset , test_dataset = random_split(dataset , [train_size, eval_size])

batch_size = 1

train_dataloader = DataLoader(train_dataset , batch_size = batch_size , shuffle = True , num_workers=0)
test_dataloder = DataLoader(test_dataset , batch_size , shuffle = True)

# print(next(iter(train_dataloader)))

print(len(train_dataloader.dataset))

11
8


In [110]:
! pip install snntorch



In [111]:
import torch
import torch.nn as nn
import torch.ao.quantization as quantization
import snntorch as snn
from snntorch import surrogate


class Conv3dBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.quant = quantization.QuantStub()
        self.dequant = quantization.DeQuantStub()

        self.conv = nn.Conv3d(in_channels, out_channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.lif = snn.Leaky(
            beta=0.95,
            spike_grad=surrogate.fast_sigmoid(),
            init_hidden=True
        )

    def forward(self, x):
        x = self.quant(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dequant(x)

        spk = self.lif(x)
        return spk



class UNET_VIDEO_CLASSIFICATION(nn.Module):
    def __init__(self, num_classes=10, in_channels=3):
        super(UNET_VIDEO_CLASSIFICATION, self).__init__()

        # --- Quantization Stubs ---
        # QuantStub converts floating point tensors to quantized tensors
        self.quant = quantization.QuantStub()
        # DeQuantStub converts quantized tensors back to floating point
        self.dequant = quantization.DeQuantStub()

        # --- Contracting Path (Encoder) ---
        # Level 1
        self.enc1 = Conv3dBlock(in_channels, 32)
        self.pool1 = nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2))


        # Level 2
        self.enc2 = Conv3dBlock(32, 64)
        self.pool2 = nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2))


        # Level 3
        self.enc3 = Conv3dBlock(64, 128)
        self.pool3 = nn.MaxPool3d(kernel_size=(1,2,2), stride=(1,2,2))


        # Level 4 (Bottleneck)
        self.bottleneck = Conv3dBlock(128, 256)

        # --- Classification Head ---
        # Replaces the U-Net Decoder for classification tasks
        self.global_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
      # x shape: [B, C, T, H, W]

      self.reset_snn()

      time_steps = x.shape[2]
      spike_sum = torch.zeros(x.shape[0], self.fc.out_features, device=x.device)

      for t in range(time_steps):
          xt = x[:, :, t:t+1, :, :]

          # Quantize
          xt = self.quant(xt)

          # Encoder
          xt = self.pool1(self.enc1(xt))
          xt = self.pool2(self.enc2(xt))
          xt = self.pool3(self.enc3(xt))

          # Bottleneck
          xt = self.bottleneck(xt)

          # Global Pool
          xt = self.global_pool(xt)
          xt = torch.flatten(xt, 1)
          xt = self.dropout(xt)
          xt = self.fc(xt)

          # Accumulate spikes
          spike_sum += xt

      # Average over time
      out = spike_sum / time_steps

      # Dequantize
      out = self.dequant(out)
      return out

    def reset_snn(self):
      for m in self.modules():
          if isinstance(m, snn.Leaky):
              m.reset_hidden()

    def fuse_model(self):
        """
        Fuses Conv+BN+ReLU layers to save memory and improve speed.
        Required for effective Quantization Aware Training.
        """
        # Fuse the sub-blocks (Conv + BN + ReLU) by specifying full paths from the model
        torch.ao.quantization.fuse_modules(
            self, ['enc1.conv', 'enc1.bn', 'enc1.relu'], inplace=True
        )
        torch.ao.quantization.fuse_modules(
            self, ['enc2.conv', 'enc2.bn', 'enc2.relu'], inplace=True
        )
        torch.ao.quantization.fuse_modules(
            self, ['enc3.conv', 'enc3.bn', 'enc3.relu'], inplace=True
        )
        torch.ao.quantization.fuse_modules(
            self, ['bottleneck.conv', 'bottleneck.bn', 'bottleneck.relu'], inplace=True
        )

In [112]:
model = UNET_VIDEO_CLASSIFICATION(num_classes=2)

print(model)
print(len(dataset))

UNET_VIDEO_CLASSIFICATION(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (enc1): Conv3dBlock(
    (quant): QuantStub()
    (dequant): DeQuantStub()
    (conv): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (lif): Leaky()
  )
  (pool1): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (enc2): Conv3dBlock(
    (quant): QuantStub()
    (dequant): DeQuantStub()
    (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (bn): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (lif): Leaky()
  )
  (pool2): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)
  (enc3): Conv3dBlock(
    (quant): QuantStub()
    (dequant): DeQuantStub()
   

## Quantization aware training

In [113]:
import os
import copy

BATCH_SIZE = 8
EPOCHS = 5
LR = 3e-4

optimizer = torch.optim.Adam(model.parameters() , lr = LR)
loss_fn = nn.CrossEntropyLoss()

def main():

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.eval() # Put model in eval mode for stable BatchNorm during fusion
  model.fuse_model()
  model.train() # Put model back in train mode for QAT preparation

  # Move model to device BEFORE preparing for QAT
  model.to(device)

  # prepare for QAT
  backend = "fbgemm"
  model.qconfig = quantization.get_default_qat_qconfig(backend)
  quantization.prepare_qat(model , inplace=True)


  # training loop
  print("starting the quantization aware training")

  for epoch in range(EPOCHS): # Corrected: Iterate over range(EPOCHS)
    model.train()
    total_loss = 0
    correct = 0
    total = 0 # Initialize total for accuracy calculation

    for videos, labels in train_dataloader:
      videos = videos.to(device)
      labels = labels.to(device)

      optimizer.zero_grad()
      output = model(videos)
      loss = loss_fn(output , labels)
      loss.backward()
      optimizer.step()

      total_loss += loss.item()
      # Corrected: predicted should be the index of the max value
      _, predicted = torch.max(output.data, 1)

      total += labels.size(0)

      correct += (predicted == labels).sum().item()

    train_loss = total_loss / len(train_dataloader)
    train_accuracy = 100 * correct / total
    print(f"epochs: {epoch+1}/{EPOCHS} | loss: {train_loss:.4f} | accuracy: {train_accuracy:.2f}%")

In [114]:
main()

starting the quantization aware training


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantization.prepare_qat(model , inplace=True)


epochs: 1/5 | loss: 0.6939 | accuracy: 50.00%
epochs: 2/5 | loss: 0.6936 | accuracy: 50.00%
epochs: 3/5 | loss: 0.6930 | accuracy: 50.00%
epochs: 4/5 | loss: 0.6926 | accuracy: 50.00%
epochs: 5/5 | loss: 0.6921 | accuracy: 50.00%


In [115]:
model.eval()
model.to('cpu')

# Perform conversion in-place to avoid deepcopy issues
quantization.convert(model , inplace=True)
quantized_model = model # Assign the in-place converted model

print("conversion succesfull")

torch.save(quantized_model.state_dict(), 'quantized_model.pth')
torch.save(model.state_dict() , "model.pth")

conversion succesfull


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantization.convert(model , inplace=True)
