In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

2.6.0+cu124


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset
import numpy as np
import random
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, log_loss
from munkres import Munkres
from torchvision import models

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(42)

In [3]:
class PairDataset(Dataset):
    def __init__(self, pairs, labels):
        self.pairs = pairs
        self.labels = labels

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx][0], self.pairs[idx][1], self.labels[idx]

In [4]:
class SiameseNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)

    def forward_once(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def forward(self, x1, x2):
        out1 = self.forward_once(x1)
        out2 = self.forward_once(x2)
        dist = F.pairwise_distance(out1, out2)
        return dist, out1, out2

In [5]:
class OrthoLinear(nn.Module):
    def __init__(self, in_dim, out_dim, eps=1e-4):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)
        self.eps = eps

    def forward(self, x):
        Y_tilde = self.fc(x)
        gram = Y_tilde.T @ Y_tilde + self.eps * torch.eye(Y_tilde.shape[1], device=x.device)
        L = torch.linalg.cholesky(gram)
        L_inv = torch.inverse(L)
        return Y_tilde @ L_inv.T

In [6]:
class SpectralNet(nn.Module):
    def __init__(self, input_dim, n_clusters, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ortho = OrthoLinear(hidden_dim, n_clusters)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.ortho(x)

In [7]:
def contrastive_loss(distance, label, margin=1.0):
    pos = label * torch.pow(distance, 2)
    neg = (1 - label) * torch.pow(torch.clamp(margin - distance, min=0.0), 2)
    return torch.mean(pos + neg)

def spectral_loss(Y, W):
    D = torch.diag(W.sum(axis=1))
    L = D - W
    num = torch.trace(Y.T @ L @ Y)
    denom = torch.trace(Y.T @ D @ Y)
    return num / (denom + 1e-12)

In [8]:
def compute_affinity(X, scale, n_neighbors=20):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, indices = nbrs.kneighbors(X)
    distances, indices = distances[:, 1:], indices[:, 1:]
    W = np.zeros((len(X), len(X)))
    for i in range(len(X)):
        for j in range(n_neighbors):
            W[i, indices[i, j]] = np.exp(-distances[i, j] ** 2 / (2 * scale ** 2))
            W[indices[i, j], i] = W[i, indices[i, j]]
    return W

def compute_scale(X, n_neighbors=20):
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, _ = nbrs.kneighbors(X)
    return np.median(distances[:, -1])

def calculate_accuracy(y_pred, y_true, n_clusters):
    cm = confusion_matrix(y_true, y_pred)
    cost = np.zeros((n_clusters, n_clusters))
    for i in range(n_clusters):
        for j in range(n_clusters):
            cost[i, j] = cm[:, j].sum() - cm[i, j]
    m = Munkres()
    mapping = m.compute(cost.tolist())
    new_labels = np.zeros_like(y_pred)
    for row, col in mapping:
        new_labels[y_pred == row] = col
    return (new_labels == y_true).mean()

In [9]:
def train_siamese(siamese_net, dataloader, epochs=50, lr=1e-3, device='cpu'):
    siamese_net.to(device)
    opt = optim.Adam(siamese_net.parameters(), lr=lr)
    for ep in range(epochs):
        siamese_net.train()
        total_loss = 0.0
        for x1, x2, labels in dataloader:
            x1 = x1.to(device).float()
            x2 = x2.to(device).float()
            labels = labels.to(device).float()
            dist, _, _ = siamese_net(x1, x2)
            loss = contrastive_loss(dist, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item() * x1.size(0)
        avg_loss = total_loss / len(dataloader.dataset)
        print(f"[Siamese] Epoch {ep + 1}/{epochs}, Avg Loss={avg_loss:.6f}")
    siamese_net.to('cpu')
    return siamese_net

def train_spectral(spectral_net, X_train, W, epochs=50, lr=1e-3, tol=1e-6, device='cpu'):
    spectral_net.to(device)
    opt = optim.Adam([
        {'params': spectral_net.fc1.parameters()},
        {'params': spectral_net.fc2.parameters()},
        {'params': spectral_net.ortho.fc.parameters()}
    ], lr=lr)

    X_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)
    W_tensor = torch.tensor(W, dtype=torch.float32, device=device)

    prev_loss = float('inf')
    for ep in range(epochs):
        spectral_net.train()
        Y = spectral_net(X_tensor)  # (n_samples, n_clusters)
        loss = spectral_loss(Y, W_tensor)
        opt.zero_grad()
        loss.backward()
        opt.step()

        loss_item = loss.item()
        print(f"[SpectralNet] Epoch {ep + 1}/{epochs}, Loss={loss_item:.8f}")

        if abs(prev_loss - loss_item) < tol:
            print("SpectralNet converged (tol reached).")
            break
        prev_loss = loss_item
    spectral_net.to('cpu')
    return spectral_net

In [None]:
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # ====== Load PneumoniaMNIST ======
    data = np.load('/home/snu/Downloads/breastmnist_224.npz', allow_pickle=True)
    all_images = np.concatenate([data['train_images'], data['val_images'], data['test_images']], axis=0)
    all_labels = np.concatenate([data['train_labels'], data['val_labels'], data['test_labels']], axis=0).squeeze()

    images = all_images.astype(np.float32) / 255.0
    images = np.repeat(images[:, None, :, :], 3, axis=1)  # (N, 3, 224, 224)
    labels = all_labels.astype(np.int64)

    # ====== Select 2000 samples per class ======
    selected_indices = []
    num_per_class = 2000
    classes = np.unique(labels)
    for c in classes:
        class_idx = np.where(labels == c)[0]
        chosen = np.random.choice(class_idx, size=min(num_per_class, len(class_idx)), replace=False)
        selected_indices.extend(chosen)

    selected_indices = np.array(selected_indices)
    images = images[selected_indices]
    labels = labels[selected_indices]

    # ====== Create dataset ======
    dataset = TensorDataset(torch.tensor(images), torch.tensor(labels))
    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    # ====== Load ResNet-18 (pretrained) ======
    resnet = models.resnet18(pretrained=True)
    resnet.fc = nn.Identity()  # remove final classification layer
    resnet = resnet.to(device)
    resnet.eval()

    # ====== Extract ResNet features ======
    feats, y_list = [], []
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device).float()
            f = resnet(imgs)
            feats.append(f.cpu())
            y_list.append(lbls)

    X = torch.cat(feats, dim=0).numpy().astype(np.float32)
    y = torch.cat(y_list, dim=0).numpy().astype(np.int64)

    # Shuffle
    perm = np.random.permutation(len(X))
    X, y = X[perm], y[perm]
    print("Balanced subset:", X.shape, y.shape)

    num_nodes, num_feats = X.shape
    n_clusters = 2
    hidden_dim = 256
    batch_size = 16
    n_neighbors = 20

    # ====== Build siamese pairs ======
    pairs, labels_pairs = [], []
    nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X)
    distances, indices = nbrs.kneighbors(X)
    for i in range(len(X)):
        for j in indices[i, 1:]:
            pairs.append([X[i], X[j]])
            labels_pairs.append(1)
        all_indices = set(range(len(X)))
        neighbor_set = set(indices[i, 1:])
        non_neighbors = list(all_indices - neighbor_set - {i})
        j = np.random.choice(non_neighbors)
        pairs.append([X[i], X[j]])
        labels_pairs.append(0)

    dataset_pairs = PairDataset(pairs, labels_pairs)
    dataloader_pairs = DataLoader(dataset_pairs, batch_size=batch_size, shuffle=True)

    # ====== Train Siamese ======
    siamese = SiameseNet(num_feats, hidden_dim)
    siamese = train_siamese(siamese, dataloader_pairs, epochs=50, device=device)

    with torch.no_grad():
        X_embed = siamese.forward_once(torch.tensor(X, dtype=torch.float32)).numpy()

    # ====== Train SpectralNet ======
    scale = compute_scale(X_embed, n_neighbors=n_neighbors)
    W = compute_affinity(X_embed, scale, n_neighbors=n_neighbors)
    spectral = SpectralNet(num_feats, n_clusters, hidden_dim)
    spectral = train_spectral(spectral, X, W, epochs=50, device=device)

    with torch.no_grad():
        Y = spectral(torch.tensor(X, dtype=torch.float32)).numpy()
        y_pred_proba = F.softmax(torch.tensor(Y), dim=1).numpy()

    kmeans = KMeans(n_clusters=n_clusters, n_init=20)
    y_pred = kmeans.fit_predict(Y)

    acc_score = calculate_accuracy(y_pred, y, n_clusters)
    acc_score_inverted = calculate_accuracy(1 - y_pred, y, n_clusters)
    if acc_score_inverted > acc_score:
        acc_score = acc_score_inverted
        y_pred = 1 - y_pred

    prec = precision_score(y, y_pred)
    rec = recall_score(y, y_pred)
    f1 = f1_score(y, y_pred)
    ll = log_loss(y, y_pred_proba)

    print("Final clustering accuracy:", acc_score)
    print("Precision:", prec)
    print("Recall:", rec)
    print("F1 Score:", f1)
    print("Log Loss:", ll)

    return {"accuracy": acc_score, "precision": prec, "recall": rec, "f1": f1, "log_loss": ll}


# --------------------------
# Multi-runs
# --------------------------
if __name__ == "__main__":
    num_runs = 10
    all_results = {"accuracy": [], "precision": [], "recall": [], "f1": [], "log_loss": []}
    for run in range(num_runs):
        print(f"\n--- Run {run+1}/{num_runs} ---")
        set_seed(run)
        res = main()
        for k in all_results.keys():
            all_results[k].append(res[k])

    print("\n=== FINAL SUMMARY ===")
    for metric, vals in all_results.items():
        print(f"{metric:>10} | mean={np.mean(vals):.4f} ± {np.std(vals):.4f}")


--- Run 1/10 ---




Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.039203
[Siamese] Epoch 2/50, Avg Loss=0.036830
[Siamese] Epoch 3/50, Avg Loss=0.036653
[Siamese] Epoch 4/50, Avg Loss=0.036221
[Siamese] Epoch 5/50, Avg Loss=0.035884
[Siamese] Epoch 6/50, Avg Loss=0.035659
[Siamese] Epoch 7/50, Avg Loss=0.035608
[Siamese] Epoch 8/50, Avg Loss=0.035700
[Siamese] Epoch 9/50, Avg Loss=0.035384
[Siamese] Epoch 10/50, Avg Loss=0.035357
[Siamese] Epoch 11/50, Avg Loss=0.035345
[Siamese] Epoch 12/50, Avg Loss=0.035360
[Siamese] Epoch 13/50, Avg Loss=0.035205
[Siamese] Epoch 14/50, Avg Loss=0.035168
[Siamese] Epoch 15/50, Avg Loss=0.035153
[Siamese] Epoch 16/50, Avg Loss=0.035076
[Siamese] Epoch 17/50, Avg Loss=0.035146
[Siamese] Epoch 18/50, Avg Loss=0.034914
[Siamese] Epoch 19/50, Avg Loss=0.034985
[Siamese] Epoch 20/50, Avg Loss=0.034969
[Siamese] Epoch 21/50, Avg Loss=0.035045
[Siamese] Epoch 22/50, Avg Loss=0.034930
[Siamese] Epoch 23/50, Avg Loss=0.034837
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.039430
[Siamese] Epoch 2/50, Avg Loss=0.037566
[Siamese] Epoch 3/50, Avg Loss=0.036858
[Siamese] Epoch 4/50, Avg Loss=0.036749
[Siamese] Epoch 5/50, Avg Loss=0.036466
[Siamese] Epoch 6/50, Avg Loss=0.036420
[Siamese] Epoch 7/50, Avg Loss=0.036075
[Siamese] Epoch 8/50, Avg Loss=0.036038
[Siamese] Epoch 9/50, Avg Loss=0.036025
[Siamese] Epoch 10/50, Avg Loss=0.035881
[Siamese] Epoch 11/50, Avg Loss=0.035700
[Siamese] Epoch 12/50, Avg Loss=0.035755
[Siamese] Epoch 13/50, Avg Loss=0.035715
[Siamese] Epoch 14/50, Avg Loss=0.035711
[Siamese] Epoch 15/50, Avg Loss=0.035557
[Siamese] Epoch 16/50, Avg Loss=0.035649
[Siamese] Epoch 17/50, Avg Loss=0.035576
[Siamese] Epoch 18/50, Avg Loss=0.035638
[Siamese] Epoch 19/50, Avg Loss=0.035388
[Siamese] Epoch 20/50, Avg Loss=0.035393
[Siamese] Epoch 21/50, Avg Loss=0.035462
[Siamese] Epoch 22/50, Avg Loss=0.035360
[Siamese] Epoch 23/50, Avg Loss=0.035258
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.039738
[Siamese] Epoch 2/50, Avg Loss=0.037769
[Siamese] Epoch 3/50, Avg Loss=0.036603
[Siamese] Epoch 4/50, Avg Loss=0.036258
[Siamese] Epoch 5/50, Avg Loss=0.036074
[Siamese] Epoch 6/50, Avg Loss=0.036047
[Siamese] Epoch 7/50, Avg Loss=0.035837
[Siamese] Epoch 8/50, Avg Loss=0.035705
[Siamese] Epoch 9/50, Avg Loss=0.035689
[Siamese] Epoch 10/50, Avg Loss=0.035558
[Siamese] Epoch 11/50, Avg Loss=0.035585
[Siamese] Epoch 12/50, Avg Loss=0.035506
[Siamese] Epoch 13/50, Avg Loss=0.035367
[Siamese] Epoch 14/50, Avg Loss=0.035428
[Siamese] Epoch 15/50, Avg Loss=0.035510
[Siamese] Epoch 16/50, Avg Loss=0.035377
[Siamese] Epoch 17/50, Avg Loss=0.035249
[Siamese] Epoch 18/50, Avg Loss=0.035287
[Siamese] Epoch 19/50, Avg Loss=0.035232
[Siamese] Epoch 20/50, Avg Loss=0.035231
[Siamese] Epoch 21/50, Avg Loss=0.035116
[Siamese] Epoch 22/50, Avg Loss=0.035196
[Siamese] Epoch 23/50, Avg Loss=0.035180
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.038887
[Siamese] Epoch 2/50, Avg Loss=0.037745
[Siamese] Epoch 3/50, Avg Loss=0.036993
[Siamese] Epoch 4/50, Avg Loss=0.036601
[Siamese] Epoch 5/50, Avg Loss=0.036183
[Siamese] Epoch 6/50, Avg Loss=0.036304
[Siamese] Epoch 7/50, Avg Loss=0.036081
[Siamese] Epoch 8/50, Avg Loss=0.036006
[Siamese] Epoch 9/50, Avg Loss=0.035912
[Siamese] Epoch 10/50, Avg Loss=0.035759
[Siamese] Epoch 11/50, Avg Loss=0.035827
[Siamese] Epoch 12/50, Avg Loss=0.035890
[Siamese] Epoch 13/50, Avg Loss=0.035645
[Siamese] Epoch 14/50, Avg Loss=0.035418
[Siamese] Epoch 15/50, Avg Loss=0.035634
[Siamese] Epoch 16/50, Avg Loss=0.035769
[Siamese] Epoch 17/50, Avg Loss=0.035477
[Siamese] Epoch 18/50, Avg Loss=0.035570
[Siamese] Epoch 19/50, Avg Loss=0.035364
[Siamese] Epoch 20/50, Avg Loss=0.035413
[Siamese] Epoch 21/50, Avg Loss=0.035432
[Siamese] Epoch 22/50, Avg Loss=0.035456
[Siamese] Epoch 23/50, Avg Loss=0.035334
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.039361
[Siamese] Epoch 2/50, Avg Loss=0.037449
[Siamese] Epoch 3/50, Avg Loss=0.036844
[Siamese] Epoch 4/50, Avg Loss=0.036271
[Siamese] Epoch 5/50, Avg Loss=0.036168
[Siamese] Epoch 6/50, Avg Loss=0.035843
[Siamese] Epoch 7/50, Avg Loss=0.035714
[Siamese] Epoch 8/50, Avg Loss=0.035522
[Siamese] Epoch 9/50, Avg Loss=0.035543
[Siamese] Epoch 10/50, Avg Loss=0.035468
[Siamese] Epoch 11/50, Avg Loss=0.035406
[Siamese] Epoch 12/50, Avg Loss=0.035224
[Siamese] Epoch 13/50, Avg Loss=0.035139
[Siamese] Epoch 14/50, Avg Loss=0.035296
[Siamese] Epoch 15/50, Avg Loss=0.035206
[Siamese] Epoch 16/50, Avg Loss=0.035250
[Siamese] Epoch 17/50, Avg Loss=0.035055
[Siamese] Epoch 18/50, Avg Loss=0.035159
[Siamese] Epoch 19/50, Avg Loss=0.035051
[Siamese] Epoch 20/50, Avg Loss=0.034938
[Siamese] Epoch 21/50, Avg Loss=0.034959
[Siamese] Epoch 22/50, Avg Loss=0.035081
[Siamese] Epoch 23/50, Avg Loss=0.035007
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.039057
[Siamese] Epoch 2/50, Avg Loss=0.036702
[Siamese] Epoch 3/50, Avg Loss=0.036158
[Siamese] Epoch 4/50, Avg Loss=0.035707
[Siamese] Epoch 5/50, Avg Loss=0.035474
[Siamese] Epoch 6/50, Avg Loss=0.035277
[Siamese] Epoch 7/50, Avg Loss=0.035327
[Siamese] Epoch 8/50, Avg Loss=0.035250
[Siamese] Epoch 9/50, Avg Loss=0.035121
[Siamese] Epoch 10/50, Avg Loss=0.034923
[Siamese] Epoch 11/50, Avg Loss=0.034873
[Siamese] Epoch 12/50, Avg Loss=0.034810
[Siamese] Epoch 13/50, Avg Loss=0.034810
[Siamese] Epoch 14/50, Avg Loss=0.034656
[Siamese] Epoch 15/50, Avg Loss=0.034809
[Siamese] Epoch 16/50, Avg Loss=0.034764
[Siamese] Epoch 17/50, Avg Loss=0.034710
[Siamese] Epoch 18/50, Avg Loss=0.034661
[Siamese] Epoch 19/50, Avg Loss=0.034646
[Siamese] Epoch 20/50, Avg Loss=0.034565
[Siamese] Epoch 21/50, Avg Loss=0.034462
[Siamese] Epoch 22/50, Avg Loss=0.034650
[Siamese] Epoch 23/50, Avg Loss=0.034489
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.038976
[Siamese] Epoch 2/50, Avg Loss=0.036685
[Siamese] Epoch 3/50, Avg Loss=0.035882
[Siamese] Epoch 4/50, Avg Loss=0.035803
[Siamese] Epoch 5/50, Avg Loss=0.035401
[Siamese] Epoch 6/50, Avg Loss=0.035316
[Siamese] Epoch 7/50, Avg Loss=0.035378
[Siamese] Epoch 8/50, Avg Loss=0.035300
[Siamese] Epoch 9/50, Avg Loss=0.034913
[Siamese] Epoch 10/50, Avg Loss=0.034955
[Siamese] Epoch 11/50, Avg Loss=0.034896
[Siamese] Epoch 12/50, Avg Loss=0.034964
[Siamese] Epoch 13/50, Avg Loss=0.034691
[Siamese] Epoch 14/50, Avg Loss=0.034778
[Siamese] Epoch 15/50, Avg Loss=0.034672
[Siamese] Epoch 16/50, Avg Loss=0.034613
[Siamese] Epoch 17/50, Avg Loss=0.034649
[Siamese] Epoch 18/50, Avg Loss=0.034586
[Siamese] Epoch 19/50, Avg Loss=0.034599
[Siamese] Epoch 20/50, Avg Loss=0.034649
[Siamese] Epoch 21/50, Avg Loss=0.034651
[Siamese] Epoch 22/50, Avg Loss=0.034584
[Siamese] Epoch 23/50, Avg Loss=0.034437
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.039234
[Siamese] Epoch 2/50, Avg Loss=0.037716
[Siamese] Epoch 3/50, Avg Loss=0.036822
[Siamese] Epoch 4/50, Avg Loss=0.036429
[Siamese] Epoch 5/50, Avg Loss=0.035911
[Siamese] Epoch 6/50, Avg Loss=0.035857
[Siamese] Epoch 7/50, Avg Loss=0.035688
[Siamese] Epoch 8/50, Avg Loss=0.035537
[Siamese] Epoch 9/50, Avg Loss=0.035425
[Siamese] Epoch 10/50, Avg Loss=0.035406
[Siamese] Epoch 11/50, Avg Loss=0.035326
[Siamese] Epoch 12/50, Avg Loss=0.035389
[Siamese] Epoch 13/50, Avg Loss=0.035110
[Siamese] Epoch 14/50, Avg Loss=0.035172
[Siamese] Epoch 15/50, Avg Loss=0.035067
[Siamese] Epoch 16/50, Avg Loss=0.035230
[Siamese] Epoch 17/50, Avg Loss=0.035076
[Siamese] Epoch 18/50, Avg Loss=0.035030
[Siamese] Epoch 19/50, Avg Loss=0.034961
[Siamese] Epoch 20/50, Avg Loss=0.035042
[Siamese] Epoch 21/50, Avg Loss=0.035030
[Siamese] Epoch 22/50, Avg Loss=0.034905
[Siamese] Epoch 23/50, Avg Loss=0.034881
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.039332
[Siamese] Epoch 2/50, Avg Loss=0.037010
[Siamese] Epoch 3/50, Avg Loss=0.036266
[Siamese] Epoch 4/50, Avg Loss=0.036314
[Siamese] Epoch 5/50, Avg Loss=0.035848
[Siamese] Epoch 6/50, Avg Loss=0.035588
[Siamese] Epoch 7/50, Avg Loss=0.035599
[Siamese] Epoch 8/50, Avg Loss=0.035485
[Siamese] Epoch 9/50, Avg Loss=0.035419
[Siamese] Epoch 10/50, Avg Loss=0.035277
[Siamese] Epoch 11/50, Avg Loss=0.035256
[Siamese] Epoch 12/50, Avg Loss=0.035186
[Siamese] Epoch 13/50, Avg Loss=0.035153
[Siamese] Epoch 14/50, Avg Loss=0.034987
[Siamese] Epoch 15/50, Avg Loss=0.034986
[Siamese] Epoch 16/50, Avg Loss=0.034980
[Siamese] Epoch 17/50, Avg Loss=0.034955
[Siamese] Epoch 18/50, Avg Loss=0.034879
[Siamese] Epoch 19/50, Avg Loss=0.035021
[Siamese] Epoch 20/50, Avg Loss=0.034833
[Siamese] Epoch 21/50, Avg Loss=0.034771
[Siamese] Epoch 22/50, Avg Loss=0.034858
[Siamese] Epoch 23/50, Avg Loss=0.034758
[Siamese] Epoch 24/50, Avg Loss



Balanced subset: (780, 512) (780,)
[Siamese] Epoch 1/50, Avg Loss=0.039505
[Siamese] Epoch 2/50, Avg Loss=0.037368
[Siamese] Epoch 3/50, Avg Loss=0.036873
[Siamese] Epoch 4/50, Avg Loss=0.036496
[Siamese] Epoch 5/50, Avg Loss=0.036288
[Siamese] Epoch 6/50, Avg Loss=0.036082
[Siamese] Epoch 7/50, Avg Loss=0.036028
[Siamese] Epoch 8/50, Avg Loss=0.035737
[Siamese] Epoch 9/50, Avg Loss=0.035704
[Siamese] Epoch 10/50, Avg Loss=0.035700
[Siamese] Epoch 11/50, Avg Loss=0.035802
[Siamese] Epoch 12/50, Avg Loss=0.035453
[Siamese] Epoch 13/50, Avg Loss=0.035449
[Siamese] Epoch 14/50, Avg Loss=0.035426
[Siamese] Epoch 15/50, Avg Loss=0.035350
[Siamese] Epoch 16/50, Avg Loss=0.035378
[Siamese] Epoch 17/50, Avg Loss=0.035381
[Siamese] Epoch 18/50, Avg Loss=0.035312
[Siamese] Epoch 19/50, Avg Loss=0.035373
[Siamese] Epoch 20/50, Avg Loss=0.035230
[Siamese] Epoch 21/50, Avg Loss=0.035190
[Siamese] Epoch 22/50, Avg Loss=0.035162
[Siamese] Epoch 23/50, Avg Loss=0.035279
[Siamese] Epoch 24/50, Avg Loss

=== FINAL SUMMARY ===
  accuracy | mean=0.5271 ± 0.0186
 precision | mean=0.6707 ± 0.0621
    recall | mean=0.3043 ± 0.2325
        f1 | mean=0.3661 ± 0.1580
  log_loss | mean=0.6926 ± 0.0041 (batch size = 32)

=== FINAL SUMMARY ===
  accuracy | mean=0.5439 ± 0.0262
 precision | mean=0.6301 ± 0.0744
    recall | mean=0.3763 ± 0.3009
        f1 | mean=0.3889 ± 0.2028
  log_loss | mean=0.6909 ± 0.0050(batch size = 16)