In [1]:
import gdown

gdown.download(id="1ZEyNMEO43u3qhJAwJeBZxFBEYc_pVYZQ")

!unzip dataset.zip

Downloading...
From (original): https://drive.google.com/uc?id=1ZEyNMEO43u3qhJAwJeBZxFBEYc_pVYZQ
From (redirected): https://drive.google.com/uc?id=1ZEyNMEO43u3qhJAwJeBZxFBEYc_pVYZQ&confirm=t&uuid=67bb0dcb-d34c-4120-ae43-21fab59d6840
To: /content/dataset.zip
100%|██████████| 1.13G/1.13G [00:15<00:00, 71.5MB/s]


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: dataset/val/no/531.npy  
  inflating: dataset/val/no/257.npy  
  inflating: dataset/val/no/243.npy  
  inflating: dataset/val/no/525.npy  
  inflating: dataset/val/no/1099.npy  
  inflating: dataset/val/no/1927.npy  
  inflating: dataset/val/no/1933.npy  
  inflating: dataset/val/no/519.npy  
  inflating: dataset/val/no/1066.npy  
  inflating: dataset/val/no/1700.npy  
  inflating: dataset/val/no/294.npy  
  inflating: dataset/val/no/2209.npy  
  inflating: dataset/val/no/280.npy  
  inflating: dataset/val/no/1714.npy  
  inflating: dataset/val/no/1072.npy  
  inflating: dataset/val/no/2235.npy  
  inflating: dataset/val/no/1728.npy  
  inflating: dataset/val/no/2221.npy  
  inflating: dataset/val/no/733.npy  
  inflating: dataset/val/no/727.npy  
  inflating: dataset/val/no/1502.npy  
  inflating: dataset/val/no/1264.npy  
  inflating: dataset/val/no/928.npy  
  inflating: dataset/val/no/1270.npy  
  inflati

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, roc_curve
import numpy as np
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### **Dataset Preparation**  

The dataset consists of `.npy` files stored in class-specific directories. A custom `NPZDataset` class is used to load and preprocess the data:  

- **Class Mapping:** Three categories—`no (0)`, `sphere (1)`, and `vort (2)`.  
- **Loading Images:** Each `.npy` file is loaded as a NumPy array and converted into a PyTorch tensor.  
- **Channel Expansion:** Since the images are single-channel, they are repeated across three channels for compatibility with the pretrained CNN models.  
- **DataLoader Setup:**  
  - **Training Data:** Batch size of `128`, shuffled for better generalization.  
  - **Validation Data:** Batch size of `32`, no shuffling to maintain consistency.  

This ensures efficient loading and preprocessing for model training.

In [3]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

class NPZDataset(Dataset):
    def __init__(self, root_dir):
        self.data = []
        self.labels = []
        self.class_map = {'no': 0, 'sphere': 1, 'vort': 2}

        for class_name in self.class_map.keys():
            class_dir = os.path.join(root_dir, class_name)
            for file in os.listdir(class_dir):
                if file.endswith(".npy"):
                    self.data.append(os.path.join(class_dir, file))
                    self.labels.append(self.class_map[class_name])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image = np.load(self.data[idx])

        image = torch.tensor(image, dtype=torch.float32,device=device)

        image = image.repeat(3, 1, 1) # Repeat Image to make it 3 Channels

        label = torch.tensor(self.labels[idx], dtype=torch.long , device=device)
        return image, label

# Load dataset
data_path = "./dataset"
train_dataset = NPZDataset(os.path.join(data_path, "train"))
val_dataset = NPZDataset(os.path.join(data_path, "val"))

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

**MobileNetV3 Initialization**  

A **pretrained MobileNetV3-Large** is used with ImageNet weights. The final classifier is replaced with `Linear(1280, 3)` to match the three-class task.

In [4]:
from torchvision.io import decode_image
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights

weights = MobileNet_V3_Large_Weights.DEFAULT
model = mobilenet_v3_large(weights=weights)

model.classifier[3] = nn.Linear(1280, 3)

model = model.to(device)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-5c1a4163.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 69.5MB/s]


**Model Training**  

The model initially struggled with certain classes, so **class weights** (`[4.0, 2.0, 0.5]`) were assigned in **CrossEntropyLoss** to give more importance to harder-to-learn categories.

This weighting helps the model improve performance on underperforming classes.

In [5]:
import torch
import torch.nn.functional as F

def train_model(model, train_loader, criterion, optimizer, epochs=20):
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        accuracy = correct / total * 100
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}, Accuracy: {accuracy:.2f}%")

# Define manual class weights
class_weights = torch.tensor([4.0, 2.0, 0.5]).to(device)

# Use the weights in the loss function
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=0.01)
train_model(model, train_loader, criterion, optimizer)

Epoch 1/20, Loss: 1.0322, Accuracy: 33.23%
Epoch 2/20, Loss: 0.8630, Accuracy: 34.38%
Epoch 3/20, Loss: 0.6066, Accuracy: 52.68%
Epoch 4/20, Loss: 0.4321, Accuracy: 62.12%
Epoch 5/20, Loss: 0.3481, Accuracy: 72.81%
Epoch 6/20, Loss: 0.2986, Accuracy: 78.09%
Epoch 7/20, Loss: 0.2649, Accuracy: 81.57%
Epoch 8/20, Loss: 0.2303, Accuracy: 84.58%
Epoch 9/20, Loss: 0.2289, Accuracy: 85.11%
Epoch 10/20, Loss: 0.2063, Accuracy: 86.81%
Epoch 11/20, Loss: 0.1984, Accuracy: 87.34%
Epoch 12/20, Loss: 0.2096, Accuracy: 86.97%
Epoch 13/20, Loss: 0.1769, Accuracy: 88.90%
Epoch 14/20, Loss: 0.1787, Accuracy: 88.67%
Epoch 15/20, Loss: 0.1933, Accuracy: 88.18%
Epoch 16/20, Loss: 0.1708, Accuracy: 89.62%
Epoch 17/20, Loss: 0.1663, Accuracy: 89.62%
Epoch 18/20, Loss: 0.1683, Accuracy: 89.93%
Epoch 19/20, Loss: 0.1472, Accuracy: 91.08%
Epoch 20/20, Loss: 0.1546, Accuracy: 90.89%


In [6]:
# Evaluate model
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images)
        probabilities = torch.softmax(outputs[:, :3], dim=1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(probabilities.cpu().numpy())

# Calculate AUC score
auc_score = roc_auc_score(y_true, y_pred, multi_class='ovr')
print(f"AUC Score: {auc_score:.4f}")

AUC Score: 0.9542


In [7]:
# prompt: plot roc curve using plotly

import plotly.graph_objects as go
from sklearn.preprocessing import label_binarize

# Binarize the labels for ROC curve calculation
y_true_binarized = label_binarize(y_true, classes=np.unique(y_true))

# Assuming y_pred is a probability matrix
n_classes = y_true_binarized.shape[1]
fpr = dict()
tpr = dict()
roc_auc = dict()

for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_true_binarized[:, i], np.array(y_pred)[:, i])
    roc_auc[i] = roc_auc_score(y_true_binarized[:, i], np.array(y_pred)[:, i])

# Plot ROC curve
fig = go.Figure()
for i in range(n_classes):
    fig.add_trace(go.Scatter(x=fpr[i], y=tpr[i], mode='lines', name=f'Class {i} (AUC = {roc_auc[i]:.2f})'))
fig.add_shape(type='line', line=dict(dash='dash'), x0=0, x1=1, y0=0, y1=1)
fig.update_layout(title='ROC Curve', xaxis_title='False Positive Rate', yaxis_title='True Positive Rate')
fig.show()


In [9]:
torch.save(model.state_dict(), 'Common_Test_Model.pth')