In [5]:
# === Imports ===
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification
import csv

In [6]:
# === Paths ===
# Project root. Defaults to the original author's local checkout; set the
# OCTDL_PROJECT_DIR environment variable to run this notebook on another
# machine without editing the cell.
base_path = os.environ.get(
    "OCTDL_PROJECT_DIR",
    "/Users/mananmathur/Documents/Academics/MIT/subject matter/YEAR 4/SEM 8/PROJECT/project",
)
dataset_base = os.path.join(base_path, "OCTDL")
logs_base = os.path.join(base_path, "logs", "swin", "octdl")

# Create the output tree up front so later cells can write logs, model weights
# and evaluation results without checking for the directories first.
for _subdir in ("training", "models", "results"):
    os.makedirs(os.path.join(logs_base, _subdir), exist_ok=True)

def unique_path(path, sep="_v"):
    """Return a filesystem path that does not exist yet.

    If ``path`` itself is free it is returned unchanged. Otherwise a counter is
    inserted between the stem and the extension (``<stem><sep><n><ext>``),
    starting at 1 and increasing until an unused name is found.

    Args:
        path: Desired file path.
        sep: Text placed between the stem and the counter.

    Returns:
        ``path`` or the first non-existing numbered variant of it.
    """
    if not os.path.exists(path):
        return path

    stem, suffix = os.path.splitext(path)
    counter = 1
    candidate = f"{stem}{sep}{counter}{suffix}"
    while os.path.exists(candidate):
        counter += 1
        candidate = f"{stem}{sep}{counter}{suffix}"
    return candidate

# Resolve collision-free output files: if a file from an earlier run already
# exists, unique_path() returns a numbered variant instead of the same name.
_log_candidate = os.path.join(logs_base, "training", "swin_octdl_log.csv")
_model_candidate = os.path.join(logs_base, "models", "swin_octdl.pth")
log_file = unique_path(_log_candidate)
model_save_path = unique_path(_model_candidate)

In [7]:
# === Setup ===
# Seed the global torch RNG before any random work happens. The newly
# initialised classifier head and the shuffled training DataLoader both draw
# from it, so this makes a Restart & Run All repeatable (best effort: some
# accelerator kernels may still be non-deterministic).
SEED = 42
torch.manual_seed(SEED)

# Pick the fastest available backend: CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

checkpoint = "microsoft/swin-tiny-patch4-window7-224"
processor = AutoImageProcessor.from_pretrained(checkpoint)
# num_labels=2 swaps the checkpoint's 1000-way classifier for a 2-way head;
# ignore_mismatched_sizes=True lets loading continue despite that shape
# mismatch (the head is then randomly initialised and must be trained).
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=2,
    ignore_mismatched_sizes=True,
).to(device)

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# === Dataset Class ===
class OCTDLBinaryDataset(Dataset):
    """Binary image dataset read from ``<root_dir>/Normal`` and ``<root_dir>/Abnormal``.

    Images under ``Normal`` get label 0 and images under ``Abnormal`` get
    label 1. Only files whose name ends in .jpg, .jpeg or .png
    (case-insensitive) are included.

    Args:
        root_dir: Directory that contains the two class sub-directories.
        processor: Callable invoked as ``processor(images=img, return_tensors="pt")``;
            its result is indexed with ``"pixel_values"``.
    """

    CLASS_NAMES = ("Normal", "Abnormal")
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, processor):
        self.processor = processor
        self.image_paths, self.labels = [], []
        for label, class_name in enumerate(self.CLASS_NAMES):
            class_path = os.path.join(root_dir, class_name)
            # os.listdir() returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical across runs and machines.
            for img_file in sorted(os.listdir(class_path)):
                if img_file.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_path, img_file))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of image files found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, convert to RGB and preprocess the image at position ``idx``.

        Returns:
            dict with ``"pixel_values"`` (processor output with dimension 0
            squeezed out) and ``"labels"`` (the class index as a tensor).
        """
        # convert() returns a new in-memory image, so the file handle opened by
        # Image.open() can be closed as soon as the conversion is done.
        with Image.open(self.image_paths[idx]) as img:
            rgb = img.convert("RGB")
        inputs = self.processor(images=rgb, return_tensors="pt")
        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": torch.tensor(self.labels[idx])
        }

In [9]:
# === Training ===
# Data: shuffled mini-batches of 16 drawn from the Train split.
train_dataset = OCTDLBinaryDataset(os.path.join(dataset_base, "Train"), processor)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Loss: cross-entropy re-weighted with scikit-learn's 'balanced' class weights
# computed from the training labels.
weights = torch.tensor(
    compute_class_weight(
        'balanced',
        classes=np.unique(train_dataset.labels),
        y=train_dataset.labels,
    ),
    dtype=torch.float,
).to(device)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

epochs = 15

# Start a fresh CSV log that holds only the header row; one row per epoch is
# appended below so partial results survive an interrupted run.
with open(log_file, mode='w', newline='') as f:
    csv.writer(f).writerow(["Epoch", "Train_Loss", "Train_Accuracy"])

for epoch in range(1, epochs + 1):
    model.train()
    # Running sums over the whole epoch (sample-weighted loss, hits, samples).
    total_loss = correct = total = 0

    for batch in train_loader:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        batch_n = pixel_values.size(0)

        # Forward pass and loss.
        outputs = model(pixel_values=pixel_values)
        logits = outputs.logits
        loss = loss_fn(logits, labels)

        # Backward pass: clear old gradients, back-propagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a per-batch mean, so scale by batch size before summing.
        total_loss += loss.item() * batch_n
        correct += logits.argmax(-1).eq(labels).sum().item()
        total += batch_n

    avg_loss = total_loss / total
    acc = correct / total
    print(f"Epoch {epoch} - Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}")

    with open(log_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch, avg_loss, acc])

# Persist only the weights (state_dict), written once after the final epoch.
torch.save(model.state_dict(), model_save_path)
print(f"✅ Final model saved to: {model_save_path}")

Epoch 1 - Loss: 0.2547, Accuracy: 0.8877
Epoch 2 - Loss: 0.0882, Accuracy: 0.9653
Epoch 3 - Loss: 0.0876, Accuracy: 0.9653
Epoch 4 - Loss: 0.0279, Accuracy: 0.9927
Epoch 5 - Loss: 0.0746, Accuracy: 0.9744
Epoch 6 - Loss: 0.0258, Accuracy: 0.9900
Epoch 7 - Loss: 0.0088, Accuracy: 0.9982
Epoch 8 - Loss: 0.0029, Accuracy: 0.9991
Epoch 9 - Loss: 0.0034, Accuracy: 0.9991
Epoch 10 - Loss: 0.0225, Accuracy: 0.9918
Epoch 11 - Loss: 0.0184, Accuracy: 0.9945
Epoch 12 - Loss: 0.0063, Accuracy: 0.9963
Epoch 13 - Loss: 0.0376, Accuracy: 0.9909
Epoch 14 - Loss: 0.0150, Accuracy: 0.9936
Epoch 15 - Loss: 0.0019, Accuracy: 0.9991
✅ Final model saved to: /Users/mananmathur/Documents/Academics/MIT/subject matter/YEAR 4/SEM 8/PROJECT/project/logs/swin/octdl/models/swin_octdl_v1.pth


In [10]:
# === Evaluation ===
def evaluate(split_name, split_path, summary_list):
    """Evaluate the global ``model`` on one dataset split and record the results.

    Runs inference over every image in ``split_path``, then:
      * writes a per-class classification report as CSV,
      * writes a confusion-matrix heatmap as PNG,
      * appends one row of headline metrics to ``summary_list``.

    Both files go to ``<logs_base>/results`` under names made collision-free by
    ``unique_path``. The model is left in eval mode afterwards.

    Args:
        split_name: Human-readable split name; used in titles and file names.
        split_path: Directory holding the ``Normal``/``Abnormal`` sub-directories.
        summary_list: List that receives a dict with the keys ``Split``,
            ``Accuracy``, ``Precision``, ``Recall``, ``F1-score`` and ``Support``
            (precision/recall/F1 are the support-weighted averages).

    Raises:
        ValueError: If the split contains no images.
    """
    print(f"\n🔍 Evaluating: {split_name.upper()}")
    dataset = OCTDLBinaryDataset(split_path, processor)
    if len(dataset) == 0:
        raise ValueError(f"No images found for split {split_name!r} in {split_path}")
    loader = DataLoader(dataset, batch_size=16, shuffle=False)

    class_names = ["Normal", "Abnormal"]
    class_ids = list(range(len(class_names)))

    all_preds, all_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            inputs = batch['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            preds = outputs.logits.argmax(dim=-1)
            all_preds.extend(preds.cpu().tolist())
            # Labels are only compared on the host, so they never need to be
            # moved to the accelerator.
            all_labels.extend(batch['labels'].tolist())

    # Passing labels= pins both classes even when a split (or the predictions)
    # contains only one of them; without it target_names would not match the
    # observed classes and scikit-learn raises. zero_division=0 keeps the
    # default value for undefined metrics but silences the warning.
    report = classification_report(
        all_labels,
        all_preds,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    df = pd.DataFrame(report).transpose()
    df.to_csv(unique_path(os.path.join(logs_base, "results", f"{split_name.lower()}_classification_report.csv")))

    # labels= also guarantees a 2x2 matrix that lines up with the tick labels.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(f"{split_name} Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    fig.savefig(unique_path(os.path.join(logs_base, "results", f"{split_name.lower()}_confusion_matrix.png")))
    plt.close(fig)

    # Depending on the scikit-learn version, the report may carry "micro avg"
    # instead of "accuracy" when labels= is not exactly the set of observed
    # classes; fall back to computing accuracy directly in that case.
    accuracy = report.get("accuracy")
    if accuracy is None:
        accuracy = float(np.mean(np.asarray(all_labels) == np.asarray(all_preds)))

    weighted = report["weighted avg"]
    summary_list.append({
        "Split": split_name,
        "Accuracy": accuracy,
        "Precision": weighted["precision"],
        "Recall": weighted["recall"],
        "F1-score": weighted["f1-score"],
        "Support": int(weighted["support"])
    })

In [11]:
# === Summary ===
# Evaluate every split in turn; each call appends one metrics row to `summary`.
summary = []
for _split in ("Train", "Validation", "Test"):
    evaluate(_split, os.path.join(dataset_base, _split), summary)

# Write the per-split headline metrics to a single spreadsheet.
summary_path = unique_path(os.path.join(logs_base, "results", "swin_octdl_summary.xlsx"))
df_summary = pd.DataFrame(summary)
df_summary.to_excel(summary_path, index=False)
print("\n✅ Evaluation complete. Results saved to /logs/swin/octdl/results/")


🔍 Evaluating: TRAIN

🔍 Evaluating: VALIDATION

🔍 Evaluating: TEST

✅ Evaluation complete. Results saved to /logs/swin/octdl/results/
