In [1]:
# Imports
import os, torch, csv
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Dataset class
class OCTDataset(Dataset):
    """Image-folder dataset: each sub-directory of ``root_dir`` is one class.

    NOTE: a later cell in this notebook defines ``OCTDataset`` again and
    shadows this definition. The discovery rules here follow that later
    definition (directories only, case-insensitive extensions) so the two
    cannot disagree about which files belong to the dataset.

    Args:
        root_dir: Directory whose immediate sub-directories are the classes.
        processor: Callable invoked as
            ``processor(images=img, return_tensors="pt")``; its result must
            provide a ``"pixel_values"`` entry.
    """

    def __init__(self, root_dir, processor):
        self.processor = processor
        self.image_paths = []
        self.labels = []
        # Only directories are classes. A stray file next to the class folders
        # (e.g. macOS ".DS_Store") would otherwise be numbered as a label and
        # make the os.listdir() call on it below fail.
        self.class_names = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )

        for label, cls in enumerate(self.class_names):
            cls_folder = os.path.join(root_dir, cls)
            # os.listdir() returns entries in arbitrary order; sorting makes
            # the sample index reproducible between runs and machines.
            for fname in sorted(os.listdir(cls_folder)):
                # Lower-case first so ".JPG" / ".PNG" files are not dropped.
                if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                    self.image_paths.append(os.path.join(cls_folder, fname))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of image files discovered under ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a model-ready dict.

        Returns:
            ``{"pixel_values": ..., "labels": ...}`` where ``pixel_values`` is
            the processor output with its leading dimension squeezed away and
            ``labels`` is the class index wrapped in a tensor.
        """
        # The context manager releases the file handle as soon as the pixels
        # have been decoded into the converted copy.
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert("RGB")
        inputs = self.processor(images=img, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)
        label = torch.tensor(self.labels[idx])
        return {"pixel_values": pixel_values, "labels": label}

In [8]:
# 📁 Path setup
def unique_path(path):
    """Return ``path`` if it is free, otherwise the first free ``<stem>_v<N><ext>``.

    NOTE(review): this helper is called throughout the notebook but was never
    defined in it, so a fresh "Restart & Run All" stopped with a NameError.
    The ``_v<N>`` suffix scheme is inferred from the saved-model path printed
    by the training cell (``swin_octid_v1.pth``) - confirm it matches the
    helper that was originally used.

    Args:
        path: Desired output file path.

    Returns:
        A path at which no file currently exists.
    """
    if not os.path.exists(path):
        return path
    stem, ext = os.path.splitext(path)
    version = 1
    while os.path.exists(f"{stem}_v{version}{ext}"):
        version += 1
    return f"{stem}_v{version}{ext}"


# Project root. Set the OCT_PROJECT_DIR environment variable to run the
# notebook on another machine; the default is the original author's location.
base_path = os.environ.get(
    "OCT_PROJECT_DIR",
    "/Users/mananmathur/Documents/Academics/MIT/subject matter/YEAR 4/SEM 8/PROJECT/project",
)
dataset_base = os.path.join(base_path, "OCTID")
logs_base = os.path.join(base_path, "logs", "swin", "octid")

# Output folders: per-epoch training log, model weights, evaluation artefacts.
for subdir in ("training", "models", "results"):
    os.makedirs(os.path.join(logs_base, subdir), exist_ok=True)

# unique_path() keeps a re-run from overwriting the outputs of an earlier run.
log_file = unique_path(os.path.join(logs_base, "training", "swin_training_log.csv"))
model_save_path = unique_path(os.path.join(logs_base, "models", "swin_octid.pth"))

In [9]:
# ✅ Device and model
# Run on the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The image processor and the classifier are loaded from the same checkpoint.
checkpoint = "microsoft/swin-tiny-patch4-window7-224"
processor = AutoImageProcessor.from_pretrained(checkpoint)

# ignore_mismatched_sizes=True allows the checkpoint's 1000-way classifier to
# be swapped for a freshly initialised 5-way head (see the warning printed by
# this cell), which is why the model must be fine-tuned before use.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=5,
    ignore_mismatched_sizes=True,
)
model.to(device)

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SwinForImageClassification(
  (swin): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0): SwinLayer(
              (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96, bias=True)
                  (value): Linear(in_features=96, out_features=96, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelfOutput(
        

In [10]:
# 📚 OCT Dataset
class OCTDataset(Dataset):
    """Folder-per-class image dataset.

    Every sub-directory of ``root_dir`` is a class; class indices follow the
    sorted directory names. Files whose lower-cased name ends in ``.jpg``,
    ``.jpeg`` or ``.png`` are the samples.
    """

    def __init__(self, root_dir, processor):
        """Index the image files below ``root_dir``; ``processor`` is stored for ``__getitem__``."""
        self.processor = processor
        self.image_paths, self.labels = [], []

        # Plain files sitting next to the class folders are skipped.
        class_dirs = [
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        ]
        class_dirs.sort()
        self.class_names = class_dirs

        image_suffixes = ('.jpg', '.jpeg', '.png')
        for label, name in enumerate(self.class_names):
            folder = os.path.join(root_dir, name)
            found = [
                os.path.join(folder, entry)
                for entry in os.listdir(folder)
                if entry.lower().endswith(image_suffixes)
            ]
            self.image_paths.extend(found)
            # One label per image found in this class folder.
            self.labels.extend([label] * len(found))

    def __len__(self):
        """Number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{"pixel_values", "labels"}`` for the image at position ``idx``."""
        image = Image.open(self.image_paths[idx]).convert("RGB")
        encoded = self.processor(images=image, return_tensors="pt")
        # squeeze(0) removes the leading dimension of the processor output.
        sample = {
            "pixel_values": encoded["pixel_values"].squeeze(0),
            "labels": torch.tensor(self.labels[idx]),
        }
        return sample

In [11]:
# 🔄 Dataloader, loss, optimizer
# Training split, served in shuffled mini-batches of 16 images.
train_dataset = OCTDataset(os.path.join(dataset_base, "Train"), processor)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Per-class loss weights from scikit-learn's 'balanced' mode, computed over the
# training labels and moved to the same device as the model.
weights = torch.tensor(
    compute_class_weight(
        'balanced',
        classes=np.unique(train_dataset.labels),
        y=train_dataset.labels,
    ),
    dtype=torch.float,
).to(device)

# Class-weighted cross-entropy, optimised with Adam over every model parameter.
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

In [12]:
# 🔁 Training loop with CSV logging
# Number of full passes over train_loader.
epochs = 15

# Start a fresh log holding only the header row; one row per epoch is appended below.
with open(log_file, mode='w', newline='') as header_fh:
    csv.writer(header_fh).writerow(["Epoch", "Train_Loss", "Train_Accuracy"])

for epoch in range(1, epochs + 1):
    model.train()
    loss_sum, hits, seen = 0, 0, 0

    for batch in train_loader:
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        batch_len = pixel_values.size(0)

        logits = model(pixel_values=pixel_values).logits
        loss = loss_fn(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Scale the batch loss by the batch size so that dividing by the number
        # of samples seen gives a per-sample figure for the whole epoch.
        loss_sum += loss.item() * batch_len
        hits += (logits.argmax(-1) == labels).sum().item()
        seen += batch_len

    avg_loss = loss_sum / seen
    acc = hits / seen
    print(f"Epoch {epoch} - Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}")

    # Re-open in append mode each epoch so completed rows are already on disk
    # if a later epoch is interrupted.
    with open(log_file, mode='a', newline='') as log_fh:
        csv.writer(log_fh).writerow([epoch, avg_loss, acc])

# Only the weights reached after the last epoch are persisted.
torch.save(model.state_dict(), model_save_path)
print(f"✅ Final model saved to: {model_save_path}")

Epoch 1 - Loss: 1.2460, Accuracy: 0.5436
Epoch 2 - Loss: 0.4292, Accuracy: 0.9024
Epoch 3 - Loss: 0.1036, Accuracy: 0.9756
Epoch 4 - Loss: 0.0651, Accuracy: 0.9861
Epoch 5 - Loss: 0.0794, Accuracy: 0.9791
Epoch 6 - Loss: 0.2290, Accuracy: 0.9199
Epoch 7 - Loss: 0.0810, Accuracy: 0.9791
Epoch 8 - Loss: 0.0169, Accuracy: 1.0000
Epoch 9 - Loss: 0.0064, Accuracy: 1.0000
Epoch 10 - Loss: 0.0034, Accuracy: 1.0000
Epoch 11 - Loss: 0.0037, Accuracy: 1.0000
Epoch 12 - Loss: 0.0051, Accuracy: 1.0000
Epoch 13 - Loss: 0.0059, Accuracy: 1.0000
Epoch 14 - Loss: 0.0040, Accuracy: 1.0000
Epoch 15 - Loss: 0.0157, Accuracy: 0.9965
✅ Final model saved to: /Users/mananmathur/Documents/Academics/MIT/subject matter/YEAR 4/SEM 8/PROJECT/project/logs/swin/models/swin_octid_v1.pth


In [13]:
# 🧪 Evaluation function
def evaluate(split_name, split_path, summary_list):
    """Evaluate the notebook's ``model`` on one dataset split and record the results.

    Runs inference over every image under ``split_path``, writes a per-class
    classification report (CSV) and a confusion-matrix heatmap (PNG) into
    ``<logs_base>/results``, and appends one summary row to ``summary_list``.

    Relies on the notebook globals ``model``, ``processor``, ``device``,
    ``logs_base`` and ``unique_path``.

    Args:
        split_name: Display name of the split; used in the progress message,
            the plot title and (lower-cased) the output file names.
        split_path: Directory holding one sub-directory per class.
        summary_list: List that receives a dict with the keys ``Split``,
            ``Accuracy``, ``Precision``, ``Recall``, ``F1-score``, ``Support``.

    Returns:
        None. ``summary_list`` is mutated in place.
    """
    print(f"\n🔍 Evaluating: {split_name.upper()}")
    dataset = OCTDataset(split_path, processor)
    loader = DataLoader(dataset, batch_size=16, shuffle=False)

    all_preds, all_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            inputs = batch['pixel_values'].to(device)
            outputs = model(pixel_values=inputs)
            preds = outputs.logits.argmax(dim=-1)
            all_preds.extend(preds.cpu().tolist())
            # Labels are only compared on the host, so they stay off the device.
            all_labels.extend(batch['labels'].tolist())

    class_names = dataset.class_names
    # Pin the label set to every class folder. Without `labels=`, scikit-learn
    # derives it from the values present in y_true/y_pred, so a class that is
    # neither present nor predicted in this split makes `target_names` the
    # wrong length (classification_report raises) and leaves the confusion
    # matrix smaller than the tick labels drawn on it.
    label_ids = list(range(len(class_names)))
    report = classification_report(
        all_labels,
        all_preds,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0,  # classes with no predictions score 0 without a warning
    )
    df = pd.DataFrame(report).transpose()
    csv_path = unique_path(os.path.join(logs_base, "results", f"{split_name.lower()}_classification_report.csv"))
    df.to_csv(csv_path)

    # Computed directly: with an explicit label set some scikit-learn versions
    # report "micro avg" in place of the "accuracy" key.
    accuracy = float((np.asarray(all_labels) == np.asarray(all_preds)).mean())

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap="Blues",
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title(f"{split_name} Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    fig.tight_layout()
    cm_path = unique_path(os.path.join(logs_base, "results", f"{split_name.lower()}_confusion_matrix.png"))
    fig.savefig(cm_path)
    # Close this figure explicitly so repeated calls do not accumulate figures.
    plt.close(fig)

    weighted = report["weighted avg"]
    summary_list.append({
        "Split": split_name,
        "Accuracy": accuracy,
        "Precision": weighted["precision"],
        "Recall": weighted["recall"],
        "F1-score": weighted["f1-score"],
        "Support": int(weighted["support"])
    })

In [14]:
# 🚀 Run evaluations
# Collects one summary row per split; evaluate() appends to it in place.
summary = []
# Each split name is also the name of its sub-directory under dataset_base.
for split_name in ("Train", "Validation", "Test"):
    evaluate(split_name, os.path.join(dataset_base, split_name), summary)

summary_path = unique_path(os.path.join(logs_base, "results", "swin_summary.xlsx"))
df_summary = pd.DataFrame(summary)
df_summary.to_excel(summary_path, index=False)
# Report the directory actually written to; the previous hard-coded
# "/logs/swin/results/" text did not match logs_base.
print(f"\n✅ Evaluation complete. Results saved to {os.path.dirname(summary_path)}")


🔍 Evaluating: TRAIN

🔍 Evaluating: VALIDATION

🔍 Evaluating: TEST

✅ Evaluation complete. Results saved to /logs/swin/results/
