# ResNet50 Cricket Shot Classification

This notebook implements a ResNet50-based classifier for cricket shot classification. Because it uses transfer learning from ImageNet-pretrained weights, it is expected to be significantly more accurate than the earlier MLP approach.

## Overview
- Uses pretrained ResNet50 for transfer learning
- Implements data augmentation for better generalization
- Fine-tunes the model for cricket shot classification
- Target accuracy: 85-95% (an expectation, not a measured result)

In [None]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install opencv-python pillow matplotlib seaborn scikit-learn tqdm

In [None]:
import json
import os
import random
import warnings

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import pandas as pd
from tqdm import tqdm
# Silence library warnings to keep the training log readable.
# NOTE(review): this also hides deprecation warnings from torch/torchvision;
# narrow the filter if something starts behaving unexpectedly.
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility. Seed every RNG the pipeline can touch:
# Python's `random`, NumPy's global RNG, and PyTorch (torch.manual_seed seeds
# the generators on all devices, CPU and CUDA).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Cricket shot types. A name's position in this list is used as its integer
# class label throughout the notebook.
SHOT_TYPES = [
    'drive',
    'legglance-flick',
    'pull',
    'sweep',
]

class CricketShotDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs for cricket shots.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of class labels, aligned index-by-index with
            ``image_paths``.
        transform: Optional callable applied to each RGB PIL image (for
            example a torchvision ``Compose`` pipeline).

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Catch misaligned inputs up front. Otherwise extra labels are silently
        # ignored, or an IndexError surfaces later inside a DataLoader worker.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and pair it with its label."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Decode inside a context manager so the underlying file handle is
        # released promptly, even if conversion fails.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
class ResNet50CricketClassifier(nn.Module):
    """ResNet50 backbone with a small MLP head for cricket shot classification.

    Args:
        num_classes: Number of output scores produced by the head.
        pretrained: If True, initialise the backbone with torchvision's
            ``ResNet50_Weights.IMAGENET1K_V1`` (the weights the legacy
            ``pretrained=True`` flag maps to); otherwise use random init.
        num_trainable_params: How many parameter *tensors*, counted from the
            end of ``resnet50().parameters()``, remain trainable. Weights and
            biases count separately, so this is not a layer count. All earlier
            tensors are frozen. The count is taken before the head is replaced,
            and the newly created head is always trainable. The default of 20
            reproduces the original freezing behaviour.
    """

    def __init__(self, num_classes=4, pretrained=True, num_trainable_params=20):
        super().__init__()

        # `pretrained=` is deprecated in torchvision >= 0.13; the weights enum
        # is the supported way to request the same checkpoint.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet50(weights=weights)

        # Freeze everything except the last `num_trainable_params` parameter
        # tensors. Computing the cut-off explicitly avoids the `[:-0]` pitfall,
        # where asking for 0 trainable tensors would freeze nothing at all.
        params = list(self.resnet.parameters())
        cutoff = max(len(params) - num_trainable_params, 0)
        for param in params[:cutoff]:
            param.requires_grad = False

        # Replace the backbone's original fc layer with a dropout-regularised
        # MLP head; these freshly created layers are trainable.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        """Return one unnormalised score per class for each image in ``x``."""
        return self.resnet(x)

# Instantiate once to sanity-check the architecture. Report the total parameter
# count and how many of those parameters remain trainable after freezing.
model = ResNet50CricketClassifier(num_classes=len(SHOT_TYPES))
all_params = list(model.parameters())
n_total = sum(p.numel() for p in all_params)
n_trainable = sum(p.numel() for p in all_params if p.requires_grad)
print(f"Model parameters: {n_total:,}")
print(f"Trainable parameters: {n_trainable:,}")

In [None]:
def load_dataset(data_dir, shot_types=None):
    """Collect image paths and integer labels from a class-per-folder layout.

    Images are expected at ``data_dir/<shot_type>/<file>`` with a ``.png``,
    ``.jpg`` or ``.jpeg`` extension (case-insensitive). A missing class folder
    produces a printed warning and is skipped.

    Args:
        data_dir: Root directory containing one sub-directory per shot type.
        shot_types: Ordered class names; a name's index is its label.
            Defaults to ``SHOT_TYPES``.

    Returns:
        ``(image_paths, labels)``: two parallel lists. Each folder's files are
        visited in sorted order, so the result (and any seeded train/test split
        built from it) does not depend on filesystem listing order.
    """
    if shot_types is None:
        shot_types = SHOT_TYPES

    image_paths = []
    labels = []

    print("Loading dataset...")

    for label, shot_type in enumerate(shot_types):
        shot_dir = os.path.join(data_dir, shot_type)
        if not os.path.isdir(shot_dir):
            print(f"Warning: Directory {shot_dir} not found")
            continue

        # os.listdir order is arbitrary; sort for a reproducible dataset order.
        image_files = sorted(
            f for f in os.listdir(shot_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        print(f"Found {len(image_files)} images for {shot_type}")

        for image_file in image_files:
            image_paths.append(os.path.join(shot_dir, image_file))
            labels.append(label)

    return image_paths, labels

def create_data_transforms(image_size=224):
    """Build the training and evaluation image pipelines.

    Args:
        image_size: Side length in pixels. Both pipelines resize every image to
            a square ``(image_size, image_size)``, ignoring aspect ratio.
            Defaults to 224.

    Returns:
        ``(train_transform, val_transform)``. The training pipeline adds random
        horizontal flips, rotation, colour jitter and affine shift/scale before
        tensor conversion. The validation pipeline only resizes, converts to a
        tensor and normalises.
    """
    size = (image_size, image_size)
    # ImageNet channel statistics, matching the pretrained ResNet50 backbone.
    # Defined once so the train and eval pipelines cannot drift apart.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Training transforms with augmentation
    train_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    # Validation transforms (no augmentation)
    val_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    return train_transform, val_transform

In [None]:
# Load dataset
data_dir = 'data'  # Update this path to your dataset directory
image_paths, labels = load_dataset(data_dir)

# Fail fast. Every later cell (splitting, training) needs images, and
# train_test_split would otherwise fail with a much less helpful error.
if len(image_paths) == 0:
    raise FileNotFoundError(
        f"No images found in {os.path.abspath(data_dir)!r}! Please check your data directory."
    )

print(f"Total images loaded: {len(image_paths)}")

# Check class distribution
unique, counts = np.unique(labels, return_counts=True)
print("\nClass distribution:")
for class_idx, count in zip(unique, counts):
    print(f"{SHOT_TYPES[class_idx]}: {count} samples")

In [None]:
# Hold out 20% of all images for the final test set, then carve 20% of the
# remainder off for validation. Both splits are stratified on the label, so
# each split keeps the overall class balance, and both are seeded.
X_train_full, X_test, y_train_full, y_test = train_test_split(
    image_paths, labels, test_size=0.2, random_state=42, stratify=labels
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train_full, y_train_full, test_size=0.2, random_state=42, stratify=y_train_full
)

print(f"Training set: {len(X_train)} samples")
print(f"Validation set: {len(X_val)} samples")
print(f"Test set: {len(X_test)} samples")

# Augmentation applies to the training split only; validation and test share
# the deterministic evaluation pipeline.
train_transform, val_transform = create_data_transforms()
train_dataset = CricketShotDataset(X_train, y_train, transform=train_transform)
val_dataset = CricketShotDataset(X_val, y_val, transform=val_transform)
test_dataset = CricketShotDataset(X_test, y_test, transform=val_transform)

# Loader settings shared by all three splits; only training data is shuffled.
loader_kwargs = dict(batch_size=16, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Args:
        model: Module to run; switched to train or eval mode accordingly.
        loader: Non-empty iterable of ``(inputs, labels)`` batches.
        criterion: Loss with mean reduction over the batch.
        device: Device the batches are moved to.
        optimizer: If given, gradients are computed and a step is taken per batch.
        desc: If given, the batches are wrapped in a tqdm progress bar with this label.

    Returns:
        ``(mean_loss, accuracy_percent)``. The loss is averaged per sample
        (each batch weighted by its size), so a short final batch is not
        over-weighted.
    """
    training = optimizer is not None
    model.train(training)
    batches = tqdm(loader, desc=desc) if desc is not None else loader

    loss_sum = 0.0
    correct = 0
    seen = 0
    # Build the autograd graph only while training.
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    return loss_sum / seen, 100 * correct / seen


def train_model(model, train_loader, val_loader, num_epochs=30, device='cuda',
                lr=0.001, weight_decay=0.01,
                checkpoint_path='best_resnet50_cricket_model.pth'):
    """Train ``model`` with AdamW, validating and checkpointing after every epoch.

    The learning rate is halved when validation loss plateaus for 5 epochs. The
    state dict with the highest validation accuracy so far is saved to
    ``checkpoint_path``; the first epoch always writes one.

    Args:
        model: Classifier to train; moved to ``device`` (nn.Module.to is in place).
        train_loader: Loader of ``(images, labels)`` training batches.
        val_loader: Loader of ``(images, labels)`` validation batches.
        num_epochs: Number of full passes over ``train_loader``.
        device: Device to train on.
        lr: Initial AdamW learning rate.
        weight_decay: AdamW weight decay.
        checkpoint_path: Where the best state dict is written.

    Returns:
        ``(train_losses, val_losses, train_accuracies, val_accuracies)``,
        four lists with one entry per epoch (accuracies in percent).

    Raises:
        ValueError: If either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    # Frozen parameters (requires_grad=False) never receive gradients, so only
    # the trainable ones are handed to the optimizer.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)

    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    # Starting at -inf guarantees a checkpoint exists after epoch 1, even if
    # validation accuracy never rises above 0%.
    best_val_acc = float('-inf')

    print("Starting training...")

    for epoch in range(num_epochs):
        avg_train_loss, train_accuracy = _run_epoch(
            model, train_loader, criterion, device,
            optimizer=optimizer, desc=f"Epoch {epoch+1}/{num_epochs}",
        )
        avg_val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)

        # Store metrics
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)

        # Update learning rate
        scheduler.step(avg_val_loss)

        # Save best model
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)

        # Print progress
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        print(f"  Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%")
        print(f"  Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")
        print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")
        print("-" * 50)

    return train_losses, val_losses, train_accuracies, val_accuracies

# Prefer the GPU when one is available; otherwise train on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f"Using device: {device}")

# A fresh classifier for the real training run, separate from the instance
# built earlier for the parameter summary.
model = ResNet50CricketClassifier(num_classes=len(SHOT_TYPES))

# train_model returns four per-epoch histories: train/val loss, train/val accuracy.
history = train_model(model, train_loader, val_loader, num_epochs=30, device=device)
train_losses, val_losses, train_accuracies, val_accuracies = history

In [None]:
# Plot training curves against 1-based epoch numbers, matching the
# "Epoch [k/N]" numbering printed during training (plain plt.plot(values)
# would start the x-axis at 0).
epochs = range(1, len(train_losses) + 1)
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15, 5))

# Loss plot
ax_loss.plot(epochs, train_losses, label='Train Loss')
ax_loss.plot(epochs, val_losses, label='Val Loss')
ax_loss.set_title('ResNet50 Training and Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()
ax_loss.grid(True)

# Accuracy plot
ax_acc.plot(epochs, train_accuracies, label='Train Accuracy')
ax_acc.plot(epochs, val_accuracies, label='Val Accuracy')
ax_acc.set_title('ResNet50 Training and Validation Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.legend()
ax_acc.grid(True)

fig.tight_layout()
plt.show()

In [None]:
def evaluate_model(model, test_loader, device='cuda', class_names=None):
    """Evaluate a trained classifier: print a classification report and plot a confusion matrix.

    Args:
        model: Classifier returning one score per class along dim 1; it is
            moved to ``device`` (nn.Module.to is in place).
        test_loader: Loader yielding ``(images, labels)`` batches.
        device: Device to run inference on.
        class_names: Class names in label order. Defaults to ``SHOT_TYPES``.

    Returns:
        ``(all_predictions, all_labels)``: flat lists of predicted and true
        class indices, in loader order.
    """
    if class_names is None:
        class_names = SHOT_TYPES
    # Always report every class. Without an explicit `labels=`, sklearn raises
    # when the test split lacks a class (target_names longer than the classes
    # present), and the confusion matrix would shrink and no longer match its
    # tick labels.
    label_ids = list(range(len(class_names)))

    model = model.to(device)
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)

            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Print classification report
    print("Classification Report:")
    print(classification_report(all_labels, all_predictions,
                                labels=label_ids, target_names=class_names))

    # Plot confusion matrix
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title('ResNet50 Confusion Matrix')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    plt.show()

    return all_predictions, all_labels

# Load the best checkpoint written during training and evaluate on the test split.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('best_resnet50_cricket_model.pth', map_location=device))
# Bind the results so the cell does not echo two long prediction lists as output.
test_predictions, test_labels = evaluate_model(model, test_loader, device)

In [None]:
# Save model metadata next to the checkpoint for downstream consumers.
best_val_accuracy = max(val_accuracies)
model_info = {
    'model_type': 'ResNet50',
    'num_classes': len(SHOT_TYPES),
    'shot_types': SHOT_TYPES,
    # NOTE(review): written as (H, W, C), while the tensors fed to the model are
    # channels-first after ToTensor. Confirm which convention readers expect.
    'input_size': (224, 224, 3),
    'best_val_accuracy': best_val_accuracy,
    'architecture': 'ResNet50 with custom classifier head'
}

with open('resnet50_model_info.json', 'w', encoding='utf-8') as f:
    json.dump(model_info, f, indent=2)

print("Training completed! Model saved as 'best_resnet50_cricket_model.pth'")
print(f"Best validation accuracy: {best_val_accuracy:.2f}%")

In [None]:
# Test the model with a sample image
def test_model_with_image(model, image_path, device='cuda'):
    """Classify one image, print the per-class probabilities and display the image.

    Args:
        model: Trained classifier returning one score per entry of
            ``SHOT_TYPES`` along dim 1; it is moved to ``device``
            (nn.Module.to is in place).
        image_path: Path to an image file readable by PIL.
        device: Device to run inference on.

    Returns:
        ``(predicted_shot, confidence)``: the predicted class name and its
        softmax probability.
    """
    # Load inside a context manager so the file handle is released promptly.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    # Reuse the evaluation pipeline so single-image inference is preprocessed
    # exactly like the validation/test splits.
    _, val_transform = create_data_transforms()
    model = model.to(device)
    image_tensor = val_transform(image).unsqueeze(0).to(device)

    # Make prediction
    model.eval()
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)[0]
        predicted_class = int(probabilities.argmax())
        confidence = probabilities[predicted_class].item()

    predicted_shot = SHOT_TYPES[predicted_class]
    print(f"Predicted shot: {predicted_shot}")
    print(f"Confidence: {confidence:.4f}")
    print("\nAll probabilities:")
    for i, shot_type in enumerate(SHOT_TYPES):
        print(f"  {shot_type}: {probabilities[i].item():.4f}")

    # Display image
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.imshow(image)
    ax.set_title(f"Predicted: {predicted_shot} ({confidence:.2%})")
    ax.axis('off')
    plt.show()

    return predicted_shot, confidence

# Example usage (uncomment and modify path as needed)
# test_model_with_image(model, 'path/to/test/image.jpg', device)