This is the code that successfully loads the saved weights and model config from the checkpoint; the loaded model reaches 76% accuracy.

In [None]:
import os
import torch
import numpy as np
from scipy.signal import butter, filtfilt
from models.resnet_1d import ResNet1D  # Make sure this path is correct

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ---------- Bandpass Filter ----------
def bandpass_filter(data, lowcut=1.0, highcut=40.0, fs=250.0, order=5):
    """Apply a zero-phase Butterworth band-pass filter.

    A Butterworth band-pass is designed with ``scipy.signal.butter`` and
    applied forward and backward with ``filtfilt``, so the output has no
    phase shift relative to the input.

    Args:
        data: Signal to filter; ``filtfilt`` operates along the last axis.
        lowcut: Lower corner frequency in Hz.
        highcut: Upper corner frequency in Hz.
        fs: Sampling rate of ``data`` in Hz.
        order: Order passed to ``butter``.

    Returns:
        The filtered signal, same shape as ``data``.
    """
    nyquist = fs / 2.0
    # butter expects corner frequencies normalized to the Nyquist frequency.
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, normalized_band, btype='band')
    return filtfilt(numerator, denominator, data)

# ---------- Preprocess EEG ----------
def preprocess_eeg(file_path, target_length=500, stats=None):
    raw = np.loadtxt(file_path, delimiter=",", skiprows=1)
    eeg = raw[:, :22].T  # [22, T]

    eeg = np.array([bandpass_filter(ch) for ch in eeg])

    if eeg.shape[1] > target_length:
        eeg = eeg[:, :target_length]
    else:
        eeg = np.pad(eeg, ((0, 0), (0, target_length - eeg.shape[1])), mode="constant")

    if stats:
        eeg[:20] = (eeg[:20] - stats['eeg_mean']) / (stats['eeg_std'] + 1e-8)

    return torch.tensor(eeg, dtype=torch.float32)

# ---------- Load Saved Model ----------
# Checkpoint to load; path is relative to the notebook's working directory.
CHECKPOINT_PATH = "checkpoints/best_retrained_model.pt"

# weights_only=False un-pickles arbitrary Python objects (presumably required
# because the checkpoint stores non-tensor 'config'/'stats' entries — confirm).
# Unpickling can execute arbitrary code: only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
model_config = checkpoint['config']  # keyword arguments for ResNet1D(...)
# Normalization statistics; later code reads 'eeg_mean', 'eeg_std',
# 'age_mean' and 'age_std' from this mapping.
stats = checkpoint['stats']

model = ResNet1D(**model_config)
model.load_state_dict(checkpoint['model_state_dict'])  # strict by default
model.to(device)
model.eval()


In [12]:

def predict_single_file(file_path, true_age, model, stats):
    """Classify a single EEG recording and print the predicted label.

    The recording is preprocessed with ``preprocess_eeg``, and the subject's
    age is z-scored with ``stats``. Both tensors are moved to the global
    ``device`` and passed to ``model(eeg, age)``.

    The model is always evaluated in eval mode, whatever mode it was passed
    in, because training-mode layers (e.g. dropout, batch norm) can change the
    prediction. Every submodule's train/eval flag is restored afterwards.

    Args:
        file_path: Path of the CSV recording to classify.
        true_age: Subject age, in the same units as ``stats['age_mean']``.
        model: Classifier called as ``model(eeg_tensor, age_tensor)`` that
            returns scores of shape [1, num_classes].
        stats: Mapping providing 'age_mean' and 'age_std' (and whatever
            ``preprocess_eeg`` reads for EEG normalization).

    Returns:
        int: Index of the highest-scoring class; 0, 1, 2 map to
        'HC', 'MCI', 'AD'.
    """
    class_names = ['HC', 'MCI', 'AD']

    eeg_tensor = preprocess_eeg(file_path, stats=stats).unsqueeze(0).to(device)  # [1, C, T]
    age_z = (true_age - stats['age_mean']) / stats['age_std']
    age_tensor = torch.tensor([age_z], dtype=torch.float32).unsqueeze(1).to(device)  # [1, 1]

    # Snapshot each submodule's flag so a mixed train/eval setup is restored
    # exactly. Plain callables (non-Modules) are used as-is.
    saved_modes = (
        [(module, module.training) for module in model.modules()]
        if isinstance(model, torch.nn.Module)
        else []
    )
    if saved_modes:
        model.eval()
    try:
        with torch.no_grad():
            output = model(eeg_tensor, age_tensor)
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    predicted_class = output.argmax(dim=1).item()

    print(f"🧠 EEG File: {os.path.basename(file_path)}")
    print(f"✅ Predicted Label: {class_names[predicted_class]}")

    return predicted_class

# Example recordings paired with the subject's age, for a quick sanity check.
# NOTE(review): the filename prefix appears to match the class index used by
# predict_single_file — confirm against the dataset's labelling.
EXAMPLE_RECORDINGS = [
    ("data/balanced_subset/0_00906.csv", 55),  # prefix 0 — presumably HC
    ("data/balanced_subset/1_00582.csv", 70),  # prefix 1 — presumably MCI
    ("data/balanced_subset/2_00746.csv", 65),  # prefix 2 — presumably AD
]

# Change the index to try a different example; the returned class index is
# this cell's displayed output.
test_file_path, true_age = EXAMPLE_RECORDINGS[2]
predict_single_file(test_file_path, true_age, model, stats)

🧠 EEG File: 2_00746.csv
✅ Predicted Label: AD


2