In [1]:
import time
from IPython.display import clear_output

def prevent_timeout(minutes=55):
    """
    Prevents notebook timeout by printing a message every specified number of minutes.
    Most notebook environments have a 60-90 minute timeout, so the default is set to 55 minutes.
    
    Args:
        minutes (int): Number of minutes between each activity signal
    """
    seconds = minutes * 60
    counter = 1
    
    print(f"Timeout prevention started. Will refresh every {minutes} minutes.")
    
    try:
        while True:
            time.sleep(seconds)
            clear_output(wait=True)
            current_time = time.strftime("%Y-%m-%d %H:%M:%S")
            print(f"Keeping session alive... Ping #{counter} at {current_time}")
            print(f"Timeout prevention active. Will refresh every {minutes} minutes.")
            counter += 1
    except KeyboardInterrupt:
        print("Timeout prevention stopped.")

# Run this in a Jupyter notebook cell to start the timeout prevention
# You can stop it by interrupting the kernel (Kernel > Interrupt) or by pressing Ctrl+C

In [2]:
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Allow loading of truncated images
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import torchvision.transforms as transforms

# Dataset Class
class VideoAnomalyDataset(Dataset):
    def __init__(self, base_folder, transform=None, use_optical_flow=False, 
                 combine_modalities=False, label_file=None):
        """
        Dataset for loading video frames for anomaly detection.
        
        Args:
            base_folder (str): Path to dataset containing video folders
            transform (callable, optional): Transform to apply to frames
            use_optical_flow (bool): Whether to use optical flow instead of RGB frames
            combine_modalities (bool): Whether to combine RGB and optical flow (returns both)
            label_file (str, optional): Path to annotation file with frame-level labels
        """
        self.base_folder = base_folder
        self.transform = transform
        self.use_optical_flow = use_optical_flow
        self.combine_modalities = combine_modalities
        
        # Find all video folders
        self.video_folders = [f for f in os.listdir(base_folder) 
                             if os.path.isdir(os.path.join(base_folder, f))]
        self.video_folders.sort()
        
        # Build frame paths and video indices for all videos
        self.frame_paths = []
        self.flow_paths = []
        self.video_indices = []
        
        for video_idx, video_folder in enumerate(self.video_folders):
            video_path = os.path.join(base_folder, video_folder)
            
            # Get frames for this video
            frames_folder = os.path.join(video_path, 'frames')
            if os.path.exists(frames_folder):
                frames = [f for f in os.listdir(frames_folder) if f.endswith('.jpg')]
                frames.sort()
                
                for frame in frames:
                    self.frame_paths.append(os.path.join(frames_folder, frame))
                    self.video_indices.append(video_idx)
                    
                    # Also get the corresponding optical flow if it exists
                    flow_folder = os.path.join(video_path, 'optical_flow')
                    flow_name = frame.replace('frame_', 'flow_')
                    flow_path = os.path.join(flow_folder, flow_name)
                    
                    if os.path.exists(flow_path):
                        self.flow_paths.append(flow_path)
                    else:
                        # If flow doesn't exist, use None as placeholder
                        self.flow_paths.append(None)
        
        # Load labels if available
        self.labels = self._load_labels(label_file)
        
        # If no labels provided, create dummy labels for demonstration
        if self.labels is None:
            np.random.seed(42)  # For reproducibility
            self.labels = np.random.choice([0, 1], size=len(self.frame_paths), 
                                           p=[0.8, 0.2]).tolist()
    
    def _load_labels(self, label_file):
        """Load labels from file if available"""
        if label_file is None or not os.path.exists(label_file):
            return None
            
        # Implement label loading logic here based on your annotation format
        # Example: CSV with frame_path,label format
        labels = {}
        with open(label_file, 'r') as f:
            for line in f:
                parts = line.strip().split(',')
                if len(parts) == 2:
                    frame_path, label = parts
                    labels[frame_path] = int(label)
                    
        # Convert to list matching frame_paths order
        return [labels.get(path, 0) for path in self.frame_paths]
    
    def __len__(self):
        return len(self.frame_paths)
    
    def __getitem__(self, idx):
        frame_path = self.frame_paths[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        video_idx = self.video_indices[idx]
        
        try:
            # Handle different modality options
            if self.use_optical_flow and self.flow_paths[idx]:
                # Use optical flow only
                image = Image.open(self.flow_paths[idx]).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, label
                
            elif self.combine_modalities and self.flow_paths[idx]:
                # Combine RGB frame and optical flow
                frame = Image.open(frame_path).convert("RGB")
                flow = Image.open(self.flow_paths[idx]).convert("RGB")
                
                if self.transform:
                    frame = self.transform(frame)
                    flow = self.transform(flow)
                
                # Stack channels (can be modified based on how you want to combine)
                combined = torch.cat((frame, flow), dim=0)  # Channel-wise concatenation
                return combined, label
                
            else:
                # Use RGB frame only (default)
                frame = Image.open(frame_path).convert("RGB")
                if self.transform:
                    frame = self.transform(frame)
                return frame, label
                
        except Exception as e:
            print(f"Error loading image {frame_path}: {e}")
            # Return a placeholder image and label
            placeholder = torch.zeros(3, 112, 112) if self.transform is None else self.transform(
                Image.new("RGB", (112, 112), color=(0, 0, 0))
            )
            return placeholder, label
    
    def get_video_indices(self):
        """Return list of video indices for each frame"""
        return self.video_indices
    
    def get_video_names(self):
        """Return list of video folder names"""
        return self.video_folders

# Modular Vision Transformer Implementation
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=112, patch_size=16, in_channels=3, embed_dim=384):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
        x = x.flatten(2)  # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=384, num_heads=4, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

class MLP(nn.Module):
    def __init__(self, embed_dim=384, hidden_dim=3072, dropout=0.1):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=384, num_heads=4, mlp_ratio=4, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, embed_dim * mlp_ratio, dropout)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=112, patch_size=16, in_channels=3, embed_dim=384, depth=4, num_heads=4, mlp_ratio=4, dropout=0.1):
        super(VisionTransformer, self).__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches + 1, embed_dim))
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)

        cls_token = self.cls_token.expand(B, -1, -1)  # (B, 1, embed_dim)
        x = torch.cat((cls_token, x), dim=1)  # (B, n_patches + 1, embed_dim)
        x = x + self.pos_embed
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        return x[:, 0]  # Return only the class token

# Define Vision Transformer Model
class ViTForAnomalyDetection(nn.Module):
    def __init__(self):
        super(ViTForAnomalyDetection, self).__init__()
        self.vit = VisionTransformer(img_size=112, patch_size=16, embed_dim=384, depth=4, num_heads=4)
       
        # Fine-tuned fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(384, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)  # Output a single score for MIL
        )
       
    def forward(self, x):
        x = self.vit(x)
        return self.fc(x)

# Initialize model with GPU support
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTForAnomalyDetection().to(device)
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPU...")
    model = torch.nn.DataParallel(model)

model.to(device)

# Initialize model weights properly for better training
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        torch.nn.init.ones_(m.weight)
        torch.nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)

model.apply(init_weights)

# MIL Loss Implementation
class MILoss(nn.Module):
    def __init__(self, lambda_reg=0.01):
        super(MILoss, self).__init__()
        self.lambda_reg = lambda_reg  # Regularization strength

    def forward(self, outputs, labels):
        """
        outputs: Model predictions (batch_size, 1)
        labels: Ground truth labels (batch_size,), where 1 = anomaly, 0 = normal
        """
        # Convert labels to -1 (normal) and 1 (anomaly)
        labels = 2 * labels.float() - 1  # 0 -> -1, 1 -> 1

        # Hinge loss term
        hinge_loss = torch.mean(torch.clamp(1 - labels * outputs, min=0))

        # L2 regularization term (using model parameters)
        l2_reg = 0.0
        for param in model.parameters():
            l2_reg += torch.norm(param, p=2)

        # Total MIL loss
        mil_loss = hinge_loss + self.lambda_reg * l2_reg
        return mil_loss

criterion = MILoss(lambda_reg=0.01)

# Optimizer & Learning Rate Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Define functions to calculate accuracy and other metrics
def calculate_accuracy(predictions, targets):
    """
    Calculate accuracy from binary predictions and targets
    """
    # Handle case when predictions is not a 1D tensor
    if predictions.dim() > 1 and predictions.size(1) == 1:
        predictions = predictions.squeeze(1)
   
    correct = (predictions == targets).float().sum().item()
    total = targets.size(0)
    return correct / total if total > 0 else 0

def get_binary_predictions(outputs):
    """
    Convert model outputs (scores) to binary predictions
    """
    return (outputs > 0).float()

# Define the dataset and dataloader
base_folder = "/kaggle/input/shanghaitech-anomaly-detection/dataset/mp"  # Replace with the path to your dataset
transform = transforms.Compose([
    transforms.Resize((112, 112)),  # Resize frames to 112x112
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # Normalize to mean 0.5, std 0.5
])

# Create dataset
dataset = VideoAnomalyDataset(base_folder, transform=transform)

# Create dataloader
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

# Training Loop with Accuracy Calculation
num_epochs = 25  # Increased epochs for better training
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
   
    for inputs, batch_labels in progress_bar:
        inputs, batch_labels = inputs.to(device), batch_labels.to(device)
       
        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, batch_labels)
       
        # Backward pass
        loss.backward()
        optimizer.step()
       
        # Get binary predictions from scores
        binary_preds = get_binary_predictions(outputs)
       
        # Handle dimensionality of predictions
        if binary_preds.dim() > 1 and binary_preds.size(1) == 1:
            binary_preds = binary_preds.squeeze(1)
           
        # Calculate batch accuracy
        batch_accuracy = calculate_accuracy(binary_preds, batch_labels)
       
        # Store predictions and labels for epoch-level metrics
        all_preds.extend(binary_preds.cpu().detach().numpy())
        all_labels.extend(batch_labels.cpu().detach().numpy())
       
        # Update running loss
        running_loss += loss.item()
        progress_bar.set_postfix({"Loss": f"{loss.item():.4f}", "Acc": f"{batch_accuracy:.4f}"})
   
    # Calculate epoch metrics
    epoch_accuracy = accuracy_score(all_labels, all_preds)
    epoch_loss = running_loss / len(dataloader)  # Average loss for the epoch
   
    # Update learning rate
    lr_scheduler.step()
   
    # Print epoch summary with explicit MIL loss
    print(f"Epoch [{epoch+1}/{num_epochs}] - MIL Loss: {epoch_loss:.4f} - Accuracy: {epoch_accuracy:.4f}")

    # Optional: Add validation here for more comprehensive evaluation
   
# Save the trained model
torch.save(model.state_dict(), "vit_anomaly_detection_model.pth")
print("Model saved successfully!")

Using 2 GPU...


Epoch 1/25: 100%|██████████| 17137/17137 [21:51<00:00, 13.07it/s, Loss=0.0465, Acc=1.0000]


Epoch [1/25] - MIL Loss: 1.6959 - Accuracy: 0.7994


Epoch 2/25: 100%|██████████| 17137/17137 [17:54<00:00, 15.95it/s, Loss=1.1461, Acc=0.4444]


Epoch [2/25] - MIL Loss: 0.4441 - Accuracy: 0.7994


Epoch 3/25: 100%|██████████| 17137/17137 [17:51<00:00, 16.00it/s, Loss=0.2520, Acc=0.8889]


Epoch [3/25] - MIL Loss: 0.4320 - Accuracy: 0.7994


Epoch 4/25: 100%|██████████| 17137/17137 [17:33<00:00, 16.27it/s, Loss=0.0329, Acc=1.0000]


Epoch [4/25] - MIL Loss: 0.4269 - Accuracy: 0.7994


Epoch 5/25: 100%|██████████| 17137/17137 [17:43<00:00, 16.12it/s, Loss=0.4708, Acc=0.7778]


Epoch [5/25] - MIL Loss: 0.4258 - Accuracy: 0.7994


Epoch 6/25: 100%|██████████| 17137/17137 [17:31<00:00, 16.30it/s, Loss=0.0210, Acc=1.0000]


Epoch [6/25] - MIL Loss: 0.4249 - Accuracy: 0.7994


Epoch 7/25: 100%|██████████| 17137/17137 [17:25<00:00, 16.40it/s, Loss=0.0164, Acc=1.0000]


Epoch [7/25] - MIL Loss: 0.4244 - Accuracy: 0.7994


Epoch 8/25: 100%|██████████| 17137/17137 [17:12<00:00, 16.60it/s, Loss=0.6895, Acc=0.6667]


Epoch [8/25] - MIL Loss: 0.4240 - Accuracy: 0.7994


Epoch 9/25: 100%|██████████| 17137/17137 [17:02<00:00, 16.75it/s, Loss=0.0251, Acc=1.0000]


Epoch [9/25] - MIL Loss: 0.4236 - Accuracy: 0.7994


Epoch 10/25: 100%|██████████| 17137/17137 [17:07<00:00, 16.68it/s, Loss=0.2417, Acc=0.8889]


Epoch [10/25] - MIL Loss: 0.4234 - Accuracy: 0.7994


Epoch 11/25: 100%|██████████| 17137/17137 [17:28<00:00, 16.34it/s, Loss=0.2427, Acc=0.8889]


Epoch [11/25] - MIL Loss: 0.4234 - Accuracy: 0.7994


Epoch 12/25: 100%|██████████| 17137/17137 [17:44<00:00, 16.09it/s, Loss=0.6902, Acc=0.6667]


Epoch [12/25] - MIL Loss: 0.4234 - Accuracy: 0.7994


Epoch 13/25: 100%|██████████| 17137/17137 [17:39<00:00, 16.17it/s, Loss=0.0170, Acc=1.0000]


Epoch [13/25] - MIL Loss: 0.4233 - Accuracy: 0.7994


Epoch 14/25: 100%|██████████| 17137/17137 [17:49<00:00, 16.02it/s, Loss=0.2425, Acc=0.8889]


Epoch [14/25] - MIL Loss: 0.4233 - Accuracy: 0.7994


Epoch 15/25: 100%|██████████| 17137/17137 [17:57<00:00, 15.90it/s, Loss=0.2462, Acc=0.8889]


Epoch [15/25] - MIL Loss: 0.4233 - Accuracy: 0.7994


Epoch 16/25: 100%|██████████| 17137/17137 [17:49<00:00, 16.03it/s, Loss=0.6895, Acc=0.6667]


Epoch [16/25] - MIL Loss: 0.4232 - Accuracy: 0.7994


Epoch 17/25: 100%|██████████| 17137/17137 [17:59<00:00, 15.88it/s, Loss=0.4665, Acc=0.7778]


Epoch [17/25] - MIL Loss: 0.4231 - Accuracy: 0.7994


Epoch 18/25: 100%|██████████| 17137/17137 [17:40<00:00, 16.15it/s, Loss=1.1312, Acc=0.4444]


Epoch [18/25] - MIL Loss: 0.4227 - Accuracy: 0.7994


Epoch 19/25: 100%|██████████| 17137/17137 [17:44<00:00, 16.09it/s, Loss=0.4656, Acc=0.7778]


Epoch [19/25] - MIL Loss: 0.4223 - Accuracy: 0.7994


Epoch 20/25: 100%|██████████| 17137/17137 [17:30<00:00, 16.32it/s, Loss=0.4643, Acc=0.7778]


Epoch [20/25] - MIL Loss: 0.4216 - Accuracy: 0.7994


Epoch 21/25: 100%|██████████| 17137/17137 [17:32<00:00, 16.28it/s, Loss=0.6879, Acc=0.6667]


Epoch [21/25] - MIL Loss: 0.4207 - Accuracy: 0.7994


Epoch 22/25: 100%|██████████| 17137/17137 [17:23<00:00, 16.42it/s, Loss=0.4610, Acc=0.7778]


Epoch [22/25] - MIL Loss: 0.4191 - Accuracy: 0.7994


Epoch 23/25: 100%|██████████| 17137/17137 [16:30<00:00, 17.30it/s, Loss=0.6771, Acc=0.6667]


Epoch [23/25] - MIL Loss: 0.4133 - Accuracy: 0.7994


Epoch 24/25: 100%|██████████| 17137/17137 [16:34<00:00, 17.23it/s, Loss=0.4548, Acc=0.7778]


Epoch [24/25] - MIL Loss: 0.4115 - Accuracy: 0.7994


Epoch 25/25: 100%|██████████| 17137/17137 [16:34<00:00, 17.24it/s, Loss=0.4548, Acc=0.7778]


Epoch [25/25] - MIL Loss: 0.4114 - Accuracy: 0.7994
Model saved successfully!
