Multiclass Classification (Play Type Classification)

Author: Sebastian Pareiss 
Year: 2024 

In this Jupyter Notebook, a neural network for multiclass classification is created. Each cell includes a brief summary of its respective task. This notebook was created with the assistance of ChatGPT 4.

In [8]:
# Import all libraries and packages

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision import  transforms
from torchvision import models
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import copy
import random
from sklearn.metrics import confusion_matrix
from collections import defaultdict
from PIL import Image

Data Preparation and Processing

- Data Transformations: Define transformations including resizing to 224x224, tensor conversion, and normalization
- Data Directory: Set path to training, validation, and test data
- Class Filtering Function: Filter datasets to include only the 'Run', 'Pass', 'Punt', 'Kick-Off' and 'Field_Goal' classes
- Dataset Application: Apply filters across training, validation, and test datasets
- Dataset Summary: Print sizes and class distributions within the datasets

In [None]:
# Per-split preprocessing pipelines for the individual frame images.
# Channel mean / std are the commonly used ImageNet statistics.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

data_transforms = {
    # Training frames: resize, convert to tensor, normalize.
    # The augmentations below are currently disabled.
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        # transforms.RandomHorizontalFlip(p=0.5),
        # transforms.RandomRotation(degrees=10),
        # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]),
}

# Validation and test frames go through the same deterministic pipeline
for eval_split in ('val', 'test'):
    data_transforms[eval_split] = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])

# Root directory that holds one sub-folder per split ('train', 'val', 'test')
data_dir = 'Sets'

def filter_classes(dataset, classes_to_include):
    """Restrict an ImageFolder-style dataset to the given classes, in place.

    Samples whose class name is not listed in ``classes_to_include`` are
    dropped. The labels of the remaining samples are renumbered so that each
    label equals the position of its class name in ``classes_to_include``.

    Args:
        dataset: Object exposing ``samples`` (list of ``(path, label)``
            tuples) and ``classes`` (class names indexed by label). Its
            ``samples``, ``targets``, ``classes`` and ``class_to_idx``
            attributes are overwritten.
        classes_to_include: Class names to keep; their order defines the new
            label indices.

    Returns:
        None. The dataset is modified in place.
    """
    remap = {name: position for position, name in enumerate(classes_to_include)}
    # Read the old class names before ``dataset.classes`` is replaced below
    old_names = dataset.classes

    kept = [
        (path, remap[old_names[old_label]])
        for path, old_label in dataset.samples
        if old_names[old_label] in classes_to_include
    ]

    dataset.samples = kept
    dataset.targets = [new_label for _, new_label in kept]
    dataset.classes = classes_to_include
    dataset.class_to_idx = remap

# Build one ImageFolder per split, then keep only the five play-type classes
splits = ['train', 'val', 'test']
play_classes = ['Run', 'Pass', 'Punt', 'Kick-Off', 'Field_Goal']

image_datasets = {}
for phase in splits:
    image_datasets[phase] = datasets.ImageFolder(os.path.join(data_dir, phase), data_transforms[phase])
for phase in splits:
    # Each dataset gets its own copy of the class list
    filter_classes(image_datasets[phase], list(play_classes))

# Number of individual frames (images) per split
dataset_sizes = {split: len(image_datasets[split]) for split in splits}
class_names = image_datasets['train'].classes

# Report split sizes and the retained class names
print("Filtered dataset sizes:", dataset_sizes)
print("Filtered classes:", class_names)

# Report how many frames each class contributes to each split
for phase in splits:
    print(f"\n{phase.upper()} dataset:")
    frames_per_class = [0] * len(class_names)
    for _, label in image_datasets[phase].samples:
        frames_per_class[label] += 1
    for class_name, count in zip(class_names, frames_per_class):
        print(f"  {class_name}: {count} frames")



Video Dataset Processing:

- Custom Dataset Class (VideoDataset): Constructs video datasets by processing directories of frame sequences, ensuring all frames in a sequence share the same label. This class also handles potential sequence imbalance by limiting the number of sequences per label if a maximum is specified
- Initialization Parameters: Includes settings for the sequence length and optional transformation application to each frame
- Sequence Processing: Frames are grouped by video and label, sorted by frame number, and checked for consistency before being compiled into sequences. Sequences whose frames carry mismatched labels are discarded.
- Dataset Setup: Applies filtering functions to train, validation, and test splits to focus on specific classes ('Run', 'Pass', 'Punt', 'Kick-Off', 'Field_Goal'). Configures the maximum number of sequences to balance and limit the data effectively
- Data Loader Customization: Implements a custom collate function to handle potential None entries in batches, ensuring robust data loading
- Dataset Verification: Outputs class and sequence distribution for each dataset, ensuring correct setup and providing transparency about the data structure and contents

In [None]:
class VideoDataset(Dataset):
    """Dataset of fixed-length frame sequences built from an ImageFolder-style dataset.

    Frames are grouped by their label directory and video id (the part of the
    filename before ``_frame_``), sorted by frame number, and cut into
    overlapping windows of ``seq_length`` consecutive frames. Files whose name
    starts with ``aug_`` are ignored.

    Args:
        image_folder_dataset: Object exposing ``samples``, a list of
            ``(path, label)`` tuples. Filenames are expected to look like
            ``<video_id>_frame_<number>.<ext>``.
        seq_length: Number of frames per sequence.
        max_sequences: Optional upper bound on the total number of sequences.
            When set, every label keeps the same number of sequences: the
            smaller of ``max_sequences // number_of_labels`` and the size of
            the rarest label.
        transform: Optional callable applied to every loaded PIL image.

    Attributes:
        sequences: List of sequences, each a list of ``(path, label)`` tuples.
        label_distribution: ``defaultdict(int)`` mapping label -> number of
            sequences kept for that label.
    """

    def __init__(self, image_folder_dataset, seq_length, max_sequences=None, transform=None):
        self.image_folder_dataset = image_folder_dataset
        self.seq_length = seq_length
        self.transform = transform
        self.sequences = []
        self.label_distribution = defaultdict(int)

        # Group frames by label directory and video id
        video_label_frames = defaultdict(list)
        for path, label in self.image_folder_dataset.samples:
            filename = os.path.basename(path)
            # Skip augmented frames
            if filename.startswith('aug_'):
                continue
            video_id = filename.split('_frame_')[0]
            label_dir = os.path.basename(os.path.dirname(path))
            video_label_frames[f"{label_dir}_{video_id}"].append((path, label))

        # Build sliding-window sequences for each video/label group
        for frames in video_label_frames.values():
            frames.sort(key=lambda item: self._frame_number(item[0]))

            for start in range(len(frames) - seq_length + 1):
                sequence = frames[start:start + seq_length]
                sequence_labels = {frame_label for _, frame_label in sequence}
                if len(sequence_labels) == 1:  # All frames share one label
                    self.sequences.append(sequence)
                    # Count under the sequence's own label; the loop variable
                    # of the grouping loop above must not be reused here.
                    self.label_distribution[sequence[0][1]] += 1

        # Balance the labels and cap the total number of sequences
        if max_sequences:
            num_labels = len(self.label_distribution)
            min_count = min(self.label_distribution.values(), default=0)
            # Split the budget across all labels so the total never exceeds
            # max_sequences, regardless of how many classes there are.
            per_label_budget = max_sequences // num_labels if num_labels else 0
            limit_per_label = min(per_label_budget, min_count)

            balanced_sequences = []
            label_counters = defaultdict(int)
            for seq in self.sequences:
                seq_label = seq[0][1]
                if label_counters[seq_label] < limit_per_label:
                    balanced_sequences.append(seq)
                    label_counters[seq_label] += 1

            self.sequences = balanced_sequences
            self.label_distribution = label_counters

    @staticmethod
    def _frame_number(path):
        """Return the integer frame number encoded in the filename of ``path``."""
        filename = os.path.basename(path)
        return int(filename.split('_frame_')[1].split('.')[0])

    def __len__(self):
        """Return the number of sequences."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Load one sequence.

        Returns:
            ``(images, label, img_paths)`` where ``images`` is the stack of
            the transformed frames, ``label`` is a scalar tensor holding the
            sequence label and ``img_paths`` lists the frame paths; or
            ``None`` if any frame could not be loaded, so that the collate
            function can drop the item.
        """
        frames = self.sequences[idx]
        images, img_paths = [], []
        for img_path, _ in frames:
            try:
                # The context manager closes the file handle; convert()
                # forces the pixel data to be read while the file is open.
                with Image.open(img_path) as raw:
                    img = raw.convert('RGB')
            except OSError:
                print(f"Skipping frame {img_path} due to loading error.")
                # An incomplete sequence is unusable, so stop loading here
                return None
            if self.transform:
                img = self.transform(img)
            images.append(img)
            img_paths.append(img_path)

        images = torch.stack(images)
        return images, torch.tensor(frames[0][1]), img_paths

# Rebuild the per-split ImageFolder datasets and re-apply the class filter
selected_classes = ['Run', 'Pass', 'Punt', 'Kick-Off', 'Field_Goal']
image_datasets = {}
for phase in ['train', 'val', 'test']:
    image_datasets[phase] = datasets.ImageFolder(os.path.join(data_dir, phase), data_transforms[phase])
for phase in ['train', 'val', 'test']:
    filter_classes(image_datasets[phase], list(selected_classes))

# Print the classes of each split to verify correct loading
for split_label, phase in (("training", 'train'), ("validation", 'val'), ("testing", 'test')):
    print(f"Classes in the {split_label} dataset:", image_datasets[phase].classes)

# Sequence configuration
seq_length = 60  # Frames per sequence
max_train_sequences = 10000  # Upper bound on training sequences
max_val_test_sequences = int(0.15 * max_train_sequences)  # Bound for validation and test

# One sequence dataset per split, each with its own cap and transform
sequence_caps = {
    'train': max_train_sequences,
    'val': max_val_test_sequences,
    'test': max_val_test_sequences,
}
video_datasets = {
    phase: VideoDataset(image_datasets[phase], seq_length, cap, data_transforms[phase])
    for phase, cap in sequence_caps.items()
}

def custom_collate_fn(batch):
    """Collate sequence samples into a batch, dropping entries that are ``None``.

    Args:
        batch: List of dataset items; each is either ``None`` or a tuple
            ``(images, label, paths)`` whose first two elements are tensors.

    Returns:
        ``(images, labels, paths)``: the stacked image tensors, the stacked
        label tensors, and a tuple holding each item's paths. If every entry
        was ``None``, two empty tensors and an empty list are returned.
    """
    # Discard items the dataset could not load
    valid_items = [sample for sample in batch if sample is not None]
    if len(valid_items) == 0:
        return torch.tensor([]), torch.tensor([]), []

    image_tensors, label_tensors, frame_paths = zip(*valid_items)
    return torch.stack(image_tensors), torch.stack(label_tensors), frame_paths

# Create the DataLoaders with the collate function that tolerates None items
dataloaders = {}
for phase in ['train', 'val', 'test']:
    dataloaders[phase] = DataLoader(
        video_datasets[phase],
        batch_size=2,
        shuffle=True,
        collate_fn=custom_collate_fn,
    )

# Report the number of sequences per split and per class
for phase in ['train', 'val', 'test']:
    split_classes = image_datasets[phase].classes
    sequence_counts = video_datasets[phase].label_distribution
    print(f"\n{phase.upper()} dataset: {len(video_datasets[phase])} sequences")
    for label, class_name in enumerate(split_classes):
        print(f"  Label {label} ({class_name}): {sequence_counts[label]} sequences")


CNN-LSTM Network for Multiclass Classification:

- Device Setup: Checks CUDA availability for GPU acceleration and sets the computation device accordingly
- CNN-LSTM Architecture: Defines a hybrid neural network model combining Convolutional Neural Network (CNN) and Long Short-Term Memory (LSTM) 
    - CNN: Utilizes a pre-trained ResNet-50 model, adapted for feature extraction
    - LSTM: Processes temporal sequences with configurable hidden sizes and layer counts
- Model Instantiation: Constructs the CNN-LSTM model with specific parameters and assigns it to the chosen computation device
- Training Setup: Establishes the loss function, optimizer, and learning rate scheduler to guide the training process

In [4]:
# Prefer the first CUDA GPU when one is available, otherwise run on the CPU
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Using device:", device)

class CNN_LSTM(nn.Module):
    """Per-frame CNN feature extractor followed by an LSTM sequence classifier.

    Args:
        cnn_model: Module applied to every frame. Its output, flattened per
            frame, must contain 2048 values (the LSTM input size).
        hidden_size: Size of the LSTM hidden state.
        num_classes: Number of output logits.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, cnn_model, hidden_size, num_classes=5, num_layers=1):
        super().__init__()
        self.cnn = cnn_model
        # input_size 2048 matches the per-frame feature size of the ResNet backbone
        self.lstm = nn.LSTM(
            input_size=2048,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x: Tensor of shape ``(batch, seq_length, C, H, W)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with the class logits
            computed from the LSTM output at the last time step.
        """
        n_batch, n_frames, channels, height, width = x.size()

        # Fold the time axis into the batch axis so the CNN sees single frames
        frames = x.view(n_batch * n_frames, channels, height, width)
        features = self.cnn(frames).view(n_batch, n_frames, -1)

        # Only the output of the final time step feeds the classifier
        lstm_out, _ = self.lstm(features)
        last_step = lstm_out[:, -1, :]
        return self.fc(last_step)

# Pre-trained ResNet-50 with its last two child modules removed (in
# torchvision's ResNet these are the average pooling and the fully connected
# layer), followed by global average pooling to one feature vector per frame
resnet = models.resnet50(pretrained=True)
backbone_layers = list(resnet.children())[:-2]
cnn_model = nn.Sequential(*backbone_layers, nn.AdaptiveAvgPool2d((1, 1)))

# Model hyper-parameters
hidden_size = 256  # LSTM hidden state size
num_classes = 5    # Number of output classes
num_layers = 2     # Number of stacked LSTM layers

# Build the combined CNN-LSTM model and move it to the selected device
cnn_lstm_model = CNN_LSTM(cnn_model, hidden_size, num_classes, num_layers)
cnn_lstm_model = cnn_lstm_model.to(device)

# Cross-entropy loss for multi-class classification
criterion = nn.CrossEntropyLoss()

# Adam optimizer over all model parameters
optimizer_ft = optim.Adam(cnn_lstm_model.parameters(), lr=0.0001)

# Multiply the learning rate by 0.1 after every 2 scheduler steps
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=2, gamma=0.1)



Using device: cuda:0




Training and Evaluation Process for the CNN-LSTM Model:

- Initialization: Sets up the training process, initializing the best model weights and tracking training duration
- Training Loop: Iterates over the specified number of epochs, dividing each epoch into training and validation phases
- Batch Processing: Processes each batch, applying model predictions, calculating loss, and updating model parameters during the training phase
- Performance Metrics: Calculates and displays running loss and accuracy for each batch, updating periodically to monitor performance
- Validation and Metrics Storage: In the validation phase, calculates the confusion matrix and updates the best model if improved accuracy is detected
- Finalization: Completes training, prints total duration, and loads the best model weights for future use

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=3, batch_update_interval=100,
                save_path='best_model_play.pth'):
    """Train ``model`` and return it with the best validation weights loaded.

    Every epoch runs a training phase followed by a validation phase over the
    module-level ``dataloaders['train']`` / ``dataloaders['val']``; tensors
    are moved to the module-level ``device``. Whenever the validation
    accuracy improves, the weights are kept in memory and saved to
    ``save_path``.

    Loss and accuracy are averaged over the number of samples actually
    processed in the phase. The loaders yield sequences, so a frame-level
    dataset size would be the wrong denominator, and skipped or partial
    batches must not be counted as full ones.

    Args:
        model: Network to train; called as ``model(inputs)``.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch after the
            training phase.
        num_epochs: Number of epochs to run.
        batch_update_interval: Print running metrics every this many batches.
        save_path: File the best weights (state dict) are written to.

    Returns:
        The model with the best validation weights loaded.
    """
    since = time.time()  # Start time, used to report the training duration

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    print("Training start")

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0     # Sum of per-sample losses
            running_corrects = 0   # Number of correct predictions
            samples_seen = 0       # Number of samples processed so far
            batch_count = 0        # Number of non-empty batches processed

            all_labels = []  # True labels, for the confusion matrix
            all_preds = []   # Predictions, for the confusion matrix

            batch_total = len(dataloaders[phase])

            for inputs, labels, _ in dataloaders[phase]:
                # The collate function yields empty tensors when every item
                # of a batch failed to load
                if inputs.size(0) == 0:
                    print("Skipping empty batch.")
                    continue

                inputs = inputs.to(device)
                labels = labels.to(device).long()
                batch_size = inputs.size(0)

                optimizer.zero_grad()

                # Track gradients only while training
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)  # Index of the largest logit
                    loss = criterion(outputs, labels)

                    if is_training:
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * batch_size
                running_corrects += torch.sum(preds == labels).item()
                samples_seen += batch_size
                batch_count += 1

                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

                if batch_count % batch_update_interval == 0:
                    print(f'Batch {batch_count}/{batch_total}: {phase} Loss: {running_loss / samples_seen:.4f} Acc: {running_corrects / samples_seen:.4f}')

            if is_training:
                scheduler.step()

            # Guard against a phase in which no sample could be processed
            denominator = max(samples_seen, 1)
            epoch_loss = running_loss / denominator
            epoch_acc = running_corrects / denominator

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val':
                if all_labels:
                    cm = confusion_matrix(np.array(all_labels).flatten(), np.array(all_preds).flatten())
                    print(f'Confusion Matrix for epoch {epoch}:\n{cm}')

                # Keep and save the weights with the best validation accuracy
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(best_model_wts, save_path)

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    # Restore the best weights before returning
    model.load_state_dict(best_model_wts)
    return model

# Cross-entropy loss for multi-class classification
criterion = nn.CrossEntropyLoss()

# Train the model; progress and the validation confusion matrix are printed
model_ft = train_model(
    model=cnn_lstm_model,
    criterion=criterion,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    num_epochs=3,
)



Evaluation of the CNN-LSTM Model on Test Data:

- Model Loading: Loads the trained CNN-LSTM model from a specified file path and sets it to evaluation mode.
- Loss and Accuracy Tracking: Initializes counters for running loss and correct predictions, and prepares to collect all labels and predictions for confusion matrix analysis
- Batch Evaluation: Processes each batch from the test DataLoader, performing a forward pass to compute losses and predictions without gradient updates
- Metrics Computation: Calculates final accuracy and loss for the test set and constructs the confusion matrix
- Results Output: Prints the confusion matrix, final accuracy, and loss, providing a comprehensive evaluation of the model's performance on test data

In [None]:
def evaluate_model(model_path, dataloaders, dataset_sizes, device, update_interval=50,
                   phase='test', criterion=None):
    """Evaluate saved CNN-LSTM weights on one dataset split and print the results.

    A new ``CNN_LSTM`` is built around a copy of the module-level
    ``cnn_model`` (with the module-level ``hidden_size``), so loading the
    checkpoint does not overwrite the backbone shared with other models.
    Loss and accuracy are averaged over the number of samples actually
    processed.

    Args:
        model_path: Path to the saved state dict.
        dataloaders: Mapping from split name to DataLoader; batches are
            ``(inputs, labels, paths)``.
        dataset_sizes: Unused; kept for backward compatibility. The sizes
            passed by this notebook count frames, not sequences, so they are
            not a valid denominator for the metrics.
        device: Device to run the evaluation on.
        update_interval: Print intermediate metrics every this many batches.
        phase: Split to evaluate. Defaults to ``'test'``.
        criterion: Loss called as ``criterion(outputs, labels)``. Defaults to
            ``nn.CrossEntropyLoss()``.

    Returns:
        None. The confusion matrix, accuracy and loss are printed.
    """
    loss_fn = nn.CrossEntropyLoss() if criterion is None else criterion

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    model = CNN_LSTM(copy.deepcopy(cnn_model), hidden_size, num_classes=5, num_layers=2).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    running_loss = 0.0
    running_corrects = 0
    samples_seen = 0
    batch_count = 0
    all_labels = []
    all_preds = []

    # No gradients are needed for evaluation
    with torch.no_grad():
        for inputs, labels, _ in dataloaders[phase]:
            # The collate function yields empty tensors when every item of a
            # batch failed to load
            if inputs.size(0) == 0:
                print("Skipping empty batch.")
                continue

            inputs = inputs.to(device)
            labels = labels.to(device).long()
            batch_size = inputs.size(0)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = loss_fn(outputs, labels)

            running_loss += loss.item() * batch_size
            running_corrects += torch.sum(preds == labels).item()
            samples_seen += batch_size
            batch_count += 1
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

            if batch_count % update_interval == 0:
                intermediate_acc = running_corrects / samples_seen
                intermediate_loss = running_loss / samples_seen
                print(f'After {batch_count} batches: Intermediate accuracy = {intermediate_acc:.4f}, Intermediate loss = {intermediate_loss:.4f}')

    # Guard against a split in which no sample could be processed
    denominator = max(samples_seen, 1)
    final_acc = running_corrects / denominator
    final_loss = running_loss / denominator

    if all_labels:
        cm = confusion_matrix(all_labels, all_preds)
        print(f'Confusion matrix:\n{cm}')
    print(f'Final accuracy: {final_acc:.4f}')
    print(f'Final loss: {final_loss:.4f}')

# Checkpoint file to evaluate and the device to run on
model_path = 'best_model_play.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Cross-entropy loss used to report the evaluation loss
criterion = nn.CrossEntropyLoss()

# Run the evaluation; results are printed
evaluate_model(
    model_path=model_path,
    dataloaders=dataloaders,
    dataset_sizes=dataset_sizes,
    device=device,
    update_interval=50,
)
