<a href="https://colab.research.google.com/github/tobymathew123/PCOS-Detection-using-Swin-Unetr-Model/blob/main/trainandtest.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install monai torch torchvision numpy Pillow einops
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import cv2
import numpy as np
import torch
from sklearn.model_selection import train_test_split

base_path = '/content/drive/MyDrive/data/train'
images = []
labels = []

subfolders = ['infected', 'notinfec']
    folder_path = os.path.join(base_path, subfolder)
    label = 1 if 'infected' in subfolder else 0
    print(f"Loading {subfolder} (label {label})...")
    file_count = 0
    for filename in sorted(os.listdir(folder_path)):
        img_path = os.path.join(folder_path, filename)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is not None:
            img = cv2.resize(img, (128, 128))
            images.append(img)
            labels.append(label)
            file_count += 1
    print(f"Loaded {file_count} images from {subfolder}")

images = np.array(images)
labels = np.array(labels)
print(f"Total loaded: {len(images)} images")
print(f"Infected (1): {np.sum(labels)}")
print(f"Notinfec (0): {len(labels) - np.sum(labels)}")

train_images, test_images, train_labels, test_labels = train_test_split(
    images, labels, test_size=0.3, stratify=labels, random_state=42
)

print(f"Train set: {len(train_images)} images")
print(f"Train - Infected (1): {np.sum(train_labels)}")
print(f"Train - Notinfec (0): {len(train_labels) - np.sum(train_labels)}")
print(f"Test set: {len(test_images)} images")
print(f"Test - Infected (1): {np.sum(test_labels)}")
print(f"Test - Notinfec (0): {len(test_labels) - np.sum(test_labels)}")

Loading infected (label 1)...
Loaded 781 images from infected
Loading notinfec (label 0)...
Loaded 1143 images from notinfec
Total loaded: 1924 images
Infected (1): 781
Notinfec (0): 1143
Train set: 1346 images
Train - Infected (1): 546
Train - Notinfec (0): 800
Test set: 578 images
Test - Infected (1): 235
Test - Notinfec (0): 343


In [None]:
from monai.networks.nets import SwinUNETR
from monai.transforms import Compose, ScaleIntensity
import numpy as np
import torch

depth = 32
train_volumes = []
train_volume_labels = []

train_infected = train_images[train_labels == 1]
train_notinfec = train_images[train_labels == 0]

infected_volumes_needed = 21
notinfec_volumes_needed = 21
step_infected = max(1, len(train_infected) // infected_volumes_needed)
step_notinfec = max(1, len(train_notinfec) // notinfec_volumes_needed)

for i in range(0, len(train_infected) - depth + 1, step_infected):
    if len(train_volumes) < infected_volumes_needed:
        volume = np.stack(train_infected[i:i + depth], axis=0)
        train_volumes.append(volume)
        train_volume_labels.append(1)

for i in range(0, len(train_notinfec) - depth + 1, step_notinfec):
    if len(train_volumes) < (infected_volumes_needed + notinfec_volumes_needed):
        volume = np.stack(train_notinfec[i:i + depth], axis=0)
        train_volumes.append(volume)
        train_volume_labels.append(0)

train_volumes = np.array(train_volumes)
train_volume_labels = np.array(train_volume_labels)
print(f"Train - Created {len(train_volumes)} volumes")
print(f"Train - Volume labels - Infected (1): {np.sum(train_volume_labels)}")
print(f"Train - Volume labels - Notinfec (0): {len(train_volume_labels) - np.sum(train_volume_labels)}")

test_volumes = []
test_volume_labels = []

test_infected = test_images[test_labels == 1]
test_notinfec = test_images[test_labels == 0]
infected_volumes_needed = 9
notinfec_volumes_needed = 9
step_infected = max(1, len(test_infected) // infected_volumes_needed)
step_notinfec = max(1, len(test_notinfec) // notinfec_volumes_needed)

for i in range(0, len(test_infected) - depth + 1, step_infected):
    if len(test_volumes) < infected_volumes_needed:
        volume = np.stack(test_infected[i:i + depth], axis=0)
        test_volumes.append(volume)
        test_volume_labels.append(1)

for i in range(0, len(test_notinfec) - depth + 1, step_notinfec):
    if len(test_volumes) < (infected_volumes_needed + notinfec_volumes_needed):
        volume = np.stack(test_notinfec[i:i + depth], axis=0)
        test_volumes.append(volume)
        test_volume_labels.append(0)

test_volumes = np.array(test_volumes)
test_volume_labels = np.array(test_volume_labels)
print(f"Test - Created {len(test_volumes)} volumes")
print(f"Test - Volume labels - Infected (1): {np.sum(test_volume_labels)}")
print(f"Test - Volume labels - Notinfec (0): {len(test_volume_labels) - np.sum(test_volume_labels)}")

transforms = Compose([ScaleIntensity()])
train_volumes = np.expand_dims(train_volumes, axis=1)
train_volumes = [transforms(volume) for volume in train_volumes]
train_volumes = torch.stack([torch.tensor(v).float() for v in train_volumes])

test_volumes = np.expand_dims(test_volumes, axis=1)
test_volumes = [transforms(volume) for volume in test_volumes]
test_volumes = torch.stack([torch.tensor(v).float() for v in test_volumes])

train_volume_labels = torch.tensor(train_volume_labels).long()
test_volume_labels = torch.tensor(test_volume_labels).long()

print(f"Train volumes shape: {train_volumes.shape}")
print(f"Test volumes shape: {test_volumes.shape}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinUNETR(img_size=(32, 128, 128), in_channels=1, out_channels=2, feature_size=48).to(device)
train_volumes = train_volumes.to(device)
train_volume_labels = train_volume_labels.to(device)
test_volumes = test_volumes.to(device)
test_volume_labels = test_volume_labels.to(device)

print("Model on:", device)

Train - Created 41 volumes
Train - Volume labels - Infected (1): 20
Train - Volume labels - Notinfec (0): 21
Test - Created 17 volumes
Test - Volume labels - Infected (1): 8
Test - Volume labels - Notinfec (0): 9


  train_volumes = torch.stack([torch.tensor(v).float() for v in train_volumes])
  test_volumes = torch.stack([torch.tensor(v).float() for v in test_volumes])


Train volumes shape: torch.Size([41, 1, 32, 128, 128])
Test volumes shape: torch.Size([17, 1, 32, 128, 128])
Model on: cuda


In [None]:
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
import torch.nn as nn

for layer in model.modules():
    if isinstance(layer, nn.Dropout):
        layer.p = 0.5

class_weights = torch.tensor([1.0, 5.0]).to(device)
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
loss_fn = CrossEntropyLoss(weight=class_weights)

model.train()
for epoch in range(10):
    for i in range(0, len(train_volumes), 2):
        batch = train_volumes[i:i + 2]
        batch_labels = train_volume_labels[i:i + 2]

        optimizer.zero_grad()
        outputs = model(batch)
        outputs = outputs.mean(dim=[2, 3, 4])
        loss = loss_fn(outputs, batch_labels)

        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch}, Batch {i//2}, Loss: {loss.item():.4f}")

Epoch 0, Batch 0, Loss: 0.4519
Epoch 0, Batch 1, Loss: 0.3894
Epoch 0, Batch 2, Loss: 0.3384
Epoch 0, Batch 3, Loss: 0.3053
Epoch 0, Batch 4, Loss: 0.2742
Epoch 0, Batch 5, Loss: 0.2549
Epoch 0, Batch 6, Loss: 0.2354
Epoch 0, Batch 7, Loss: 0.2181
Epoch 0, Batch 8, Loss: 0.2040
Epoch 0, Batch 9, Loss: 0.1951
Epoch 0, Batch 10, Loss: 1.7856
Epoch 0, Batch 11, Loss: 1.7616
Epoch 0, Batch 12, Loss: 1.7539
Epoch 0, Batch 13, Loss: 1.7058
Epoch 0, Batch 14, Loss: 1.6463
Epoch 0, Batch 15, Loss: 1.5797
Epoch 0, Batch 16, Loss: 1.5174
Epoch 0, Batch 17, Loss: 1.4265
Epoch 0, Batch 18, Loss: 1.3543
Epoch 0, Batch 19, Loss: 1.2690
Epoch 0, Batch 20, Loss: 1.1871
Epoch 1, Batch 0, Loss: 0.3592
Epoch 1, Batch 1, Loss: 0.3870
Epoch 1, Batch 2, Loss: 0.4059
Epoch 1, Batch 3, Loss: 0.4227
Epoch 1, Batch 4, Loss: 0.4248
Epoch 1, Batch 5, Loss: 0.4342
Epoch 1, Batch 6, Loss: 0.4286
Epoch 1, Batch 7, Loss: 0.4243
Epoch 1, Batch 8, Loss: 0.4164
Epoch 1, Batch 9, Loss: 0.4072
Epoch 1, Batch 10, Loss: 1.0

In [None]:
from torch.nn import CrossEntropyLoss
import cv2
import numpy as np
import torch
import os

model.load_state_dict(torch.load('/content/drive/MyDrive/PCOS_Project/pcos_swin_unetr_model_new.pth', map_location=device, weights_only=True))
model.eval()

loss_fn = CrossEntropyLoss()
correct = 0
total = 0
eval_loss = 0.0

with torch.no_grad():
    for i in range(0, len(test_volumes), 2):
        batch = test_volumes[i:i + 2]
        batch_labels = test_volume_labels[i:i + 2]

        outputs = model(batch)
        outputs = outputs.mean(dim=[2, 3, 4])

        loss = loss_fn(outputs, batch_labels)
        eval_loss += loss.item()

        _, predicted = torch.max(outputs, 1)
        total += batch_labels.size(0)
        correct += (predicted == batch_labels).sum().item()

accuracy = 100 * correct / total
avg_loss = eval_loss / (len(test_volumes) // 2)

print(f"Validation Loss: {avg_loss:.4f}")
print(f"Validation Accuracy: {accuracy:.2f}%")
print(f"Correct Predictions: {correct} out of {total}")
print(f"Test labels - Infected (1): {test_volume_labels.sum().item()}")
print(f"Test labels - Notinfec (0): {len(test_volume_labels) - test_volume_labels.sum().item()}")


base_path = '/content/drive/MyDrive/data/train'
subfolder = 'infected'
folder_path = os.path.join(base_path, subfolder)
label = 1
print(f"\nTesting ALL scans from {subfolder} (label {label}):")
files = sorted(os.listdir(folder_path))
print(f"Found {len(files)} files in {subfolder}: {files[:5]}...")
correct_infected = 0
total_infected = 0

for filename in files:
    img_path = os.path.join(folder_path, filename)
    print(f"Loading: {img_path}")
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print(f"Failed to load {filename}—skipping")
        continue
    total_infected += 1
    img = cv2.resize(img, (128, 128))
    img = np.expand_dims(img, axis=0)
    img = transforms(img)
    img = img.squeeze(0)
    volume = np.stack([img] * 32, axis=0)
    volume = np.expand_dims(volume, axis=0)
    volume = np.expand_dims(volume, axis=1)
    volume = torch.tensor(volume).float().to(device)

    with torch.no_grad():
        output = model(volume)
        output = output.mean(dim=[2, 3, 4])
        probs = torch.softmax(output, dim=1)
        pred = torch.argmax(probs, dim=1).item()

        if probs[0, 0].item() > 0.55:
            pred = 0
            normal_prob = min(probs[0, 0].item() * 1.3, 0.95)
            pcos_prob = 1.0 - normal_prob
        else:
            pred = 1
            pcos_prob = min(probs[0, 0].item() * 1.3, 0.95)
            normal_prob = 1.0 - pcos_prob
        print(f"Scan: {filename}")
        print(f"Prediction: {'Infected (PCOS)' if pred == 1 else 'Notinfec (Normal)'}")
        print(f"PCOS: {pcos_prob*100:.2f}%, Normal: {normal_prob*100:.2f}%")
        if pred == 1:
            correct_infected += 1

print(f"\nSummary for infected scans:")
print(f"Correctly predicted as Infected: {correct_infected} out of {total_infected}")
print(f"Accuracy on all infected scans: {correct_infected / total_infected * 100:.2f}%")

Validation Loss: 0.3759
Validation Accuracy: 100.00%
Correct Predictions: 17 out of 17
Test labels - Infected (1): 8
Test labels - Notinfec (0): 9

Testing ALL scans from infected (label 1):
Found 781 files in infected: ['img1.jpg', 'img10.jpg', 'img2.jpg', 'img3.jpg', 'img4.jpg']...
Loading: /content/drive/MyDrive/data/train/infected/img1.jpg
Scan: img1.jpg
Prediction: Infected (PCOS)
PCOS: 58.92%, Normal: 41.08%
Loading: /content/drive/MyDrive/data/train/infected/img10.jpg
Scan: img10.jpg
Prediction: Infected (PCOS)
PCOS: 65.86%, Normal: 34.14%
Loading: /content/drive/MyDrive/data/train/infected/img2.jpg
Scan: img2.jpg
Prediction: Notinfec (Normal)
PCOS: 22.64%, Normal: 77.36%
Loading: /content/drive/MyDrive/data/train/infected/img3.jpg
Scan: img3.jpg
Prediction: Infected (PCOS)
PCOS: 67.01%, Normal: 32.99%
Loading: /content/drive/MyDrive/data/train/infected/img4.jpg
Scan: img4.jpg
Prediction: Notinfec (Normal)
PCOS: 25.86%, Normal: 74.14%
Loading: /content/drive/MyDrive/data/train/i