In [1]:
import re
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GroupKFold
from aeon.classification.convolution_based import RocketClassifier
from tsai.all import Learner
from tsai.all import InceptionTime
from tsai.models.InceptionTimePlus import InceptionTimePlus
from tsai.models.ResNet import ResNet
from tsai.models.XceptionTime import XceptionTime
from tsai.all import TSDataLoaders
from tsai.all import get_splits
from tsai.all import accuracy
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_curve, auc, roc_auc_score, accuracy_score
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt

from tsai.models.InceptionTime import InceptionBlock
from tsai.models.layers import GAP1d

In [2]:
# -----------------------------------------------
# 1. Load a folder AND extract group ID from filename
# -----------------------------------------------
def load_folder(folder_path, label):
    X_list, y_list, groups = [], [], []

    for fname in sorted(os.listdir(folder_path)):
        if fname.endswith(".csv"):

            # ---- extract sample number for grouping ----
            # matches "_sample_7.csv" OR "_sample_7_1.csv"
            m = re.search(r"sample_(\d+)", fname)
            if m is None:
                continue
            sample_id = int(m.group(1))   # group ID

            df = pd.read_csv(os.path.join(folder_path, fname))
            df.columns = [c.strip() for c in df.columns]

            sig = df["Signal/nA"].values
            X_list.append(sig.reshape(1, -1))   # (1, L)
            y_list.append(label)
            groups.append(sample_id)

    return X_list, y_list, groups


# ------------------------------------------------
# 2. Load all folders
# ------------------------------------------------
DATA_PATHS = {
    0: r"C:\Dev\MasterThesis\data\Si_wafer_clear_SRRs",
    1: r"C:\Dev\MasterThesis\data\SRRs_cleaned_with_ethanol",
    2: r"C:\Dev\MasterThesis\data\SRRs_with_10ppb_Acetamiprid",
    3: r"C:\Dev\MasterThesis\data\SRRs_with_100ppb_Acetamiprid",
    4: r"C:\Dev\MasterThesis\data\SRRs_with_1000ppb_Acetamiprid",
}

X, y, groups = [], [], []

for label, path in DATA_PATHS.items():
    Xi, yi, gi = load_folder(path, label)
    X += Xi
    y += yi
    groups += gi

X = np.array(X, dtype=object)
y = np.array(y, dtype=int)
groups = np.array(groups, dtype=int)

print("Total samples:", len(X))


# ------------------------------------------------
# 3. Pad all signals to equal length
# ------------------------------------------------
max_len = max(ts.shape[1] for ts in X)
X_pad = np.zeros((len(X), 1, max_len))

for i, ts in enumerate(X):
    X_pad[i, 0, :ts.shape[1]] = ts

# tsai format: (N, seq_len, channels)
X_tsai = np.swapaxes(X_pad, 1, 2).astype(np.float32)
c_in = 1
c_out = len(np.unique(y))


# ------------------------------------------------
# 4. Grouped 10-fold cross-validation
# ------------------------------------------------
def grouped_inception_cv(X, y, groups, epochs=40):
    # ---- FIX SHAPE ----
    if X.shape[1] != 1:   # <-- wrong orientation
        X = np.swapaxes(X, 1, 2)

    print("Corrected X shape:", X.shape)
    
    gkf = GroupKFold(n_splits=10)
    fold_results = []
    auc_per_fold = []

    for fold, (train_idx, valid_idx) in enumerate(gkf.split(X, y, groups)):

        print(f"\n============================")
        print(f"   FOLD {fold+1} / 10")
        print(f"============================")

        dls = TSDataLoaders.from_numpy(
            X[train_idx], y[train_idx],
            splits=[list(train_idx), list(valid_idx)],
            bs=32, num_workers=0
        )

        model = InceptionTimePlus(c_in=c_in, c_out=c_out)
        learn = Learner(dls, model, metrics=accuracy)

        learn.fit_one_cycle(epochs, 1e-3)

        # ---- predictions for ROC AUC ----
        preds, targets = learn.get_preds(ds_idx=1)   # validation only
        preds = preds.numpy()
        targets = targets.numpy().astype(int)

        # AUC per class
        aucs = []
        for cls in range(c_out):
            y_true = (targets == cls).astype(int)
            y_score = preds[:, cls]
            auc_val = roc_auc_score(y_true, y_score)
            aucs.append(auc_val)

        print("AUCs:", aucs)
        auc_per_fold.append(aucs)

        # accuracy
        acc = accuracy(torch.tensor(preds), torch.tensor(targets)).item()
        fold_results.append(acc)

    return fold_results, auc_per_fold


# ------------------------------------------------
# 5. RUN THE FULL PIPELINE
# ------------------------------------------------
accs, aucs = grouped_inception_cv(X_tsai, y, groups, epochs=40)

print("\nFinal accuracy per fold:")
print(accs)

print("\nMean accuracy:", np.mean(accs))

print("\nMean AUC per class:")
print(np.mean(np.array(aucs), axis=0))


Total samples: 400
Corrected X shape: (400, 1, 4001)

   FOLD 1 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.611614,1.606999,0.25,00:46
1,1.609055,1.607211,0.25,00:46
2,1.607851,1.607047,0.25,00:45
3,1.606188,1.611724,0.180556,00:45
4,1.605382,1.623542,0.194444,00:45
5,1.596699,2.182077,0.194444,00:45
6,1.581566,3.049576,0.194444,00:47
7,1.561404,2.600558,0.194444,00:51
8,1.543379,2.705513,0.194444,00:50
9,1.524722,2.423602,0.194444,00:50


AUCs: [0.8981481481481481, 0.8083441981747066, 1.0, 0.8409387222946545, 0.8583743842364532]

   FOLD 2 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.610755,1.607257,0.180556,00:51
1,1.608534,1.607488,0.180556,00:50
2,1.608365,1.608454,0.208333,00:50
3,1.607356,1.609161,0.180556,00:51
4,1.606979,1.6338,0.180556,00:51
5,1.60154,1.874179,0.194444,00:50
6,1.586492,2.878037,0.180556,00:50
7,1.567757,3.059578,0.194444,00:50
8,1.541072,2.340415,0.194444,00:50
9,1.529595,2.273565,0.194444,00:50


AUCs: [0.874485596707819, 0.7692307692307693, 0.9975369458128078, 0.7679269882659714, 0.8793103448275862]

   FOLD 3 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.610286,1.607179,0.25,00:51
1,1.608576,1.607353,0.25,00:51
2,1.607456,1.607666,0.180556,00:51
3,1.606175,1.610682,0.208333,00:51
4,1.60225,1.618998,0.194444,00:51
5,1.59267,2.754516,0.194444,00:51
6,1.578074,2.740171,0.194444,00:51
7,1.558118,3.244506,0.194444,00:57
8,1.533469,2.044488,0.194444,00:53
9,1.520357,2.057014,0.194444,00:51


AUCs: [0.9002057613168724, 0.7848761408083442, 1.0, 0.7966101694915255, 0.895320197044335]

   FOLD 4 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.611183,1.607159,0.25,00:50
1,1.60985,1.60747,0.25,00:50
2,1.608415,1.608786,0.180556,00:51
3,1.606775,1.608407,0.180556,00:50
4,1.604639,1.605021,0.125,00:50
5,1.594455,2.513197,0.180556,00:50
6,1.587438,2.566331,0.194444,00:50
7,1.570355,4.550733,0.194444,00:50
8,1.56649,3.239896,0.194444,00:50
9,1.554971,1.85911,0.194444,00:50


AUCs: [0.8837448559670782, 0.7953063885267275, 1.0, 0.8722294654498044, 0.9076354679802956]

   FOLD 5 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.61063,1.607171,0.25,00:50
1,1.608666,1.607632,0.180556,00:50
2,1.607191,1.607865,0.25,00:50
3,1.606027,1.609485,0.180556,00:50
4,1.604871,1.612553,0.194444,00:50
5,1.595964,2.175546,0.180556,00:50
6,1.580183,3.523977,0.194444,00:50
7,1.566045,2.931791,0.194444,00:50
8,1.543,3.818985,0.194444,00:50
9,1.517629,1.587577,0.194444,00:50


AUCs: [0.9012345679012346, 0.8252933507170794, 1.0, 0.8500651890482399, 0.8842364532019704]

   FOLD 6 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.610252,1.607151,0.25,00:50
1,1.608598,1.607386,0.25,00:50
2,1.607687,1.608116,0.347222,00:50
3,1.60557,1.608007,0.194444,00:50
4,1.602976,1.628237,0.166667,00:51
5,1.592367,3.100019,0.194444,00:50
6,1.572606,3.148144,0.194444,00:50
7,1.554283,2.475112,0.194444,00:50
8,1.534603,2.90555,0.194444,00:50
9,1.508746,1.859893,0.194444,00:50


AUCs: [0.9207818930041152, 0.8083441981747066, 1.0, 0.8565840938722294, 0.8940886699507389]

   FOLD 7 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.613241,1.607052,0.25,00:51
1,1.610474,1.60712,0.25,00:50
2,1.608432,1.608055,0.180556,00:50
3,1.606505,1.609386,0.180556,00:50
4,1.603184,1.613378,0.180556,00:50
5,1.5915,2.505195,0.194444,00:51
6,1.575626,3.39537,0.194444,00:50
7,1.561334,3.227718,0.194444,00:50
8,1.538006,2.065432,0.194444,00:50
9,1.514396,3.294731,0.194444,00:50


AUCs: [0.9135802469135802, 0.7731421121251629, 1.0, 0.8422425032594524, 0.8731527093596059]

   FOLD 8 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.612597,1.607564,0.194444,00:51
1,1.610047,1.608104,0.194444,00:51
2,1.608794,1.609187,0.180556,00:51
3,1.607195,1.606521,0.180556,00:50
4,1.601521,1.633766,0.180556,00:50
5,1.58616,2.697075,0.194444,00:50
6,1.569025,3.296235,0.194444,00:50
7,1.55906,4.557605,0.194444,00:50
8,1.552782,1.623518,0.305556,00:50
9,1.544443,1.928229,0.194444,00:50


AUCs: [0.8559670781893004, 0.8031290743155151, 1.0, 0.8148631029986961, 0.8460591133004927]

   FOLD 9 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.610923,1.607188,0.25,00:50
1,1.609179,1.607645,0.180556,00:50
2,1.607269,1.608739,0.180556,00:50
3,1.605532,1.608403,0.166667,00:50
4,1.603969,1.62749,0.194444,00:50
5,1.594527,2.877244,0.194444,00:50
6,1.579598,3.732547,0.194444,00:50
7,1.561825,3.603869,0.194444,00:50
8,1.54717,3.120521,0.194444,00:50
9,1.524717,3.071802,0.194444,00:50


AUCs: [0.881687242798354, 0.8213820078226858, 1.0, 0.8396349413298566, 0.8349753694581281]

   FOLD 10 / 10


epoch,train_loss,valid_loss,accuracy,time
0,1.611675,1.60727,0.25,00:51
1,1.60956,1.607746,0.194444,00:51
2,1.608016,1.609356,0.180556,00:50
3,1.606837,1.609952,0.180556,00:50
4,1.605358,1.61415,0.194444,00:50
5,1.59629,2.517113,0.194444,00:50
6,1.578806,3.885685,0.194444,00:50
7,1.560239,4.111368,0.194444,00:50
8,1.535521,3.707723,0.194444,00:50
9,1.514165,2.187713,0.194444,00:51


AUCs: [0.8796296296296297, 0.8591916558018253, 1.0, 0.8305084745762712, 0.8041871921182266]

Final accuracy per fold:
[0.5555555820465088, 0.4722222089767456, 0.5416666865348816, 0.6111111044883728, 0.625, 0.6111111044883728, 0.5833333134651184, 0.5833333134651184, 0.5416666865348816, 0.5972222089767456]

Mean accuracy: 0.5722222208976746

Mean AUC per class:
[0.8909465  0.80482399 0.99975369 0.83116037 0.86773399]


In [3]:
import numpy as np
from collections import Counter

print("Unique groups:", len(np.unique(groups)))
print("Group counts:", Counter(groups))


Unique groups: 40
Group counts: Counter({np.int64(1): 10, np.int64(10): 10, np.int64(11): 10, np.int64(12): 10, np.int64(13): 10, np.int64(14): 10, np.int64(15): 10, np.int64(16): 10, np.int64(17): 10, np.int64(18): 10, np.int64(19): 10, np.int64(2): 10, np.int64(20): 10, np.int64(21): 10, np.int64(22): 10, np.int64(23): 10, np.int64(24): 10, np.int64(25): 10, np.int64(26): 10, np.int64(27): 10, np.int64(28): 10, np.int64(29): 10, np.int64(3): 10, np.int64(30): 10, np.int64(31): 10, np.int64(32): 10, np.int64(33): 10, np.int64(34): 10, np.int64(35): 10, np.int64(36): 10, np.int64(37): 10, np.int64(38): 10, np.int64(39): 10, np.int64(4): 10, np.int64(40): 10, np.int64(5): 10, np.int64(6): 10, np.int64(7): 10, np.int64(8): 10, np.int64(9): 10})
