In [None]:
# Mount Google Drive at /content/drive.
# The dataset paths used later in this notebook live under /content/drive/MyDrive.
# NOTE(review): google.colab is presumably only importable inside a Colab runtime — confirm before running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install necessary packages (YOLOv8 and torchsummary for model summary)
!pip install -q ultralytics torchsummary

In [None]:
# Imports and Device Setup
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchsummary import summary
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, confusion_matrix
from tqdm import tqdm
import os

# Prefer the GPU whenever PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Load the pretrained YOLOv8s-CLS checkpoint through the ultralytics wrapper
from ultralytics import YOLO

# The checkpoint is referenced by its bare file name; the original note says
# ultralytics downloads the official classification weights automatically.
hub_model = YOLO('yolov8s-cls.pt')
hub_model = hub_model.to(device)

# Put the wrapped torch module into inference mode for the inspection cells that follow.
hub_model.model.eval()
print("Successfully loaded YOLOv8s‑CLS via ultralytics.")

Successfully loaded YOLOv8s‑CLS via ultralytics.


In [None]:
# Print the module tree of the wrapped ClassificationModel so the layer layout
# (and in particular the final Linear layer that gets replaced later) can be read off.
print(hub_model.model)

ClassificationModel(
  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (1): Conv(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (2): C2f(
      (cv1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (cv2): Conv(
        (conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
   

In [None]:
# Replace the pretrained output projection with one sized for this task.
# hub_model.model is the ultralytics ClassificationModel; the last entry of its
# inner nn.Sequential is the classification head, which exposes the output
# projection as `.linear`.

# Number of target classes. Kept in one place so the log message and the new
# layer cannot drift apart; a later cell re-derives it from the dataset folders.
num_classes = 4

classify_block = hub_model.model.model[-1]          # final classification head module
in_features    = classify_block.linear.in_features  # width of the features fed to the head
print(f"Replacing final Linear: in_features = {in_features}, out_features = {num_classes}")

# Freshly initialised layer: only the backbone keeps its pretrained weights.
classify_block.linear = nn.Linear(in_features, num_classes).to(device)

# The raw nn.Sequential (backbone followed by the head) is what gets trained and evaluated below.
classifier = hub_model.model.model.to(device)

# Print a layer-by-layer summary for a 3x224x224 input to confirm the new head is in place.
summary(classifier, input_size=(3, 224, 224))

Replacing final Linear: in_features = 1280, out_features = 4
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
              SiLU-3         [-1, 32, 112, 112]               0
              SiLU-4         [-1, 32, 112, 112]               0
              SiLU-5         [-1, 32, 112, 112]               0
              SiLU-6         [-1, 32, 112, 112]               0
              SiLU-7         [-1, 32, 112, 112]               0
              SiLU-8         [-1, 32, 112, 112]               0
              SiLU-9         [-1, 32, 112, 112]               0
             SiLU-10         [-1, 32, 112, 112]               0
             SiLU-11         [-1, 32, 112, 112]               0
             SiLU-12         [-1, 32, 112, 112]               0
             SiLU-13         [-1, 32, 112,

In [None]:
# Data paths: dataset root on the mounted Drive with one sub-folder per split.
data_dir  = "/content/drive/MyDrive/spectrograms_split"
train_dir = os.path.join(data_dir, "train")
val_dir   = os.path.join(data_dir, "val")
test_dir  = os.path.join(data_dir, "test")
balanced_test_dir = os.path.join(data_dir, "test_balanced")

# Fail fast if any split folder is missing. The balanced test folder is checked too,
# because it is loaded alongside the other splits further down.
for path in [train_dir, val_dir, test_dir, balanced_test_dir]:
    assert os.path.isdir(path), f"Directory not found: {path}"

In [None]:
# Shared preprocessing for every split, applied in this order:
#   1. resize to 224x224
#   2. convert the image to a tensor
#   3. normalise each channel with the ImageNet mean / std
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(preprocessing_steps)

In [None]:
# One ImageFolder per split; every split goes through the same preprocessing pipeline.
train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset   = datasets.ImageFolder(val_dir,   transform=transform)
test_dataset  = datasets.ImageFolder(test_dir,  transform=transform)
balanced_dataset  = datasets.ImageFolder(balanced_test_dir, transform=transform)

# ImageFolder derives integer labels from the sub-folder names of each split
# independently. If a split is missing a class folder (or has an extra one) its
# label indices shift and every metric computed on it is silently wrong, so
# require the mapping to be identical everywhere.
for split_name, split_dataset in [
    ("val", val_dataset),
    ("test", test_dataset),
    ("test_balanced", balanced_dataset),
]:
    assert split_dataset.class_to_idx == train_dataset.class_to_idx, (
        f"Class folders of '{split_name}' do not match train: "
        f"{split_dataset.class_to_idx} vs {train_dataset.class_to_idx}"
    )

print("Classes:", train_dataset.classes)  # e.g. ['mild','moderate','normal','severe']
num_classes = len(train_dataset.classes)

# The classifier head was resized in an earlier cell; make sure it agrees with the data.
assert num_classes == classify_block.linear.out_features, (
    f"Dataset has {num_classes} classes but the classifier head outputs "
    f"{classify_block.linear.out_features}"
)

Classes: ['mild', 'moderate', 'normal', 'severe']


In [None]:
# Batch size and worker count are shared by all loaders; only the training
# loader shuffles, so evaluation order stays fixed.
loader_kwargs = dict(batch_size=32, num_workers=2)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
balanced_loader = DataLoader(balanced_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Full fine-tuning: mark every weight of `classifier` (backbone and head) as trainable.
for weight in classifier.parameters():
    weight.requires_grad = True

In [None]:
# Optimisation setup:
#   - cross-entropy loss over the class logits
#   - Adam over every classifier parameter, learning rate 1e-4
#   - StepLR multiplying the learning rate by 0.1 after every 10 scheduler steps
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=classifier.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.1)

In [None]:
# Training + Validation Loop (20 Epochs)
#
# Each epoch runs one pass over train_loader with gradient updates, steps the
# LR scheduler once, then runs one pass over val_loader without gradients and
# reports the mean per-batch validation loss and the validation accuracy.
num_epochs = 20
for epoch in range(num_epochs):
    # ---- Train phase ----
    classifier.train()
    running_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]"):
        imgs, labels = imgs.to(device), labels.to(device)

        # In train mode the classification head is expected to return raw logits;
        # the tuple guard is kept defensively.
        raw_outputs = classifier(imgs)
        outputs = raw_outputs[0] if isinstance(raw_outputs, tuple) else raw_outputs

        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1:2d} Train Loss: {avg_train_loss:.4f}")
    lr_scheduler.step()

    # ---- Validation phase ----
    classifier.eval()
    val_loss    = 0.0
    correct_val = 0
    total_val   = 0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
            imgs, labels = imgs.to(device), labels.to(device)

            raw_outputs = classifier(imgs)
            # In eval mode the ultralytics classification head applies softmax.
            # CrossEntropyLoss applies its own log-softmax, so feeding it the
            # probabilities squashes the loss into a narrow band (about 0.74-1.74
            # for 4 classes) regardless of how well the model fits. Compute the
            # loss from logits instead; the argmax predictions are unaffected.
            if isinstance(raw_outputs, tuple):
                # Recent ultralytics releases return (probabilities, logits).
                outputs, logits = raw_outputs[0], raw_outputs[1]
                loss = criterion(logits, labels)
            else:
                # Older releases return only the probabilities — TODO confirm for
                # the installed version. log(p) is a valid set of logits for p
                # (softmax(log p) == p); the clamp avoids log(0).
                outputs = raw_outputs
                loss = criterion(torch.log(outputs.clamp_min(1e-12)), labels)

            val_loss += loss.item()

            _, preds = torch.max(outputs, dim=1)
            correct_val += (preds == labels).sum().item()
            total_val   += labels.size(0)

    avg_val_loss = val_loss / len(val_loader)
    val_acc      = 100.0 * correct_val / total_val
    print(f"Epoch {epoch+1:2d} Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}%\n")

print("Training complete.")

Epoch 1/20 [Train]: 100%|██████████| 453/453 [04:28<00:00,  1.69it/s]


Epoch  1 Train Loss: 1.0341


Epoch 1/20 [Val]: 100%|██████████| 97/97 [00:50<00:00,  1.93it/s]


Epoch  1 Val Loss: 1.1453 | Val Acc: 65.33%



Epoch 2/20 [Train]: 100%|██████████| 453/453 [04:03<00:00,  1.86it/s]


Epoch  2 Train Loss: 0.6305


Epoch 2/20 [Val]: 100%|██████████| 97/97 [00:50<00:00,  1.91it/s]


Epoch  2 Val Loss: 1.0381 | Val Acc: 74.14%



Epoch 3/20 [Train]: 100%|██████████| 453/453 [03:53<00:00,  1.94it/s]


Epoch  3 Train Loss: 0.3331


Epoch 3/20 [Val]: 100%|██████████| 97/97 [00:49<00:00,  1.96it/s]


Epoch  3 Val Loss: 0.9908 | Val Acc: 77.36%



Epoch 4/20 [Train]: 100%|██████████| 453/453 [03:55<00:00,  1.93it/s]


Epoch  4 Train Loss: 0.1318


Epoch 4/20 [Val]: 100%|██████████| 97/97 [00:47<00:00,  2.02it/s]


Epoch  4 Val Loss: 0.9770 | Val Acc: 77.46%



Epoch 5/20 [Train]: 100%|██████████| 453/453 [03:54<00:00,  1.93it/s]


Epoch  5 Train Loss: 0.0514


Epoch 5/20 [Val]: 100%|██████████| 97/97 [00:47<00:00,  2.05it/s]


Epoch  5 Val Loss: 0.9705 | Val Acc: 77.30%



Epoch 6/20 [Train]: 100%|██████████| 453/453 [03:51<00:00,  1.95it/s]


Epoch  6 Train Loss: 0.0325


Epoch 6/20 [Val]: 100%|██████████| 97/97 [00:48<00:00,  1.99it/s]


Epoch  6 Val Loss: 0.9704 | Val Acc: 77.20%



Epoch 7/20 [Train]: 100%|██████████| 453/453 [03:49<00:00,  1.97it/s]


Epoch  7 Train Loss: 0.0326


Epoch 7/20 [Val]: 100%|██████████| 97/97 [00:47<00:00,  2.04it/s]


Epoch  7 Val Loss: 0.9662 | Val Acc: 77.33%



Epoch 8/20 [Train]: 100%|██████████| 453/453 [03:51<00:00,  1.96it/s]


Epoch  8 Train Loss: 0.0206


Epoch 8/20 [Val]: 100%|██████████| 97/97 [00:48<00:00,  1.99it/s]


Epoch  8 Val Loss: 0.9678 | Val Acc: 77.62%



Epoch 9/20 [Train]: 100%|██████████| 453/453 [03:49<00:00,  1.98it/s]


Epoch  9 Train Loss: 0.0273


Epoch 9/20 [Val]: 100%|██████████| 97/97 [00:48<00:00,  2.02it/s]


Epoch  9 Val Loss: 0.9812 | Val Acc: 75.72%



Epoch 10/20 [Train]: 100%|██████████| 453/453 [03:55<00:00,  1.93it/s]


Epoch 10 Train Loss: 0.0303


Epoch 10/20 [Val]: 100%|██████████| 97/97 [00:50<00:00,  1.94it/s]


Epoch 10 Val Loss: 0.9641 | Val Acc: 77.97%



Epoch 11/20 [Train]: 100%|██████████| 453/453 [04:11<00:00,  1.80it/s]


Epoch 11 Train Loss: 0.0141


Epoch 11/20 [Val]: 100%|██████████| 97/97 [00:51<00:00,  1.88it/s]


Epoch 11 Val Loss: 0.9525 | Val Acc: 78.62%



Epoch 12/20 [Train]: 100%|██████████| 453/453 [03:58<00:00,  1.90it/s]


Epoch 12 Train Loss: 0.0089


Epoch 12/20 [Val]: 100%|██████████| 97/97 [00:49<00:00,  1.97it/s]


Epoch 12 Val Loss: 0.9506 | Val Acc: 78.88%



Epoch 13/20 [Train]: 100%|██████████| 453/453 [03:46<00:00,  2.00it/s]


Epoch 13 Train Loss: 0.0059


Epoch 13/20 [Val]: 100%|██████████| 97/97 [00:47<00:00,  2.04it/s]


Epoch 13 Val Loss: 0.9500 | Val Acc: 78.91%



Epoch 14/20 [Train]: 100%|██████████| 453/453 [03:50<00:00,  1.97it/s]


Epoch 14 Train Loss: 0.0045


Epoch 14/20 [Val]: 100%|██████████| 97/97 [00:47<00:00,  2.06it/s]


Epoch 14 Val Loss: 0.9483 | Val Acc: 79.39%



Epoch 15/20 [Train]: 100%|██████████| 453/453 [03:52<00:00,  1.95it/s]


Epoch 15 Train Loss: 0.0050


Epoch 15/20 [Val]: 100%|██████████| 97/97 [00:49<00:00,  1.98it/s]


Epoch 15 Val Loss: 0.9487 | Val Acc: 79.07%



Epoch 16/20 [Train]: 100%|██████████| 453/453 [03:54<00:00,  1.93it/s]


Epoch 16 Train Loss: 0.0044


Epoch 16/20 [Val]: 100%|██████████| 97/97 [00:48<00:00,  2.00it/s]


Epoch 16 Val Loss: 0.9472 | Val Acc: 79.43%



Epoch 17/20 [Train]: 100%|██████████| 453/453 [03:45<00:00,  2.00it/s]


Epoch 17 Train Loss: 0.0030


Epoch 17/20 [Val]: 100%|██████████| 97/97 [00:47<00:00,  2.05it/s]


Epoch 17 Val Loss: 0.9458 | Val Acc: 79.52%



Epoch 18/20 [Train]: 100%|██████████| 453/453 [03:48<00:00,  1.98it/s]


Epoch 18 Train Loss: 0.0026


Epoch 18/20 [Val]: 100%|██████████| 97/97 [00:46<00:00,  2.09it/s]


Epoch 18 Val Loss: 0.9492 | Val Acc: 78.62%



Epoch 19/20 [Train]: 100%|██████████| 453/453 [03:53<00:00,  1.94it/s]


Epoch 19 Train Loss: 0.0030


Epoch 19/20 [Val]: 100%|██████████| 97/97 [00:49<00:00,  1.97it/s]


Epoch 19 Val Loss: 0.9475 | Val Acc: 79.17%



Epoch 20/20 [Train]: 100%|██████████| 453/453 [03:51<00:00,  1.96it/s]


Epoch 20 Train Loss: 0.0027


Epoch 20/20 [Val]: 100%|██████████| 97/97 [00:46<00:00,  2.10it/s]

Epoch 20 Val Loss: 0.9465 | Val Acc: 79.46%

Training complete.





In [None]:
# Evaluate on the test split: overall accuracy, macro-averaged metrics,
# a per-class report and the confusion matrix.
classifier.eval()
all_preds, all_labels = [], []
correct_test = 0
total_test   = 0

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Testing"):
        imgs = imgs.to(device)
        labels = labels.to(device)

        raw_outputs = classifier(imgs)
        # When the model hands back a tuple, its first element holds the class scores.
        outputs = raw_outputs[0] if isinstance(raw_outputs, tuple) else raw_outputs

        # Predicted class = index of the highest score in each row.
        preds = torch.max(outputs, dim=1).indices
        correct_test += (preds == labels).sum().item()
        total_test   += labels.size(0)
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())

test_acc = 100.0 * correct_test / total_test
print(f"\nTest Accuracy: {test_acc:.2f}%")

# Collapse the per-batch tensors into flat numpy arrays for scikit-learn.
all_preds  = torch.cat(all_preds).numpy()
all_labels = torch.cat(all_labels).numpy()

macro_kwargs = dict(average="macro", zero_division=0)
precision_test = precision_score(all_labels, all_preds, **macro_kwargs)
recall_test    = recall_score(all_labels, all_preds, **macro_kwargs)
f1_test        = f1_score(all_labels, all_preds, **macro_kwargs)

print(f"Test Precision (macro): {precision_test:.4f}")
print(f"Test Recall    (macro): {recall_test:.4f}")
print(f"Test F1‑Score  (macro): {f1_test:.4f}\n")

print("Test: Per‑class Precision / Recall / F1:\n")
per_class_report = classification_report(
    all_labels,
    all_preds,
    target_names=test_dataset.classes,
    zero_division=0,
)
print(per_class_report)

cm = confusion_matrix(all_labels, all_preds)
print("Test Confusion Matrix (rows=true, cols=predicted):\n", cm)

Testing: 100%|██████████| 98/98 [07:42<00:00,  4.72s/it]


Test Accuracy: 79.68%
Test Precision (macro): 0.7737
Test Recall    (macro): 0.7771
Test F1‑Score  (macro): 0.7744

Test: Per‑class Precision / Recall / F1:

              precision    recall  f1-score   support

        mild       0.84      0.76      0.79      1262
    moderate       0.74      0.79      0.77       741
      normal       0.71      0.69      0.70       124
      severe       0.80      0.86      0.83       979

    accuracy                           0.80      3106
   macro avg       0.77      0.78      0.77      3106
weighted avg       0.80      0.80      0.80      3106

Test Confusion Matrix (rows=true, cols=predicted):
 [[954 136  22 150]
 [ 87 589  10  55]
 [ 14  20  86   4]
 [ 84  46   3 846]]





In [None]:
# Balanced Test Evaluation
#
# Same procedure as the regular test evaluation, run on the "test_balanced"
# split. The loader for that split is `balanced_loader`; the name
# `test_balanced_loader` used previously is never defined in this notebook and
# raised a NameError on a fresh run.
classifier.eval()
all_preds    = []
all_labels   = []
correct_test = 0
total_test   = 0

with torch.no_grad():
    for imgs, labels in tqdm(balanced_loader, desc="Balanced Testing"):
        imgs, labels = imgs.to(device), labels.to(device)

        raw_outputs = classifier(imgs)
        # The model may return a tuple; its first element holds the per-class
        # scores. NOTE(review): in eval mode these are presumably softmax
        # probabilities rather than logits — irrelevant for the argmax below.
        if isinstance(raw_outputs, tuple):
            outputs = raw_outputs[0]
        else:
            outputs = raw_outputs

        _, preds = torch.max(outputs, dim=1)
        correct_test += (preds == labels).sum().item()
        total_test   += labels.size(0)
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())

test_acc = 100.0 * correct_test / total_test
print(f"\nBalanced Test Accuracy: {test_acc:.2f}%")

all_preds  = torch.cat(all_preds).numpy()
all_labels = torch.cat(all_labels).numpy()

# 'macro' is the unweighted mean of the per-class scores. On a split with equal
# support per class (the original note states 4 classes x 100 images — not
# verifiable from this notebook) it coincides with the support-weighted average.
precision_test = precision_score(all_labels, all_preds, average="macro", zero_division=0)
recall_test    = recall_score(all_labels, all_preds, average="macro", zero_division=0)
f1_test        = f1_score(all_labels, all_preds, average="macro", zero_division=0)

print(f"Balanced Test Precision (macro): {precision_test:.4f}")
print(f"Balanced Test Recall    (macro): {recall_test:.4f}")
print(f"Balanced Test F1-Score  (macro): {f1_test:.4f}\n")

print("Balanced Test: Per-class Precision / Recall / F1:\n")
print(classification_report(
    all_labels,
    all_preds,
    target_names=balanced_loader.dataset.classes,
    zero_division=0
))

cm = confusion_matrix(all_labels, all_preds)
print("Balanced Test Confusion Matrix (rows=true, cols=predicted):\n", cm)