In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.metrics import (
    average_precision_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
)
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.io import read_video
from torchvision.models.video import MViT_V2_S_Weights, R2Plus1D_18_Weights, mvit_v2_s, r2plus1d_18
from tqdm import tqdm

print("All modules loaded")
# Record the torch version and GPU availability next to the outputs so the
# environment of a given run can be reproduced later.
print(f"torch {torch.__version__} | CUDA available: {torch.cuda.is_available()}")

All modules loaded


In [10]:
# Paths to the video folder and to the CSV with `filename` / `label` columns.
DATA_DIR = './eyeVideos'
CSV_FILE = './video_labels.csv'

# Build a filename -> label mapping. set_index/to_dict avoids the slow row-wise
# iterrows() (which also does not preserve dtypes); as before, a duplicated
# filename keeps the label from its last row.
labels_df = pd.read_csv(CSV_FILE)
labels_dict = labels_df.set_index('filename')['label'].to_dict()

# Fail fast on data problems that would otherwise only surface mid-training:
# every labelled file must exist, and labels must be valid targets for the
# 3-class model / CrossEntropyLoss used below.
missing_files = [
    name for name in labels_dict
    if not os.path.isfile(os.path.join(DATA_DIR, name))
]
assert not missing_files, (
    f"{len(missing_files)} labelled videos not found in {DATA_DIR}: {missing_files[:5]}"
)
assert set(labels_dict.values()) <= {0, 1, 2}, (
    f"unexpected labels: {sorted(set(labels_dict.values()))}"
)

# Custom Video Dataset
class VideoDataset(Dataset):
    """Map-style dataset returning a fixed-length random clip from each labelled video.

    Each item is ``(video, label)``. ``video`` is a float tensor of shape
    ``(C, num_frames, H, W)`` with values in [0, 1] before ``transform`` is applied.
    ``label`` is the value stored for the file in ``labels_dict``. A new random
    clip is drawn on every access, so repeated reads of one index differ.

    Args:
        root_dir: Directory that contains the video files.
        labels_dict: Mapping ``filename -> label``. Its keys define which files
            are loaded and in what index order.
        transform: Optional callable applied to the ``(C, T, H, W)`` float tensor.
    """

    def __init__(self, root_dir, labels_dict, transform=None):
        self.root_dir = root_dir
        self.labels_dict = labels_dict
        self.transform = transform
        self.video_paths = list(labels_dict.keys())

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video_name = self.video_paths[idx]
        video_path = os.path.join(self.root_dir, video_name)
        # Decodes the whole file; frames come back as (Frames, H, W, C).
        video, _, _ = read_video(video_path, pts_unit='sec')

        video = self.random_selection(video)

        video = video.permute(3, 0, 1, 2)  # (Frames, H, W, C) -> (C, Frames, H, W)
        video = video.float() / 255.0  # 8-bit pixel values -> [0, 1]

        if self.transform is not None:
            video = self.transform(video)

        label = self.labels_dict[video_name]
        return video, label

    @staticmethod
    def random_selection(video, num_frames=100, interval=1):
        """Pick ``num_frames`` frames, ``interval`` apart, from a random start.

        Args:
            video: Tensor whose first dimension is time, e.g. ``(Frames, H, W, C)``.
            num_frames: Number of frames in the returned clip (>= 1).
            interval: Step between consecutive selected frames (>= 1).

        Returns:
            A tensor with exactly ``num_frames`` entries along dim 0. If the video
            is too short, it is sampled from frame 0 and padded by repeating its
            last selected frame. Equal-length clips are required for the
            DataLoader to batch them.

        Raises:
            ValueError: if ``num_frames``/``interval`` < 1 or the video has no frames.
        """
        if num_frames < 1 or interval < 1:
            raise ValueError("num_frames and interval must both be >= 1")
        video_length = video.shape[0]
        if video_length == 0:
            raise ValueError("cannot sample a clip from a video with no frames")

        # Frames actually covered by the clip: first frame + (num_frames - 1) steps.
        span = (num_frames - 1) * interval + 1
        max_start = video_length - span
        if max_start < 0:
            clip = video[::interval]
            pad = clip[-1:].expand(num_frames - clip.shape[0], *clip.shape[1:])
            return torch.cat([clip, pad], dim=0)

        # max_start is a valid (inclusive) start, hence the +1. torch's RNG is used
        # because DataLoader re-seeds it per worker, unlike NumPy's global RNG.
        start = int(torch.randint(0, max_start + 1, (1,)))
        return video[start:start + span:interval]

In [11]:
# Sanity check on one raw (untransformed) sample: inspect the clip shape,
# which should be (C, Frames, H, W), plus its label and source file.
dataset = VideoDataset(DATA_DIR, labels_dict)

video_sample, label = dataset[0]
video_filename = dataset.video_paths[0]
video_shape = video_sample.shape

(video_shape, label, video_filename)


(torch.Size([3, 100, 256, 256]), 2, 'hypo_1_163.mp4')

In [12]:
# Preprocessing (deterministic, not augmentation): resize every frame to 112x112
# and map the [0, 1] pixel range to [-1, 1].
# NOTE(review): pretrained video backbones are usually trained with dataset-specific
# per-channel mean/std rather than 0.5/0.5 - confirm against the weights' docs.
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Seed everything that draws random numbers (split, shuffling, clip selection)
# so the train/test partition and every reported metric survive a kernel restart.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset = VideoDataset(DATA_DIR, labels_dict, transform=transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(SEED)
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Load R(2+1)D model from local weights (below).
# Online alternative: r2plus1d_18(weights=R2Plus1D_18_Weights.DEFAULT)
# Offline loading:
def load_model_offline(model_name, pth_path, ignore_key):
    """Build a 3-class model and initialise it from a local checkpoint.

    Args:
        model_name: Model constructor (e.g. ``r2plus1d_18``) that accepts the
            ``weights`` and ``num_classes`` keyword arguments.
        pth_path: Path to a ``state_dict`` checkpoint file.
        ignore_key: Prefix of checkpoint keys to drop. This is the classification
            head, whose shape differs because the model is built with 3 classes.

    Returns:
        The model, with every parameter outside the ``ignore_key`` head loaded
        from the checkpoint.

    Raises:
        RuntimeError: if a non-head parameter is missing from the checkpoint, or
            the checkpoint contains keys the model does not have. Either case
            means the checkpoint does not match the architecture.
    """
    model = model_name(weights=None, num_classes=3)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine;
    # weights_only refuses to unpickle arbitrary objects.
    state_dict = torch.load(pth_path, map_location="cpu", weights_only=True)
    for key in [k for k in state_dict if k.startswith(ignore_key)]:
        print(f"Removing {key} from checkpoint due to shape mismatch.")
        del state_dict[key]

    # strict=False is needed because the head was removed, but on its own it would
    # also accept a completely wrong checkpoint. Verify that only head
    # parameters were left uninitialised.
    result = model.load_state_dict(state_dict, strict=False)
    missing = [k for k in result.missing_keys if not k.startswith(ignore_key)]
    if missing or result.unexpected_keys:
        raise RuntimeError(
            f"Checkpoint {pth_path!r} does not match the model: "
            f"missing={missing}, unexpected={list(result.unexpected_keys)}"
        )
    return model
model = load_model_offline(r2plus1d_18, "./pretrained_weights/r2plus1d_18-91a641e6.pth", ignore_key="fc")
# Alternative backbone (its head parameters live under "head"). The checkpoint name
# is presumably "mvit_v2_s-ae3be167.pth" (the earlier path had a stray space); verify.
# model = load_model_offline(mvit_v2_s, "./pretrained_weights/mvit_v2_s-ae3be167.pth", ignore_key="head")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
# NOTE(review): 0.01 is a high Adam learning rate for fine-tuning pretrained
# weights (values around 1e-4 are more common). Worth tuning.
LEARNING_RATE = 0.01
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop
num_epochs = 50
training_loss = []  # loss of every batch across all epochs (plotted below)
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Iterating the tqdm wrapper advances the bar by itself. Calling
    # pbar.update() as well would count every batch twice.
    with tqdm(train_loader, total=len(train_loader), desc=f"epoch {epoch + 1}/{num_epochs}") as pbar:
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            training_loss.append(batch_loss)
            pbar.set_postfix(batch_loss=batch_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")

print("Training completed!")

Removing fc.weight from checkpoint due to shape mismatch.
Removing fc.bias from checkpoint due to shape mismatch.


 27%|██▋       | 68/252 [02:57<07:03,  2.30s/it, running_loss=1.06] 

In [None]:
# Per-batch training loss over the whole run (x-axis counts batches across epochs).
fig, ax = plt.subplots()
ax.plot(training_loss)
ax.set(title="Training loss per batch", xlabel="Batch (cumulative over epochs)", ylabel="Cross-entropy loss")
plt.show()

In [None]:
# Evaluation: collect true labels, softmax probabilities and argmax predictions
# for every test sample.
model.eval()
eval_device = next(model.parameters()).device  # run wherever the model lives
all_labels = []
all_probs = []
all_preds = []

with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        outputs = model(inputs.to(eval_device))  # raw logits, (batch, classes)
        probs = F.softmax(outputs, dim=1)
        preds = torch.argmax(probs, dim=1)

        all_labels.extend(labels.numpy())  # labels never left the CPU
        all_probs.extend(probs.cpu().numpy())
        all_preds.extend(preds.cpu().numpy())

all_labels = np.array(all_labels)
all_probs = np.array(all_probs)
all_preds = np.array(all_preds)

# One-vs-rest AUROC & AUPRC per class. Both are undefined when the test split has
# only positives or only negatives for a class; record NaN instead of letting
# roc_auc_score raise and abort the cell.
num_classes = all_probs.shape[1]
auroc_scores = []
auprc_scores = []

for i in range(num_classes):
    y_true = (all_labels == i).astype(int)
    y_score = all_probs[:, i]

    if 0 < y_true.sum() < len(y_true):
        auroc = roc_auc_score(y_true, y_score)
        auprc = average_precision_score(y_true, y_score)
    else:
        auroc = auprc = float("nan")
        print(f"Class {i}: only one label value in the test split; AUROC/AUPRC undefined.")

    auroc_scores.append(auroc)
    auprc_scores.append(auprc)

    print(f"Class {i} - AUROC: {auroc:.4f}, AUPRC: {auprc:.4f}")

In [None]:
# Confusion matrix. Class ids are given explicitly so the matrix is always 3x3
# and lines up with the tick labels, even if a class never occurs in the test
# split (confusion_matrix otherwise drops absent classes).
class_labels = ["normal", "hyper", "hypo"]  # position == integer class id
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_labels))))

fig, ax = plt.subplots(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_labels, yticklabels=class_labels, ax=ax)
ax.set(xlabel="Predicted Label", ylabel="True Label", title="Confusion Matrix")
plt.show()

In [None]:
# Per-class precision / recall / F1. Maps class name -> integer label.
CLASS_NAMES = {'normal': 0, 'hyper': 1, 'hypo': 2}

# Order names by label id and pass `labels` explicitly. The rows are then
# labelled correctly regardless of dict order, and the report does not raise
# when a class is missing from the test split.
names_by_id = sorted(CLASS_NAMES, key=CLASS_NAMES.get)
print(classification_report(
    all_labels,
    all_preds,
    labels=[CLASS_NAMES[name] for name in names_by_id],
    target_names=names_by_id,
))

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
# Reloading requires rebuilding the same architecture first.
save_path = "video_classification_model.pth"
model_state = model.state_dict()
torch.save(model_state, save_path)
print(f"Model saved to {save_path}")

In [None]:

# Per-sample test predictions with the probability of every class.
# test_loader uses shuffle=False, so rows follow test_dataset.indices. That lets
# each row be traced back to its source video (appended as the last column).
num_classes = all_probs.shape[1]
test_filenames = [test_dataset.dataset.video_paths[i] for i in test_dataset.indices]
assert len(test_filenames) == len(all_labels), "predictions and test split are out of sync"

df = pd.DataFrame({
    "true_label": all_labels,
    "predicted_label": all_preds,
    **{f"prob_class_{i}": all_probs[:, i] for i in range(num_classes)},
    "filename": test_filenames,
})

# Save as CSV
csv_filename = "predictions.csv"
df.to_csv(csv_filename, index=False)

print(f"Predictions saved to {csv_filename}")