In [1]:
import argparse
import os
import os.path as osp
import shutil
import time
from itertools import product
from json import dumps
import random
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.optim import Adam
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import DataParallel
from torch_geometric.seed import seed_everything
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import DataLoader, Data
from torch import nn
from tqdm import tqdm

from pytorch_metric_learning.losses import NTXentLoss, VICRegLoss
from sklearn.model_selection import GridSearchCV, StratifiedKFold

import train_utils
from data_utils import extract_edge_attributes
from torch_geometric.datasets import TUDataset
from layers.input_encoder import LinearEncoder, LinearEdgeEncoder
from layers.layer_utils import make_gnn_layer
from models.GraphClassification import GraphClassification
from models.model_utils import make_GNN

import GCL.losses as L
import GCL.augmentors as A
from GCL.eval import get_split, SVMEvaluator, LREvaluator
from GCL.models import DualBranchContrast

import warnings
warnings.filterwarnings('ignore')

In [2]:
parser = argparse.ArgumentParser(f'arguments for training and testing')
parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.')
parser.add_argument('--dataset_name', type=str, default="COLLAB",
                    choices=("MUTAG", "PROTEINS", "PTC_MR", "IMDBBINARY", "COLLAB"), help='Name of dataset')
parser.add_argument('--drop_prob', type=float, default=0.5,
                    help='Probability of zeroing an activation in dropout layers.') 
parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
                        multiple GPUs are available.')
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
parser.add_argument('--l2_wd', type=float, default=3e-3, help='L2 weight decay.')
parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs.')
parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size of the model")
parser.add_argument("--embedding_size", type=int, default=16, help="Output size of the logistic model")
parser.add_argument("--model_name", type=str, default="KHopGNNConv",
                    choices=("KHopGNNConv"), help="Base GNN model")
parser.add_argument("--K", type=int, default=2, help="Number of hop to consider")
parser.add_argument("--num_layer", type=int, default=2, help="Number of layer for feature encoder")
parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
                    help="Jumping knowledge method")
parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
parser.add_argument("--virtual_node", action="store_true", default=False, 
                    help="If true, add virtual node information in each layer")
parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN")
parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
                    help="Combine method in k-hop aggregation")
parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
                    help="Pooling method in graph classification")
parser.add_argument('--norm_type', type=str, default="Batch",
                    choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
                    help="Normalization method in model")
parser.add_argument('--aggr', type=str, default="add",
                    help='Aggregation method in GNN layer, only works in GraphSAGE')
parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.")
parser.add_argument("--pos_attr", type=bool, default=False, help="Positional attributes.")

_StoreAction(option_strings=['--pos_attr'], dest='pos_attr', nargs=None, const=None, default=False, type=<class 'bool'>, choices=None, required=False, help='Positional attributes.', metavar=None)

In [3]:
args = parser.parse_args("")

In [4]:
args.name = args.model_name + "_" + str(args.K) + "_" + str(args.search)

In [5]:
torch.set_num_threads(1)

In [6]:
# Set up logging and devices
args.save_dir = train_utils.get_save_dir(args.save_dir, args.name, type=args.dataset_name)
log = train_utils.get_logger(args.save_dir, args.name)
device, args.gpu_ids = train_utils.get_available_devices()

In [7]:
if len(args.gpu_ids) > 1 and args.parallel:
    log.info(f'Using multi-gpu training')
    args.parallel = True
    loader = DataListLoader
    args.batch_size *= max(1, len(args.gpu_ids))
else:
    log.info(f'Using single-gpu training')
    args.parallel = False
    loader = DataLoader

[02.13.25 14:31:44] Using single-gpu training


In [8]:
# Set random seed
seed = args.seed
log.info(f'Using random seed {seed}...')
seed_everything(seed)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

[02.13.25 14:31:44] Using random seed 2...


In [9]:
def edge_feature_transform(g):
    return extract_edge_attributes(g, args.pos_enc_dim)

In [10]:
tag = str(int(time.time()))

In [11]:
def get_model(args):
    layer = make_gnn_layer(args)
    init_emb = LinearEncoder(args.input_size, args.hidden_size, pos_attr=args.pos_attr)
    init_edge_attr_emb = LinearEdgeEncoder(args.edge_attr_size, args.hidden_size, edge_attr=True)
    init_edge_attr_v2_emb = LinearEdgeEncoder(args.edge_attr_v2_size, args.hidden_size, edge_attr=False)
    
    GNNModel = make_GNN(args)
    
    gnn = GNNModel(
        num_layer=args.num_layer,
        gnn_layer=layer,
        JK=args.JK,
        norm_type=args.norm_type,
        init_emb=init_emb,
        init_edge_attr_emb=init_edge_attr_emb,
        init_edge_attr_v2_emb=init_edge_attr_v2_emb,
        residual=args.residual,
        virtual_node=args.virtual_node,
        drop_prob=args.drop_prob)

    model = GraphClassification(embedding_model=gnn,
                                pooling_method=args.pooling_method,
                                output_size=args.output_size)
    
    model.reset_parameters()

    if args.parallel:
        model = DataParallel(model, args.gpu_ids)
    return model

In [12]:
class Projection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Projection, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x) + self.linear(x)

In [13]:
def vicreg_loss(embeddings1, embeddings2, lambda_var=25, lambda_cov=25, mu=1):
    """
    Calculate the VICReg loss between two sets of embeddings with the correct handling of covariance differences.
    
    Args:
    embeddings1, embeddings2 (torch.Tensor): Embeddings from two views, shape (batch_size, feature_dim).
    lambda_var (float): Coefficient for the variance loss.
    lambda_cov (float): Coefficient for the covariance loss.
    mu (float): Coefficient for the invariance loss.
    
    Returns:
    torch.Tensor: The total VICReg loss.
    """
    # Invariance Loss
    invariance_loss = F.mse_loss(embeddings1, embeddings2)

    # Variance Loss
    def variance_loss(embeddings1, embeddings2):
        mean_embeddings1 = embeddings1.mean(dim=0)
        mean_embeddings2 = embeddings2.mean(dim=0)
        
        std_dev1 = torch.sqrt((embeddings1 - mean_embeddings1).var(dim=0) + 1e-4)
        std_dev2 = torch.sqrt((embeddings2 - mean_embeddings2).var(dim=0) + 1e-4)
        
        return torch.mean(torch.abs(F.relu(1 - std_dev1) - F.relu(1 - std_dev2)))

    variance_loss_value = variance_loss(embeddings1, embeddings2)

    # Covariance Loss
    def covariance_loss(embeddings1, embeddings2):
        batch_size, feature_dim = embeddings1.size()
        
        embeddings_centered1 = embeddings1 - embeddings1.mean(dim=0)
        embeddings_centered2 = embeddings2 - embeddings2.mean(dim=0)
        
        covariance_matrix1 = torch.matmul(embeddings_centered1.T, embeddings_centered1) / (batch_size - 1)
        covariance_matrix2 = torch.matmul(embeddings_centered2.T, embeddings_centered2) / (batch_size - 1)
        
        covariance_matrix1.fill_diagonal_(0)
        covariance_matrix2.fill_diagonal_(0)
        
        cov_diff = torch.abs(covariance_matrix1.pow(2) - covariance_matrix2.pow(2))
        return cov_diff.sum() / feature_dim

    covariance_loss_value = covariance_loss(embeddings1, embeddings2)

    total_loss = mu * invariance_loss + lambda_var * variance_loss_value + lambda_cov * covariance_loss_value
    
    return total_loss

In [14]:
class Encoder(torch.nn.Module):
    def __init__(self, model_1, model_2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2
        
    def get_embedding(self, data):
        z, g = self.model_1(data)
        z_pos, g_pos = self.model_2(data)
        
        z = self.mlp1(z)
        g = self.mlp2(g)
        
        z_pos = self.mlp1(z_pos)
        g_pos = self.mlp2(g_pos)

        g = torch.cat((g, g_pos), 1)
        z = torch.cat((z, z_pos), 1)

        return g.detach(), z.detach()

    def forward(self, data):
        data1 = self.aug1(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        data2 = self.aug2(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        
        # Structural features
        z1, g1 = self.model_1(data1)
        z2, g2 = self.model_1(data2)
        
        # Positional features
        z1_pos, g1_pos = self.model_2(data1)
        z2_pos, g2_pos = self.model_2(data2)
        
        h1, h2 = [self.mlp1(h) for h in [z1, z2]]
        g1, g2 = [self.mlp2(g) for g in [g1, g2]]
        
        h1_pos, h2_pos = [self.mlp1(h_pos) for h_pos in [z1_pos, z2_pos]]
        g1_pos, g2_pos = [self.mlp2(g_pos) for g_pos in [g1_pos, g2_pos]]
        
        h1 = torch.cat((h1, h1_pos), 1)
        h2 = torch.cat((h2, h2_pos), 1)
        g1 = torch.cat((g1, g1_pos), 1)
        g2 = torch.cat((g2, g2_pos), 1)
        
        return h1, h2, g1, g2

In [15]:
def train(encoder_model, dataloader, optimizer, device):
    best = float("inf")
    cnt_wait = 0
    best_t = 0
    
    loss_func = NTXentLoss(temperature=0.10)
    
    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        h1, h2, g1, g2 = encoder_model(data)
        
        embeddings = torch.cat((g1, g2))
        
        # The same index corresponds to a positive pair
        indices = torch.arange(0, g1.size(0), device=device)
        labels = torch.cat((indices, indices))
        
        reg_loss = vicreg_loss(h1, h2, lambda_var=24, lambda_cov=24, mu=1)
        loss = loss_func(embeddings, labels) + 0.005*reg_loss
        
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        
        if epoch % 20 == 0:
            print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, epoch_loss))

        if epoch_loss < best:
            best = epoch_loss
            best_t = epoch
            cnt_wait = 0
            torch.save(encoder_model.state_dict(), './pkl/best_model_'+ args.dataset_name + tag + '.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break
            
    return epoch_loss

In [16]:
def test(encoder_model, dataloader, seeds, device):
    encoder_model.eval()
    x = []
    y = []
    for data in dataloader:
        data = data.to(device)
        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
        graph_embeds, node_embeds = encoder_model.get_embedding(data)
        x.append(graph_embeds)
        y.append(data.y)
    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)

    result = []
        
    random.shuffle(seeds)
    seeds = seeds.tolist()
    for _ in np.arange(10):
        random_seed = seeds.pop()
        split = get_split(num_samples=x.size()[0], train_ratio=0.1, test_ratio=0.8, seed=random_seed)
        result.append(LREvaluator()(x, y, split))
    
    return result

In [17]:
path = "./data/data_splits/" + args.dataset_name

In [18]:
train_indices, test_indices, train_test_indices = [], [], []
for i in range(10):
    train_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 1))
    test_filename = os.path.join(path, '10fold_idx', 'test_idx-{}.txt'.format(i + 1))
    train_indices.append(torch.from_numpy(np.loadtxt(train_filename, dtype=int)).to(torch.long))
    test_indices.append(torch.from_numpy(np.loadtxt(test_filename, dtype=int)).to(torch.long))

train_test_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 2))
train_test_indices.append(torch.from_numpy(np.loadtxt(train_test_filename, dtype=int)).to(torch.long))

In [19]:
dataset = TUDataset(root='./data/'+args.dataset_name, name=args.dataset_name, 
                    pre_transform=T.Compose([edge_feature_transform]))

In [20]:
# Determine the maximum degree in the dataset
max_degree = max([data.num_nodes for data in dataset])

# Apply the OneHotDegree transform
dataset.transform = T.OneHotDegree(max_degree)

In [21]:
classes = torch.unique(dataset.y)
args.n_classes = len(classes)

In [22]:
args.input_size = dataset.num_node_features
args.pos_size = args.pos_enc_dim
args.output_size = args.hidden_size
args.edge_attr_size =  dataset.edge_attr.shape[1]
args.edge_attr_v2_size =  dataset.edge_attr_v2.shape[1]

In [23]:
aug1 = A.RandomChoice([A.FeatureDropout(pf=0.1),
                           A.FeatureMasking(pf=0.1),
                           A.EdgeAttrMasking(pf=0.1)], 1)

aug2 = A.RandomChoice([A.RWSampling(num_seeds=1000, walk_length=10),
                       A.NodeDropping(pn=0.1),
                       A.EdgeRemoving(pe=0.1)], 1)

model_1 = get_model(args)
model_1.to(device)

args.pos_attr = True
args.input_size = args.pos_size
model_2 = get_model(args)
model_2.to(device)

mlp1 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)
mlp2 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)

encoder_model = Encoder(model_1=model_1, model_2=model_2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2).to(device)

In [24]:
optimizer = Adam(encoder_model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
with tqdm(total=100, desc='(T)') as pbar:
    for epoch in range(1, 101):
        loss = train(encoder_model, dataloader, optimizer, device)
        pbar.set_postfix({'loss': loss})
        pbar.update()

(T):   1%|▍                                        | 1/100 [00:20<33:05, 20.05s/it, loss=39.6]

Early stopping


(T):   2%|▊                                        | 2/100 [00:39<32:18, 19.78s/it, loss=22.2]

Early stopping


(T):   3%|█▏                                       | 3/100 [00:58<31:29, 19.48s/it, loss=26.7]

Early stopping


(T):   4%|█▋                                       | 4/100 [01:16<30:04, 18.80s/it, loss=22.9]

Early stopping


(T):   5%|██                                       | 5/100 [01:35<30:07, 19.03s/it, loss=20.2]

Early stopping


(T):   6%|██▌                                        | 6/100 [01:54<29:21, 18.74s/it, loss=17]

Early stopping


(T):   7%|██▊                                      | 7/100 [02:13<29:23, 18.96s/it, loss=17.4]

Early stopping


(T):   8%|███▎                                     | 8/100 [02:31<28:25, 18.53s/it, loss=14.6]

Early stopping


(T):   9%|███▋                                     | 9/100 [02:50<28:19, 18.67s/it, loss=12.9]

Early stopping


(T):  10%|████                                    | 10/100 [03:07<27:28, 18.32s/it, loss=11.5]

Early stopping


(T):  11%|████▌                                     | 11/100 [03:26<27:16, 18.39s/it, loss=11]

Early stopping


(T):  12%|████▊                                   | 12/100 [03:44<27:04, 18.46s/it, loss=11.5]

Early stopping


(T):  13%|█████▏                                  | 13/100 [04:03<26:55, 18.57s/it, loss=9.87]

Early stopping


(T):  14%|█████▌                                  | 14/100 [04:23<27:18, 19.05s/it, loss=9.92]

Early stopping


(T):  15%|██████                                  | 15/100 [04:45<28:03, 19.81s/it, loss=14.6]

Early stopping


(T):  16%|██████▍                                 | 16/100 [05:07<28:53, 20.63s/it, loss=12.7]

Early stopping


(T):  17%|██████▊                                 | 17/100 [05:29<28:54, 20.90s/it, loss=9.86]

Early stopping


(T):  18%|███████▏                                | 18/100 [06:51<53:34, 39.20s/it, loss=10.9]

Early stopping


(T):  19%|███████▌                                | 19/100 [07:11<45:12, 33.48s/it, loss=11.6]

Early stopping
Epoch: 20, Loss: 0.3412
Epoch: 20, Loss: 0.8206
Epoch: 20, Loss: 1.1216
Epoch: 20, Loss: 1.4762
Epoch: 20, Loss: 2.1474
Epoch: 20, Loss: 2.4616
Epoch: 20, Loss: 3.2713
Epoch: 20, Loss: 3.5104
Epoch: 20, Loss: 3.7768
Epoch: 20, Loss: 4.2789
Epoch: 20, Loss: 4.5001
Epoch: 20, Loss: 4.9948
Epoch: 20, Loss: 5.3547
Epoch: 20, Loss: 5.5658
Epoch: 20, Loss: 5.8985
Epoch: 20, Loss: 6.1594
Epoch: 20, Loss: 6.7395
Epoch: 20, Loss: 7.7740
Epoch: 20, Loss: 7.9403
Epoch: 20, Loss: 8.1006


(T):  20%|████████                                | 20/100 [07:30<38:57, 29.22s/it, loss=8.73]

Epoch: 20, Loss: 8.7347
Early stopping


(T):  21%|████████▍                               | 21/100 [07:51<35:06, 26.67s/it, loss=9.61]

Early stopping


(T):  22%|████████▊                               | 22/100 [08:11<32:03, 24.66s/it, loss=8.65]

Early stopping


(T):  23%|█████████▏                              | 23/100 [08:31<29:47, 23.21s/it, loss=11.4]

Early stopping


(T):  24%|█████████▌                              | 24/100 [08:51<28:05, 22.18s/it, loss=9.89]

Early stopping


(T):  25%|██████████                              | 25/100 [09:09<26:31, 21.22s/it, loss=9.08]

Early stopping


(T):  26%|██████████▍                             | 26/100 [09:29<25:35, 20.75s/it, loss=8.28]

Early stopping


(T):  27%|██████████▊                             | 27/100 [09:48<24:38, 20.25s/it, loss=7.17]

Early stopping


(T):  28%|███████████▏                            | 28/100 [10:07<23:42, 19.76s/it, loss=7.14]

Early stopping


(T):  29%|███████████▌                            | 29/100 [10:27<23:29, 19.85s/it, loss=5.99]

Early stopping


(T):  30%|████████████                            | 30/100 [10:46<22:43, 19.48s/it, loss=7.55]

Early stopping


(T):  31%|████████████▍                           | 31/100 [11:04<22:10, 19.28s/it, loss=6.28]

Early stopping


(T):  32%|████████████▊                           | 32/100 [11:24<22:01, 19.43s/it, loss=6.31]

Early stopping


(T):  33%|█████████████▏                          | 33/100 [11:43<21:27, 19.21s/it, loss=9.36]

Early stopping


(T):  34%|█████████████▌                          | 34/100 [12:02<21:05, 19.18s/it, loss=9.16]

Early stopping


(T):  35%|██████████████                          | 35/100 [12:22<21:02, 19.43s/it, loss=6.51]

Early stopping


(T):  36%|███████████████▍                           | 36/100 [12:43<21:05, 19.77s/it, loss=6]

Early stopping


(T):  37%|██████████████▊                         | 37/100 [13:03<20:53, 19.90s/it, loss=6.17]

Early stopping


(T):  38%|███████████████▏                        | 38/100 [13:22<20:18, 19.65s/it, loss=6.46]

Early stopping


(T):  39%|███████████████▌                        | 39/100 [13:41<19:44, 19.42s/it, loss=7.44]

Early stopping
Epoch: 40, Loss: 0.6319
Epoch: 40, Loss: 0.9817
Epoch: 40, Loss: 1.1930
Epoch: 40, Loss: 1.6296
Epoch: 40, Loss: 1.9636
Epoch: 40, Loss: 2.4716
Epoch: 40, Loss: 2.5966
Epoch: 40, Loss: 2.8360
Epoch: 40, Loss: 3.1752
Epoch: 40, Loss: 3.5718
Epoch: 40, Loss: 3.8313
Epoch: 40, Loss: 4.0338
Epoch: 40, Loss: 4.1419
Epoch: 40, Loss: 4.2443
Epoch: 40, Loss: 4.6227
Epoch: 40, Loss: 5.0432
Epoch: 40, Loss: 5.2656
Epoch: 40, Loss: 5.9580
Epoch: 40, Loss: 6.2319
Epoch: 40, Loss: 6.4521


(T):  40%|████████████████                        | 40/100 [13:59<19:08, 19.14s/it, loss=6.76]

Epoch: 40, Loss: 6.7614
Early stopping


(T):  41%|████████████████▍                       | 41/100 [14:17<18:34, 18.89s/it, loss=7.58]

Early stopping


(T):  42%|████████████████▊                       | 42/100 [14:37<18:25, 19.07s/it, loss=6.81]

Early stopping


(T):  43%|█████████████████▏                      | 43/100 [14:57<18:24, 19.38s/it, loss=9.29]

Early stopping


(T):  44%|█████████████████▌                      | 44/100 [15:16<17:57, 19.24s/it, loss=8.93]

Early stopping


(T):  45%|██████████████████                      | 45/100 [15:34<17:25, 19.00s/it, loss=8.58]

Early stopping


(T):  46%|██████████████████▍                     | 46/100 [15:53<17:00, 18.91s/it, loss=7.97]

Early stopping


(T):  47%|██████████████████▊                     | 47/100 [16:13<16:50, 19.07s/it, loss=7.07]

Early stopping


(T):  48%|███████████████████▏                    | 48/100 [16:32<16:36, 19.17s/it, loss=7.01]

Early stopping


(T):  49%|███████████████████▌                    | 49/100 [16:51<16:12, 19.07s/it, loss=7.14]

Early stopping


(T):  50%|████████████████████                    | 50/100 [17:10<15:54, 19.08s/it, loss=7.04]

Early stopping


(T):  51%|████████████████████▍                   | 51/100 [17:29<15:37, 19.14s/it, loss=6.27]

Early stopping


(T):  52%|████████████████████▊                   | 52/100 [17:48<15:21, 19.19s/it, loss=8.23]

Early stopping


(T):  53%|█████████████████████▏                  | 53/100 [18:07<14:58, 19.12s/it, loss=7.11]

Early stopping


(T):  54%|█████████████████████▌                  | 54/100 [18:27<14:39, 19.11s/it, loss=6.78]

Early stopping


(T):  55%|██████████████████████▌                  | 55/100 [18:46<14:18, 19.07s/it, loss=6.5]

Early stopping


(T):  56%|██████████████████████▍                 | 56/100 [19:06<14:14, 19.42s/it, loss=6.49]

Early stopping


(T):  57%|██████████████████████▊                 | 57/100 [19:24<13:42, 19.13s/it, loss=7.05]

Early stopping


(T):  58%|███████████████████████▏                | 58/100 [19:42<13:11, 18.83s/it, loss=5.36]

Early stopping


(T):  59%|███████████████████████▌                | 59/100 [20:02<12:57, 18.97s/it, loss=7.87]

Early stopping
Epoch: 60, Loss: 0.3394
Epoch: 60, Loss: 0.9140
Epoch: 60, Loss: 1.3405
Epoch: 60, Loss: 1.5124
Epoch: 60, Loss: 1.8358
Epoch: 60, Loss: 2.3631
Epoch: 60, Loss: 2.5222
Epoch: 60, Loss: 2.7182
Epoch: 60, Loss: 2.9669
Epoch: 60, Loss: 3.4280
Epoch: 60, Loss: 3.5651
Epoch: 60, Loss: 3.7749
Epoch: 60, Loss: 4.0850
Epoch: 60, Loss: 4.9810
Epoch: 60, Loss: 5.2345
Epoch: 60, Loss: 5.3553
Epoch: 60, Loss: 5.7720
Epoch: 60, Loss: 6.3051
Epoch: 60, Loss: 6.4917
Epoch: 60, Loss: 7.3329


(T):  60%|████████████████████████▌                | 60/100 [20:21<12:39, 19.00s/it, loss=8.4]

Epoch: 60, Loss: 8.4014
Early stopping


(T):  61%|████████████████████████▍               | 61/100 [20:40<12:20, 18.99s/it, loss=6.95]

Early stopping


(T):  62%|████████████████████████▊               | 62/100 [21:00<12:13, 19.31s/it, loss=7.89]

Early stopping


(T):  63%|█████████████████████████▏              | 63/100 [21:19<11:54, 19.31s/it, loss=5.37]

Early stopping


(T):  64%|█████████████████████████▌              | 64/100 [21:39<11:46, 19.62s/it, loss=5.06]

Early stopping


(T):  65%|██████████████████████████              | 65/100 [21:57<11:10, 19.17s/it, loss=6.61]

Early stopping


(T):  66%|██████████████████████████▍             | 66/100 [22:17<10:51, 19.16s/it, loss=5.91]

Early stopping


(T):  67%|██████████████████████████▊             | 67/100 [22:35<10:27, 19.02s/it, loss=5.44]

Early stopping


(T):  68%|███████████████████████████▏            | 68/100 [22:54<10:03, 18.85s/it, loss=6.28]

Early stopping


(T):  69%|███████████████████████████▌            | 69/100 [23:14<09:53, 19.14s/it, loss=6.16]

Early stopping


(T):  70%|████████████████████████████            | 70/100 [23:32<09:31, 19.06s/it, loss=5.15]

Early stopping


(T):  71%|████████████████████████████▍           | 71/100 [23:52<09:18, 19.24s/it, loss=5.93]

Early stopping


(T):  72%|████████████████████████████▊           | 72/100 [24:12<09:06, 19.52s/it, loss=3.95]

Early stopping


(T):  73%|█████████████████████████████▏          | 73/100 [24:32<08:45, 19.48s/it, loss=4.35]

Early stopping


(T):  74%|█████████████████████████████▌          | 74/100 [24:50<08:19, 19.19s/it, loss=5.14]

Early stopping


(T):  75%|██████████████████████████████          | 75/100 [25:09<07:55, 19.02s/it, loss=6.13]

Early stopping


(T):  76%|██████████████████████████████▍         | 76/100 [25:27<07:28, 18.70s/it, loss=6.54]

Early stopping


(T):  77%|██████████████████████████████▊         | 77/100 [25:46<07:15, 18.93s/it, loss=8.07]

Early stopping


(T):  78%|███████████████████████████████▏        | 78/100 [26:07<07:09, 19.51s/it, loss=7.54]

Early stopping


(T):  79%|███████████████████████████████▌        | 79/100 [26:27<06:52, 19.66s/it, loss=6.32]

Early stopping
Epoch: 80, Loss: 0.7124
Epoch: 80, Loss: 1.4763
Epoch: 80, Loss: 1.7031
Epoch: 80, Loss: 1.8335
Epoch: 80, Loss: 2.0491
Epoch: 80, Loss: 2.4651
Epoch: 80, Loss: 2.6776
Epoch: 80, Loss: 2.8584
Epoch: 80, Loss: 2.9703
Epoch: 80, Loss: 3.5927
Epoch: 80, Loss: 3.9752
Epoch: 80, Loss: 4.5072
Epoch: 80, Loss: 4.7231
Epoch: 80, Loss: 4.8292
Epoch: 80, Loss: 4.9129
Epoch: 80, Loss: 5.1215
Epoch: 80, Loss: 6.1630
Epoch: 80, Loss: 6.5236
Epoch: 80, Loss: 7.3365
Epoch: 80, Loss: 7.6222


(T):  80%|████████████████████████████████        | 80/100 [26:46<06:30, 19.54s/it, loss=7.77]

Epoch: 80, Loss: 7.7736
Early stopping


(T):  81%|████████████████████████████████▍       | 81/100 [27:08<06:20, 20.03s/it, loss=6.14]

Early stopping


(T):  82%|█████████████████████████████████▌       | 82/100 [27:26<05:52, 19.58s/it, loss=4.8]

Early stopping


(T):  83%|█████████████████████████████████▏      | 83/100 [27:45<05:30, 19.46s/it, loss=6.87]

Early stopping


(T):  84%|█████████████████████████████████▌      | 84/100 [28:05<05:11, 19.45s/it, loss=6.88]

Early stopping


(T):  85%|██████████████████████████████████      | 85/100 [28:23<04:48, 19.24s/it, loss=5.88]

Early stopping


(T):  86%|███████████████████████████████████▎     | 86/100 [28:41<04:22, 18.74s/it, loss=5.2]

Early stopping


(T):  87%|██████████████████████████████████▊     | 87/100 [29:01<04:07, 19.07s/it, loss=5.06]

Early stopping


(T):  88%|███████████████████████████████████▏    | 88/100 [29:21<03:52, 19.33s/it, loss=5.89]

Early stopping


(T):  89%|███████████████████████████████████▌    | 89/100 [29:41<03:35, 19.56s/it, loss=6.75]

Early stopping


(T):  90%|████████████████████████████████████    | 90/100 [29:59<03:12, 19.26s/it, loss=7.32]

Early stopping


(T):  91%|████████████████████████████████████▍   | 91/100 [30:19<02:53, 19.23s/it, loss=5.71]

Early stopping


(T):  92%|████████████████████████████████████▊   | 92/100 [30:38<02:33, 19.17s/it, loss=5.68]

Early stopping


(T):  93%|█████████████████████████████████████▏  | 93/100 [30:58<02:16, 19.53s/it, loss=5.92]

Early stopping


(T):  94%|█████████████████████████████████████▌  | 94/100 [31:16<01:54, 19.09s/it, loss=6.12]

Early stopping


(T):  95%|██████████████████████████████████████  | 95/100 [31:36<01:36, 19.25s/it, loss=5.42]

Early stopping


(T):  96%|██████████████████████████████████████▍ | 96/100 [31:55<01:16, 19.21s/it, loss=5.51]

Early stopping


(T):  97%|██████████████████████████████████████▊ | 97/100 [32:13<00:56, 18.97s/it, loss=5.13]

Early stopping


(T):  98%|███████████████████████████████████████▏| 98/100 [32:32<00:37, 18.86s/it, loss=5.09]

Early stopping


(T):  99%|███████████████████████████████████████▌| 99/100 [32:51<00:19, 19.04s/it, loss=7.62]

Early stopping
Epoch: 100, Loss: 0.1778
Epoch: 100, Loss: 0.9387
Epoch: 100, Loss: 1.1140
Epoch: 100, Loss: 1.7223
Epoch: 100, Loss: 1.9993
Epoch: 100, Loss: 2.3975
Epoch: 100, Loss: 2.7458
Epoch: 100, Loss: 3.2711
Epoch: 100, Loss: 3.4458
Epoch: 100, Loss: 3.5478
Epoch: 100, Loss: 3.8667
Epoch: 100, Loss: 4.2854
Epoch: 100, Loss: 4.6378
Epoch: 100, Loss: 4.9271
Epoch: 100, Loss: 5.0584
Epoch: 100, Loss: 5.5715
Epoch: 100, Loss: 5.8336
Epoch: 100, Loss: 6.1097
Epoch: 100, Loss: 6.6783
Epoch: 100, Loss: 7.4625


(T): 100%|███████████████████████████████████████| 100/100 [33:10<00:00, 19.90s/it, loss=7.81]

Epoch: 100, Loss: 7.8110
Early stopping





In [25]:
test_result = test(encoder_model, dataloader, train_test_indices[0], device)

(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.762, F1Ma=0.724]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.766, F1Ma=0.743]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.756, F1Ma=0.713]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.752, F1Ma=0.703]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.752, F1Ma=0.719]
(LR): 100%|█████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.76, F1Ma=0.743]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.762, F1Ma=0.741]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.752, F1Ma=0.732]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.758, F1Ma=0.728]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.754, F1Ma=0.732]


In [26]:
micro_f1_values = [entry['micro_f1']*100 for entry in test_result]

In [27]:
np_micro_f1_values = np.array(micro_f1_values)
micro_f1_mean = np.mean(np_micro_f1_values)
micro_f1_std = np.std(np_micro_f1_values)

print(f'test acc mean = {micro_f1_mean:.4f} ± {micro_f1_std:.4f}')

test acc mean = 75.7400 ± 0.4737
