In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

2.6.0+cu124


In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset
import numpy as np
import random
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, log_loss
from munkres import Munkres
from torchvision import models

In [2]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

In [3]:
class PairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx][0], self.pairs[idx][1], self.labels[idx]

In [4]:
class SiameseNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)

    def forward_once(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def forward(self, x1, x2):
        out1 = self.forward_once(x1)
        out2 = self.forward_once(x2)
        dist = F.pairwise_distance(out1, out2)
        return dist, out1, out2

In [5]:
class OrthoLinear(nn.Module):
    def __init__(self, in_dim, out_dim, eps=1e-4):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.eps = eps

    def forward(self, x):
        Y_tilde = self.fc(x)
        gram = Y_tilde.T @ Y_tilde + self.eps * torch.eye(Y_tilde.shape[1], device=x.device)
        L = torch.linalg.cholesky(gram)
        L_inv = torch.inverse(L)
        return Y_tilde @ L_inv.T

In [6]:
class SpectralNet(nn.Module):
    def __init__(self, input_dim, n_clusters, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ortho = OrthoLinear(hidden_dim, n_clusters)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.ortho(x)

In [7]:
def contrastive_loss(distance, label, margin=1.0):
    pos = label * torch.pow(distance, 2)
    neg = (1 - label) * torch.pow(torch.clamp(margin - distance, min=0.0), 2)
    return torch.mean(pos + neg)

def spectral_loss(Y, W):
    D = torch.diag(W.sum(axis=1))
    L = D - W
    num = torch.trace(Y.T @ L @ Y)
    denom = torch.trace(Y.T @ D @ Y)
    return num / (denom + 1e-12)

In [8]:
def compute_affinity(X, scale, n_neighbors=20):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, indices = nbrs.kneighbors(X)
    distances, indices = distances[:, 1:], indices[:, 1:]
    W = np.zeros((len(X), len(X)))
    for i in range(len(X)):
        for j in range(n_neighbors):
            W[i, indices[i, j]] = np.exp(-distances[i, j] ** 2 / (2 * scale ** 2))
            W[indices[i, j], i] = W[i, indices[i, j]]
    return W

def compute_scale(X, n_neighbors=20):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, _ = nbrs.kneighbors(X)
    return np.median(distances[:, -1])

def calculate_accuracy(y_pred, y_true, n_clusters):
    cm = confusion_matrix(y_true, y_pred)
    cost = np.zeros((n_clusters, n_clusters))
    for i in range(n_clusters):
        for j in range(n_clusters):
            cost[i, j] = cm[:, j].sum() - cm[i, j]
    m = Munkres()
    mapping = m.compute(cost.tolist())
    new_labels = np.zeros_like(y_pred)
    for row, col in mapping:
        new_labels[y_pred == row] = col
    return (new_labels == y_true).mean()

In [9]:
def train_siamese(siamese_net, dataloader, epochs=50, lr=1e-3, device='cpu'):
    siamese_net.to(device)
    opt = optim.Adam(siamese_net.parameters(), lr=lr)
    for ep in range(epochs):
        siamese_net.train()
        total_loss = 0.0
        for x1, x2, labels in dataloader:
            x1 = x1.to(device).float()
            x2 = x2.to(device).float()
            labels = labels.to(device).float()
            dist, _, _ = siamese_net(x1, x2)
            loss = contrastive_loss(dist, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item() * x1.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        print(f"[Siamese] Epoch {ep + 1}/{epochs}, Avg Loss={avg_loss:.6f}")
    siamese_net.to('cpu')
    return siamese_net

def train_spectral(spectral_net, X_train, W, epochs=50, lr=1e-3, tol=1e-6, device='cpu'):
    spectral_net.to(device)
    opt = optim.Adam([
        {'params': spectral_net.fc1.parameters()},
        {'params': spectral_net.fc2.parameters()},
        {'params': spectral_net.ortho.fc.parameters()}
    ], lr=lr)

    X_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)
    W_tensor = torch.tensor(W, dtype=torch.float32, device=device)

    prev_loss = float('inf')
    for ep in range(epochs):
        spectral_net.train()
        Y = spectral_net(X_tensor)  # (n_samples, n_clusters)
        loss = spectral_loss(Y, W_tensor)
        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_item = loss.item()
        print(f"[SpectralNet] Epoch {ep + 1}/{epochs}, Loss={loss_item:.8f}")

        if abs(prev_loss - loss_item) < tol:
            print("SpectralNet converged (tol reached).")
            break
        prev_loss = loss_item
    spectral_net.to('cpu')
    return spectral_net

In [13]:
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ====== Load PneumoniaMNIST ======
    data = np.load('/home/snu/Downloads/pneumoniamnist_224.npz', allow_pickle=True)
    all_images = np.concatenate([data['train_images'], data['val_images'], data['test_images']], axis=0)
    all_labels = np.concatenate([data['train_labels'], data['val_labels'], data['test_labels']], axis=0).squeeze()

    images = all_images.astype(np.float32) / 255.0
    images = np.repeat(images[:, None, :, :], 3, axis=1)  # (N, 3, 224, 224)
    labels = all_labels.astype(np.int64)

    # ====== Select 2000 samples per class ======
    selected_indices = []
    num_per_class = 2000
    classes = np.unique(labels)
    for c in classes:
        class_idx = np.where(labels == c)[0]
        chosen = np.random.choice(class_idx, size=min(num_per_class, len(class_idx)), replace=False)
        selected_indices.extend(chosen)

    selected_indices = np.array(selected_indices)
    images = images[selected_indices]
    labels = labels[selected_indices]

    # ====== Create dataset ======
    dataset = TensorDataset(torch.tensor(images), torch.tensor(labels))
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    # ====== Load ResNet-18 (pretrained) ======
    resnet = models.resnet18(pretrained=True)
    resnet.fc = nn.Identity()  # remove final classification layer
    resnet = resnet.to(device)
    resnet.eval()

    # ====== Extract ResNet features ======
    feats, y_list = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device).float()
            f = resnet(imgs)
            feats.append(f.cpu())
            y_list.append(lbls)

    X = torch.cat(feats, dim=0).numpy().astype(np.float32)
    y = torch.cat(y_list, dim=0).numpy().astype(np.int64)

    # Shuffle
    perm = np.random.permutation(len(X))
    X, y = X[perm], y[perm]
    print("Balanced subset:", X.shape, y.shape)

    num_nodes, num_feats = X.shape
    n_clusters = 2
    hidden_dim = 256
    batch_size = 16
    n_neighbors = 20

    # ====== Build siamese pairs ======
    pairs, labels_pairs = [], []
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, indices = nbrs.kneighbors(X)
    for i in range(len(X)):
        for j in indices[i, 1:]:
            pairs.append([X[i], X[j]])
            labels_pairs.append(1)
        all_indices = set(range(len(X)))
        neighbor_set = set(indices[i, 1:])
        non_neighbors = list(all_indices - neighbor_set - {i})
        j = np.random.choice(non_neighbors)
        pairs.append([X[i], X[j]])
        labels_pairs.append(0)

    dataset_pairs = PairDataset(pairs, labels_pairs)
    dataloader_pairs = DataLoader(dataset_pairs, batch_size=batch_size, shuffle=True)

    # ====== Train Siamese ======
    siamese = SiameseNet(num_feats, hidden_dim)
    siamese = train_siamese(siamese, dataloader_pairs, epochs=50, device=device)

    with torch.no_grad():
        X_embed = siamese.forward_once(torch.tensor(X, dtype=torch.float32)).numpy()

    # ====== Train SpectralNet ======
    scale = compute_scale(X_embed, n_neighbors=n_neighbors)
    W = compute_affinity(X_embed, scale, n_neighbors=n_neighbors)
    spectral = SpectralNet(num_feats, n_clusters, hidden_dim)
    spectral = train_spectral(spectral, X, W, epochs=50, device=device)

    with torch.no_grad():
        Y = spectral(torch.tensor(X, dtype=torch.float32)).numpy()
        y_pred_proba = F.softmax(torch.tensor(Y), dim=1).numpy()

    kmeans = KMeans(n_clusters=n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(Y)

    acc_score = calculate_accuracy(y_pred, y, n_clusters)
    acc_score_inverted = calculate_accuracy(1 - y_pred, y, n_clusters)
    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec = precision_score(y, y_pred)
    rec = recall_score(y, y_pred)
    f1 = f1_score(y, y_pred)
    ll = log_loss(y, y_pred_proba)

    print("Final clustering accuracy:", acc_score)
    print("Precision:", prec)
    print("Recall:", rec)
    print("F1 Score:", f1)
    print("Log Loss:", ll)

    return {"accuracy": acc_score, "precision": prec, "recall": rec, "f1": f1, "log_loss": ll}


# --------------------------
# Multi-runs
# --------------------------
if __name__ == "__main__":
    num_runs = 10
    all_results = {"accuracy": [], "precision": [], "recall": [], "f1": [], "log_loss": []}
    for run in range(num_runs):
        print(f"\n--- Run {run+1}/{num_runs} ---")
        set_seed(run)
        res = main()
        for k in all_results.keys():
            all_results[k].append(res[k])

    print("\n=== FINAL SUMMARY ===")
    for metric, vals in all_results.items():
        print(f"{metric:>10} | mean={np.mean(vals):.4f} ± {np.std(vals):.4f}")


--- Run 1/10 ---




Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029973
[Siamese] Epoch 2/50, Avg Loss=0.029008
[Siamese] Epoch 3/50, Avg Loss=0.028685
[Siamese] Epoch 4/50, Avg Loss=0.028501
[Siamese] Epoch 5/50, Avg Loss=0.028510
[Siamese] Epoch 6/50, Avg Loss=0.028595
[Siamese] Epoch 7/50, Avg Loss=0.028478
[Siamese] Epoch 8/50, Avg Loss=0.028447
[Siamese] Epoch 9/50, Avg Loss=0.028396
[Siamese] Epoch 10/50, Avg Loss=0.028256
[Siamese] Epoch 11/50, Avg Loss=0.028247
[Siamese] Epoch 12/50, Avg Loss=0.028211
[Siamese] Epoch 13/50, Avg Loss=0.028168
[Siamese] Epoch 14/50, Avg Loss=0.028111
[Siamese] Epoch 15/50, Avg Loss=0.028082
[Siamese] Epoch 16/50, Avg Loss=0.028072
[Siamese] Epoch 17/50, Avg Loss=0.027985
[Siamese] Epoch 18/50, Avg Loss=0.027941
[Siamese] Epoch 19/50, Avg Loss=0.027991
[Siamese] Epoch 20/50, Avg Loss=0.027952
[Siamese] Epoch 21/50, Avg Loss=0.027915
[Siamese] Epoch 22/50, Avg Loss=0.027953
[Siamese] Epoch 23/50, Avg Loss=0.027911
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029306
[Siamese] Epoch 2/50, Avg Loss=0.028232
[Siamese] Epoch 3/50, Avg Loss=0.027841
[Siamese] Epoch 4/50, Avg Loss=0.027758
[Siamese] Epoch 5/50, Avg Loss=0.027663
[Siamese] Epoch 6/50, Avg Loss=0.027607
[Siamese] Epoch 7/50, Avg Loss=0.027919
[Siamese] Epoch 8/50, Avg Loss=0.027846
[Siamese] Epoch 9/50, Avg Loss=0.027773
[Siamese] Epoch 10/50, Avg Loss=0.027685
[Siamese] Epoch 11/50, Avg Loss=0.027591
[Siamese] Epoch 12/50, Avg Loss=0.027590
[Siamese] Epoch 13/50, Avg Loss=0.027597
[Siamese] Epoch 14/50, Avg Loss=0.027491
[Siamese] Epoch 15/50, Avg Loss=0.027485
[Siamese] Epoch 16/50, Avg Loss=0.027463
[Siamese] Epoch 17/50, Avg Loss=0.027460
[Siamese] Epoch 18/50, Avg Loss=0.027403
[Siamese] Epoch 19/50, Avg Loss=0.027469
[Siamese] Epoch 20/50, Avg Loss=0.027402
[Siamese] Epoch 21/50, Avg Loss=0.027377
[Siamese] Epoch 22/50, Avg Loss=0.027406
[Siamese] Epoch 23/50, Avg Loss=0.027316
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029421
[Siamese] Epoch 2/50, Avg Loss=0.028094
[Siamese] Epoch 3/50, Avg Loss=0.027782
[Siamese] Epoch 4/50, Avg Loss=0.027677
[Siamese] Epoch 5/50, Avg Loss=0.027621
[Siamese] Epoch 6/50, Avg Loss=0.027426
[Siamese] Epoch 7/50, Avg Loss=0.027334
[Siamese] Epoch 8/50, Avg Loss=0.027316
[Siamese] Epoch 9/50, Avg Loss=0.027160
[Siamese] Epoch 10/50, Avg Loss=0.027444
[Siamese] Epoch 11/50, Avg Loss=0.027450
[Siamese] Epoch 12/50, Avg Loss=0.027359
[Siamese] Epoch 13/50, Avg Loss=0.027384
[Siamese] Epoch 14/50, Avg Loss=0.027383
[Siamese] Epoch 15/50, Avg Loss=0.027309
[Siamese] Epoch 16/50, Avg Loss=0.027303
[Siamese] Epoch 17/50, Avg Loss=0.027296
[Siamese] Epoch 18/50, Avg Loss=0.027246
[Siamese] Epoch 19/50, Avg Loss=0.027267
[Siamese] Epoch 20/50, Avg Loss=0.027247
[Siamese] Epoch 21/50, Avg Loss=0.027184
[Siamese] Epoch 22/50, Avg Loss=0.027195
[Siamese] Epoch 23/50, Avg Loss=0.027206
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029304
[Siamese] Epoch 2/50, Avg Loss=0.028125
[Siamese] Epoch 3/50, Avg Loss=0.027816
[Siamese] Epoch 4/50, Avg Loss=0.027675
[Siamese] Epoch 5/50, Avg Loss=0.027515
[Siamese] Epoch 6/50, Avg Loss=0.027414
[Siamese] Epoch 7/50, Avg Loss=0.027368
[Siamese] Epoch 8/50, Avg Loss=0.027296
[Siamese] Epoch 9/50, Avg Loss=0.027179
[Siamese] Epoch 10/50, Avg Loss=0.027082
[Siamese] Epoch 11/50, Avg Loss=0.026987
[Siamese] Epoch 12/50, Avg Loss=0.026918
[Siamese] Epoch 13/50, Avg Loss=0.026961
[Siamese] Epoch 14/50, Avg Loss=0.026918
[Siamese] Epoch 15/50, Avg Loss=0.026913
[Siamese] Epoch 16/50, Avg Loss=0.026847
[Siamese] Epoch 17/50, Avg Loss=0.026853
[Siamese] Epoch 18/50, Avg Loss=0.026817
[Siamese] Epoch 19/50, Avg Loss=0.026828
[Siamese] Epoch 20/50, Avg Loss=0.026850
[Siamese] Epoch 21/50, Avg Loss=0.026761
[Siamese] Epoch 22/50, Avg Loss=0.026802
[Siamese] Epoch 23/50, Avg Loss=0.026723
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.028963
[Siamese] Epoch 2/50, Avg Loss=0.028127
[Siamese] Epoch 3/50, Avg Loss=0.027953
[Siamese] Epoch 4/50, Avg Loss=0.027791
[Siamese] Epoch 5/50, Avg Loss=0.027705
[Siamese] Epoch 6/50, Avg Loss=0.027536
[Siamese] Epoch 7/50, Avg Loss=0.027514
[Siamese] Epoch 8/50, Avg Loss=0.027451
[Siamese] Epoch 9/50, Avg Loss=0.027374
[Siamese] Epoch 10/50, Avg Loss=0.027329
[Siamese] Epoch 11/50, Avg Loss=0.027343
[Siamese] Epoch 12/50, Avg Loss=0.027250
[Siamese] Epoch 13/50, Avg Loss=0.027264
[Siamese] Epoch 14/50, Avg Loss=0.027257
[Siamese] Epoch 15/50, Avg Loss=0.027211
[Siamese] Epoch 16/50, Avg Loss=0.027212
[Siamese] Epoch 17/50, Avg Loss=0.027140
[Siamese] Epoch 18/50, Avg Loss=0.027150
[Siamese] Epoch 19/50, Avg Loss=0.027152
[Siamese] Epoch 20/50, Avg Loss=0.027135
[Siamese] Epoch 21/50, Avg Loss=0.027103
[Siamese] Epoch 22/50, Avg Loss=0.027066
[Siamese] Epoch 23/50, Avg Loss=0.027084
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029782
[Siamese] Epoch 2/50, Avg Loss=0.028805
[Siamese] Epoch 3/50, Avg Loss=0.028573
[Siamese] Epoch 4/50, Avg Loss=0.028452
[Siamese] Epoch 5/50, Avg Loss=0.028404
[Siamese] Epoch 6/50, Avg Loss=0.028322
[Siamese] Epoch 7/50, Avg Loss=0.028218
[Siamese] Epoch 8/50, Avg Loss=0.028149
[Siamese] Epoch 9/50, Avg Loss=0.028087
[Siamese] Epoch 10/50, Avg Loss=0.028042
[Siamese] Epoch 11/50, Avg Loss=0.027983
[Siamese] Epoch 12/50, Avg Loss=0.027969
[Siamese] Epoch 13/50, Avg Loss=0.027917
[Siamese] Epoch 14/50, Avg Loss=0.027887
[Siamese] Epoch 15/50, Avg Loss=0.027824
[Siamese] Epoch 16/50, Avg Loss=0.027831
[Siamese] Epoch 17/50, Avg Loss=0.027822
[Siamese] Epoch 18/50, Avg Loss=0.027768
[Siamese] Epoch 19/50, Avg Loss=0.027731
[Siamese] Epoch 20/50, Avg Loss=0.027727
[Siamese] Epoch 21/50, Avg Loss=0.027729
[Siamese] Epoch 22/50, Avg Loss=0.027671
[Siamese] Epoch 23/50, Avg Loss=0.027600
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029651
[Siamese] Epoch 2/50, Avg Loss=0.028371
[Siamese] Epoch 3/50, Avg Loss=0.028221
[Siamese] Epoch 4/50, Avg Loss=0.028099
[Siamese] Epoch 5/50, Avg Loss=0.028418
[Siamese] Epoch 6/50, Avg Loss=0.028349
[Siamese] Epoch 7/50, Avg Loss=0.028208
[Siamese] Epoch 8/50, Avg Loss=0.028087
[Siamese] Epoch 9/50, Avg Loss=0.028025
[Siamese] Epoch 10/50, Avg Loss=0.027967
[Siamese] Epoch 11/50, Avg Loss=0.027843
[Siamese] Epoch 12/50, Avg Loss=0.027778
[Siamese] Epoch 13/50, Avg Loss=0.027857
[Siamese] Epoch 14/50, Avg Loss=0.027794
[Siamese] Epoch 15/50, Avg Loss=0.027803
[Siamese] Epoch 16/50, Avg Loss=0.027784
[Siamese] Epoch 17/50, Avg Loss=0.027705
[Siamese] Epoch 18/50, Avg Loss=0.027726
[Siamese] Epoch 19/50, Avg Loss=0.027703
[Siamese] Epoch 20/50, Avg Loss=0.027690
[Siamese] Epoch 21/50, Avg Loss=0.027629
[Siamese] Epoch 22/50, Avg Loss=0.027602
[Siamese] Epoch 23/50, Avg Loss=0.027573
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029405
[Siamese] Epoch 2/50, Avg Loss=0.028208
[Siamese] Epoch 3/50, Avg Loss=0.028308
[Siamese] Epoch 4/50, Avg Loss=0.027791
[Siamese] Epoch 5/50, Avg Loss=0.027624
[Siamese] Epoch 6/50, Avg Loss=0.027545
[Siamese] Epoch 7/50, Avg Loss=0.027538
[Siamese] Epoch 8/50, Avg Loss=0.027439
[Siamese] Epoch 9/50, Avg Loss=0.027406
[Siamese] Epoch 10/50, Avg Loss=0.027331
[Siamese] Epoch 11/50, Avg Loss=0.027322
[Siamese] Epoch 12/50, Avg Loss=0.027714
[Siamese] Epoch 13/50, Avg Loss=0.027710
[Siamese] Epoch 14/50, Avg Loss=0.027709
[Siamese] Epoch 15/50, Avg Loss=0.027704
[Siamese] Epoch 16/50, Avg Loss=0.027652
[Siamese] Epoch 17/50, Avg Loss=0.027596
[Siamese] Epoch 18/50, Avg Loss=0.027613
[Siamese] Epoch 19/50, Avg Loss=0.027624
[Siamese] Epoch 20/50, Avg Loss=0.027569
[Siamese] Epoch 21/50, Avg Loss=0.027551
[Siamese] Epoch 22/50, Avg Loss=0.027485
[Siamese] Epoch 23/50, Avg Loss=0.027516
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029106
[Siamese] Epoch 2/50, Avg Loss=0.028181
[Siamese] Epoch 3/50, Avg Loss=0.027549
[Siamese] Epoch 4/50, Avg Loss=0.027667
[Siamese] Epoch 5/50, Avg Loss=0.027668
[Siamese] Epoch 6/50, Avg Loss=0.027592
[Siamese] Epoch 7/50, Avg Loss=0.027529
[Siamese] Epoch 8/50, Avg Loss=0.027382
[Siamese] Epoch 9/50, Avg Loss=0.027417
[Siamese] Epoch 10/50, Avg Loss=0.027368
[Siamese] Epoch 11/50, Avg Loss=0.027278
[Siamese] Epoch 12/50, Avg Loss=0.027238
[Siamese] Epoch 13/50, Avg Loss=0.027227
[Siamese] Epoch 14/50, Avg Loss=0.027190
[Siamese] Epoch 15/50, Avg Loss=0.027182
[Siamese] Epoch 16/50, Avg Loss=0.027221
[Siamese] Epoch 17/50, Avg Loss=0.027217
[Siamese] Epoch 18/50, Avg Loss=0.027178
[Siamese] Epoch 19/50, Avg Loss=0.027172
[Siamese] Epoch 20/50, Avg Loss=0.027137
[Siamese] Epoch 21/50, Avg Loss=0.027130
[Siamese] Epoch 22/50, Avg Loss=0.027086
[Siamese] Epoch 23/50, Avg Loss=0.027080
[Siamese] Epoch 24/50, Avg Lo



Balanced subset: (3583, 512) (3583,)
[Siamese] Epoch 1/50, Avg Loss=0.029835
[Siamese] Epoch 2/50, Avg Loss=0.028690
[Siamese] Epoch 3/50, Avg Loss=0.028332
[Siamese] Epoch 4/50, Avg Loss=0.028308
[Siamese] Epoch 5/50, Avg Loss=0.028209
[Siamese] Epoch 6/50, Avg Loss=0.028201
[Siamese] Epoch 7/50, Avg Loss=0.028104
[Siamese] Epoch 8/50, Avg Loss=0.028042
[Siamese] Epoch 9/50, Avg Loss=0.027994
[Siamese] Epoch 10/50, Avg Loss=0.027925
[Siamese] Epoch 11/50, Avg Loss=0.027818
[Siamese] Epoch 12/50, Avg Loss=0.027782
[Siamese] Epoch 13/50, Avg Loss=0.027788
[Siamese] Epoch 14/50, Avg Loss=0.027736
[Siamese] Epoch 15/50, Avg Loss=0.027666
[Siamese] Epoch 16/50, Avg Loss=0.027666
[Siamese] Epoch 17/50, Avg Loss=0.027662
[Siamese] Epoch 18/50, Avg Loss=0.028123
[Siamese] Epoch 19/50, Avg Loss=0.028074
[Siamese] Epoch 20/50, Avg Loss=0.028021
[Siamese] Epoch 21/50, Avg Loss=0.028072
[Siamese] Epoch 22/50, Avg Loss=0.027967
[Siamese] Epoch 23/50, Avg Loss=0.027977
[Siamese] Epoch 24/50, Avg Lo

=== FINAL SUMMARY ===
  accuracy | mean=0.9032 ± 0.0126
 precision | mean=0.4878 ± 0.4007
    recall | mean=0.4239 ± 0.3377
        f1 | mean=0.4533 ± 0.3663
  log_loss | mean=0.6941 ± 0.0080