In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Adaptive Gamma Correction


In [None]:
def soft_argmax(x, beta=100.0):
    """
    Temperature-scaled softmax together with the expected index it implies.

    Args:
        x: Score tensor whose last dimension enumerates the candidates,
           e.g. [batch_size, num_experts].
        beta: Factor multiplied into `x` before the softmax; larger values
              push the distribution towards a one-hot argmax.

    Returns:
        Tuple `(soft_indices, softmax_probs)`:
            soft_indices: probability-weighted mean of the positions
                0..K-1 along the last dimension, shape `x.shape[:-1]`.
            softmax_probs: softmax of `beta * x` over the last dimension,
                same shape as `x`.
    """
    # Sharpen the scores, then normalise them into a distribution.
    probs = F.softmax(x * beta, dim=-1)
    # Candidate positions 0..K-1, created with the dtype/device of the input.
    positions = torch.arange(x.size(-1), dtype=x.dtype, device=x.device)
    # Expected position under `probs` (a differentiable stand-in for argmax).
    expected_index = (positions * probs).sum(dim=-1)

    return expected_index, probs


In [None]:
class FeatureExtractor(nn.Module):
    """Per-frame statistical descriptors (mean, std, Shannon entropy) of a video clip.

    The module has no learnable parameters; it only summarises the input so
    that a downstream network can consume a fixed-size feature vector.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()

    def forward(self, video_clip):
        """
        Compute mean, standard deviation and Shannon entropy for each frame.

        Args:
            video_clip: Tensor of shape [batch_size, T, H, W, C]. Any number of
                channels is accepted; channels are averaged into one grayscale
                plane. Integer tensors are converted to float first.

        Returns:
            features: Tensor of shape [batch_size, T*3], laid out frame by
                frame as (mean, std, entropy).
        """
        batch_size, T, H, W, C = video_clip.shape

        # mean/std are not defined for integer tensors.
        if not video_clip.is_floating_point():
            video_clip = video_clip.float()

        # Channel average. For C == 1 this equals dropping the channel axis,
        # and unlike a squeeze it also works for C not in {1, 3}.
        gray = video_clip.mean(dim=-1)  # [batch_size, T, H, W]

        mean_val = gray.mean(dim=(2, 3))  # [batch_size, T]
        std_val = gray.std(dim=(2, 3))    # [batch_size, T]

        # The entropy helper works on a stack of 2-D frames, so fold T into batch.
        entropy_val = self._calculate_entropy(
            gray.reshape(batch_size * T, H, W)
        ).reshape(batch_size, T)

        # [batch_size, T, 3] -> [batch_size, T*3] keeps the per-frame
        # (mean, std, entropy) ordering.
        features = torch.stack([mean_val, std_val, entropy_val], dim=-1)
        return features.reshape(batch_size, T * 3)

    def _calculate_entropy(self, gray_frame):
        """Shannon entropy (bits) of a 256-bin intensity histogram, one value per frame.

        Args:
            gray_frame: Tensor whose first dimension indexes frames,
                e.g. [num_frames, H, W].

        Returns:
            Tensor of shape [num_frames].

        The histogram covers [0, 1]; `torch.histc` ignores values outside that
        range, so such pixels do not contribute to the entropy.
        """
        entropies = []

        for frame in gray_frame:
            hist = torch.histc(frame.flatten(), bins=256, min=0.0, max=1.0)
            # clamp_min avoids 0/0 when every pixel fell outside [0, 1].
            probs = hist / hist.sum().clamp_min(1.0)
            # Drop empty bins so log2 never sees zero.
            probs = probs[probs > 0]
            entropies.append(-torch.sum(probs * torch.log2(probs)))

        if not entropies:
            return gray_frame.new_zeros(0)
        return torch.stack(entropies)


In [None]:
class GatingNetwork(nn.Module):
    """Small MLP that scores the experts and turns the scores into soft selections.

    Layout: Linear -> BatchNorm1d -> ReLU, twice, followed by a final Linear
    whose outputs are passed to `soft_argmax`.
    NOTE: BatchNorm1d in training mode needs more than one sample per batch.
    """

    def __init__(self, input_size, num_experts=5, hidden_dim=64, beta=100.0):
        super().__init__()
        self.num_experts = num_experts
        self.beta = beta

        # The second hidden layer is half as wide as the first.
        bottleneck_dim = hidden_dim // 2

        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, bottleneck_dim)
        self.bn2 = nn.BatchNorm1d(bottleneck_dim)
        self.fc3 = nn.Linear(bottleneck_dim, num_experts)

        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Score the experts for every sample.

        Args:
            x: Feature tensor of shape [batch_size, input_size]

        Returns:
            lambda_vector: Softmax expert weights [batch_size, num_experts]
            gate_indices: Soft argmax indices [batch_size]
        """
        hidden = self.relu(self.bn1(self.fc1(x)))
        hidden = self.relu(self.bn2(self.fc2(hidden)))
        expert_scores = self.fc3(hidden)  # [batch_size, num_experts]

        # soft_argmax returns (indices, probabilities); this method returns
        # them in the opposite order.
        soft_index, expert_probs = soft_argmax(expert_scores, self.beta)
        return expert_probs, soft_index


In [None]:
class GammaIntensityCorrection(nn.Module):
    """Blend of several gamma-corrected versions of a clip, one per gamma "expert"."""

    def __init__(self, gamma_values=(1.0, 1.5, 2.0, 2.5, 3.0)):
        super(GammaIntensityCorrection, self).__init__()
        # Copy into a fresh list so the instance never aliases the caller's
        # list or a shared default object.
        self.gamma_values = list(gamma_values)
        self.num_experts = len(self.gamma_values)

    def forward(self, video_clip, expert_weights):
        """
        Apply weighted gamma correction based on expert selection.

        Args:
            video_clip: Input video [batch_size, T, H, W, C] (any number of
                trailing dimensions after the batch axis is accepted).
            expert_weights: Gating weights [batch_size, num_experts]

        Returns:
            corrected_clip: Weighted sum of the gamma-corrected clips, same
                shape as `video_clip`. Values are float32; if the input
                maximum exceeds 1.0 it is divided by 255 first.
        """
        batch_size = video_clip.shape[0]

        # One corrected copy per gamma: [batch_size, num_experts, *clip_dims]
        corrected_clips = torch.stack(
            [self._apply_gamma_correction(video_clip, gamma) for gamma in self.gamma_values],
            dim=1,
        )

        # Give the weights one singleton axis per non-batch clip dimension so
        # they broadcast over the stacked clips.
        weight_shape = (batch_size, self.num_experts) + (1,) * (video_clip.dim() - 1)
        expert_weights = expert_weights.reshape(weight_shape)

        return torch.sum(corrected_clips * expert_weights, dim=1)

    def _apply_gamma_correction(self, image, gamma):
        """
        Apply gamma intensity correction formula:
        GIC_gamma(I) = [(max(I) - min(I)) * ((I - min(I))/(max(I) - min(I)))^(1/gamma)] + min(I)

        min/max are taken per sample (over every non-batch dimension). A
        sample whose values are all equal is returned unchanged.
        """
        # Ensure image is in float format
        if image.dtype != torch.float32:
            image = image.float()

        # Heuristic: a maximum above 1.0 anywhere in the batch is treated as
        # 0-255 data and rescaled.
        if torch.max(image) > 1.0:
            image = image / 255.0

        batch_size = image.shape[0]
        stat_shape = (batch_size,) + (1,) * (image.dim() - 1)

        # Per-sample extrema, computed in one reduction instead of a Python loop.
        flat = image.reshape(batch_size, -1)
        min_val = flat.amin(dim=1).reshape(stat_shape)
        max_val = flat.amax(dim=1).reshape(stat_shape)
        span = max_val - min_val

        # Replace a zero range by 1 to avoid dividing by zero; those samples
        # are restored to their input by the final `where`.
        safe_span = torch.where(span > 0, span, torch.ones_like(span))
        normalized = (image - min_val) / safe_span
        corrected = torch.pow(normalized, 1.0 / gamma)

        return torch.where(span > 0, span * corrected + min_val, image)


In [None]:
class AdaptiveGammaCorrection(nn.Module):
    """Complete Adaptive Gamma Correction module using SoftArgmax.

    Pipeline: per-frame statistics -> gating network -> weighted blend of
    gamma-corrected clips.
    """

    def __init__(self, gamma_values=(1.0, 1.5, 2.0, 2.5, 3.0), hidden_dim=64, beta=100.0,
                 num_frames=None):
        """
        Args:
            gamma_values: Candidate gamma values, one per expert.
            hidden_dim: Width of the gating network's first hidden layer.
            beta: Temperature passed to the gating network's soft argmax.
            num_frames: Optional clip length T. When given, the gating network
                is built immediately (input size 3*T), so its parameters are
                visible to optimizers, `state_dict` and freeze helpers right
                after construction. When None, it is built on the first
                forward call instead.
        """
        super(AdaptiveGammaCorrection, self).__init__()
        # Copy so the instance never aliases a shared default or caller list.
        self.gamma_values = list(gamma_values)
        self.num_experts = len(self.gamma_values)

        # Feature extractor
        self.feature_extractor = FeatureExtractor()

        self.hidden_dim = hidden_dim
        self.beta = beta

        # Gating network: eager when the clip length is known, lazy otherwise.
        self.gating_network = None
        if num_frames is not None:
            self.gating_network = self._build_gating_network(3 * num_frames)

        # Gamma correction module
        self.gic = GammaIntensityCorrection(self.gamma_values)

    def _build_gating_network(self, input_size, device=None):
        """Create the gating network for `input_size` features, optionally on `device`."""
        network = GatingNetwork(input_size, self.num_experts, self.hidden_dim, self.beta)
        return network if device is None else network.to(device)

    def forward(self, video_clip):
        """
        Process video clip with adaptive gamma correction

        Args:
            video_clip: Input video tensor [batch_size, T, H, W, C]

        Returns:
            corrected_clip: Gamma corrected video [batch_size, T, H, W, C]
            lambda_vector: Expert weights [batch_size, num_experts]
            gate_indices: Soft expert indices [batch_size]
            selected_gammas: Interpolated gamma values [batch_size]
        """
        # Extract features
        features = self.feature_extractor(video_clip)  # [batch_size, T*3]

        # Lazy path: parameters created here are NOT part of any optimizer
        # that was constructed before this first call.
        if self.gating_network is None:
            self.gating_network = self._build_gating_network(
                features.shape[1], device=video_clip.device
            )

        # Get expert weights and indices using SoftArgmax
        lambda_vector, gate_indices = self.gating_network(features)

        # Apply gamma correction
        corrected_clip = self.gic(video_clip, lambda_vector)

        # Get selected gamma values for interpretation
        selected_gammas = self._get_selected_gammas(gate_indices)

        return corrected_clip, lambda_vector, gate_indices, selected_gammas

    def _get_selected_gammas(self, gate_indices):
        """Convert soft indices to gamma values by linear interpolation.

        Indices are clamped to the valid range, then interpolated between the
        two neighbouring entries of `self.gamma_values`.
        """
        last_idx = len(self.gamma_values) - 1
        gamma_tensor = torch.tensor(self.gamma_values, device=gate_indices.device)
        indices_clamped = torch.clamp(gate_indices, 0, last_idx)

        lower_idx = torch.floor(indices_clamped).long()
        upper_idx = torch.clamp(lower_idx + 1, 0, last_idx)

        # Fractional part = weight of the upper neighbour.
        alpha = indices_clamped - lower_idx.float()
        selected_gammas = (1 - alpha) * gamma_tensor[lower_idx] + alpha * gamma_tensor[upper_idx]

        return selected_gammas


In [None]:
# Demo AGC instance: five gamma experts, 64-unit gating MLP, soft-argmax temperature 100.
agc = AdaptiveGammaCorrection(gamma_values=[1.0, 1.5, 2.0, 2.5, 3.0], hidden_dim=64, beta=100.0)

In [None]:
# Dummy clip [batch=2, T=32, H=224, W=224, C=3] with intensities uniform in [0, 1).
# torch.randn would produce negative and >1 values, which the entropy histogram
# (range [0, 1]) ignores and which trigger the ">1 means 0-255" rescaling in
# GammaIntensityCorrection. A private generator keeps the demo reproducible
# without touching the global RNG state.
video_clip = torch.rand(2, 32, 224, 224, 3, generator=torch.Generator().manual_seed(0))

In [None]:
# Run the AGC module once on the dummy clip. AdaptiveGammaCorrection creates its
# gating network inside forward() when it does not exist yet, so this call can
# also be what instantiates agc.gating_network.
corrected_clip, lambda_vector, gate_indices, selected_gammas = agc(video_clip)

Original shape: torch.Size([2, 32, 224, 224, 3])
Corrected shape: torch.Size([2, 32, 224, 224, 3])
Expert weights: tensor([[1.0000e+00, 2.6171e-30, 1.5817e-09, 3.6904e-30, 0.0000e+00],
        [8.2547e-09, 9.2468e-14, 1.0717e-05, 6.2197e-01, 3.7802e-01]],
       grad_fn=<SoftmaxBackward0>)
Gate indices: tensor([3.1633e-09, 3.3780e+00], grad_fn=<SumBackward1>)
Selected gammas: tensor([1.0000, 2.6890], grad_fn=<AddBackward0>)


In [None]:
# Shapes of the four AGC outputs: corrected clip, expert weights, soft indices, interpolated gammas.
corrected_clip.shape, lambda_vector.shape, gate_indices.shape, selected_gammas.shape

(torch.Size([2, 32, 224, 224, 3]),
 torch.Size([2, 5]),
 torch.Size([2]),
 torch.Size([2]))

# Video Swin Backbone

In [None]:
from torchvision.models.video import swin3d_s,Swin3D_S_Weights

In [None]:
class VideoSwinBackbone(nn.Module):
    """Feature extractor built on torchvision's `swin3d_s`.

    Only `patch_embed`, `pos_drop`, `features` and `norm` of the wrapped model
    are called, so the output is a channel-first feature volume rather than
    class scores.
    """

    def __init__(self, pretrained=True):
        """
        Args:
            pretrained: Load the KINETICS400_V1 weights when True, otherwise
                pass `weights=None` to `swin3d_s`.
        """
        super().__init__()
        self.model = swin3d_s(weights=Swin3D_S_Weights.KINETICS400_V1 if pretrained else None)

    def forward(self, x):
        """
        Args:
            x: Clip in channel-last layout [B, T, H, W, C].

        Returns:
            Feature volume in channel-first layout [B, C', T', H', W'].
        """
        net = self.model
        # The wrapped model is fed channel-first video: [B, C, T, H, W].
        tokens = net.patch_embed(x.permute(0, 4, 1, 2, 3))  # B _T _H _W C
        tokens = net.norm(net.features(net.pos_drop(tokens)))  # B _T _H _W C
        # Back to channel-first for the 3-D pooling heads downstream.
        return tokens.permute(0, 4, 1, 2, 3)  # B, C, _T, _H, _W

In [None]:
# Build a backbone with its default pretrained=True and run it on the
# gamma-corrected demo clip.
output = VideoSwinBackbone()(corrected_clip)

Downloading: "https://download.pytorch.org/models/swin3d_s-da41c237.pth" to /root/.cache/torch/hub/checkpoints/swin3d_s-da41c237.pth
100%|██████████| 218M/218M [00:02<00:00, 85.6MB/s]


In [None]:
# Channel-first feature volume returned by the backbone: [B, C, T', H', W'].
output.shape

torch.Size([2, 768, 16, 7, 7])

# Adaptive Head Selection

In [None]:
class I3DHead(nn.Module):
    """
    I3D-style classification head.

    Global average pooling over (T', H', W'), then dropout, then one linear
    layer producing class logits.
    """

    def __init__(self, in_channels, num_classes, dropout_rate=0.5):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)            # collapses T', H', W' to 1x1x1
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        """
        Args:
            x: Feature volume [B, C, T', H', W']

        Returns:
            Class logits [B, num_classes]
        """
        pooled = self.pool(x).flatten(start_dim=1)  # [B, C]
        return self.fc(self.dropout(pooled))


In [None]:
class AdaptiveHeadSelection(nn.Module):
    """
    Mixture of classification heads.

    Keeps one `I3DHead` per expert and blends their logits with the gating
    weights supplied by the caller.
    """

    def __init__(self, in_channels, num_classes, num_experts, dropout_rate=0.5):
        super().__init__()
        self.num_experts = num_experts
        expert_heads = [
            I3DHead(in_channels, num_classes, dropout_rate)
            for _ in range(num_experts)
        ]
        self.heads = nn.ModuleList(expert_heads)

    def forward(self, features, expert_weights):
        """
        Args:
            features: Backbone features [B, C, T', H', W']
            expert_weights: Gating weights [B, E]

        Returns:
            fused_logits: Weighted sum of the head outputs [B, num_classes]
            head_logits: Output of every head [B, E, num_classes]
        """
        # Every head sees the same features; stack the results along a new expert axis.
        per_head = torch.stack([head(features) for head in self.heads], dim=1)

        # [B, E, 1] weights broadcast over the class axis, then sum over experts.
        fused = (expert_weights.unsqueeze(-1) * per_head).sum(dim=1)
        return fused, per_head

# Combined

In [None]:
class DGAMModel(nn.Module):
    """End-to-end model: adaptive gamma correction -> video backbone -> gated heads."""

    def __init__(self, agc, backbone, num_classes):
        """
        Args:
            agc: Gamma-correction module exposing `num_experts`; calling it
                returns a 4-tuple whose first two items are the corrected
                clip and the expert weights.
            backbone: Feature extractor exposing `model.num_features`.
            num_classes: Number of output classes for every head.
        """
        super().__init__()
        self.agc = agc
        self.backbone = backbone
        # One head per gamma expert, sized to the backbone's feature width.
        self.ahs = AdaptiveHeadSelection(
            in_channels=backbone.model.num_features,
            num_classes=num_classes,
            num_experts=agc.num_experts,
        )

    def forward(self, video_clip):
        """
        Returns final class logits and auxiliary per-expert logits.
        """
        # The AGC's last two outputs (soft indices, gammas) are not needed here.
        corrected_clip, expert_weights, *_ = self.agc(video_clip)
        features = self.backbone(corrected_clip)  # [B,C,T',H',W']
        # The same gating weights that blended the gammas also blend the heads.
        return self.ahs(features, expert_weights)

In [None]:
# pretrained=True loads the Kinetics-400 checkpoint; pass pretrained=False to skip it.
backbone = VideoSwinBackbone(pretrained=True)
model = DGAMModel(agc=agc, backbone=backbone, num_classes=2)

# Smoke-test forward pass on a dummy clip with intensities in [0, 1)
# (torch.randn would feed negative / >1 "pixel" values into the gamma correction).
output = model(torch.rand(2, 32, 224, 224, 3))

In [None]:
# (fused_logits, per_head_logits) returned by the demo forward pass.
output

(tensor([[-0.0940, -0.0656],
         [-0.0791, -0.1593]], grad_fn=<SumBackward1>),
 tensor([[[-0.0940, -0.0656],
          [ 0.0873, -0.1315],
          [ 0.0351,  0.1086],
          [-0.0880, -0.0831],
          [ 0.0607,  0.0510]],
 
         [[-0.1210, -0.0157],
          [ 0.0921, -0.4061],
          [ 0.0368,  0.1223],
          [-0.0791, -0.1593],
          [-0.0153, -0.1523]]], grad_fn=<CatBackward0>))

In [None]:
# fused logits are [B, num_classes]; per-head logits are [B, num_experts, num_classes].
output[0].shape, output[1].shape

(torch.Size([2, 2]), torch.Size([2, 5, 2]))

In [None]:
# import torch
# import torch.nn.functional as F
# from torch.optim import AdamW
# from torch.optim.lr_scheduler import CosineAnnealingLR
# from torch.cuda.amp import autocast, GradScaler

# def get_optimizer(model):
#     # Custom LR for gating
#     params = [
#         {"params": [], "lr": 3e-4},  # default group
#         {"params": [], "lr": 3e-4 * 1000}  # gating group
#     ]

#     for name, param in model.named_parameters():
#         if not param.requires_grad:
#             continue
#         if "gating_network" in name:
#             params[1]["params"].append(param)
#         else:
#             params[0]["params"].append(param)

#     return AdamW(params, betas=(0.9, 0.999), weight_decay=0.05)

# def accuracy(output, target, topk=(1, 5)):
#     """Computes the top-k accuracy"""
#     maxk = max(topk)
#     _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
#     correct = pred.eq(target.view(-1, 1).expand_as(pred))

#     return [
#         correct[:, :k].float().sum().item() / target.size(0)
#         for k in topk
#     ]

# @torch.no_grad()
# def evaluate(model, dataloader, device):
#     model.eval()
#     loss_total, correct1, correct5, total = 0.0, 0, 0, 0

#     for video_clip, labels in dataloader:
#         video_clip, labels = video_clip.to(device), labels.to(device)

#         with autocast():
#             fused_logits, _ = model(video_clip)
#             loss = F.cross_entropy(fused_logits, labels)

#         loss_total += loss.item() * labels.size(0)
#         acc1, acc5 = accuracy(fused_logits, labels)
#         correct1 += acc1 * labels.size(0)
#         correct5 += acc5 * labels.size(0)
#         total += labels.size(0)

#     return loss_total / total, correct1 / total, correct5 / total

# def train_model(model, train_loader, val_loader, device):
#     total_epochs = 100
#     warmup_epochs = 2.5
#     warmup_steps = int(warmup_epochs * len(train_loader))
#     global_step = 0

#     model.to(device)
#     optimizer = get_optimizer(model)
#     scaler = GradScaler()
#     scheduler = CosineAnnealingLR(optimizer, T_max=total_epochs - int(warmup_epochs))

#     for epoch in range(1, total_epochs + 1):
#         model.train()
#         train_loss, correct1, correct5, total = 0.0, 0, 0, 0

#         for video_clip, labels in train_loader:
#             video_clip, labels = video_clip.to(device), labels.to(device)
#             global_step += 1

#             if epoch <= warmup_epochs:
#                 lr_scale = global_step / warmup_steps
#                 for pg in optimizer.param_groups:
#                     pg["lr"] = 3e-4 * (1000 if pg["lr"] > 3e-4 else 1) * lr_scale

#             optimizer.zero_grad()
#             with autocast():
#                 fused_logits, _ = model(video_clip)
#                 loss = F.cross_entropy(fused_logits, labels)

#             scaler.scale(loss).backward()
#             scaler.step(optimizer)
#             scaler.update()

#             train_loss += loss.item() * labels.size(0)
#             acc1, acc5 = accuracy(fused_logits, labels)
#             correct1 += acc1 * labels.size(0)
#             correct5 += acc5 * labels.size(0)
#             total += labels.size(0)

#         if epoch > warmup_epochs:
#             scheduler.step()

#         print(f"[Epoch {epoch:03d}] Train Loss: {train_loss/total:.4f}, "
#               f"Top-1: {correct1/total:.4f}, Top-5: {correct5/total:.4f}")

#         # Evaluate every 2 epochs
#         if epoch % 2 == 0:
#             val_loss, val_top1, val_top5 = evaluate(model, val_loader, device)
#             print(f"[Epoch {epoch:03d}] Val Loss: {val_loss:.4f}, "
#                   f"Top-1: {val_top1:.4f}, Top-5: {val_top5:.4f}")


In [None]:
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.cuda.amp import autocast, GradScaler

# ------------------------------------------
# 1. Freeze utilities
# ------------------------------------------
def freeze_module(module):
    """Disable gradient tracking for every parameter of `module` (in place)."""
    module.requires_grad_(False)

def unfreeze_module(module):
    """Re-enable gradient tracking for every parameter of `module` (in place)."""
    module.requires_grad_(True)

def freeze_backbone_and_heads(model):
    """Freeze `model.backbone` and every head in `model.ahs.heads`."""
    for part in (model.backbone, *model.ahs.heads):
        freeze_module(part)

def freeze_gating(model):
    """Freeze the parameters of `model.agc.gating_network`.

    Raises:
        AttributeError: if the gating network does not exist yet.
            AdaptiveGammaCorrection may create it only during its first
            forward pass, so before that there is nothing to freeze.
    """
    gating = model.agc.gating_network
    if gating is None:
        # Same exception type the bare `.parameters()` call on None would
        # raise, but with an actionable message.
        raise AttributeError(
            "model.agc.gating_network has not been created yet; "
            "run one forward pass through the model before freezing it"
        )
    freeze_module(gating)

def unfreeze_backbone_and_heads(model):
    """Unfreeze `model.backbone` and every head in `model.ahs.heads`."""
    for part in (model.backbone, *model.ahs.heads):
        unfreeze_module(part)

def unfreeze_gating(model):
    """Unfreeze the parameters of `model.agc.gating_network`.

    Raises:
        AttributeError: if the gating network does not exist yet.
            AdaptiveGammaCorrection may create it only during its first
            forward pass, so before that there is nothing to unfreeze.
    """
    gating = model.agc.gating_network
    if gating is None:
        # Same exception type the bare `.parameters()` call on None would
        # raise, but with an actionable message.
        raise AttributeError(
            "model.agc.gating_network has not been created yet; "
            "run one forward pass through the model before unfreezing it"
        )
    unfreeze_module(gating)

# ------------------------------------------
# 2. Optimizer builder
# ------------------------------------------
def get_optimizer(model, gating_lr_mult=1.0, base_lr=3e-4, weight_decay=0.05,
                  betas=(0.9, 0.999)):
    """Build an AdamW optimizer with a separate learning rate for the gating network.

    Only parameters with `requires_grad=True` are included, so the result
    depends on which parts of the model are frozen at call time.

    Args:
        model: Module whose trainable parameters are optimised.
        gating_lr_mult: Multiplier applied to `base_lr` for parameters whose
            name contains "gating_network".
        base_lr: Learning rate for all other parameters.
        weight_decay: AdamW weight decay, applied to both groups.
        betas: AdamW beta coefficients.

    Returns:
        AdamW optimizer with exactly two parameter groups: index 0 holds the
        base parameters and index 1 the gating parameters. Either group may
        be empty.
    """
    base_params, gating_params = [], []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Route by parameter name: anything under a `gating_network` submodule.
        target = gating_params if "gating_network" in name else base_params
        target.append(param)

    return AdamW([
        {"params": base_params, "lr": base_lr},
        {"params": gating_params, "lr": base_lr * gating_lr_mult}
    ], betas=betas, weight_decay=weight_decay)

# ------------------------------------------
# 3. Accuracy metrics
# ------------------------------------------
def accuracy(output, target, topk=(1, 5)):
    """Top-k accuracy for each k in `topk`.

    Args:
        output: Class scores of shape [batch_size, num_classes].
        target: Ground-truth class indices of shape [batch_size].
        topk: Iterable of k values.

    Returns:
        List of floats in [0, 1], one per entry of `topk`. A k larger than
        the number of classes is evaluated over all classes. An empty batch
        yields 0.0 for every k.
    """
    batch_size = target.size(0)
    if batch_size == 0:
        return [0.0 for _ in topk]

    # topk() raises when asked for more entries than there are classes
    # (e.g. top-5 on a binary problem), so cap k at the class count.
    maxk = min(max(topk), output.size(1))
    _, pred = output.topk(maxk, 1, True, True)
    correct = pred.eq(target.view(-1, 1).expand_as(pred))
    # Slicing with k > maxk simply keeps every available column.
    return [correct[:, :k].float().sum().item() / batch_size for k in topk]

# ------------------------------------------
# 4. Evaluation
# ------------------------------------------
@torch.no_grad()
def evaluate(model, dataloader, device):
    """Evaluate `model` on `dataloader` and return sample-weighted averages.

    Puts the model into eval mode (it is not switched back afterwards).

    Args:
        model: Module returning `(fused_logits, per_head_logits)`.
        dataloader: Iterable of `(video_clip, labels)` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple `(mean_loss, top1, top5)`. All three are NaN when the
        dataloader yields no samples.
    """
    model.eval()
    # CUDA autocast is only meaningful on a CUDA device; skip it elsewhere.
    use_amp = torch.device(device).type == "cuda"
    loss_sum, correct1, correct5, total = 0.0, 0, 0, 0

    for video_clip, labels in dataloader:
        video_clip, labels = video_clip.to(device), labels.to(device)
        with autocast(enabled=use_amp):
            fused_logits, _ = model(video_clip)
            loss = F.cross_entropy(fused_logits, labels)

        # Weight every batch by its size so a short final batch is not over-counted.
        n = labels.size(0)
        loss_sum += loss.item() * n
        acc1, acc5 = accuracy(fused_logits, labels)
        correct1 += acc1 * n
        correct5 += acc5 * n
        total += n

    if total == 0:
        # Metrics over zero samples are undefined; avoid ZeroDivisionError.
        nan = float("nan")
        return nan, nan, nan

    return loss_sum / total, correct1 / total, correct5 / total

# ------------------------------------------
# 5. Full training loop with switching logic
# ------------------------------------------
def _run_training_stage(model, train_loader, val_loader, device, stage_tag,
                        num_epochs, gating_lr_mult, scaler, use_amp):
    """Train for `num_epochs` over whichever parameters currently require grad.

    A fresh optimizer and cosine schedule are built for the stage. Validation
    runs every second epoch. Progress lines are prefixed with `stage_tag`.
    """
    optimizer = get_optimizer(model, gating_lr_mult=gating_lr_mult)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

    for epoch in range(1, num_epochs + 1):
        # NOTE: train() is applied to the whole model, so dropout and
        # BatchNorm layers inside "frozen" sub-modules still run in training
        # mode; freezing only stops their weights from being updated.
        model.train()
        total_loss, correct1, correct5, total = 0.0, 0, 0, 0

        for video_clip, labels in train_loader:
            video_clip, labels = video_clip.to(device), labels.to(device)

            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                fused_logits, _ = model(video_clip)
                loss = F.cross_entropy(fused_logits, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            n = labels.size(0)
            total_loss += loss.item() * n
            acc1, acc5 = accuracy(fused_logits, labels)
            correct1 += acc1 * n
            correct5 += acc5 * n
            total += n

        scheduler.step()
        seen = max(total, 1)  # an empty loader must not divide by zero
        print(f"[{stage_tag}-Epoch {epoch:02d}] Loss: {total_loss/seen:.4f}, "
              f"Top-1: {correct1/seen:.4f}, Top-5: {correct5/seen:.4f}")

        if epoch % 2 == 0:
            val_loss, val_top1, val_top5 = evaluate(model, val_loader, device)
            print(f"[{stage_tag}-Epoch {epoch:02d}] Val Loss: {val_loss:.4f}, "
                  f"Top-1: {val_top1:.4f}, Top-5: {val_top5:.4f}")


def train_dgam_switching(model, train_loader, val_loader, device,
                         stage1_epochs=30, stage2_epochs=30):
    """Two-stage training: backbone and heads first, then the gating network alone.

    Args:
        model: DGAM model exposing `backbone`, `ahs.heads` and
            `agc.gating_network`. It is moved to `device` in place.
        train_loader: Iterable of `(video_clip, labels)` training batches.
        val_loader: Iterable of `(video_clip, labels)` validation batches.
        device: Device used for the model and every batch.
        stage1_epochs: Epochs with the gating network frozen.
        stage2_epochs: Epochs with backbone and heads frozen; the gating
            learning rate is multiplied by 1000 in this stage.
    """
    # Batches are moved to `device` inside the loop, so the model has to live
    # there too; otherwise the first forward pass fails on a device mismatch.
    model.to(device)

    # Mixed precision (CUDA autocast + loss scaling) only when training on CUDA.
    use_amp = torch.device(device).type == "cuda"
    scaler = GradScaler(enabled=use_amp)

    # -------------------------------
    # Stage 1: Train backbone + heads
    # -------------------------------
    print("🧠 Stage 1: Training backbone and heads (gating frozen)")
    freeze_gating(model)
    unfreeze_backbone_and_heads(model)
    # The gating group is empty here, so its multiplier has no effect.
    _run_training_stage(model, train_loader, val_loader, device, "Stage1",
                        stage1_epochs, 1.0, scaler, use_amp)

    # -------------------------------
    # Stage 2: Train gating only
    # -------------------------------
    print("🔀 Stage 2: Training gating network (backbone & heads frozen)")
    freeze_backbone_and_heads(model)
    unfreeze_gating(model)
    _run_training_stage(model, train_loader, val_loader, device, "Stage2",
                        stage2_epochs, 1000.0, scaler, use_amp)


In [None]:
# NOTE(review): NUM_CLASSES, train_loader and val_loader are not defined in any
# cell visible in this notebook; they must exist in the kernel before this cell runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reuses the `agc` and `backbone` objects created in earlier cells, so this model
# shares those sub-modules with the demo model built above.
model = DGAMModel(agc, backbone, num_classes=NUM_CLASSES)
train_dgam_switching(model, train_loader, val_loader, device)
