In [1]:
import time
from IPython.display import clear_output

def prevent_timeout(minutes=55):
    """
    Prevents notebook timeout by printing a message every specified number of minutes.
    Most notebook environments have a 60-90 minute timeout, so the default is set to 55 minutes.
    
    Args:
        minutes (int): Number of minutes between each activity signal
    """
    seconds = minutes * 60
    counter = 1
    
    print(f"Timeout prevention started. Will refresh every {minutes} minutes.")
    
    try:
        while True:
            time.sleep(seconds)
            clear_output(wait=True)
            current_time = time.strftime("%Y-%m-%d %H:%M:%S")
            print(f"Keeping session alive... Ping #{counter} at {current_time}")
            print(f"Timeout prevention active. Will refresh every {minutes} minutes.")
            counter += 1
    except KeyboardInterrupt:
        print("Timeout prevention stopped.")

# Run this in a Jupyter notebook cell to start the timeout prevention
# You can stop it by interrupting the kernel (Kernel > Interrupt) or by pressing Ctrl+C

In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from PIL import Image
import os
from tqdm import tqdm
from torchvision.models import vit_b_16, ViT_B_16_Weights
import numpy as np

# Data Augmentation for Frames
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # Normalize to mean 0.5, std 0.5
])

class FrameDataset(Dataset):
    def __init__(self, base_folder, transform=None):
        self.base_folder = base_folder
        self.transform = transform
        self.frame_paths = []
        self.labels = []

        video_folders = sorted([f for f in os.listdir(base_folder) if os.path.isdir(os.path.join(base_folder, f))])
        for video_folder in video_folders:
            frame_folder = os.path.join(base_folder, video_folder, 'frames')
           
            if not os.path.exists(frame_folder):
                continue
           
            frames = sorted([f for f in os.listdir(frame_folder) if f.endswith(".jpg")])
            self.frame_paths.extend([os.path.join(frame_folder, f) for f in frames])
            self.labels.extend([1 if 'anomaly' in video_folder.lower() else 0] * len(frames))  # Assign labels

    def __len__(self):
        return len(self.frame_paths)

    def __getitem__(self, idx):
        frame = Image.open(self.frame_paths[idx]).convert("RGB")
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        if self.transform:
            frame = self.transform(frame)

        return frame, label

# Load dataset
base_folder = "/kaggle/input/shanghaitech-anomaly-detection/dataset/mp"
dataset = FrameDataset(base_folder, transform)

# Weighted Sampling to Handle Class Imbalance
class_counts = np.bincount(dataset.labels)
weights = 1.0 / class_counts[dataset.labels]
sampler = WeightedRandomSampler(weights, len(weights))
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

# Define Vision Transformer Model
class ViTForAnomalyDetection(nn.Module):
    def __init__(self):
        super(ViTForAnomalyDetection, self).__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
       
        # Fine-tuned fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1)  # Output a single score for MIL
        )
       
    def forward(self, x):
        x = self.vit(x)
        return self.fc(x)

# Initialize model with GPU support
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTForAnomalyDetection().to(device)
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

# MIL Loss Implementation
class MILoss(nn.Module):
    def __init__(self, lambda_reg=0.01):
        super(MILoss, self).__init__()
        self.lambda_reg = lambda_reg  # Regularization strength

    def forward(self, outputs, labels):
        """
        outputs: Model predictions (batch_size, 1)
        labels: Ground truth labels (batch_size,), where 1 = anomaly, 0 = normal
        """
        # Convert labels to -1 (normal) and 1 (anomaly)
        labels = 2 * labels.float() - 1  # 0 -> -1, 1 -> 1

        # Hinge loss term
        hinge_loss = torch.mean(torch.clamp(1 - labels * outputs, min=0))

        # L2 regularization term (using model parameters)
        l2_reg = 0.0
        for param in model.parameters():
            l2_reg += torch.norm(param, p=2)

        # Total MIL loss
        mil_loss = hinge_loss + self.lambda_reg * l2_reg
        return mil_loss

criterion = MILoss(lambda_reg=0.01)

# Optimizer & Learning Rate Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# Training Loop
num_epochs = 2  # Increase epochs
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
   
    for inputs, batch_labels in progress_bar:
        inputs, batch_labels = inputs.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

        # Update running loss
        running_loss += loss.item()
        progress_bar.set_postfix({"MIL Loss": loss.item()})
   
    lr_scheduler.step()
    print(f"Epoch [{epoch+1}/{num_epochs}] - MIL Loss: {running_loss / len(dataloader):.4f}")

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 193MB/s]  
Epoch 1/2: 100%|██████████| 8569/8569 [2:37:42<00:00,  1.10s/it, MIL Loss=0.0337] 


Epoch [1/2] - MIL Loss: 2.0551


Epoch 2/2: 100%|██████████| 8569/8569 [2:23:51<00:00,  1.01s/it, MIL Loss=0.0226]  

Epoch [2/2] - MIL Loss: 0.0272



