In [None]:
!pip install decord
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
from glob import glob
import numpy as np
from google.colab import drive
import decord
from sklearn.model_selection import train_test_split
import json
import time
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, accuracy_score, classification_report

# Mount Google Drive so the dataset, checkpoints and results live on Drive
# rather than on the ephemeral Colab disk.
drive.mount('/content/drive')

# Dataset root and output locations
UCF_CRIME_PATH = "/content/drive/MyDrive/Anamoly2"  # Update this path
CHECKPOINT_DIR = "/content/drive/MyDrive/timesformer_checkpoints"
RESULTS_DIR = "/content/drive/MyDrive/ucf_crime_results"
# Create the output directories up front (no-op if they already exist).
for _output_dir in (CHECKPOINT_DIR, RESULTS_DIR):
    os.makedirs(_output_dir, exist_ok=True)

# Function to extract frames from videos
def extract_frames(video_path, num_frames=8):
    """Decode ``num_frames`` evenly spaced frames from a video file.

    Args:
        video_path: Path to a video file readable by ``decord.VideoReader``.
        num_frames: Number of frames to sample, spread uniformly from the
            first to the last frame (indices repeat for clips shorter than
            ``num_frames``).

    Returns:
        An array of shape (num_frames, height, width, 3), or None when the
        video has no frames or cannot be decoded. Errors are printed rather
        than raised so one bad file does not abort a whole epoch.
    """
    try:
        vr = decord.VideoReader(video_path)
        total_frames = len(vr)
        if total_frames == 0:
            print(f"Warning: No frames in video {video_path}")
            return None

        # Evenly spaced frame indices over the whole clip
        frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        # One batched decode instead of a separate seek per frame
        return vr.get_batch(frame_indices.tolist()).asnumpy()
    except Exception as e:
        print(f"Error extracting frames from {video_path}: {e}")
        return None

# Custom VideoClassifier model (no TimeSformer dependency)
class SimpleVideoClassifier(nn.Module):
    """Frame-wise ResNet18 encoder, temporal average pooling, linear head.

    Every frame is encoded independently by a pretrained ResNet18 (with its
    final fc layer removed). The resulting 512-d features are averaged over
    the time axis and classified by a single linear layer.

    Args:
        num_classes: Number of output logits.
        num_frames: Accepted for interface compatibility but unused; the
            adaptive pooling handles any clip length.
    """

    def __init__(self, num_classes, num_frames=8):
        super().__init__()
        from torchvision.models import resnet18, ResNet18_Weights

        backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Everything except the last child (the fc classifier)
        backbone_layers = list(backbone.children())[:-1]

        # Attribute names are part of the saved state_dict keys; keep them.
        self.feature_extractor = nn.Sequential(*backbone_layers)
        self.temporal_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a [batch, channels, frames, height, width] clip batch to logits."""
        b, c, t, h, w = x.shape

        # Fold time into the batch axis so the 2-D CNN sees individual frames
        per_frame = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

        # [b*t, 512, 1, 1] -> [b, t, 512]
        feats = self.feature_extractor(per_frame).reshape(b, t, -1)

        # Average over time: [b, 512, t] -> [b, 512]
        pooled = self.temporal_pool(feats.transpose(1, 2)).squeeze(-1)

        return self.classifier(pooled)

# Dataset class
class UCFCrimeDataset(Dataset):
    """Map-style dataset of fixed-length frame clips sampled from video files.

    Each item is ``(video_tensor, label)``. ``video_tensor`` has layout
    [C, T, H, W] with T == ``num_frames``, built from the frames returned by
    ``extract_frames``. Videos that cannot be read are replaced by all-black
    224x224 frames instead of raising.

    Args:
        video_paths: Sequence of video file paths.
        labels: Sequence of class labels aligned with ``video_paths``.
        transform: Optional callable applied to each frame. It receives a
            [C, H, W] tensor in the frame's native dtype (uint8 for decoded
            video and for the black fallback). Dtype-converting transforms
            such as ``ConvertImageDtype`` can therefore rescale to [0, 1].
            Without a transform, frames are returned as float32 in [0, 255].
        num_frames: Number of frames sampled per video.
    """

    def __init__(self, video_paths, labels, transform=None, num_frames=8):
        self.video_paths = video_paths
        self.labels = labels
        self.transform = transform
        self.num_frames = num_frames

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        frames = extract_frames(video_path, self.num_frames)
        if frames is None:
            # Unreadable video: substitute black frames so one bad file does
            # not abort the epoch.
            frames = np.zeros((self.num_frames, 224, 224, 3), dtype=np.uint8)

        return self._frames_to_tensor(frames), label

    def _frames_to_tensor(self, frames):
        """Convert a (T, H, W, C) frame array into a [C, T, H, W] tensor.

        Frames reach ``self.transform`` in their original integer dtype.
        Casting to float beforehand would make ConvertImageDtype a no-op and
        feed [0, 255] values into ImageNet normalization.
        """
        frame_tensors = []
        for frame in frames:
            frame_tensor = torch.from_numpy(frame).permute(2, 0, 1)  # [C, H, W]
            if self.transform:
                frame_tensor = self.transform(frame_tensor)
            else:
                frame_tensor = frame_tensor.float()
            frame_tensors.append(frame_tensor)

        # [T, C, H, W] -> [C, T, H, W], the layout SimpleVideoClassifier expects
        return torch.stack(frame_tensors).permute(1, 0, 2, 3)

# Define transformation
# Per-frame pipeline: resize to 224x224, convert to float32, then normalize
# with the standard ImageNet channel mean/std.
# NOTE: ConvertImageDtype only rescales when the dtype actually changes
# (uint8 -> float32 maps [0, 255] to [0, 1]); float32 input passes through
# unscaled, so frames should reach this pipeline as uint8.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ConvertImageDtype(torch.float32),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# Function to load dataset with a test split
def load_dataset_with_test():
    """
    Load dataset with train, validation, and test splits.

    Every sub-directory of UCF_CRIME_PATH is treated as one class. Its
    ``.mp4``/``.avi`` files (extension matched case-insensitively, hidden
    files skipped) are the samples. Class indices follow the sorted directory
    names, and the mapping is written to CHECKPOINT_DIR/class_mapping.json.

    The split is 64% train / 16% val / 20% test, stratified by label whenever
    more than one class is present, with a fixed random_state. Paths are
    sorted before splitting because directory listing order is arbitrary.
    Sorting keeps the split identical across sessions, which matters because
    train() rewrites test_data.json every time it resumes.

    Returns:
        (train_paths, val_paths, test_paths, train_labels, val_labels,
        test_labels, num_classes, class_to_idx)

    Raises:
        ValueError: If no video files are found. train_test_split may also
            raise ValueError when a stratified split is impossible.
    """
    video_extensions = {".mp4", ".avi"}
    video_paths = []
    labels = []

    class_dirs = sorted([d for d in os.listdir(UCF_CRIME_PATH)
                         if os.path.isdir(os.path.join(UCF_CRIME_PATH, d))])
    class_to_idx = {cls_name: i for i, cls_name in enumerate(class_dirs)}

    print(f"Found {len(class_dirs)} classes: {class_dirs}")
    print(f"Class to index mapping: {class_to_idx}")

    # Save class mapping for later use
    with open(os.path.join(CHECKPOINT_DIR, "class_mapping.json"), "w") as f:
        json.dump(class_to_idx, f, indent=4)

    for cls_name in class_dirs:
        cls_dir = os.path.join(UCF_CRIME_PATH, cls_name)
        # Sorted so the seeded split below is reproducible across runs
        videos = sorted(
            os.path.join(cls_dir, entry)
            for entry in os.listdir(cls_dir)
            if not entry.startswith(".")
            and os.path.splitext(entry)[1].lower() in video_extensions
        )

        print(f"Found {len(videos)} videos in class {cls_name}")

        video_paths.extend(videos)
        labels.extend([class_to_idx[cls_name]] * len(videos))

    if not video_paths:
        raise ValueError("No video files found. Please check your dataset path.")

    # First split: 80% for train+val, 20% for testing
    train_val_paths, test_paths, train_val_labels, test_labels = train_test_split(
        video_paths, labels, test_size=0.2, random_state=42,
        stratify=labels if len(set(labels)) > 1 else None
    )

    # Second split: 80% of train_val for training, 20% for validation (0.2 * 0.8 = 0.16 of total)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_val_paths, train_val_labels, test_size=0.2, random_state=42,
        stratify=train_val_labels if len(set(train_val_labels)) > 1 else None
    )

    print(f"Training samples: {len(train_paths)}")
    print(f"Validation samples: {len(val_paths)}")
    print(f"Test samples: {len(test_paths)}")

    return train_paths, val_paths, test_paths, train_labels, val_labels, test_labels, len(class_dirs), class_to_idx

# Save checkpoint function
def _atomic_torch_save(obj, path):
    """torch.save ``obj`` to a temp file, then atomically rename it to ``path``.

    An interrupted save (e.g. a dropped runtime) therefore never leaves a
    truncated file at ``path``. The ``.tmp`` suffix keeps partial files out
    of the ``checkpoint_epoch_*.pth`` pattern used to pick a resume point.
    """
    tmp_path = path + ".tmp"
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present if saving or renaming failed
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def save_checkpoint(model, optimizer, epoch, val_loss, val_acc, checkpoint_dir, filename, is_best=False, training_history=None):
    """
    Save model checkpoint with comprehensive metadata.

    Args:
        model: Module whose state_dict is stored.
        optimizer: Optimizer whose state_dict is stored.
        epoch: 0-based index of the epoch just finished. ``epoch + 1`` is
            stored, so the value can be used directly as the resume start epoch.
        val_loss, val_acc: Validation metrics of that epoch.
        checkpoint_dir: Directory to write into.
        filename: File name of the checkpoint inside ``checkpoint_dir``.
        is_best: Also write the same checkpoint to ``best_model.pth``.
        training_history: Optional metrics history stored alongside.
    """
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_acc': val_acc,
        'timestamp': time.time(),
        'training_history': training_history
    }

    filepath = os.path.join(checkpoint_dir, filename)
    _atomic_torch_save(checkpoint, filepath)
    print(f"Checkpoint saved to {filepath}")

    # If this is the best model, make a copy as best_model.pth
    if is_best:
        best_filepath = os.path.join(checkpoint_dir, "best_model.pth")
        _atomic_torch_save(checkpoint, best_filepath)
        print(f"Best model saved to {best_filepath}")

# Function to find the latest checkpoint
def find_latest_checkpoint(checkpoint_dir):
    """
    Return the most recently modified ``checkpoint_epoch_*.pth`` in a directory.

    Recency is judged by file modification time, not by the epoch number in
    the file name. Returns None when no matching file exists.
    """
    candidates = glob(os.path.join(checkpoint_dir, "checkpoint_epoch_*.pth"))
    if not candidates:
        return None
    # max() keeps the first of equally recent files, like a stable reverse sort
    return max(candidates, key=os.path.getmtime)

# Function to evaluate the model and save comprehensive metrics
def evaluate_model(model, test_loader, device, class_names, class_indices, results_dir):
    """
    Evaluate the model and save comprehensive metrics.

    All outputs go to a new ``results_dir/evaluation_<timestamp>`` directory:
    metrics.json, metrics_by_class.csv, summary_metrics.csv,
    confusion_matrix.csv/.png, and per-class precision/recall/F1 bar charts.
    Accuracy/loss curves are added when CHECKPOINT_DIR/training_history.json
    exists.

    Every metric is computed over all class indices ``0..len(class_names)-1``.
    Without an explicit ``labels=``, sklearn sizes the confusion matrix and
    report from the classes actually present. A class missing from the test
    labels and predictions would then shrink the matrix, which breaks the CSV
    and heatmap labelling, and makes classification_report reject
    ``target_names``.

    Args:
        model: Classifier returning one logit per class.
        test_loader: Iterable of (inputs, labels) batches.
        device: Device the model is on; batches are moved there.
        class_names: Class names ordered by class index.
        class_indices: Unused; kept so existing callers keep working.
        results_dir: Parent directory for the evaluation output directory.

    Returns:
        The metrics dictionary that is also written to metrics.json.
    """
    model.eval()
    all_preds = []
    all_labels = []

    print("Running inference on test set...")
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Fix the label set so every per-class output has exactly len(class_names) entries
    label_ids = list(range(len(class_names)))

    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels, all_preds, average=None, labels=label_ids, zero_division=0
    )
    precision_avg, recall_avg, f1_avg, _ = precision_recall_fscore_support(
        all_labels, all_preds, average='weighted', labels=label_ids, zero_division=0
    )

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    report = classification_report(
        all_labels, all_preds, labels=label_ids, target_names=class_names,
        output_dict=True, zero_division=0
    )

    # Print summary results
    print(f"Overall Accuracy: {accuracy:.4f}")
    print(f"Weighted Precision: {precision_avg:.4f}")
    print(f"Weighted Recall: {recall_avg:.4f}")
    print(f"Weighted F1 Score: {f1_avg:.4f}")

    results = {
        'accuracy': float(accuracy),
        'precision_per_class': {class_names[i]: float(precision[i]) for i in label_ids},
        'recall_per_class': {class_names[i]: float(recall[i]) for i in label_ids},
        'f1_per_class': {class_names[i]: float(f1[i]) for i in label_ids},
        'support_per_class': {class_names[i]: int(support[i]) for i in label_ids},
        'precision_weighted': float(precision_avg),
        'recall_weighted': float(recall_avg),
        'f1_weighted': float(f1_avg),
        'confusion_matrix': cm.tolist(),
        'classification_report': report
    }

    # Create results directory with timestamp
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    eval_dir = os.path.join(results_dir, f"evaluation_{timestamp}")
    os.makedirs(eval_dir, exist_ok=True)

    # Save results as JSON
    with open(os.path.join(eval_dir, "metrics.json"), "w") as f:
        json.dump(results, f, indent=4)

    # Save metrics as CSV for easy import to Excel/sheets
    metrics_df = pd.DataFrame({
        'Class': list(results['precision_per_class'].keys()),
        'Precision': list(results['precision_per_class'].values()),
        'Recall': list(results['recall_per_class'].values()),
        'F1 Score': list(results['f1_per_class'].values()),
        'Support': list(results['support_per_class'].values())
    })
    metrics_df.to_csv(os.path.join(eval_dir, 'metrics_by_class.csv'), index=False)

    # Save summary metrics
    summary_df = pd.DataFrame({
        'Metric': ['Accuracy', 'Weighted Precision', 'Weighted Recall', 'Weighted F1'],
        'Value': [
            results['accuracy'],
            results['precision_weighted'],
            results['recall_weighted'],
            results['f1_weighted']
        ]
    })
    summary_df.to_csv(os.path.join(eval_dir, 'summary_metrics.csv'), index=False)

    # Save confusion matrix as CSV (rows: true class, columns: predicted class)
    cm_df = pd.DataFrame(
        results['confusion_matrix'],
        index=list(results['precision_per_class'].keys()),
        columns=list(results['precision_per_class'].keys())
    )
    cm_df.to_csv(os.path.join(eval_dir, 'confusion_matrix.csv'))

    # Plot and save confusion matrix
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.savefig(os.path.join(eval_dir, 'confusion_matrix.png'))
    plt.close()

    # Plot metrics by class
    metrics_to_plot = [
        ('precision_per_class', 'Precision by Class'),
        ('recall_per_class', 'Recall by Class'),
        ('f1_per_class', 'F1 Score by Class')
    ]

    for metric_key, title in metrics_to_plot:
        plt.figure(figsize=(14, 8))
        values = results[metric_key]

        # Sort by class name for consistent display
        classes, scores = zip(*sorted(values.items()))

        bars = plt.bar(classes, scores, color='skyblue')
        plt.xlabel('Class')
        plt.ylabel('Score')
        plt.title(title)
        plt.xticks(rotation=45, ha='right')
        plt.ylim(0, 1.0)

        # Add values on top of bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                     f'{height:.2f}', ha='center', va='bottom')

        plt.tight_layout()
        plt.savefig(os.path.join(eval_dir, f"{title.lower().replace(' ', '_')}.png"))
        plt.close()

    # Plot training history if available
    history_path = os.path.join(CHECKPOINT_DIR, "training_history.json")
    if os.path.exists(history_path):
        with open(history_path, "r") as f:
            history = json.load(f)

        # Plot training & validation accuracy
        plt.figure(figsize=(12, 6))
        plt.plot(history['train_acc'], label='Training Accuracy')
        plt.plot(history['val_acc'], label='Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Training and Validation Accuracy')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(eval_dir, 'accuracy_history.png'))
        plt.close()

        # Plot training & validation loss
        plt.figure(figsize=(12, 6))
        plt.plot(history['train_loss'], label='Training Loss')
        plt.plot(history['val_loss'], label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(eval_dir, 'loss_history.png'))
        plt.close()

    print(f"Evaluation complete. Results saved to {eval_dir}")
    return results

# Main training function
def _new_training_history():
    """Return an empty per-epoch metrics history."""
    return {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns (mean of per-batch losses, accuracy in percent) and prints
    progress every 5 batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if (i+1) % 5 == 0:
            print(f"Batch {i+1}, Loss: {loss.item():.4f}, Acc: {100*correct/total:.2f}%")

    return running_loss / len(loader), 100 * correct / total


def _validate_one_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns (mean of per-batch losses, accuracy in percent).
    """
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            val_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    return val_loss / len(loader), 100 * val_correct / val_total


def train():
    """Train SimpleVideoClassifier on the UCF Crime splits, resuming if possible.

    Resumes from the checkpoint chosen by find_latest_checkpoint(CHECKPOINT_DIR)
    when one loads cleanly. Each epoch writes ``checkpoint_epoch_<n>.pth`` and
    ``last_checkpoint.pth``, plus ``best_model.pth`` when validation accuracy
    improves on the best seen so far. The test split is written to
    ``test_data.json`` for evaluate(). ``final_model.pth`` and
    ``training_history.json`` are written at the end.
    """
    print("Starting Video Classification training on UCF Crime Dataset")

    # Load dataset with test split
    (train_paths, val_paths, test_paths, train_labels, val_labels, test_labels,
     num_classes, class_mapping) = load_dataset_with_test()

    train_dataset = UCFCrimeDataset(train_paths, train_labels, transform)
    val_dataset = UCFCrimeDataset(val_paths, val_labels, transform)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)

    # Save test set info for later evaluation
    test_data = {
        'test_paths': test_paths,
        'test_labels': test_labels
    }
    with open(os.path.join(CHECKPOINT_DIR, "test_data.json"), "w") as f:
        json.dump(test_data, f)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    def build_model_and_optimizer():
        # Move the model to its device BEFORE creating/restoring the optimizer.
        # Optimizer.load_state_dict casts state tensors to each parameter's
        # current device. Restoring while parameters are still on the CPU
        # would leave AdamW's moment buffers behind after model.to(device).
        net = SimpleVideoClassifier(num_classes).to(device)
        return net, optim.AdamW(net.parameters(), lr=1e-4)

    model, optimizer = build_model_and_optimizer()

    start_epoch = 0
    num_epochs = 20  # Extended training duration
    best_val_acc = 0.0
    # Metrics of the most recent epoch. Seeded from the checkpoint on resume
    # so the final save works even when no epochs remain to run.
    val_loss = None
    val_acc = None
    history = _new_training_history()

    latest_checkpoint = find_latest_checkpoint(CHECKPOINT_DIR)
    if latest_checkpoint:
        print(f"Found checkpoint: {latest_checkpoint}")
        try:
            checkpoint = torch.load(latest_checkpoint, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            val_loss = checkpoint.get('val_loss')
            val_acc = checkpoint.get('val_acc')

            if checkpoint.get('training_history') is not None:
                history = checkpoint['training_history']

            # The checkpoint holds the LAST epoch's accuracy, not the best.
            # Take the best from the history so a later, worse epoch cannot
            # overwrite best_model.pth.
            best_val_acc = max(history['val_acc'], default=val_acc or 0.0)

            print(f"Resuming from epoch {start_epoch} with best validation accuracy {best_val_acc}")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting from scratch instead.")
            # A failed load may have partially overwritten weights or state
            model, optimizer = build_model_and_optimizer()
            start_epoch, best_val_acc, val_loss, val_acc = 0, 0.0, None, None
            history = _new_training_history()
    else:
        print("No checkpoint found. Starting from scratch.")

    criterion = nn.CrossEntropyLoss()

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")

        epoch_loss, epoch_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        val_loss, val_acc = _validate_one_epoch(model, val_loader, criterion, device)
        print(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        # Save regular checkpoint every epoch
        checkpoint_filename = f"checkpoint_epoch_{epoch+1}.pth"
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            print(f"New best validation accuracy: {val_acc:.2f}%")

        save_checkpoint(
            model, optimizer, epoch, val_loss, val_acc,
            CHECKPOINT_DIR, checkpoint_filename,
            is_best=is_best,
            training_history=history
        )

        # Always save a fallback "last_checkpoint.pth"
        save_checkpoint(
            model, optimizer, epoch, val_loss, val_acc,
            CHECKPOINT_DIR, "last_checkpoint.pth",
            is_best=False,
            training_history=history
        )

    print("Training complete!")

    # Save the final model regardless of performance
    final_checkpoint_path = os.path.join(CHECKPOINT_DIR, "final_model.pth")
    torch.save({
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'val_acc': val_acc,
        'training_history': history
    }, final_checkpoint_path)
    print(f"Final model saved to {final_checkpoint_path}")

    # Save training history as JSON
    history_path = os.path.join(CHECKPOINT_DIR, "training_history.json")
    with open(history_path, 'w') as f:
        json.dump(history, f, indent=4)
    print(f"Training history saved to {history_path}")

# Function to evaluate the model on the test set
def evaluate():
    """Interactively evaluate a saved model on the stored test split.

    Reads test_data.json and class_mapping.json from CHECKPOINT_DIR, asks
    which saved model to load, and hands it to evaluate_model, which writes
    results under RESULTS_DIR. Returns early, without prompting, when the
    test split is missing or none of the candidate model files exists.
    Without that second check, the choice loop below could never terminate.
    """
    print("Starting evaluation on test set")

    # Load test data
    test_data_path = os.path.join(CHECKPOINT_DIR, "test_data.json")
    if not os.path.exists(test_data_path):
        print("Test data not found. Please run training first.")
        return

    with open(test_data_path, "r") as f:
        test_data = json.load(f)

    test_paths = test_data['test_paths']
    test_labels = test_data['test_labels']

    # Load class mapping
    with open(os.path.join(CHECKPOINT_DIR, "class_mapping.json"), "r") as f:
        class_mapping = json.load(f)

    # Reverse the class mapping to get index to class name
    idx_to_class = {int(v): k for k, v in class_mapping.items()}
    class_names = [idx_to_class[i] for i in range(len(idx_to_class))]

    test_dataset = UCFCrimeDataset(test_paths, test_labels, transform)
    test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=0)

    # Choose model to evaluate
    model_choices = {
        "1": ("best_model.pth", "Best validation model"),
        "2": ("final_model.pth", "Final model (after all epochs)"),
        "3": ("last_checkpoint.pth", "Last saved checkpoint")
    }
    available = {
        key: entry for key, entry in model_choices.items()
        if os.path.exists(os.path.join(CHECKPOINT_DIR, entry[0]))
    }
    if not available:
        print("No saved model found. Please run training first.")
        return

    print("\nChoose a model to evaluate:")
    for key, (filename, description) in available.items():
        print(f"{key}: {description}")

    choice = input("Enter choice (1/2/3): ")
    while choice not in available:
        print("Invalid choice or model file not found.")
        choice = input("Enter choice (1/2/3): ")

    model_path = os.path.join(CHECKPOINT_DIR, available[choice][0])
    print(f"Evaluating model: {model_path}")

    # Load model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = len(class_mapping)
    model = SimpleVideoClassifier(num_classes)

    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    evaluate_model(model, test_loader, device, class_names, class_mapping, RESULTS_DIR)

# Main execution function - allows choosing between training and evaluation
def main():
    """Ask whether to train or evaluate, then run the chosen action."""
    print("UCF Crime Video Classification - Training and Evaluation")
    for menu_line in ("1: Train model", "2: Evaluate model"):
        print(menu_line)

    selection = input("Enter choice (1/2): ")

    actions = {'1': train, '2': evaluate}
    action = actions.get(selection)
    if action is None:
        print("Invalid choice. Exiting.")
        return
    action()


if __name__ == "__main__":
    main()