In [None]:
import torch
import numpy as np
import librosa

def compute_spectrogram(waveform, sample_rate=16000, n_fft=1024, hop_length=512):
    # convert to numpy
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.numpy()

    # Compute the STFT
    stft = librosa.stft(waveform, n_fft=n_fft, hop_length=hop_length)
    spectrogram = librosa.amplitude_to_db(np.abs(stft), ref=np.max)  # Convert to dB scale
    phase = np.angle(stft)  # Get the phase information

    # print(spectrogram.shape)

    return torch.from_numpy(spectrogram).float(), torch.from_numpy(phase).float()

def compute_mel_spectrogram(
    waveform,
    sample_rate=16000,
    n_fft=1024,
    hop_length=512,
    n_mels=128,
    fmin=0,
    fmax=None,
    to_db=True
):
    """
    Compute the Mel-spectrogram from a raw waveform signal.

    Args:
        waveform (torch.Tensor or np.ndarray): 1D input waveform.
        sample_rate (int): Sampling rate of the waveform (Hz).
        n_fft (int): FFT window size for the STFT.
        hop_length (int): Number of samples between successive frames.
        n_mels (int): Number of Mel filter banks.
        fmin (float): Minimum frequency (Hz).
        fmax (float): Maximum frequency (Hz), defaults to Nyquist (sr/2).
        to_db (bool): Whether to convert power values to decibel (dB) scale.

    Returns:
        torch.FloatTensor: Mel-spectrogram with shape [n_mels, time_frames].
    """

    # Convert tensor to numpy array if needed
    if isinstance(waveform, torch.Tensor):
        waveform = waveform.detach().cpu().numpy()

    # Compute Mel-spectrogram (power)
    mel_spec = librosa.feature.melspectrogram(
        y=waveform,
        sr=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax or sample_rate // 2
    )

    # Optionally convert to dB scale for better dynamic range visualization
    if to_db:
        mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

    # Convert to PyTorch tensor (float32)
    mel_spec = torch.from_numpy(mel_spec).float()

    return mel_spec


In [None]:
# ============================================================
# Parkinson Detection - Multi-Branch Fusion (CNN + Transformer)
# ============================================================

import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import AutoModel, AutoProcessor
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, classification_report
import torchaudio, os
from PIL import Image
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
labels =["HC_AH", "PD_AH"] #
class ParkinsonDataset(Dataset):
    def __init__(self, root_dir, processor, img_transform=None):
        self.samples = []
        self.processor = processor
        self.img_transform = img_transform

        for label, cls in enumerate(labels):
            cls_dir = os.path.join(root_dir, cls)
            for fname in os.listdir(cls_dir):
                if fname.endswith(".wav"):
                    wav_path = os.path.join(cls_dir, fname)
                    png_path = wav_path.replace("/audio/","/").replace(".wav", ".jpg")
                    #print(fname)
                    #if os.path.exists(png_path):
                    self.samples.append((wav_path, png_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        wav_path, img_path, label = self.samples[idx]

        # Load waveform
        waveform, sr = torchaudio.load(wav_path)
        waveform = torchaudio.functional.resample(waveform, sr, 16000)

        # Whisper processor → Mel log features
        inputs = self.processor(audio=waveform.squeeze().numpy(),
                                sampling_rate=16000, return_tensors="pt")
        wav_input = inputs.input_features.squeeze(0)  # [80, time]

        # Load Mel-spectrogram image
        #image,_ = compute_spectrogram(waveform) #Image.open(img_path).convert("RGB")
        spectrogram, _ = compute_spectrogram(waveform)
        if spectrogram.ndim == 3:
          spectrogram = spectrogram.squeeze(0)  # remove extra channel dim

        # Convert to 3-channel NumPy image (for CNN)
        image = np.stack([spectrogram, spectrogram, spectrogram], axis=-1)  # (H, W, 3)
        if self.img_transform:
            image = self.img_transform(image)

        return image, wav_input, label

In [None]:
# -------------------------
# Model definition
# -------------------------
class CNNBranch(nn.Module):
    def __init__(self, out_dim=1024):
        super().__init__()
        base = models.inception_v3(models.Inception_V3_Weights.IMAGENET1K_V1, aux_logits=True)
        base.fc = nn.Identity()
        for p in base.parameters():
            p.requires_grad = False
        self.base = base
        self.fc = nn.Sequential(
            nn.BatchNorm1d(2048),
            nn.Linear(2048, out_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

    def forward(self, x):
        # ✅ Handle InceptionOutputs object
        outputs = self.base(x)
        if hasattr(outputs, "logits"):  # typical for InceptionV3
            feat = outputs.logits
        elif isinstance(outputs, tuple):  # backward compatibility
            feat = outputs[0]
        else:
            feat = outputs

        return self.fc(feat)


class TransformerBranch(nn.Module):
    def __init__(self, model_name="openai/whisper-tiny", out_dim=512):
        super().__init__()
        from transformers import WhisperModel
        # chỉ load phần encoder (không có decoder)
        whisper = WhisperModel.from_pretrained(model_name)
        self.encoder = whisper.encoder

        for p in self.encoder.parameters():
            p.requires_grad = False  # freeze encoder

        hidden_dim = self.encoder.config.d_model
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

    def forward(self, x):
        # x shape: [batch, features, time]
        out = self.encoder(x)
        pooled = out.last_hidden_state.mean(dim=1)
        return self.fc(pooled)


class FusionModel2(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn = CNNBranch(1024)
        self.trf = TransformerBranch(out_dim=512)

        self.gate = nn.Sequential(nn.Linear(1536, 1536), nn.Sigmoid())
        self.classifier = nn.Sequential(
            nn.Linear(1536, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 2)
        )

    def forward(self, img, wav):
        f1 = self.cnn(img)
        f2 = self.trf(wav)
        fused = torch.cat([f1, f2], dim=1)
        attn = self.gate(fused)
        fused = attn * fused + (1 - attn) * fused.mean(dim=1, keepdim=True)
        return self.classifier(fused)

class FusionModel(nn.Module):
    def __init__(self, use_transformer=True):
        super().__init__()
        self.use_transformer = use_transformer  # <--- new flag

        self.cnn = CNNBranch(1024)
        self.trf = TransformerBranch(out_dim=512)

        # total feature dim changes depending on transformer usage
        fusion_dim = 1024 + (512 if use_transformer else 0)

        self.gate = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim),
            nn.Sigmoid()
        )

        self.classifier = nn.Sequential(
            nn.Linear(fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, 2)
        )

    def forward(self, img, wav=None):
        # CNN branch
        f1 = self.cnn(img)

        # Transformer branch (optional)
        if self.use_transformer and wav is not None:
            f2 = self.trf(wav)
            fused = torch.cat([f1, f2], dim=1)
        else:
            fused = f1  # use only CNN features

        # gating mechanism
        attn = self.gate(fused)
        fused = attn * fused + (1 - attn) * fused.mean(dim=1, keepdim=True)

        # classification head
        return self.classifier(fused)


# -------------------------
# Training & Evaluation
# -------------------------
def train_model(model, train_loader, val_loader, epochs=10, lr=1e-4):
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        preds, labels = [], []

        for img, wav, y in train_loader:
            img, wav, y = img.to(device), wav.to(device), y.to(device)
            optimizer.zero_grad()
            logits = model(img, wav)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds += torch.argmax(logits, dim=1).cpu().tolist()
            labels += y.cpu().tolist()

        acc = accuracy_score(labels, preds)
        print(f"Epoch {epoch+1}/{epochs} - Train loss: {total_loss/len(train_loader):.4f} - Acc: {acc:.3f}")

        evaluate(model, val_loader)

def evaluate(model, loader):
    model.eval()
    preds, probs, labels = [], [], []
    with torch.no_grad():
        for img, wav, y in loader:
            img, wav = img.to(device), wav.to(device)
            logits = model(img, wav)
            pred = torch.argmax(logits, dim=1)
            prob = torch.softmax(logits, dim=1)[:, 1]
            preds += pred.cpu().tolist()
            probs += prob.cpu().tolist()
            labels += y.tolist()

    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, probs)
    cm = confusion_matrix(labels, preds)
    print(f"Val Accuracy: {acc:.3f} | AUC: {auc:.3f}")
    print("Confusion Matrix:\n", cm)
    print(classification_report(labels, preds, target_names=["HC", "PD"]))
    return acc, auc


# -------------------------
# Run training
# -------------------------
#if __name__ == "__main__":
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")






In [None]:
test_dataset.samples

[('/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/adrso178.wav',
  '/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/adrso178.jpg',
  0),
 ('/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/adrso172.wav',
  '/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/adrso172.jpg',
  0),
 ('/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/test_clipping_distortion_4_adrso010.wav',
  '/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/test_clipping_distortion_4_adrso010.jpg',
  0),
 ('/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/test_add_reverb_5_adrso264.wav',
  '/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/test_add_reverb_5_adrso264.jpg',
  0),
 ('/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/adrso016.wav',
  '/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test/HC_AH/adrso016.jpg',
  0),


In [None]:
img_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

dataset = ParkinsonDataset("/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/train", processor, img_transform)
test_dataset = ParkinsonDataset("/content/drive/MyDrive/Colab Notebooks/Parkinson/Augmented/test", processor, img_transform)
#dataset = ParkinsonDataset("/content/drive/MyDrive/Colab Notebooks/Parkinson/Dataset/audio", processor, img_transform)


train_size = int(0.6 * len(dataset))
val_size = len(dataset) - train_size
#train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])
train_ds = dataset
val_ds = test_dataset

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False)



# Base Pipeline

In [None]:
model = FusionModel(use_transformer=False)
train_model(model, train_loader, val_loader, epochs=20, lr=1e-4)


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 1/20 - Train loss: 0.6891 - Acc: 0.590


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 0.725 | AUC: 0.930
Confusion Matrix:
 [[19  1]
 [10 10]]
              precision    recall  f1-score   support

          HC       0.66      0.95      0.78        20
          PD       0.91      0.50      0.65        20

    accuracy                           0.72        40
   macro avg       0.78      0.72      0.71        40
weighted avg       0.78      0.72      0.71        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 2/20 - Train loss: 0.5859 - Acc: 0.680


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 0.825 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 7 13]]
              precision    recall  f1-score   support

          HC       0.74      1.00      0.85        20
          PD       1.00      0.65      0.79        20

    accuracy                           0.82        40
   macro avg       0.87      0.82      0.82        40
weighted avg       0.87      0.82      0.82        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 3/20 - Train loss: 0.4680 - Acc: 0.780


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 0.950 | AUC: 0.997
Confusion Matrix:
 [[20  0]
 [ 2 18]]
              precision    recall  f1-score   support

          HC       0.91      1.00      0.95        20
          PD       1.00      0.90      0.95        20

    accuracy                           0.95        40
   macro avg       0.95      0.95      0.95        40
weighted avg       0.95      0.95      0.95        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 4/20 - Train loss: 0.3994 - Acc: 0.840


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 0.875 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 5 15]]
              precision    recall  f1-score   support

          HC       0.80      1.00      0.89        20
          PD       1.00      0.75      0.86        20

    accuracy                           0.88        40
   macro avg       0.90      0.88      0.87        40
weighted avg       0.90      0.88      0.87        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 5/20 - Train loss: 0.2642 - Acc: 0.915


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 6/20 - Train loss: 0.2282 - Acc: 0.915


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 7/20 - Train loss: 0.2306 - Acc: 0.905


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 8/20 - Train loss: 0.1547 - Acc: 0.955


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 9/20 - Train loss: 0.1717 - Acc: 0.925


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 10/20 - Train loss: 0.1049 - Acc: 0.955


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 11/20 - Train loss: 0.1986 - Acc: 0.925


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 12/20 - Train loss: 0.1764 - Acc: 0.910


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 13/20 - Train loss: 0.2766 - Acc: 0.860


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 14/20 - Train loss: 0.1694 - Acc: 0.920


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 15/20 - Train loss: 0.1428 - Acc: 0.955


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 16/20 - Train loss: 0.1011 - Acc: 0.970


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 17/20 - Train loss: 0.1605 - Acc: 0.915


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 18/20 - Train loss: 0.1921 - Acc: 0.925


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 19/20 - Train loss: 0.1614 - Acc: 0.945


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Epoch 20/20 - Train loss: 0.2195 - Acc: 0.900


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


Val Accuracy: 1.000 | AUC: 1.000
Confusion Matrix:
 [[20  0]
 [ 0 20]]
              precision    recall  f1-score   support

          HC       1.00      1.00      1.00        20
          PD       1.00      1.00      1.00        20

    accuracy                           1.00        40
   macro avg       1.00      1.00      1.00        40
weighted avg       1.00      1.00      1.00        40



# Proposed pipeline

In [1]:
import math
import numpy as np
import librosa.display
import matplotlib.pyplot as plt

rows= len(val_loader) // 4
for images, wavs, labels in val_loader:
  # Vẽ lưới ảnh
  plt.figure(figsize=(16, 4 * rows))
  cols = 4
  n_images = images.shape[0]
  rows = math.ceil(n_images / cols)

  # Chuyển tensor về numpy để hiển thị
  images_np = images.permute(0, 2, 3, 1).numpy()  # [B, H, W, C]
  images_np = (images_np * 0.5 + 0.5).clip(0, 1)  # khử chuẩn hóa [-1,1] → [0,1]

  # Vẽ lưới ảnh
  plt.figure(figsize=(16, 4 * rows))
  for i in range(n_images):
      plt.subplot(rows, cols, i + 1)
      plt.imshow(images_np[i])
      plt.title(f"Label: {'PD' if labels[i]==1 else 'HC'}")
      plt.axis("off")

NameError: name 'val_loader' is not defined