# In [None]:  (notebook cell marker left over from the Jupyter export)
import os
import json
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms as transforms

import decord
from einops import rearrange
import timm

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report

# For accessing Google Drive
from google.colab import drive

# Constants
# Crime category name for every 1-based class ID (model labels are ID - 1).
class_names = dict(enumerate(
    (
        "Abuse",
        "Arrest",
        "Arson",
        "Assault",
        "Burglary",
        "Explosion",
        "Fighting",
        "Normal",
        "RoadAccidents",
        "Robbery",
        "Shooting",
        "Shoplifting",
        "Stealing",
        "Vandalism",
    ),
    start=1,
))
NUM_CLASSES = 14  # Number of crime categories

# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define LoRA layers
class LoRALinear(nn.Module):
    """Linear layer made of a frozen base projection plus a trainable low-rank update.

    The layer computes ``original(x) + lora_up(lora_down(x)) * scale`` with
    ``scale = lora_alpha / r``. ``lora_up`` is zero-initialised, so a freshly
    built layer returns exactly what ``original`` returns.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        r: Rank of the low-rank update (width of the bottleneck).
        lora_alpha: Numerator of the scaling factor applied to the update.
    """

    def __init__(self, in_features, out_features, r=8, lora_alpha=16):
        super().__init__()
        # Base projection; callers are expected to copy pretrained weights into it.
        self.original = nn.Linear(in_features, out_features)
        # Low-rank pair: in_features -> r -> out_features, both without bias.
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = lora_alpha / r

        # Random down-projection, zero up-projection: the update starts at 0.
        nn.init.normal_(self.lora_down.weight, mean=0.0, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

        # Only the low-rank pair is trainable; the base projection stays fixed.
        for frozen in self.original.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        """Return the frozen projection of ``x`` plus the scaled low-rank update."""
        base_out = self.original(x)
        low_rank_out = self.lora_up(self.lora_down(x))
        return base_out + low_rank_out * self.scale

# Define TimeSformer model with LoRA
class TimeSformerForVideoClassification(nn.Module):
    """Frame-wise ViT video classifier fine-tuned through LoRA adapters.

    Every frame of a clip is pushed independently through a timm
    ``vit_base_patch16_224`` backbone; the per-frame CLS embeddings are
    averaged over the frame axis and classified by a linear head. No
    attention is computed across frames in ``forward`` — temporal
    information is combined only by that average.

    The backbone is frozen. The only trainable parameters are the LoRA
    matrices injected into each block's ``attn.qkv`` projection and the
    classification head ``self.head``.

    Args:
        num_classes: Number of output logits.
        embed_dim: Input width of the classification head. It must equal the
            backbone's token width (presumably 768 for this ViT — confirm
            before changing).
        depth: Stored as an attribute only; it does not configure the backbone.
        num_heads: Stored as an attribute only; it does not configure the backbone.
        num_frames: Stored as an attribute only; ``forward`` reads the frame
            count from the input shape.
        lora_rank: Rank ``r`` of every injected ``LoRALinear``.
        pretrained: When truthy AND ``pretrained_model_path`` is given, weights
            are loaded from that file. timm's own pretrained weights are never
            requested (``pretrained=False`` is passed to ``create_model``).
        pretrained_model_path: Path of a checkpoint readable by ``torch.load``.
    """

    def __init__(self, num_classes, embed_dim=768, depth=12, num_heads=12, num_frames=8,
                 lora_rank=16, pretrained=True, pretrained_model_path=None):
        super().__init__()
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth

        # Backbone architecture only; weights come from the checkpoint below (if any).
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=False)

        # If pretrained model path is provided, load weights
        if pretrained and pretrained_model_path:
            print(f"Loading pretrained TimeSformer weights from {pretrained_model_path}")
            self._load_pretrained_weights(pretrained_model_path)

        # Freeze the whole backbone BEFORE injecting the adapters, so that only
        # the LoRA matrices created below (and the head) require gradients.
        # Without this every backbone weight keeps requires_grad=True: autograd
        # computes and stores gradients for them and the parameter count
        # printed below over-reports what is actually being trained.
        for param in self.vit.parameters():
            param.requires_grad = False

        # Replace attention blocks with LoRA attention blocks
        self._replace_attention_with_lora(lora_rank)

        # Classification head applied to the frame-averaged CLS embedding.
        self.head = nn.Linear(embed_dim, num_classes)

        # Count parameters
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")

    def _load_pretrained_weights(self, pretrained_model_path):
        """Copy matching ``backbone.*`` tensors from a checkpoint into the ViT.

        Loading is best-effort: any failure is reported and the backbone keeps
        its current (random) initialisation.
        """
        try:
            # Load pretrained weights
            pretrained_dict = torch.load(pretrained_model_path, map_location='cpu')

            # Handle different model formats
            if 'model_state' in pretrained_dict:
                # Original TimeSformer format
                pretrained_dict = pretrained_dict['model_state']
            elif 'state_dict' in pretrained_dict:
                # PyTorch checkpoint format
                pretrained_dict = pretrained_dict['state_dict']

            # Keep only 'backbone.'-prefixed tensors whose stripped name AND shape
            # exist in the ViT. Checking the shape prevents one incompatible
            # tensor from making load_state_dict raise and discarding every weight.
            model_dict = self.vit.state_dict()
            matched = {}
            for key, value in pretrained_dict.items():
                if 'backbone.' not in key:
                    continue
                target_key = key.replace('backbone.', '')
                if target_key in model_dict and value.shape == model_dict[target_key].shape:
                    matched[target_key] = value

            if not matched:
                # Previously this case still printed a success message although
                # nothing had been loaded.
                print("Warning: no tensors in the checkpoint matched the ViT backbone")
                print("Initializing with random weights")
                return

            # Load weights
            model_dict.update(matched)
            self.vit.load_state_dict(model_dict)
            print(f"Matched {len(matched)} of {len(model_dict)} backbone tensors")
            print(f"Successfully loaded pretrained TimeSformer weights")
        except Exception as e:
            print(f"Error loading pretrained weights: {e}")
            print("Initializing with random weights")

    def _replace_attention_with_lora(self, lora_rank):
        """
        Replace the fused query/key/value projection of every block with a LoRA version
        that starts from the same weights.
        """
        for block in self.vit.blocks:
            original_qkv = block.attn.qkv
            # nn.Linear stores its weight as (out_features, in_features).
            qkv_out_features, qkv_in_features = original_qkv.weight.shape

            # Create LoRA version of qkv projection
            lora_qkv = LoRALinear(qkv_in_features, qkv_out_features, r=lora_rank)

            # Initialize the frozen path with the original weights
            lora_qkv.original.weight.data = original_qkv.weight.data.clone()
            if original_qkv.bias is not None:
                lora_qkv.original.bias.data = original_qkv.bias.data.clone()
            else:
                # The source projection has no bias: drop the randomly
                # initialised one LoRALinear created instead of keeping it.
                lora_qkv.original.bias = None

            # Replace original with LoRA version
            block.attn.qkv = lora_qkv

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``[batch_size, num_frames, channels, height, width]``.

        Returns:
            Logits of shape ``[batch_size, num_classes]``.
        """
        batch_size, num_frames, channels, height, width = x.shape

        # Fold frames into the batch axis: [batch_size * num_frames, C, H, W].
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(batch_size * num_frames, channels, height, width)

        # Forward through ViT: patch tokens, prepend CLS, add positions
        x = self.vit.patch_embed(x)
        cls_token = self.vit.cls_token.expand(batch_size * num_frames, -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.vit.pos_drop(x + self.vit.pos_embed)

        # Forward through transformer blocks
        for block in self.vit.blocks:
            x = block(x)

        x = self.vit.norm(x)

        # Take cls token output (index 0, where it was concatenated above)
        x = x[:, 0]

        # Unfold back to [batch_size, num_frames, embed_dim]
        x = x.reshape(batch_size, num_frames, -1)

        # Average pooling over frames dimension
        x = torch.mean(x, dim=1)

        # Classification head
        x = self.head(x)

        return x

# Define dataset and data loading functions
def load_data_from_directory(data_dir):
    """
    Load video paths and labels from a directory with one folder per class.

    Folder names are the values of ``class_names``; a missing class folder is
    reported and skipped. Files are matched by extension (.mp4/.avi/.mov,
    case-insensitively).

    Args:
        data_dir: Root directory that contains the class folders.

    Returns:
        ``(video_paths, labels)``: two parallel lists. Labels are 0-indexed
        (``class_id - 1``). Within each class, paths are in sorted file-name order.
    """
    video_extensions = ('.mp4', '.avi', '.mov')
    video_paths = []
    labels = []

    for class_id, class_name in class_names.items():
        class_dir = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_dir):
            print(f"Warning: Class directory {class_dir} not found")
            continue

        # os.listdir returns entries in arbitrary order. Sorting makes the list
        # reproducible, so a split seeded with a fixed random_state assigns each
        # video to the same subset in every session (important when training is
        # resumed: otherwise test videos can end up in the training set).
        for video_file in sorted(os.listdir(class_dir)):
            if video_file.lower().endswith(video_extensions):
                video_paths.append(os.path.join(class_dir, video_file))
                labels.append(class_id - 1)  # Convert to 0-indexed

    return video_paths, labels

def split_data(video_paths, labels, train_ratio=0.7, val_ratio=0.15):
    """
    Partition the videos into stratified train / validation / test subsets.

    Two consecutive ``train_test_split`` calls are used, both stratified and
    both seeded with ``random_state=42``: the first separates the training
    set, the second divides the remainder into validation and test.

    Args:
        video_paths: Sequence of video paths.
        labels: Class label of each path (same length and order).
        train_ratio: Fraction of all samples used for training.
        val_ratio: Fraction of all samples used for validation; what is left
            after training and validation becomes the test set.

    Returns:
        ``((train_paths, train_labels), (val_paths, val_labels),
        (test_paths, test_labels))``.
    """
    # Stage 1: training set versus everything else (validation + test).
    stage_one = train_test_split(
        video_paths,
        labels,
        test_size=(1-train_ratio),
        stratify=labels,
        random_state=42,
    )
    train_paths, held_out_paths, train_labels, held_out_labels = stage_one

    # Stage 2: val_ratio is relative to the full dataset, so rescale it to the
    # held-out part before dividing that part into validation and test.
    val_size = val_ratio / (1 - train_ratio)
    stage_two = train_test_split(
        held_out_paths,
        held_out_labels,
        test_size=(1-val_size),
        stratify=held_out_labels,
        random_state=42,
    )
    val_paths, test_paths, val_labels, test_labels = stage_two

    train_split = (train_paths, train_labels)
    val_split = (val_paths, val_labels)
    test_split = (test_paths, test_labels)
    return train_split, val_split, test_split

class VideoDataset(Dataset):
    """
    Dataset that turns each video file into a fixed-length stack of frames.

    Args:
        video_paths: Sequence of video file paths.
        labels: Integer class label for each path (same order).
        transform: Optional callable applied to every resized frame (an
            H x W x 3 uint8 RGB array); it must return a ``[C, H, W]`` tensor.
            When omitted, frames are converted to float tensors in ``[0, 1]``.
        num_frames: Number of frames sampled uniformly over the whole video.
        target_size: Size passed to ``cv2.resize`` for every frame and used for
            the fallback clip. NOTE(review): cv2 reads it as (width, height)
            while the fallback tensor uses it as (height, width); the two only
            agree for square sizes.
    """
    def __init__(self, video_paths, labels, transform=None, num_frames=8, target_size=(224, 224)):
        self.video_paths = video_paths
        self.labels = labels
        self.transform = transform
        self.num_frames = num_frames
        self.target_size = target_size

    def __len__(self):
        """Return the number of videos."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for video ``idx``.

        ``frames`` has one entry per sampled frame (``[num_frames, C, H, W]``
        for transforms that return ``[C, H, W]``); ``label`` is a scalar long
        tensor. If the video cannot be read or processed, the error is printed
        and an all-zero clip is returned together with the sample's own label,
        so one corrupt file does not abort an epoch.
        """
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        try:
            # Load video
            vr = decord.VideoReader(video_path)
            total_frames = len(vr)
            if total_frames == 0:
                raise ValueError("video contains no frames")

            # Sample frames uniformly between the first and the last frame
            frame_indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)

            # Get frames
            frames = vr.get_batch(frame_indices).asnumpy()

            processed_frames = []
            for frame in frames:
                # decord already decodes to RGB, so no colour conversion is
                # applied: the former BGR2RGB call swapped the channels and fed
                # BGR data to the RGB-ordered normalisation statistics.
                frame = cv2.resize(frame, self.target_size)
                if self.transform:
                    # NOTE(review): random transforms are drawn independently
                    # for every frame, so crops/flips are not consistent within
                    # a clip — confirm this is intended.
                    frame = self.transform(frame)
                else:
                    # Without a transform the frame is still a numpy array,
                    # which torch.stack cannot handle: convert HWC uint8 to a
                    # CHW float tensor in [0, 1].
                    frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
                processed_frames.append(frame)

            # Stack frames along a new leading (time) axis
            frames_tensor = torch.stack(processed_frames)

            return frames_tensor, torch.tensor(label, dtype=torch.long)

        except Exception as e:
            print(f"Error loading video {video_path}: {e}")
            # Best-effort fallback: an all-zero clip. The true label is kept —
            # returning class 0 here silently relabelled every unreadable video
            # and corrupted both the training targets and the evaluation metrics.
            dummy_tensor = torch.zeros((self.num_frames, 3, self.target_size[0], self.target_size[1]))
            return dummy_tensor, torch.tensor(label, dtype=torch.long)

# Training function
def _run_epoch(model, loader, criterion, optimizer, desc):
    """Run one pass over ``loader``; update the weights when ``optimizer`` is given.

    The caller chooses train/eval mode and the gradient mode. Returns
    ``(mean_loss, accuracy, targets, predictions)``, where the mean loss is
    averaged over ``len(loader.dataset)`` samples.
    """
    running_loss = 0.0
    predictions = []
    targets = []

    for inputs, labels in tqdm(loader, desc=desc):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if optimizer is not None:
            # Zero the parameter gradients
            optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        if optimizer is not None:
            # Backward pass and optimize
            loss.backward()
            optimizer.step()

        # loss is a batch mean: weight it by the batch size so that the epoch
        # value is a true per-sample average even if the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        predictions.extend(preds.cpu().numpy())
        targets.extend(labels.cpu().numpy())

    mean_loss = running_loss / len(loader.dataset)
    accuracy = float(accuracy_score(targets, predictions))
    return mean_loss, accuracy, targets, predictions


def _save_validation_report(targets, predictions, epoch, output_dirs):
    """Write the confusion-matrix PNG and classification-report JSON of one epoch
    into every directory of ``output_dirs``."""
    # Pass the full label set explicitly. Otherwise scikit-learn only uses the
    # classes that occur in targets/predictions: the matrix shrinks, the tick
    # labels no longer line up, and classification_report raises because the
    # number of target_names differs from the number of classes found.
    all_labels = list(range(NUM_CLASSES))
    label_names = [class_names[i+1] for i in all_labels]

    cm = confusion_matrix(targets, predictions, labels=all_labels)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=label_names,
                yticklabels=label_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix - Epoch {epoch+1}')
    plt.tight_layout()
    for output_dir in output_dirs:
        plt.savefig(os.path.join(output_dir, f"confusion_matrix_epoch_{epoch+1}.png"))
    plt.close()

    report = classification_report(
        targets, predictions,
        labels=all_labels,
        target_names=label_names,
        output_dict=True,
        zero_division=0
    )
    for output_dir in output_dirs:
        with open(os.path.join(output_dir, f"classification_report_epoch_{epoch+1}.json"), "w") as f:
            json.dump(report, f, indent=4)


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10,
               scheduler=None, checkpoint_path="checkpoints", best_model_path="checkpoints/best_model.pth",
               start_epoch=0):
    """
    Train the model, validating and checkpointing after every epoch.

    Each epoch writes a full checkpoint and the metric history both locally
    and to the Google Drive checkpoint folder; the best weights (by validation
    accuracy) are written to ``best_model_path`` and to Drive. A confusion
    matrix and a classification report for the validation set are saved when
    ``epoch % 5 == 0`` (0-based epoch index) and after the last epoch. After
    training, the loss/accuracy curves are saved as ``training_history.png``.

    Args:
        model: Network to train (expected to already be on ``device``).
        train_loader: Loader of training batches ``(inputs, labels)``.
        val_loader: Loader of validation batches ``(inputs, labels)``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Total number of epochs (not the number of additional ones).
        scheduler: Optional LR scheduler, stepped once per epoch.
        checkpoint_path: Local directory for the per-epoch checkpoints.
        best_model_path: Local file for the best model's ``state_dict``.
        start_epoch: Number of epochs already completed. When > 0 the metric
            history is reloaded from ``metrics.json`` on Drive.

    Returns:
        Dict with ``train_losses``, ``train_accuracies``, ``val_losses``,
        ``val_accuracies`` (one entry per epoch) and ``best_val_acc``.
    """
    # Make sure every output directory exists before anything is written
    drive_checkpoint_path = "/content/drive/MyDrive/crime_classification_checkpoints"
    os.makedirs(drive_checkpoint_path, exist_ok=True)
    os.makedirs("results", exist_ok=True)
    os.makedirs(checkpoint_path, exist_ok=True)
    # best_model_path may point outside checkpoint_path; torch.save does not
    # create missing parent directories.
    best_model_dir = os.path.dirname(best_model_path)
    if best_model_dir:
        os.makedirs(best_model_dir, exist_ok=True)

    # Initialize tracking variables
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    best_val_acc = 0.0

    # Load previous metrics if resuming training
    if start_epoch > 0:
        try:
            with open(os.path.join(drive_checkpoint_path, "metrics.json"), "r") as f:
                metrics = json.load(f)

            previous_val_accuracies = metrics["val_accuracies"]
            # Keep one history entry per completed epoch: the metrics file can
            # hold more epochs than the checkpoint that is being resumed.
            restored = (
                metrics["train_losses"][:start_epoch],
                metrics["train_accuracies"][:start_epoch],
                metrics["val_losses"][:start_epoch],
                previous_val_accuracies[:start_epoch],
            )
            train_losses, train_accuracies, val_losses, val_accuracies = restored

            # The best model stored so far corresponds to the best accuracy ever
            # recorded, so compare against the untruncated history.
            if previous_val_accuracies:
                best_val_acc = max(previous_val_accuracies)

            print(f"Loaded metrics from previous training session")
        except Exception as e:
            print(f"Could not load previous metrics: {e}")

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        # Training phase
        model.train()
        epoch_loss, epoch_acc, _, _ = _run_epoch(
            model, train_loader, criterion, optimizer, "Training"
        )
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        print(f"Training Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

        # Validation phase (no weight updates, no gradient tracking)
        model.eval()
        with torch.no_grad():
            val_epoch_loss, val_epoch_acc, val_targets, val_predictions = _run_epoch(
                model, val_loader, criterion, None, "Validation"
            )
        val_losses.append(val_epoch_loss)
        val_accuracies.append(val_epoch_acc)
        print(f"Validation Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}")

        # Update learning rate scheduler if provided
        if scheduler:
            scheduler.step()

        # Save checkpoint for every epoch ('epoch' = number of completed epochs)
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': epoch_loss,
            'val_loss': val_epoch_loss,
            'train_acc': epoch_acc,
            'val_acc': val_epoch_acc,
        }
        if scheduler:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()

        # Save locally and to Google Drive
        checkpoint_name = f"checkpoint_epoch_{epoch+1}.pth"
        torch.save(checkpoint, os.path.join(checkpoint_path, checkpoint_name))
        torch.save(checkpoint, os.path.join(drive_checkpoint_path, checkpoint_name))

        # Save best model based on validation accuracy
        if val_epoch_acc > best_val_acc:
            best_val_acc = val_epoch_acc
            torch.save(model.state_dict(), best_model_path)
            # Also save to Drive
            torch.save(model.state_dict(), os.path.join(drive_checkpoint_path, "best_model.pth"))
            print(f"Saved best model with validation accuracy: {best_val_acc:.4f}")

        # Save metrics after each epoch, locally and to Drive
        metrics = {
            "train_losses": train_losses,
            "train_accuracies": train_accuracies,
            "val_losses": val_losses,
            "val_accuracies": val_accuracies,
        }
        with open("results/metrics.json", "w") as f:
            json.dump(metrics, f)
        with open(os.path.join(drive_checkpoint_path, "metrics.json"), "w") as f:
            json.dump(metrics, f)

        # Generate and save confusion matrix and classification report
        if epoch % 5 == 0 or epoch == num_epochs - 1:
            _save_validation_report(
                val_targets, val_predictions, epoch, ("results", drive_checkpoint_path)
            )

    # Plot training and validation loss and accuracy
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accuracies, label='Training Accuracy')
    plt.plot(val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig("results/training_history.png")
    # Also save to Drive
    plt.savefig(os.path.join(drive_checkpoint_path, "training_history.png"))
    plt.close()

    return {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies,
        'best_val_acc': best_val_acc
    }

# Define evaluation function
def evaluate_model(model, test_loader, criterion):
    """
    Evaluate the model on the test set and save the results.

    Computes loss, accuracy and weighted precision / recall / F1, plus a
    per-class classification report and a confusion matrix. The JSON results
    and the confusion-matrix PNG are written to ``results/`` and to the Google
    Drive results folder.

    Args:
        model: Trained network (expected to already be on ``device``).
        test_loader: Loader of test batches ``(inputs, labels)``.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        Dict with ``test_loss``, ``test_accuracy``, ``test_precision``,
        ``test_recall``, ``test_f1`` and ``classification_report``.
    """
    # Create output directories. "results" is created here as well so that the
    # function also works when it is called without a preceding training run.
    drive_results_path = "/content/drive/MyDrive/crime_classification_results"
    os.makedirs(drive_results_path, exist_ok=True)
    os.makedirs("results", exist_ok=True)

    model.eval()
    test_loss = 0.0
    test_predictions = []
    test_targets = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Testing"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Weight the batch-mean loss by the batch size (see division below)
            test_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            test_predictions.extend(preds.cpu().numpy())
            test_targets.extend(labels.cpu().numpy())

    # Calculate metrics. zero_division=0 scores a class without predictions as
    # 0 (the value scikit-learn would use anyway) without emitting a warning.
    test_loss /= len(test_loader.dataset)
    test_acc = accuracy_score(test_targets, test_predictions)
    test_precision = precision_score(test_targets, test_predictions, average='weighted', zero_division=0)
    test_recall = recall_score(test_targets, test_predictions, average='weighted', zero_division=0)
    test_f1 = f1_score(test_targets, test_predictions, average='weighted', zero_division=0)

    # Pass the full label set explicitly: if a class is absent from both the
    # targets and the predictions, scikit-learn would otherwise build a smaller
    # matrix (misaligned tick labels) and classification_report would raise
    # because target_names no longer matches the number of classes.
    all_labels = list(range(NUM_CLASSES))
    label_names = [class_names[i+1] for i in all_labels]

    # Generate confusion matrix
    cm = confusion_matrix(test_targets, test_predictions, labels=all_labels)

    # Generate classification report
    report = classification_report(
        test_targets, test_predictions,
        labels=all_labels,
        target_names=label_names,
        output_dict=True,
        zero_division=0
    )

    # Save detailed results
    results = {
        "test_loss": test_loss,
        "test_accuracy": test_acc,
        "test_precision": test_precision,
        "test_recall": test_recall,
        "test_f1": test_f1,
        "classification_report": report
    }

    with open("results/test_results.json", "w") as f:
        json.dump(results, f, indent=4)

    # Also save to Drive
    with open(os.path.join(drive_results_path, "test_results.json"), "w") as f:
        json.dump(results, f, indent=4)

    # Plot confusion matrix
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=label_names,
                yticklabels=label_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - Test Set')
    plt.tight_layout()
    plt.savefig("results/test_confusion_matrix.png")
    # Also save to Drive
    plt.savefig(os.path.join(drive_results_path, "test_confusion_matrix.png"))
    plt.close()

    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {test_precision:.4f}")
    print(f"Test Recall: {test_recall:.4f}")
    print(f"Test F1 Score: {test_f1:.4f}")

    return results

# Function for inference on a single video
def predict_video(model, video_path, num_frames=8, transform=None, target_size=(224, 224)):
    """
    Make a prediction on a single video.

    Frames are sampled and preprocessed the same way ``VideoDataset`` does it:
    uniform sampling, ``cv2.resize`` to ``target_size``, then ``transform``.

    Args:
        model: Classifier called with a ``[1, num_frames, C, H, W]`` tensor on
            ``device``; it is switched to eval mode.
        video_path: Path of the video file.
        num_frames: Number of frames sampled uniformly over the video.
        transform: Optional callable applied to every resized frame (an
            H x W x 3 uint8 RGB array); it must return a ``[C, H, W]`` tensor.
            When omitted, frames are converted to float tensors in ``[0, 1]``.
        target_size: Size passed to ``cv2.resize`` for every frame.

    Returns:
        Dict with ``prediction`` (0-based class index), ``predicted_class``
        (name), ``confidence`` (softmax probability of that class) and
        ``probabilities`` (name -> probability), or ``None`` if the video
        could not be processed (the error is printed).
    """
    model.eval()

    try:
        # Load video
        vr = decord.VideoReader(video_path)
        total_frames = len(vr)

        # Sample frames uniformly between the first and the last frame
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

        # Get frames
        frames = vr.get_batch(frame_indices).asnumpy()

        processed_frames = []
        for frame in frames:
            # decord already decodes to RGB; the former BGR2RGB conversion
            # swapped the colour channels, so it is not applied.
            frame = cv2.resize(frame, target_size)
            if transform:
                frame = transform(frame)
            else:
                # No transform: convert the HWC uint8 array to a CHW float
                # tensor so that the frames can be stacked and fed to the model.
                frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
            processed_frames.append(frame)

        # Stack frames into [num_frames, channels, height, width]. Every frame
        # is already channels-first here, so no permute is needed: the former
        # permute(0, 3, 1, 2) scrambled the axes of ToTensor outputs and made
        # every call with such a transform fail and return None.
        frames_tensor = torch.stack(processed_frames)

        # Add batch dimension
        frames_tensor = frames_tensor.unsqueeze(0).to(device)

        # Make prediction
        with torch.no_grad():
            outputs = model(frames_tensor)
            probabilities = torch.nn.functional.softmax(outputs, dim=1)
            confidence, prediction = torch.max(probabilities, 1)

        predicted_index = prediction.item()
        return {
            'prediction': predicted_index,
            # class_names is keyed by 1-based IDs, the model output is 0-based
            'predicted_class': class_names[predicted_index + 1],
            'confidence': confidence.item(),
            'probabilities': {class_names[i+1]: float(prob) for i, prob in enumerate(probabilities[0])}
        }

    except Exception as e:
        print(f"Error processing video {video_path}: {e}")
        return None

# Modified function to use manually downloaded pretrained weights
def download_pretrained_weights(model_name="TimeSformer_divST_8x32_224_K400"):
    """
    Copy manually downloaded TimeSformer weights from Google Drive to the runtime.

    Nothing is downloaded from the network: the file
    ``pretrained_models/<model_name>.pyth`` must already exist in MyDrive.

    Args:
        model_name: Checkpoint file name without the ``.pyth`` extension. It
            selects both the Drive source file and the local copy (previously
            the source path was hard-coded, so any other name copied the
            default checkpoint under a misleading file name).

    Returns:
        Path of the local copy, or ``None`` if the file is not on Drive.
    """
    import os
    import shutil
    from google.colab import drive

    # Mount Google Drive if in Colab
    drive.mount('/content/drive')

    # Path to your manually downloaded weights in Google Drive
    drive_model_path = f"/content/drive/MyDrive/pretrained_models/{model_name}.pyth"

    # Create directory for pretrained models in the local runtime
    os.makedirs("pretrained_models", exist_ok=True)
    local_path = f"pretrained_models/{model_name}.pyth"

    if not os.path.exists(drive_model_path):
        print(f"Error: Model file not found at {drive_model_path}")
        return None

    # Copy from Drive to local runtime
    print(f"Copying pretrained weights from Drive to local runtime...")
    shutil.copy(drive_model_path, local_path)
    print(f"Copied to {local_path}")

    return local_path

def _checkpoint_epoch(filename):
    """Return the epoch number in a ``checkpoint_epoch_<N>.pth`` file name, or -1
    when the name does not contain a parsable number."""
    epoch_text = filename[len("checkpoint_epoch_"):].split(".")[0]
    try:
        return int(epoch_text)
    except ValueError:
        return -1


def main():
    """Run the whole pipeline on Colab.

    Steps: mount Drive, build the datasets and loaders, create the LoRA model,
    resume from the newest Drive checkpoint if there is one, train, evaluate
    the best model on the test set, save the model configuration, and run one
    sample prediction.
    """
    # Mount Google Drive
    drive.mount('/content/drive')

    # Set parameters
    num_frames = 8
    batch_size = 8  # Reduced from 16 to 8 for T4 GPU
    num_epochs = 20
    learning_rate = 1e-4
    embedding_dim = 768
    num_heads = 12
    depth = 12
    lora_rank = 16

    # Path to your dataset
    data_root = "/content/drive/MyDrive/Anamoly"  # Update this to your actual dataset path

    # Define transforms for training and validation/testing
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Get pretrained model path (None when the weights file is not on Drive)
    pretrained_model_path = download_pretrained_weights()
    if pretrained_model_path is None:
        print("Warning: pretrained weights unavailable, the backbone will not be initialized from a checkpoint")

    # Load data from directory structure
    print("Loading video data...")
    video_paths, labels = load_data_from_directory(data_root)
    print(f"Found {len(video_paths)} videos across {len(set(labels))} classes")
    if not video_paths:
        # Fail here with a clear message instead of inside the stratified split.
        raise RuntimeError(f"No videos found under {data_root}")

    # Split data
    (train_paths, train_labels), (val_paths, val_labels), (test_paths, test_labels) = split_data(
        video_paths, labels, train_ratio=0.7, val_ratio=0.15
    )

    print(f"Training set: {len(train_paths)} videos")
    print(f"Validation set: {len(val_paths)} videos")
    print(f"Test set: {len(test_paths)} videos")

    # Create datasets (augmentation only for training)
    train_dataset = VideoDataset(
        train_paths, train_labels, transform=train_transform,
        num_frames=num_frames, target_size=(224, 224)
    )

    val_dataset = VideoDataset(
        val_paths, val_labels, transform=val_transform,
        num_frames=num_frames, target_size=(224, 224)
    )

    test_dataset = VideoDataset(
        test_paths, test_labels, transform=val_transform,
        num_frames=num_frames, target_size=(224, 224)
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True
    )

    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False,
        num_workers=4, pin_memory=True
    )

    # Initialize model
    print("Initializing TimeSformer model with LoRA fine-tuning...")
    model = TimeSformerForVideoClassification(
        num_classes=NUM_CLASSES,
        embed_dim=embedding_dim,
        depth=depth,
        num_heads=num_heads,
        num_frames=num_frames,
        lora_rank=lora_rank,
        pretrained=True,
        pretrained_model_path=pretrained_model_path
    ).to(device)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()

    # Only optimize LoRA parameters and classification head (head at 10x the LR).
    # NOTE(review): the substring filter 'head' also matches any backbone
    # parameter whose name contains "head" (e.g. a `vit.head.*` classifier, if
    # the timm model has one). The groups are left unchanged so that optimizer
    # states stored in existing checkpoints still load — confirm before tightening.
    optimizer = optim.AdamW([
        {'params': [p for n, p in model.named_parameters() if 'lora' in n], 'lr': learning_rate},
        {'params': [p for n, p in model.named_parameters() if 'head' in n], 'lr': learning_rate * 10}
    ], weight_decay=0.01)

    # Learning rate scheduler
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Check for existing checkpoints in Drive
    drive_checkpoint_path = "/content/drive/MyDrive/crime_classification_checkpoints"
    start_epoch = 0

    if os.path.exists(drive_checkpoint_path):
        checkpoint_files = [f for f in os.listdir(drive_checkpoint_path) if f.startswith("checkpoint_epoch_")]

        if checkpoint_files:
            # Pick the checkpoint with the highest epoch NUMBER. A plain string
            # sort puts "checkpoint_epoch_9.pth" after "checkpoint_epoch_12.pth",
            # so training resumed from a stale checkpoint once 10 epochs existed.
            latest_name = max(checkpoint_files, key=_checkpoint_epoch)
            latest_checkpoint = os.path.join(drive_checkpoint_path, latest_name)
            print(f"Found checkpoint: {latest_checkpoint}")

            try:
                # map_location lets a checkpoint written on a GPU runtime be
                # loaded on a CPU-only runtime (and the other way round).
                checkpoint = torch.load(latest_checkpoint, map_location=device)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                start_epoch = checkpoint['epoch']

                if scheduler and 'scheduler_state_dict' in checkpoint:
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

                print(f"Resuming training from epoch {start_epoch}")
            except Exception as e:
                print(f"Error loading checkpoint: {e}")
                print("Starting training from scratch")
                start_epoch = 0
        else:
            print("No checkpoint found, starting training from scratch")
    else:
        print("No checkpoint directory found, starting training from scratch")

    # Train model
    print("Starting training...")
    train_results = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=num_epochs,
        scheduler=scheduler,
        checkpoint_path="checkpoints",
        best_model_path="checkpoints/best_model.pth",
        start_epoch=start_epoch
    )

    # Load best model for evaluation (Drive copy preferred, local copy as fallback)
    best_model_drive_path = os.path.join(drive_checkpoint_path, "best_model.pth")
    if os.path.exists(best_model_drive_path):
        print("Loading best model from Drive for evaluation...")
        model.load_state_dict(torch.load(best_model_drive_path, map_location=device))
    else:
        print("Loading best model from local path for evaluation...")
        model.load_state_dict(torch.load("checkpoints/best_model.pth", map_location=device))

    # Evaluate on test set
    print("Evaluating on test set...")
    test_results = evaluate_model(model, test_loader, criterion)

    # Save the constructor arguments so that the model can be rebuilt later
    model_info = {
        "num_classes": NUM_CLASSES,
        "embed_dim": embedding_dim,
        "depth": depth,
        "num_heads": num_heads,
        "num_frames": num_frames,
        "lora_rank": lora_rank
    }

    os.makedirs("results", exist_ok=True)
    with open("results/model_info.json", "w") as f:
        json.dump(model_info, f, indent=4)

    # Also save to Drive
    drive_results_path = "/content/drive/MyDrive/crime_classification_results"
    os.makedirs(drive_results_path, exist_ok=True)
    with open(os.path.join(drive_results_path, "model_info.json"), "w") as f:
        json.dump(model_info, f, indent=4)

    print("Training and evaluation complete!")
    print(f"Best validation accuracy: {train_results['best_val_acc']:.4f}")
    print(f"Test accuracy: {test_results['test_accuracy']:.4f}")

    # Example inference on a sample video
    if test_paths:
        sample_video = test_paths[0]
        print(f"Running inference on sample video: {sample_video}")
        prediction = predict_video(model, sample_video, num_frames=num_frames, transform=val_transform)

        # predict_video returns None when the video could not be processed
        if prediction:
            print(f"Predicted class: {prediction['predicted_class']}")
            print(f"Confidence: {prediction['confidence']:.4f}")

            # Save prediction to Drive
            with open(os.path.join(drive_results_path, "sample_prediction.json"), "w") as f:
                json.dump(prediction, f, indent=4)

if __name__ == "__main__":
    main()