# 1. Setup and Imports

In [None]:
# Install necessary libraries (uncomment if needed)
# !pip install torch torchvision torchaudio
# !pip install wandb
# !pip install ipywidgets
# !pip install opencv-python
# !pip install scikit-learn
# !pip install matplotlib seaborn


In [None]:
import os
import logging
import math
from typing import List, Tuple, Dict, Optional, Union

import glob
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import shufflenet_v2_x1_0, ShuffleNet_V2_X1_0_Weights
from torchvision.models.video import r3d_18, R3D_18_Weights

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_fscore_support

import cv2
from tqdm import tqdm
import gc

# Import Weights & Biases
import wandb

# Import ipywidgets for UI
import ipywidgets as widgets
from IPython.display import display, clear_output

# Reclaim Python garbage and cached CUDA allocator blocks left behind by
# earlier runs in this kernel. The CUDA call is skipped when no GPU is present.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 2. Configuration

In [None]:
dataset_root = "/kaggle/input/ucf101"

class VideoClassificationConfig:
    """Hyper-parameters and run settings for UCF101 video classification.

    Supported choices. Anything else raises ``ValueError`` here, instead of
    failing later when the model is built:
        model_type: ``'i3dshufflenet'``.
        temporal_module: ``'transformer'`` or ``'shuffle'``.

    Every constructor argument is stored under the same attribute name. The
    instance also carries:
        device: ``cuda`` if available, otherwise ``cpu``.
        classes: the 101 UCF101 action names. A video's label is its index here.

    Raises:
        ValueError: unsupported ``model_type``/``temporal_module``, or
            ``accumulation_steps < 1``.
    """

    SUPPORTED_MODEL_TYPES = ('i3dshufflenet',)
    SUPPORTED_TEMPORAL_MODULES = ('transformer', 'shuffle')

    def __init__(
        self,
        epochs: int = 50,
        batch_size: int = 16,
        learning_rate: float = 0.0001,
        num_workers: int = 4,
        videos_per_class: int = 10,  # Max number of videos collected per class
        model_type: str = 'i3dshufflenet',  # Only 'i3dshufflenet' is implemented
        pretrained: bool = True,
        accumulation_steps: int = 2,  # Batches per optimizer step (gradient accumulation)
        use_amp: bool = True,  # Use Automatic Mixed Precision
        max_batch_size: int = 32,  # Maximum batch size to attempt
        wandb_project: str = 'har-i3dshufflenet-ucf101',
        checkpoint_path: str = 'best_model_i3dshufflenet.pth',
        resume: bool = True,  # Flag to resume training from checkpoint
        scheduler_mode: str = 'plateau',  # 'plateau' or 'cosine'
        scheduler_factor: float = 0.1,  # Factor for ReduceLROnPlateau
        scheduler_patience: int = 5,  # Patience for ReduceLROnPlateau
        early_stop_patience: int = 10,  # Patience for Early Stopping
        checkpoint_interval: int = 10,  # Save checkpoint every N epochs
        temporal_module: str = 'transformer',  # 'transformer' or 'shuffle'
        use_attention: bool = True,
        aux_loss: bool = True,
        transformer_layers: int = 2,
        transformer_heads: int = 4
    ):
        # Fail fast on options the model code cannot build.
        if model_type not in self.SUPPORTED_MODEL_TYPES:
            raise ValueError(f"Unsupported model type: {model_type}")
        if temporal_module not in self.SUPPORTED_TEMPORAL_MODULES:
            raise ValueError(f"Unsupported temporal_module: {temporal_module}")
        if accumulation_steps < 1:
            raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_workers = num_workers
        self.videos_per_class = videos_per_class
        self.model_type = model_type
        self.pretrained = pretrained
        self.accumulation_steps = accumulation_steps
        self.use_amp = use_amp
        self.max_batch_size = max_batch_size
        self.wandb_project = wandb_project
        self.checkpoint_path = checkpoint_path
        self.resume = resume

        # Scheduler parameters
        self.scheduler_mode = scheduler_mode
        self.scheduler_factor = scheduler_factor
        self.scheduler_patience = scheduler_patience

        self.early_stop_patience = early_stop_patience
        self.checkpoint_interval = checkpoint_interval

        # Temporal module
        self.temporal_module = temporal_module
        self.transformer_layers = transformer_layers
        self.transformer_heads = transformer_heads

        # Attention mechanism
        self.use_attention = use_attention

        # Auxiliary loss
        self.aux_loss = aux_loss

        # Device configuration
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Predefined action classes (label = index in this list)
        self.classes = [
            "ApplyEyeMakeup", "ApplyLipstick", "Archery", "BabyCrawling", "BalanceBeam",
            "BandMarching", "BaseballPitch", "Basketball", "BasketballDunk", "BenchPress",
            "Biking", "Billiards", "BlowDryHair", "BlowingCandles", "BodyWeightSquats",
            "Bowling", "BoxingPunchingBag", "BoxingSpeedBag", "BreastStroke", "BrushingTeeth",
            "CleanAndJerk", "CliffDiving", "CricketBowling", "CricketShot", "CuttingInKitchen",
            "Diving", "Drumming", "Fencing", "FieldHockeyPenalty", "FloorGymnastics",
            "FrisbeeCatch", "FrontCrawl", "GolfSwing", "Haircut", "HammerThrow",
            "Hammering", "HandstandPushups", "HandstandWalking", "HeadMassage", "HighJump",
            "HorseRace", "HorseRiding", "HulaHoop", "IceDancing", "JavelinThrow",
            "JugglingBalls", "JumpingJack", "JumpRope", "Kayaking", "Knitting",
            "LongJump", "Lunges", "MilitaryParade", "Mixing", "MoppingFloor",
            "Nunchucks", "ParallelBars", "PizzaTossing", "PlayingCello", "PlayingDaf",
            "PlayingDhol", "PlayingFlute", "PlayingGuitar", "PlayingPiano", "PlayingSitar",
            "PlayingTabla", "PlayingViolin", "PoleVault", "PommelHorse", "PullUps",
            "Punch", "PushUps", "Rafting", "RockClimbingIndoor", "RopeClimbing",
            "Rowing", "SalsaSpin", "ShavingBeard", "Shotput", "SkateBoarding",
            "Skiing", "Skijet", "SkyDiving", "SoccerJuggling", "SoccerPenalty",
            "StillRings", "SumoWrestling", "Surfing", "Swing", "TableTennisShot",
            "TaiChi", "TennisSwing", "ThrowDiscus", "TrampolineJumping", "Typing",
            "UnevenBars", "VolleyballSpiking", "WalkingWithDog", "WallPushups", "WritingOnBoard",
            "YoYo"
        ]

# 3. Utility Functions

In [None]:
def format_frames(frame, output_size):
    """Resize one frame and scale its pixel values into ``[0, 1]``.

    Args:
        frame: image array accepted by ``cv2.resize``.
        output_size: target size as ``(width, height)``, which is the order
            ``cv2.resize`` expects.

    Returns:
        The resized frame divided by 255, as a floating-point array.
    """
    resized = cv2.resize(frame, output_size)
    return resized / 255.0

def frames_from_video_file(video_path, n_frames=32, output_size=(224, 224), frame_step=15):
    """Decode ``n_frames`` RGB frames, ``frame_step`` frames apart, from a video file.

    The strided window spans ``1 + (n_frames - 1) * frame_step`` frames. If the
    video is long enough, the window starts at a uniformly random valid frame.
    Otherwise it starts at frame 0, and samples past the end are zero-padded.

    NOTE(review): with the defaults the window is 466 frames long. Shorter videos
    yield mostly zero-padded clips, so check ``frame_step`` against the dataset.

    Args:
        video_path: path to the video (converted with ``str``).
        n_frames: number of frames to return.
        output_size: ``(width, height)`` passed to ``cv2.resize``.
        frame_step: number of decoded frames between consecutive samples.

    Returns:
        ``np.ndarray`` of shape ``(n_frames, height, width, 3)`` with RGB values
        in ``[0, 1]``. All zeros if the first frame cannot be read.
    """
    src = cv2.VideoCapture(str(video_path))
    try:
        video_length = int(src.get(cv2.CAP_PROP_FRAME_COUNT))
        need_length = 1 + (n_frames - 1) * frame_step

        if need_length > video_length:
            start = 0
        else:
            # randint is inclusive on both ends, so video_length - need_length is
            # the last start that keeps the whole window inside the video.
            start = random.randint(0, video_length - need_length)

        src.set(cv2.CAP_PROP_POS_FRAMES, start)
        ret, frame = src.read()
        if not ret:
            return np.zeros((n_frames, output_size[1], output_size[0], 3))

        # OpenCV decodes BGR. Downstream ImageNet normalisation and the pretrained
        # backbones expect RGB channel order.
        result = [format_frames(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), output_size)]

        for _ in range(n_frames - 1):
            for _ in range(frame_step):
                ret, frame = src.read()
            if ret:
                result.append(format_frames(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), output_size))
            else:
                # Past the end of the video: pad with black frames.
                result.append(np.zeros_like(result[0]))
    finally:
        # Release the capture on every path, including the early return above.
        src.release()

    # Ensure exactly n_frames are returned
    result = result[:n_frames]
    while len(result) < n_frames:
        result.append(np.zeros_like(result[0]))

    return np.array(result)

def adjust_batch_size(config, exception):
    """Halve ``config.batch_size`` in place after an out-of-memory error.

    Args:
        config: object with a mutable integer ``batch_size`` attribute.
        exception: the exception raised by the failed step.

    Returns:
        True if ``exception`` is a ``RuntimeError`` whose message mentions
        "out of memory" and the batch size was reduced. False otherwise,
        including when the batch size is already 1.
    """
    is_oom = isinstance(exception, RuntimeError) and 'out of memory' in str(exception).lower()
    if not is_oom or config.batch_size <= 1:
        return False

    config.batch_size = max(1, config.batch_size // 2)
    logging.warning(f"OOM detected. Reducing batch size to {config.batch_size}.")
    return True

# 4. Dataset Preparation

In [None]:
class VideoDataset(Dataset):
    """Map-style dataset yielding ``(clip, label)`` pairs for video classification.

    Each item decodes ``n_frames`` frames with ``frames_from_video_file`` and
    returns:
      * a float tensor ``[C, T, H, W]`` (``(H, W) = input_size``) normalised with
        ImageNet mean/std, and
      * the label as a ``torch.long`` scalar.

    When ``train`` is True (the default), a random horizontal flip and colour
    jitter are drawn once per clip and applied to every frame, so the clip stays
    temporally consistent. Pass ``train=False`` for evaluation splits to disable
    augmentation.

    If decoding or preprocessing fails, the error is logged and an all-zero clip
    labelled 0 is returned so that a single bad file does not stop the DataLoader.
    """
    def __init__(
        self,
        file_paths: List[str],
        targets: List[int],
        config: VideoClassificationConfig,
        n_frames: int = 32,
        input_size: Tuple[int, int] = (224, 224),
        train: bool = True
    ):
        self.file_paths = file_paths
        self.targets = targets
        self.n_frames = n_frames
        self.input_size = input_size  # (height, width) of the returned clips
        self.config = config
        self.train = train

        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
        if train:
            # These tensor transforms accept a whole [T, C, H, W] clip and draw
            # their random parameters once per call, i.e. once per clip.
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2
                ),
                normalize,
            ])
        else:
            self.transform = normalize

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        try:
            # [T, H, W, 3] array with values in [0, 1]
            frames = self._extract_frames(self.file_paths[idx])

            # Whole-clip float tensor [T, C, H, W]. ascontiguousarray guards
            # against negative-stride arrays, which torch.from_numpy rejects.
            clip = torch.from_numpy(np.ascontiguousarray(frames)).float().permute(0, 3, 1, 2)
            clip = self.transform(clip)

            # [T, C, H, W] -> [C, T, H, W], the layout 3D convolutions expect.
            video_tensor = clip.permute(1, 0, 2, 3).contiguous()

            return video_tensor, torch.tensor(self.targets[idx], dtype=torch.long)

        except Exception as e:
            logging.error(f"Error processing video {self.file_paths[idx]}: {e}")
            # Return dummy data to prevent breaking the DataLoader.
            # NOTE(review): these clips are still trained on as class 0.
            dummy_frames = torch.zeros(3, self.n_frames, *self.input_size)
            return dummy_frames, torch.tensor(0, dtype=torch.long)

    def _extract_frames(self, video_path: str) -> np.ndarray:
        """Decode ``n_frames`` frames as an ``(n_frames, H, W, 3)`` array."""
        height, width = self.input_size
        # frames_from_video_file takes a cv2-style (width, height) size.
        return frames_from_video_file(
            video_path,
            n_frames=self.n_frames,
            output_size=(width, height)
        )

def prepare_dataset(config: VideoClassificationConfig, root_dir: Optional[str] = None):
    """Collect UCF101 video paths and labels, then split them into train/validation.

    For each name in ``config.classes`` (label = its index in that list), up to
    ``config.videos_per_class`` ``.avi`` files are taken from
    ``<root>/UCF101/UCF-101/<class>/**``. Paths are sorted first, so the subset
    is reproducible. The pooled list is shuffled and split 80/20, stratified by
    label.

    Args:
        config: provides ``classes`` and ``videos_per_class``.
        root_dir: dataset root. Defaults to the module-level ``dataset_root``.

    Returns:
        ``(train_paths, val_paths, train_targets, val_targets)``.

    Raises:
        ValueError: if no video files are found.
    """
    # Use a distinct local name. Assigning to ``dataset_root`` inside this
    # function would make it local and raise UnboundLocalError.
    root = dataset_root if root_dir is None else root_dir

    file_paths: List[str] = []
    targets: List[int] = []

    for label, cls in enumerate(config.classes):
        search_pattern = os.path.join(root, "UCF101", "UCF-101", cls, "**", "*.avi")
        # glob order is filesystem-dependent; sort so the cap picks the same files every run.
        class_files = sorted(glob.glob(search_pattern, recursive=True))[:config.videos_per_class]

        if not class_files:
            logging.warning(f"No .avi files found for class '{cls}' in '{search_pattern}'.")

        file_paths.extend(class_files)
        targets.extend([label] * len(class_files))

    if not file_paths:
        raise ValueError("No video files found. Please check the dataset path and file extensions.")

    # Shuffle paths and labels together.
    combined = list(zip(file_paths, targets))
    random.shuffle(combined)
    file_paths = [path for path, _ in combined]
    targets = [target for _, target in combined]

    # Split dataset into training and validation sets
    train_paths, val_paths, train_targets, val_targets = train_test_split(
        file_paths, targets, test_size=0.2, random_state=42, stratify=targets
    )

    return train_paths, val_paths, train_targets, val_targets

# 5. Model Definitions

In [None]:
import math

class SEBlock(nn.Module):
    """Squeeze-and-Excitation block for 5D ``[B, C, T, H, W]`` tensors.

    Averages over T, H and W, then passes the result through a
    ``C -> C // reduction -> C`` bottleneck (at least one hidden unit) with a
    sigmoid output. Each channel of the input is scaled by its weight.
    """
    def __init__(self, channel, reduction=16):
        super().__init__()
        # At least one hidden unit. ``channel < reduction`` would otherwise give
        # a zero-width bottleneck and a constant 0.5 gate.
        hidden = max(1, channel // reduction)
        self.fc1 = nn.Linear(channel, hidden, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, channel, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c = x.shape[:2]
        # Squeeze: global average over T, H, W -> [B, C]
        y = x.mean(dim=(2, 3, 4))
        # Excitation: FC -> ReLU -> FC -> Sigmoid
        y = self.sigmoid(self.fc2(self.relu(self.fc1(y))))
        # Scale: channel-wise multiplication (broadcast over T, H, W)
        return x * y.view(b, c, 1, 1, 1)

class TemporalAttention(nn.Module):
    """Squeeze-and-excitation style channel gate for ``[B, C, T, H, W]`` inputs.

    Despite the name, the input is pooled over time AND space, and one sigmoid
    weight is produced per channel (shape ``[B, C, 1, 1, 1]``). Every time step
    of a channel is therefore scaled by the same factor.
    """
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        # At least one hidden unit. ``channels < reduction`` would otherwise give
        # a zero-width bottleneck and a constant 0.5 gate.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c = x.shape[:2]
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1, 1)
        return x * y

class SpatialAttention(nn.Module):
    """Gate every (t, h, w) position with a sigmoid map computed from all channels.

    A single-output 7x7x7 convolution with padding 3 (spatial/temporal size
    preserved) produces a ``[B, 1, T, H, W]`` map. The map is broadcast across
    channels and multiplied into the input.
    """
    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv3d(channels, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.sigmoid(self.conv(x))
        return gate * x

class HybridAttention(nn.Module):
    """Apply the channel gate (``TemporalAttention``) first, then the spatial gate (``SpatialAttention``).

    Args:
        channels: channel count of the ``[B, C, T, H, W]`` input.
        reduction: bottleneck reduction passed to ``TemporalAttention``.
    """
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.temporal_attn = TemporalAttention(channels, reduction)
        self.spatial_attn = SpatialAttention(channels)

    def forward(self, x):
        return self.spatial_attn(self.temporal_attn(x))

class TransformerTemporalEncoder(nn.Module):
    """Self-attention encoder over a sequence of per-time-step feature vectors.

    Stacks ``num_layers`` batch-first ``nn.TransformerEncoderLayer`` blocks
    (ReLU feed-forward of width ``4 * embed_dim``), then applies a LayerNorm.
    Input and output are both ``[B, T, C]`` with ``C == embed_dim``.
    """
    def __init__(self, embed_dim: int, num_heads: int, num_layers: int, dropout: float = 0.1):
        super().__init__()
        layer_kwargs = dict(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=dropout,
            activation='relu',
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs), num_layers=num_layers
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # [B, T, C] -> [B, T, C]
        return self.norm(self.transformer_encoder(x))

class ChannelShuffle(nn.Module):
    """ShuffleNet-style channel shuffle for 5D ``[B, C, T, H, W]`` tensors.

    Channels are viewed as ``[groups, C // groups]``, transposed, and flattened
    back. With 2 groups, channel order ``0 1 2 3`` becomes ``0 2 1 3``.
    ``C`` must be divisible by ``groups``, and the input must be view-compatible
    (for example, contiguous).
    """
    def __init__(self, groups: int):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        n, c, t, h, w = x.size()
        grouped = x.view(n, self.groups, c // self.groups, t, h, w)
        shuffled = grouped.permute(0, 2, 1, 3, 4, 5).contiguous()
        return shuffled.view(n, c, t, h, w)

class TemporalShuffleBlock(nn.Module):
    """Depthwise temporal conv -> GroupNorm -> ReLU -> channel shuffle.

    The ``(3, 1, 1)`` depthwise kernel mixes neighbouring time steps within each
    channel only. ``temporal_stride`` subsamples time. ``channels`` must be
    divisible by 32 (GroupNorm groups) and by 4 (shuffle groups).
    """
    def __init__(self, channels: int, temporal_stride: int = 1):
        super().__init__()
        self.temporal_conv = nn.Conv3d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=(3, 1, 1),
            stride=(temporal_stride, 1, 1),
            padding=(1, 0, 0),
            groups=channels,
        )
        # The attribute is named ``bn``, but it is a GroupNorm.
        self.bn = nn.GroupNorm(num_groups=32, num_channels=channels)
        self.shuffle = ChannelShuffle(groups=4)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.temporal_conv(x)
        out = self.bn(out)
        out = self.relu(out)
        return self.shuffle(out)

class AttentionBasedFusion(nn.Module):
    """Fuse a pooled I3D clip descriptor with ShuffleNet features using per-sample cross-attention.

    For each clip independently:
      1. The globally averaged I3D descriptor is the single query.
      2. The ShuffleNet features of the same clip (one token per time step,
         spatially averaged) are projected to keys and values in the query space.
      3. The attended vector is concatenated with the query and projected back
         to ``i3d_channels``.
      4. The result is broadcast-added to the I3D feature map as a residual.

    Attention never mixes different samples of a batch.

    Args:
        i3d_channels: channel count ``C_i3d`` of the I3D feature map.
        shufflenet_channels: channel count ``C_shuff`` of the ShuffleNet features.
    """
    def __init__(self, i3d_channels: int, shufflenet_channels: int):
        super().__init__()
        self.query = nn.Linear(i3d_channels, i3d_channels)
        # Keys and values are projected into the query dimension so that
        # query-key dot products are defined.
        self.key = nn.Linear(shufflenet_channels, i3d_channels)
        self.value = nn.Linear(shufflenet_channels, i3d_channels)
        self.softmax = nn.Softmax(dim=-1)
        # Input is [pooled I3D descriptor, attended value] -> 2 * i3d_channels.
        self.fc = nn.Linear(2 * i3d_channels, i3d_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, i3d_feat, shufflenet_feat):
        """
        Args:
            i3d_feat: ``[B, C_i3d, T, H, W]``.
            shufflenet_feat: ``[B, C_shuff, T_s, H_s, W_s]``. Its temporal and
                spatial extents may differ from ``i3d_feat``.

        Returns:
            ``[B, C_i3d, T, H, W]``: ``i3d_feat`` plus the fused per-clip
            descriptor broadcast over T, H and W.
        """
        batch = i3d_feat.size(0)
        i3d_pooled = i3d_feat.mean(dim=(2, 3, 4))                        # [B, C_i3d]
        shuff_tokens = shufflenet_feat.mean(dim=(3, 4)).transpose(1, 2)  # [B, T_s, C_shuff]

        q = self.query(i3d_pooled).unsqueeze(1)  # [B, 1, C_i3d]
        k = self.key(shuff_tokens)               # [B, T_s, C_i3d]
        v = self.value(shuff_tokens)             # [B, T_s, C_i3d]

        # Scaled dot-product attention within each sample.
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))  # [B, 1, T_s]
        attn_output = torch.bmm(self.softmax(scores), v).squeeze(1)        # [B, C_i3d]

        fused = self.relu(self.fc(torch.cat((i3d_pooled, attn_output), dim=1)))  # [B, C_i3d]

        # Residual: broadcast the per-clip descriptor over T, H, W.
        return i3d_feat + fused.view(batch, -1, 1, 1, 1)

class EnhancedI3DShuffleNet(nn.Module):
    """
    Hybrid video classifier: a 3D ResNet-18 (``r3d_18``) clip backbone fused with a
    per-frame ShuffleNetV2 backbone.

    Pipeline for an input ``[B, 3, T, H, W]``:
      1. r3d_18 stem + layer1..layer4 -> spatio-temporal map ``[B, C_i3d, T', H', W']``.
         The backbone's built-in global pool and fc are bypassed, so the map keeps its extent.
      2. ShuffleNetV2 on every frame -> ``[B, C_shuff, T, 1, 1]``.
      3. Temporal modelling on the r3d map: two ``TemporalShuffleBlock`` halves
         (``'shuffle'``) or a transformer over time (``'transformer'``).
      4. Optional channel + spatial attention.
      5. Attention-based fusion with the ShuffleNet features.
      6. 1x1x1-conv classifier head, averaged to ``[B, num_classes]``.

    In training mode with ``aux_loss=True``, ``forward`` returns
    ``(logits, aux_logits)``. Otherwise it returns ``logits``.
    """
    def __init__(
        self,
        num_classes: int,
        pretrained: bool = True,
        dropout_prob: float = 0.5,
        temporal_module: str = 'transformer',  # 'shuffle' or 'transformer'
        use_attention: bool = True,
        aux_loss: bool = False,
        transformer_layers: int = 2,
        transformer_heads: int = 4
    ):
        super().__init__()

        # 3D ResNet-18 clip backbone. forward() only runs stem + layer1..layer4.
        self.i3d_backbone = r3d_18(weights=R3D_18_Weights.DEFAULT if pretrained else None)

        # Per-frame 2D backbone. ``weights=`` replaces the deprecated ``pretrained=`` flag.
        self.shuffle_backbone = shufflenet_v2_x1_0(
            weights=ShuffleNet_V2_X1_0_Weights.DEFAULT if pretrained else None
        )

        # Feature widths of the original heads (typically 512 for r3d_18 and
        # 1024 for shufflenet_v2_x1_0).
        self.i3d_features = self.i3d_backbone.fc.in_features
        self.shuffle_features = self.shuffle_backbone.fc.in_features

        # Remove original fully connected layers
        self.i3d_backbone.fc = nn.Identity()
        self.shuffle_backbone.fc = nn.Identity()

        # Temporal modeling
        self.temporal_module = temporal_module
        if temporal_module == 'shuffle':
            self.temporal_blocks = nn.ModuleList([
                TemporalShuffleBlock(self.i3d_features // 2),
                TemporalShuffleBlock(self.i3d_features // 2)
            ])
        elif temporal_module == 'transformer':
            self.temporal_encoder = TransformerTemporalEncoder(
                embed_dim=self.i3d_features,
                num_heads=transformer_heads,
                num_layers=transformer_layers,
                dropout=0.1
            )
        else:
            raise ValueError(f"Unsupported temporal_module: {temporal_module}")

        # Attention mechanism
        self.use_attention = use_attention
        if use_attention:
            self.attention = HybridAttention(self.i3d_features)

        # Feature fusion
        self.fusion = AttentionBasedFusion(self.i3d_features, self.shuffle_features)

        # Classifier head: 1x1x1 convolutions with GroupNorm and Dropout3d
        self.classifier = nn.Sequential(
            nn.Dropout3d(dropout_prob),
            nn.Conv3d(self.i3d_features, self.i3d_features // 2, kernel_size=1),
            nn.GroupNorm(num_groups=16, num_channels=self.i3d_features // 2),
            nn.ReLU(inplace=True),
            nn.Dropout3d(dropout_prob),
            nn.Conv3d(self.i3d_features // 2, num_classes, kernel_size=1)
        )

        # Auxiliary classifier for deep supervision
        self.aux_loss = aux_loss
        if aux_loss:
            self.aux_classifier = nn.Sequential(
                nn.AdaptiveAvgPool3d(1),
                nn.Flatten(),
                nn.Linear(self.i3d_features, self.i3d_features // 2),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_prob),
                nn.Linear(self.i3d_features // 2, num_classes)
            )

    def _extract_i3d_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run r3d_18 up to layer4 and return the un-pooled map ``[B, C_i3d, T', H', W']``."""
        backbone = self.i3d_backbone
        x = backbone.stem(x)
        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        return backbone.layer4(x)

    def _process_shuffle_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run ShuffleNet on every frame of ``x`` (``[B, C, T, H, W]``).

        Returns ``[B, C_shuff, T, 1, 1]``. Spatial dims are left as singletons
        rather than expanded to H x W: the fusion only pools over them, and
        pooling an expanded view would iterate over every broadcast element.
        """
        b, c, t, h, w = x.shape
        frames = x.transpose(1, 2).reshape(b * t, c, h, w)    # [B*T, C, H, W]
        per_frame = self.shuffle_backbone(frames)             # [B*T, C_shuff]
        per_frame = per_frame.view(b, t, -1).transpose(1, 2)  # [B, C_shuff, T]
        return per_frame[..., None, None]                     # [B, C_shuff, T, 1, 1]

    def _apply_temporal_module(self, feats: torch.Tensor) -> torch.Tensor:
        """Temporal modelling on ``[B, C, T, H, W]``; the output has the same shape."""
        if self.temporal_module == 'shuffle':
            half = feats.size(1) // 2
            first, second = torch.split(feats, half, dim=1)
            return torch.cat(
                [self.temporal_blocks[0](first), self.temporal_blocks[1](second)], dim=1
            )

        # 'transformer': average over space only (keep time), attend across the
        # T time-step tokens, then add the result back as a residual broadcast over H, W.
        tokens = feats.mean(dim=(3, 4)).transpose(1, 2)  # [B, T, C]
        tokens = self.temporal_encoder(tokens)           # [B, T, C]
        return feats + tokens.transpose(1, 2)[..., None, None]

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        # x: [B, 3, T, H, W]
        i3d_features = self._extract_i3d_features(x)           # [B, C_i3d, T', H', W']
        shuffle_features = self._process_shuffle_features(x)  # [B, C_shuff, T, 1, 1]

        i3d_features = self._apply_temporal_module(i3d_features)

        if self.use_attention:
            i3d_features = self.attention(i3d_features)

        fused_features = self.fusion(i3d_features, shuffle_features)  # [B, C_i3d, T', H', W']

        # Classification: per-position logits averaged over T', H', W'.
        output = self.classifier(fused_features).mean(dim=(2, 3, 4))  # [B, num_classes]

        if self.aux_loss and self.training:
            aux_output = self.aux_classifier(i3d_features)  # [B, num_classes]
            return output, aux_output

        return output

def create_model(config: VideoClassificationConfig) -> nn.Module:
    """Build the model selected by ``config.model_type`` and move it to ``config.device``.

    Only ``'i3dshufflenet'`` is supported. Its temporal module
    (``'transformer'`` or ``'shuffle'``), attention, auxiliary head and
    transformer sizes come from ``config``. Dropout is fixed at 0.5.

    Raises:
        ValueError: for any other ``model_type``.
    """
    num_classes = len(config.classes)

    if config.model_type != 'i3dshufflenet':
        raise ValueError(f"Unsupported model type: {config.model_type}")

    model = EnhancedI3DShuffleNet(
        num_classes=num_classes,
        pretrained=config.pretrained,
        dropout_prob=0.5,
        temporal_module=config.temporal_module,
        use_attention=config.use_attention,
        aux_loss=config.aux_loss,
        transformer_layers=config.transformer_layers,
        transformer_heads=config.transformer_heads,
    )
    return model.to(config.device)

# 6. Training and Validation Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device, scaler, config):
    """Train ``model`` for one epoch with mixed precision and gradient accumulation.

    Gradients are accumulated over ``config.accumulation_steps`` batches before
    each optimizer step. A leftover partial group at the end of the epoch is
    also stepped, so its gradients are not discarded.

    If the model returns ``(logits, aux_logits)``, the loss is
    ``criterion(logits) + 0.4 * criterion(aux_logits)`` when ``config.aux_loss``
    is set, and ``criterion(logits)`` otherwise. Accuracy is always measured on
    ``logits``.

    Args:
        model, dataloader, criterion, optimizer: usual training objects.
        device: target device (``torch.device`` or string).
        scaler: gradient scaler providing ``scale``/``step``/``update``.
        config: reads ``use_amp``, ``aux_loss`` and ``accumulation_steps``.

    Returns:
        ``(epoch_loss, epoch_acc)``: mean per-batch loss and accuracy in percent.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    batches = len(dataloader)
    accumulation_steps = max(1, config.accumulation_steps)

    # Autocast must target the device in use. AMP is only enabled on CUDA.
    device_type = device.type if isinstance(device, torch.device) else torch.device(device).type
    amp_enabled = config.use_amp and device_type == 'cuda'

    # Progress bar for the entire epoch
    progress_bar = tqdm(enumerate(dataloader), total=batches, desc="Training")

    optimizer.zero_grad()

    for batch_idx, (videos, labels) in progress_bar:
        videos, labels = videos.to(device, non_blocking=True), labels.to(device, non_blocking=True)

        with torch.amp.autocast(device_type=device_type, enabled=amp_enabled):
            outputs = model(videos)
            if isinstance(outputs, tuple):
                logits, aux_logits = outputs
                loss = criterion(logits, labels)
                if config.aux_loss:
                    loss = loss + 0.4 * criterion(aux_logits, labels)
            else:
                logits = outputs
                loss = criterion(logits, labels)

        # Scale down so the accumulated gradient matches one large batch.
        scaler.scale(loss / accumulation_steps).backward()

        # Step after every full group and after the final (possibly partial) group.
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            torch.cuda.empty_cache()

        # Compute batch statistics
        running_loss += loss.item()
        _, predicted = logits.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Compute and update progress bar
        current_loss = running_loss / (batch_idx + 1)
        current_acc = 100. * correct / total

        progress_bar.set_postfix({
            'Loss': f'{current_loss:.4f}',
            'Accuracy': f'{current_acc:.2f}%'
        })

        # Log to Weights & Biases every batch
        wandb.log({
            'Train/Loss': current_loss,
            'Train/Accuracy': current_acc,
            'Train/Batch': batch_idx + 1
        })

    # Compute epoch-level metrics (guarded against an empty loader)
    epoch_loss = running_loss / max(batches, 1)
    epoch_acc = 100. * correct / max(total, 1)

    logging.info(f"\nTraining Epoch Summary:")
    logging.info(f"Epoch Loss: {epoch_loss:.4f}")
    logging.info(f"Epoch Accuracy: {epoch_acc:.2f}%")

    # Log epoch metrics to Weights & Biases
    wandb.log({
        'Train/Epoch Loss': epoch_loss,
        'Train/Epoch Accuracy': epoch_acc
    })

    return epoch_loss, epoch_acc

def validate(model, dataloader, criterion, device, config):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Running and final metrics are reported to tqdm, ``logging`` and Weights & Biases.

    Returns:
        ``(val_loss, val_acc, all_preds, all_labels, class_report)``:
          * mean per-batch loss,
          * accuracy in percent,
          * per-sample predicted and true label indices,
          * a scikit-learn classification report covering every class in
            ``config.classes``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    batches = len(dataloader)
    all_preds = []
    all_labels = []

    # Autocast must target the device in use. AMP is only enabled on CUDA.
    device_type = device.type if isinstance(device, torch.device) else torch.device(device).type
    amp_enabled = config.use_amp and device_type == 'cuda'

    # Progress bar for validation
    progress_bar = tqdm(dataloader, total=batches, desc="Validation")

    with torch.no_grad():
        for batch_idx, (videos, labels) in enumerate(progress_bar):
            videos, labels = videos.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device_type, enabled=amp_enabled):
                outputs = model(videos)
                # A model with an auxiliary head may return (logits, aux_logits);
                # only the main head is scored.
                logits = outputs[0] if isinstance(outputs, tuple) else outputs
                loss = criterion(logits, labels)

            # Compute batch statistics
            running_loss += loss.item()
            _, predicted = logits.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Collect predictions for final analysis
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            # Compute and update progress bar
            current_loss = running_loss / (batch_idx + 1)
            current_acc = 100. * correct / total

            progress_bar.set_postfix({
                'Loss': f'{current_loss:.4f}',
                'Accuracy': f'{current_acc:.2f}%'
            })

            # Log to Weights & Biases every batch
            wandb.log({
                'Validation/Loss': current_loss,
                'Validation/Accuracy': current_acc,
                'Validation/Batch': batch_idx + 1
            })

    # Compute validation-level metrics (guarded against an empty loader)
    val_loss = running_loss / max(batches, 1)
    val_acc = 100. * correct / max(total, 1)

    # Calculate additional metrics with zero_division=0
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0)

    # Pin the label set to every class index so the report always lines up with
    # `target_names`, even when some classes are absent from this split.
    class_report = classification_report(
        all_labels,
        all_preds,
        labels=list(range(len(config.classes))),
        target_names=config.classes,
        zero_division=0
    )
    logging.info(f"\nValidation Summary:")
    logging.info(f"Validation Loss: {val_loss:.4f}")
    logging.info(f"Validation Accuracy: {val_acc:.2f}%")
    logging.info(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")

    # Log metrics to Weights & Biases
    wandb.log({
        'Validation/Epoch Loss': val_loss,
        'Validation/Epoch Accuracy': val_acc,
        'Validation/Precision': precision,
        'Validation/Recall': recall,
        'Validation/F1-Score': f1
    })

    return val_loss, val_acc, all_preds, all_labels, class_report

# 7. Training Utilities: Checkpointing and Plotting

In [None]:
def load_checkpoint(config, model, optimizer, scheduler, scaler):
    """Restore training state from ``config.checkpoint_path`` if that file exists.

    The model, optimizer, scheduler and scaler are restored, in that order, from
    the ``'model_state_dict'``, ``'optimizer_state_dict'``,
    ``'scheduler_state_dict'`` and ``'scaler_state_dict'`` entries.

    NOTE: ``torch.load`` unpickles the file, so only load trusted checkpoints.

    Returns:
        ``(start_epoch, best_val_acc)``: the checkpoint's ``'epoch'`` and
        ``'best_val_acc'`` (0 if absent), or ``(0, 0)`` when no file exists.
    """
    path = config.checkpoint_path
    if not os.path.isfile(path):
        logging.info(f"No checkpoint found at '{path}'. Starting fresh.")
        return 0, 0

    logging.info(f"Loading checkpoint from '{path}'")
    checkpoint = torch.load(path, map_location=config.device)
    restore_plan = (
        (model, 'model_state_dict'),
        (optimizer, 'optimizer_state_dict'),
        (scheduler, 'scheduler_state_dict'),
        (scaler, 'scaler_state_dict'),
    )
    for component, key in restore_plan:
        component.load_state_dict(checkpoint[key])

    start_epoch = checkpoint['epoch']
    best_val_acc = checkpoint.get('best_val_acc', 0)
    logging.info(f"Loaded checkpoint '{path}' (Epoch {start_epoch})")
    return start_epoch, best_val_acc

def get_current_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group, or ``None`` if it has no groups."""
    return next((group['lr'] for group in optimizer.param_groups), None)

def plot_training_history(history):
    """Plot per-epoch accuracy (left) and loss (right) for the train and validation splits.

    Args:
        history: mapping with sequences under ``'train_acc'``, ``'val_acc'``,
            ``'train_loss'`` and ``'val_loss'``.
    """
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (acc_ax, 'train_acc', 'Train Accuracy', 'val_acc', 'Validation Accuracy', 'Model Accuracy', 'Accuracy'),
        (loss_ax, 'train_loss', 'Train Loss', 'val_loss', 'Validation Loss', 'Model Loss', 'Loss'),
    )
    for ax, train_key, train_label, val_key, val_label, title, ylabel in panels:
        ax.plot(history[train_key], label=train_label)
        ax.plot(history[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()

    fig.tight_layout()
    plt.show()

def plot_confusion_matrix_custom(y_true, y_pred, class_names):
    """Compute and draw an annotated confusion matrix over all classes.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        class_names: display names, one per class, in class-index order.

    Assumes labels are integer class indices ``0 .. len(class_names) - 1``
    (the training loss is ``CrossEntropyLoss``, whose targets are class
    indices) -- confirm if labels are produced differently.
    """
    # Pin the matrix to every class index. Without `labels`, sklearn keeps only the
    # classes that occur in y_true or y_pred, so a class missing from the
    # validation split would shrink the matrix and misalign every tick label.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(20, 16))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set(xlabel='Predicted Label', ylabel='True Label', title='Confusion Matrix')
    plt.show()

# 9. Execution

In [None]:
# Configure logging first so that anything logged while building the config goes
# to both the log file and the notebook output. `force=True` replaces any handlers
# already on the root logger: a bare `logging.info(...)` made before this point
# (in an earlier cell, or on a previous run of this cell) would otherwise have
# installed a default WARNING-level handler and turned this call into a no-op.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("training_i3dshufflenet.log"),
        logging.StreamHandler()
    ],
    force=True,
)

config = VideoClassificationConfig(
    model_type='i3dshufflenet',
    epochs=50,
    batch_size=16,
    learning_rate=0.0001,
    accumulation_steps=2,
    use_amp=True,
    wandb_project='thanhnx',
    checkpoint_path='best_model_i3dshufflenet.pth',
    resume=True,
    scheduler_mode='plateau',
    scheduler_factor=0.1,
    scheduler_patience=5,
    early_stop_patience=10,
    checkpoint_interval=10,
    temporal_module='transformer',  # 'i3d', 'shuffle', or 'transformer'
    use_attention=True,
    aux_loss=True,
    transformer_layers=2,
    transformer_heads=4
)

wandb.init(
    project=config.wandb_project,
    config=vars(config),
    resume='allow' if config.resume else False,
    job_type='training',
    settings=wandb.Settings(init_timeout=120)
)

# Prepare dataset
train_paths, val_paths, train_targets, val_targets = prepare_dataset(config)
logging.info(f"Training samples: {len(train_paths)}, Validation samples: {len(val_paths)}")
wandb.config.update({
    'train_samples': len(train_paths),
    'val_samples': len(val_paths)
})

train_dataset = VideoDataset(train_paths, train_targets, config)
val_dataset = VideoDataset(val_paths, val_targets, config)

# Create data loaders.
# DataLoader raises ValueError for persistent_workers=True with num_workers == 0
# (single-process loading, e.g. while debugging), so only keep workers alive
# when there are workers to keep.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    persistent_workers=config.num_workers > 0
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True,
    persistent_workers=config.num_workers > 0
)

# Build the network selected by `config.model_type`.
model = create_model(config)
logging.info(f"Using model type: {config.model_type}")

# Loss and optimizer
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' per-class weights derived from the training targets, converted to a
# float tensor on the training device for use by the weighted cross-entropy loss.
class_weights = torch.tensor(
    compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_targets),
        y=train_targets,
    ),
    dtype=torch.float,
).to(config.device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

# The scheduler is stepped with validation accuracy, so higher is better ('max').
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=config.scheduler_factor,
    patience=config.scheduler_patience,
    verbose=True,
    min_lr=1e-6
)

# Gradient scaler for mixed precision; disabled unless config.use_amp is set.
scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)

# Resume from the saved checkpoint when requested, otherwise start from scratch.
if config.resume:
    start_epoch, best_val_acc = load_checkpoint(config, model, optimizer, scheduler, scaler)
else:
    start_epoch, best_val_acc = 0, 0

# Training loop
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
epochs_no_improve = 0
early_stop = False

# Cap on batch-size-reduction retries within a single epoch, so the loop is
# guaranteed to terminate even if `adjust_batch_size` keeps reporting success.
MAX_OOM_RETRIES_PER_EPOCH = 5
oom_retries = 0

# The epoch counter is advanced by hand (rather than `for epoch in range(...)`)
# so that an epoch interrupted by a recoverable RuntimeError is really re-run:
# decrementing a `for` loop variable has no effect on the next iteration.
epoch = start_epoch
while epoch < config.epochs:
    if early_stop:
        logging.info("Early stopping triggered.")
        break

    print(f'Epoch {epoch+1}/{config.epochs}')
    logging.info(f'Epoch {epoch+1}/{config.epochs}')
    wandb.log({'epoch': epoch + 1})

    try:
        # Training phase
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, config.device, scaler, config
        )

        # Validation phase
        val_loss, val_acc, all_preds, all_labels, class_report = validate(
            model, val_loader, criterion, config.device, config
        )

        # Step the scheduler with validation accuracy (scheduler runs in 'max' mode)
        scheduler.step(val_acc)

        # Retrieve and log the current learning rate
        current_lr = get_current_lr(optimizer)
        wandb.log({'Learning Rate': current_lr})
        logging.info(f"Current Learning Rate: {current_lr}")

        # Check for improvement
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0

            # 'epoch' stores the number of completed epochs, i.e. the next epoch to run.
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc,
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
            }, config.checkpoint_path)
            logging.info(f"New best model saved with validation accuracy: {best_val_acc:.2f}%")

            # Create a W&B Artifact
            artifact = wandb.Artifact('best_model_i3dshufflenet', type='model')
            artifact.add_file(config.checkpoint_path)
            wandb.log_artifact(artifact)
            logging.info(f"Checkpoint {config.checkpoint_path} logged to W&B Artifacts.")
            print(f"Checkpoint {config.checkpoint_path} logged to W&B Artifacts.")

        else:
            epochs_no_improve += 1
            logging.info(f"No improvement in validation accuracy for {epochs_no_improve} epoch(s).")

        # Save checkpoint every N epochs
        if (epoch + 1) % config.checkpoint_interval == 0:
            checkpoint_name = f'checkpoint_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc,
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
            }, checkpoint_name)
            logging.info(f"Checkpoint saved at epoch {epoch+1} as '{checkpoint_name}'.")

        # Update history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        logging.info(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        logging.info(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')

        # Check early stopping condition
        if epochs_no_improve >= config.early_stop_patience:
            logging.info(f"No improvement for {config.early_stop_patience} consecutive epochs. Stopping training.")
            early_stop = True

    except RuntimeError as e:
        # `adjust_batch_size` (defined elsewhere) decides whether the error is
        # recoverable; presumably it also lowers `config.batch_size` -- confirm.
        if oom_retries < MAX_OOM_RETRIES_PER_EPOCH and adjust_batch_size(config, e):
            oom_retries += 1

            # Reinitialize data loaders with the new batch size
            # (persistent_workers requires num_workers > 0).
            train_loader = DataLoader(
                train_dataset,
                batch_size=config.batch_size,
                shuffle=True,
                num_workers=config.num_workers,
                pin_memory=True,
                persistent_workers=config.num_workers > 0
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=config.batch_size,
                shuffle=False,
                num_workers=config.num_workers,
                pin_memory=True,
                persistent_workers=config.num_workers > 0
            )

            # Reinitialize optimizer and scheduler.
            # NOTE(review): this discards Adam's moment estimates and any LR
            # reductions already made by the scheduler; consider keeping them.
            optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode='max',
                factor=config.scheduler_factor,
                patience=config.scheduler_patience,
                verbose=True,
                min_lr=1e-6
            )
            scaler = torch.cuda.amp.GradScaler(enabled=config.use_amp)

            # Drop gradients left by the failed step: frees their memory and keeps
            # them out of the retried epoch in case train_epoch does not zero them
            # before its first update.
            model.zero_grad(set_to_none=True)
            torch.cuda.empty_cache()
            gc.collect()
            continue  # `epoch` is unchanged, so the same epoch is retried

        logging.error("Unrecoverable CUDA OOM error.")
        raise

    # Epoch finished successfully: reset the retry budget and move on.
    oom_retries = 0
    epoch += 1
    torch.cuda.empty_cache()
    gc.collect()

wandb.finish()

# `all_labels`, `all_preds` and `class_report` only exist once at least one
# validation pass has completed in this session. When resuming at or beyond
# `config.epochs`, no epoch runs and they would raise NameError.
if history['val_acc']:
    plot_training_history(history)
    plot_confusion_matrix_custom(all_labels, all_preds, config.classes)

    print("\nClassification Report:")
    print(class_report)
else:
    logging.warning("No epoch completed in this session; skipping training plots and classification report.")

final_model_path = 'final_model_i3dshufflenet.pth'
torch.save(model.state_dict(), final_model_path)
logging.info(f"Final model saved to '{final_model_path}'.")