# In [None]:
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import clip
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm

# Output directory for every figure this script writes; created on demand.
save_folder = 'saved_plots_6_classes'
os.makedirs(save_folder, exist_ok=True)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pre-trained CLIP checkpoint together with the image transform returned for it.
model, preprocess = clip.load('ViT-L/14@336px', device=device, jit=False)

# Define custom dataset
class ImageTitleDataset(Dataset):
    """Image/caption pairs for CLIP fine-tuning, using the class name as caption.

    Args:
        image_paths: Paths of the image files, one per sample.
        labels: Class-name strings aligned with ``image_paths``; each one is
            also tokenized and used as the text paired with its image.
        class_to_idx: Mapping from class name to integer index. Stored on the
            instance; the dataset itself does not read it.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, class_to_idx):
        # Fail early and clearly instead of with an index error mid-epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.class_to_idx = class_to_idx
        # Tokenize every caption once so __getitem__ is a plain row lookup.
        self.tokenized_labels = clip.tokenize(labels)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(preprocessed_image, tokenized_label, label_string)`` for ``idx``."""
        # Open inside a context manager so the file handle is released as soon
        # as the transform has consumed the pixel data.
        with Image.open(self.image_paths[idx]) as img:
            image = preprocess(img)
        return image, self.tokenized_labels[idx], self.labels[idx]

# Prepare dataset and class-wise splitting
folder_path = "labeled_activity_dataset_6_classes/"

# Every visible sub-directory is one class. os.listdir() returns entries in
# arbitrary order, so sort them to keep the class -> index mapping stable
# across runs and machines, and skip stray files / hidden directories (e.g.
# .DS_Store, .ipynb_checkpoints) that would otherwise be treated as classes.
folders = sorted(
    entry for entry in os.listdir(folder_path)
    if not entry.startswith(".") and os.path.isdir(os.path.join(folder_path, entry))
)

class_to_idx = {cls_name: idx for idx, cls_name in enumerate(folders)}
train_images, val_images, test_images = [], [], []
train_labels, val_labels, test_labels = [], [], []

for folder in folders:
    folder_path_full = os.path.join(folder_path, folder)
    # Sorted for the same reason as above: a fixed random_state only gives a
    # reproducible split when the input order is itself reproducible.
    images_in_folder = sorted(
        name for name in os.listdir(folder_path_full)
        if not name.startswith(".") and os.path.isfile(os.path.join(folder_path_full, name))
    )
    image_paths = [os.path.join(folder_path_full, img) for img in images_in_folder]

    # Split 2:1:1 for train, val, test (half for training, the rest halved again)
    train, temp = train_test_split(image_paths, test_size=0.5, random_state=42)
    val, test = train_test_split(temp, test_size=0.5, random_state=42)

    train_images.extend(train)
    val_images.extend(val)
    test_images.extend(test)

    # The folder name doubles as the class label for every image inside it.
    train_labels.extend([folder] * len(train))
    val_labels.extend([folder] * len(val))
    test_labels.extend([folder] * len(test))

# Create datasets
train_dataset = ImageTitleDataset(train_images, train_labels, class_to_idx)
val_dataset = ImageTitleDataset(val_images, val_labels, class_to_idx)
test_dataset = ImageTitleDataset(test_images, test_labels, class_to_idx)

# Custom batch sampler to ensure one image per class in each batch
class ClassBalancedBatchSampler:
    """Batch sampler that yields one sample index per class in every batch.

    Batch ``i`` contains the ``i``-th sample of each class, listed in the
    iteration order of ``class_to_idx``. The number of batches equals the size
    of the smallest class, so surplus samples of larger classes are not used.

    Args:
        dataset: Either an object exposing a ``labels`` sequence aligned with
            its sample indices, or an iterable of ``(image, text, label)``
            triples.
        class_to_idx: Mapping whose keys are the class labels.
    """

    def __init__(self, dataset, class_to_idx):
        self.dataset = dataset
        self.class_to_idx = class_to_idx
        self.class_to_image_idx = {cls: [] for cls in class_to_idx}

        # Prefer the dataset's own label list: iterating the dataset itself
        # runs __getitem__ for every sample (for an image dataset that means
        # decoding and preprocessing each file) just to read back a label.
        labels = getattr(dataset, "labels", None)
        if labels is None:
            labels = (label for _, _, label in dataset)

        # Group sample indices by class
        for idx, label in enumerate(labels):
            self.class_to_image_idx[label].append(idx)

    def __iter__(self):
        """Yield lists of dataset indices, one index per class."""
        for i in range(len(self)):
            yield [self.class_to_image_idx[cls][i] for cls in self.class_to_idx]

    def __len__(self):
        """Return the number of batches (size of the smallest class, 0 if no classes)."""
        return min(
            (len(indices) for indices in self.class_to_image_idx.values()),
            default=0,
        )

# Samplers that put one sample of every class into each batch.
train_sampler = ClassBalancedBatchSampler(dataset=train_dataset, class_to_idx=class_to_idx)
val_sampler = ClassBalancedBatchSampler(dataset=val_dataset, class_to_idx=class_to_idx)

# Train/val batches come from the class-balanced samplers; the test loader
# walks the test set one sample at a time in stored order.
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_sampler=val_sampler)
train_loader = DataLoader(dataset=train_dataset, batch_sampler=train_sampler)

# Loss function and optimizer
def _clip_contrastive_loss(clip_model, criterion, images, texts):
    """Return the symmetric image<->text cross-entropy loss for one batch.

    Row ``i`` of ``images`` is paired with row ``i`` of ``texts``, so the
    target for both logit matrices is the diagonal ``0 .. batch_size - 1``.
    """
    logits_per_image, logits_per_text = clip_model(images, texts)
    ground_truth = torch.arange(len(images), dtype=torch.long, device=images.device)
    return (criterion(logits_per_image, ground_truth) + criterion(logits_per_text, ground_truth)) / 2


# clip.load() returns half-precision weights on CUDA. Stepping Adam directly
# on fp16 parameters is numerically unstable (small updates underflow and the
# loss can turn NaN), so fine-tune in full precision. This is a no-op when the
# weights are already float32.
model.float()

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

# Early stopping parameters
patience = 5  # Number of epochs to wait for improvement
best_val_loss = float('inf')
patience_counter = 0

# Training and validation with early stopping
num_epochs = 100
train_losses, val_losses = [], []

# The per-epoch averages below divide by the number of batches; fail with a
# clear message rather than a ZeroDivisionError after a wasted epoch.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_loader and val_loader must each yield at least one batch")

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0
    for images, texts, _ in train_loader:
        images = images.to(device)  # Move images to device
        texts = texts.to(device)    # Move tokenized texts to device

        optimizer.zero_grad()
        loss = _clip_contrastive_loss(model, loss_fn, images, texts)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_losses.append(train_loss / len(train_loader))

    # ---- validation pass (no gradients, no weight updates) ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, texts, _ in val_loader:
            images = images.to(device)
            texts = texts.to(device)
            val_loss += _clip_contrastive_loss(model, loss_fn, images, texts).item()

    val_losses.append(val_loss / len(val_loader))
    print(f"Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")

    # Check for early stopping
    if val_losses[-1] < best_val_loss:
        best_val_loss = val_losses[-1]
        patience_counter = 0  # Reset counter if validation loss improves
        # Keep the weights of the best epoch so far on disk
        torch.save(model.state_dict(), "best_model_6_classes.pth")
    else:
        patience_counter += 1  # Increment counter if no improvement
        print(f"Early stopping patience counter: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered. Stopping training.")
            break

# Restore the weights stored in the checkpoint file before evaluating.
model.load_state_dict(torch.load("best_model_6_classes.pth"))

# Loss curves: one point per recorded epoch, x positions are 0-based indices.
loss_fig, loss_ax = plt.subplots(figsize=(8, 6))
train_epochs = range(len(train_losses))
val_epochs = range(len(val_losses))
loss_ax.plot(train_epochs, train_losses, label="Train Loss")
loss_ax.plot(val_epochs, val_losses, label="Validation Loss")
loss_ax.set_xlabel("Epochs")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
loss_ax.set_title("Training and Validation Loss")
loss_fig.savefig(os.path.join(save_folder, "train_val_loss.png"))  # Save plot

# Test Phase
model.eval()
all_preds, all_labels, all_probs = [], [], []

# Class names in mapping order; position i here corresponds to logit column i.
class_names = list(class_to_idx.keys())

# Tokenize all class descriptions once for efficiency
class_texts = clip.tokenize(class_names).to(device)

with torch.no_grad():
    # The class prompts never change, so encode and L2-normalise them once
    # instead of re-encoding them for every test image.
    text_features = model.encode_text(class_texts)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    logit_scale = model.logit_scale.exp()

    for image_path, label in zip(test_images, test_labels):
        # Load and preprocess the test image; the context manager releases the
        # file handle once the tensor has been produced.
        with Image.open(image_path) as img:
            image = preprocess(img).unsqueeze(0).to(device)

        # Encode the image and normalise it to unit length
        image_features = model.encode_image(image)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)

        # Scaled cosine similarities between this image and every class prompt
        # (the same quantity the model's forward pass returns as logits_per_image).
        logits_per_image = logit_scale * image_features @ text_features.t()
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

        # Get the top predicted class
        pred_idx = logits_per_image.argmax(dim=1).item()
        pred_label = class_names[pred_idx]

        # Store predictions and true labels
        all_preds.append(pred_label)
        all_labels.append(label)
        all_probs.append(probs[0])

# Headline metrics over the whole test set (F1 uses the "weighted" average).
accuracy = accuracy_score(all_labels, all_preds)
f1 = f1_score(all_labels, all_preds, average="weighted")
print(f"Test Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")

# Integer ids for the classes, following the iteration order of class_to_idx.
numeric_class_labels = [*range(len(class_to_idx))]
label_to_numeric = dict(zip(class_to_idx, numeric_class_labels))
numeric_all_labels = list(map(label_to_numeric.__getitem__, all_labels))
numeric_all_preds = list(map(label_to_numeric.__getitem__, all_preds))

# Confusion matrix over the integer ids, first saved with numeric tick labels only.
cm = confusion_matrix(numeric_all_labels, numeric_all_preds, labels=numeric_class_labels)
disp = ConfusionMatrixDisplay(cm, display_labels=numeric_class_labels)
disp.plot(cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.savefig(os.path.join(save_folder, "confusion_matrix.png"))  # Save confusion matrix plot

# Second version: widen the left margin and write each class name next to its row.
plt.gcf().subplots_adjust(left=0.25)
for row, class_name in enumerate(class_to_idx):
    plt.text(-2.5, row, f"{class_name}", rotation=0, va="center", ha="right", fontsize=10)
plt.savefig(os.path.join(save_folder, "confusion_matrix_with_labels.png"))  # Save updated confusion matrix plot

# Group test images by true class, keeping each image's prediction alongside
# it so no per-image search through test_images is needed later.
test_images_grouped = {class_name: [] for class_name in class_to_idx.keys()}
test_preds_grouped = {class_name: [] for class_name in class_to_idx.keys()}
for img_path, label, pred_label in zip(test_images, all_labels, all_preds):
    test_images_grouped[label].append(img_path)
    test_preds_grouped[label].append(pred_label)

# Plot grouped images with predicted class names
n_cols = 4
for class_name, images in test_images_grouped.items():
    if not images:
        # A class without test samples has nothing to draw.
        continue

    # Round the row count UP so the figure height matches the subplot grid;
    # flooring it yields a zero-height figure for fewer than 4 images, which
    # matplotlib rejects.
    n_rows = (len(images) + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(16, n_rows * 4))
    plt.suptitle(f"Class: {class_name} (Numeric Label: {class_to_idx[class_name]})", fontsize=16, y=1.02)

    for idx, (img_path, pred_label) in enumerate(zip(images, test_preds_grouped[class_name])):
        # Green title for a correct prediction, red for a wrong one
        is_correct = pred_label == class_name
        title_color = "green" if is_correct else "red"

        plt.subplot(n_rows, n_cols, idx + 1)
        # imshow copies the pixel data, so the file can be closed right after.
        with Image.open(img_path) as img:
            plt.imshow(img)
        plt.title(f"Pred: {pred_label}", fontsize=10, color=title_color, pad=5)
        plt.axis("off")

    # Save the figure
    save_path = os.path.join(save_folder, f"{class_name}_grouped_images.png")
    plt.tight_layout(rect=[0, 0, 1, 0.95])  # Make space for the title
    plt.savefig(save_path)  # Save plot to the folder
    plt.show()  # Display the figure (where the backend supports it)
    plt.close(fig)  # Release the figure so open figures do not pile up across classes

print("All plots saved successfully!")