In [1]:
import argparse
import os
import os.path as osp
import shutil
import time
from itertools import product
from json import dumps
import random
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.optim import Adam
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import DataParallel
from torch_geometric.seed import seed_everything
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import DataLoader, Data
from torch import nn
from tqdm import tqdm

from pytorch_metric_learning.losses import NTXentLoss, VICRegLoss
from sklearn.model_selection import GridSearchCV, StratifiedKFold

import train_utils
from data_utils import extract_edge_attributes
from torch_geometric.datasets import TUDataset
from layers.input_encoder import LinearEncoder, LinearEdgeEncoder
from layers.layer_utils import make_gnn_layer
from models.GraphClassification import GraphClassification
from models.model_utils import make_GNN

import GCL.losses as L
import GCL.augmentors as A
from GCL.eval import get_split, SVMEvaluator, LREvaluator
from GCL.models import DualBranchContrast

import warnings
warnings.filterwarnings('ignore')

In [2]:
parser = argparse.ArgumentParser(f'arguments for training and testing')
parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.')
parser.add_argument('--dataset_name', type=str, default="PROTEINS",
                    choices=("MUTAG", "PROTEINS", "PTC_MR", "IMDBBINARY"), help='Name of dataset')
parser.add_argument('--drop_prob', type=float, default=0.5,
                    help='Probability of zeroing an activation in dropout layers.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
                        multiple GPUs are available.')
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
parser.add_argument('--l2_wd', type=float, default=3e-4, help='L2 weight decay.')
parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs.')
parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size of the model")
parser.add_argument("--embedding_size", type=int, default=16, help="Output size of the logistic model")
parser.add_argument("--model_name", type=str, default="KHopGNNConv",
                    choices=("KHopGNNConv"), help="Base GNN model")
parser.add_argument("--K", type=int, default=2, help="Number of hop to consider")
parser.add_argument("--num_layer", type=int, default=2, help="Number of layer for feature encoder")
parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
                    help="Jumping knowledge method")
parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
parser.add_argument("--virtual_node", action="store_true", default=False, 
                    help="If true, add virtual node information in each layer")
parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN")
parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
                    help="Combine method in k-hop aggregation")
parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
                    help="Pooling method in graph classification")
parser.add_argument('--norm_type', type=str, default="Batch",
                    choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
                    help="Normalization method in model")
parser.add_argument('--aggr', type=str, default="add",
                    help='Aggregation method in GNN layer, only works in GraphSAGE')
parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.")
parser.add_argument("--pos_attr", type=bool, default=False, help="Positional attributes.")
parser.add_argument("--feature_augmentation", type=bool, default=False, help="If true, feature augmentation.")

_StoreAction(option_strings=['--feature_augmentation'], dest='feature_augmentation', nargs=None, const=None, default=False, type=<class 'bool'>, choices=None, required=False, help='If true, feature augmentation.', metavar=None)

In [3]:
args = parser.parse_args("")

In [4]:
args.name = args.model_name + "_" + str(args.K) + "_" + str(args.search)

In [5]:
# Set up logging and devices
args.save_dir = train_utils.get_save_dir(args.save_dir, args.name, type=args.dataset_name)
log = train_utils.get_logger(args.save_dir, args.name)
device, args.gpu_ids = train_utils.get_available_devices()

In [6]:
if len(args.gpu_ids) > 1 and args.parallel:
    log.info(f'Using multi-gpu training')
    args.parallel = True
    loader = DataListLoader
    args.batch_size *= max(1, len(args.gpu_ids))
else:
    log.info(f'Using single-gpu training')
    args.parallel = False
    loader = DataLoader

[02.14.25 09:26:12] Using single-gpu training


In [7]:
# Set random seed
seed = args.seed
log.info(f'Using random seed {seed}...')
seed_everything(seed)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

[02.14.25 09:26:12] Using random seed 2...


In [8]:
def edge_feature_transform(g):
    return extract_edge_attributes(g, args.pos_enc_dim)

In [9]:
tag = str(int(time.time()))

In [10]:
def get_model(args):
    layer = make_gnn_layer(args)
    init_emb = LinearEncoder(args.input_size, args.hidden_size, pos_attr=args.pos_attr)
    init_edge_attr_emb = LinearEdgeEncoder(args.edge_attr_size, args.hidden_size, edge_attr=True)
    init_edge_attr_v2_emb = LinearEdgeEncoder(args.edge_attr_v2_size, args.hidden_size, edge_attr=False)
    
    GNNModel = make_GNN(args)
    
    gnn = GNNModel(
        num_layer=args.num_layer,
        gnn_layer=layer,
        JK=args.JK,
        norm_type=args.norm_type,
        init_emb=init_emb,
        init_edge_attr_emb=init_edge_attr_emb,
        init_edge_attr_v2_emb=init_edge_attr_v2_emb,
        residual=args.residual,
        virtual_node=args.virtual_node,
        drop_prob=args.drop_prob)

    model = GraphClassification(embedding_model=gnn,
                                pooling_method=args.pooling_method,
                                output_size=args.output_size)
    
    model.reset_parameters()

    if args.parallel:
        model = DataParallel(model, args.gpu_ids)
    return model

In [11]:
class Projection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Projection, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x) + self.linear(x)

In [12]:
def vicreg_loss(embeddings1, embeddings2, lambda_var=25, lambda_cov=25, mu=1):
    """
    Calculate the VICReg loss between two sets of embeddings with the correct handling of covariance differences.
    
    Args:
    embeddings1, embeddings2 (torch.Tensor): Embeddings from two views, shape (batch_size, feature_dim).
    lambda_var (float): Coefficient for the variance loss.
    lambda_cov (float): Coefficient for the covariance loss.
    mu (float): Coefficient for the invariance loss.
    
    Returns:
    torch.Tensor: The total VICReg loss.
    """
    # Invariance Loss
    invariance_loss = F.mse_loss(embeddings1, embeddings2)

    # Variance Loss
    def variance_loss(embeddings1, embeddings2):
        mean_embeddings1 = embeddings1.mean(dim=0)
        mean_embeddings2 = embeddings2.mean(dim=0)
        
        std_dev1 = torch.sqrt((embeddings1 - mean_embeddings1).var(dim=0) + 1e-4)
        std_dev2 = torch.sqrt((embeddings2 - mean_embeddings2).var(dim=0) + 1e-4)
        
        return torch.mean(torch.abs(F.relu(1 - std_dev1) - F.relu(1 - std_dev2)))

    variance_loss_value = variance_loss(embeddings1, embeddings2)

    # Covariance Loss
    def covariance_loss(embeddings1, embeddings2):
        batch_size, feature_dim = embeddings1.size()
        
        embeddings_centered1 = embeddings1 - embeddings1.mean(dim=0)
        embeddings_centered2 = embeddings2 - embeddings2.mean(dim=0)
        
        covariance_matrix1 = torch.matmul(embeddings_centered1.T, embeddings_centered1) / (batch_size - 1)
        covariance_matrix2 = torch.matmul(embeddings_centered2.T, embeddings_centered2) / (batch_size - 1)
        
        covariance_matrix1.fill_diagonal_(0)
        covariance_matrix2.fill_diagonal_(0)
        
        cov_diff = torch.abs(covariance_matrix1.pow(2) - covariance_matrix2.pow(2))
        return cov_diff.sum() / feature_dim

    covariance_loss_value = covariance_loss(embeddings1, embeddings2)

    total_loss = mu * invariance_loss + lambda_var * variance_loss_value + lambda_cov * covariance_loss_value
    
    return total_loss

In [13]:
class Encoder(torch.nn.Module):
    def __init__(self, model_1, model_2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2
        
    def get_embedding(self, data):
        z, g = self.model_1(data)
        z_pos, g_pos = self.model_2(data)
        
        z = self.mlp1(z)
        g = self.mlp2(g)
        
        z_pos = self.mlp1(z_pos)
        g_pos = self.mlp2(g_pos)

        g = torch.cat((g, g_pos), 1)
        z = torch.cat((z, z_pos), 1)

        return g.detach(), z.detach()

    def forward(self, data):
        data1 = self.aug1(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        data2 = self.aug2(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        
        # Structural features
        z1, g1 = self.model_1(data1)
        z2, g2 = self.model_1(data2)
        
        # Positional features
        z1_pos, g1_pos = self.model_2(data1)
        z2_pos, g2_pos = self.model_2(data2)
        
        h1, h2 = [self.mlp1(h) for h in [z1, z2]]
        g1, g2 = [self.mlp2(g) for g in [g1, g2]]
        
        h1_pos, h2_pos = [self.mlp1(h_pos) for h_pos in [z1_pos, z2_pos]]
        g1_pos, g2_pos = [self.mlp2(g_pos) for g_pos in [g1_pos, g2_pos]]
        
        h1 = torch.cat((h1, h1_pos), 1)
        h2 = torch.cat((h2, h2_pos), 1)
        g1 = torch.cat((g1, g1_pos), 1)
        g2 = torch.cat((g2, g2_pos), 1)
        
        return h1, h2, g1, g2

In [14]:
def train(encoder_model, dataloader, optimizer, device):
    best = float("inf")
    cnt_wait = 0
    best_t = 0
    
    loss_func = NTXentLoss(temperature=0.10)
    
    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        h1, h2, g1, g2 = encoder_model(data)
        
        embeddings = torch.cat((g1, g2))
        
        # The same index corresponds to a positive pair
        indices = torch.arange(0, g1.size(0), device=device)
        labels = torch.cat((indices, indices))
        
        reg_loss = vicreg_loss(h1, h2, lambda_var=25, lambda_cov=25, mu=1)
        loss = loss_func(embeddings, labels) + 0.009*reg_loss
        
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        
        if epoch % 20 == 0:
            print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, epoch_loss))

        if epoch_loss < best:
            best = epoch_loss
            best_t = epoch
            cnt_wait = 0
            torch.save(encoder_model.state_dict(), './pkl/best_model_'+ args.dataset_name + tag + '.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break
            
    return epoch_loss

In [15]:
def test(encoder_model, dataloader, seeds, device):
    encoder_model.eval()
    x = []
    y = []
    for data in dataloader:
        data = data.to(device)
        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
        graph_embeds, node_embeds = encoder_model.get_embedding(data)
        x.append(graph_embeds)
        y.append(data.y)
    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)

    result = []
        
    random.shuffle(seeds)
    seeds = seeds.tolist()
    for _ in np.arange(10):
        random_seed = seeds.pop()
        split = get_split(num_samples=x.size()[0], train_ratio=0.1, test_ratio=0.8, seed=random_seed)
        result.append(LREvaluator()(x, y, split))
    
    return result

In [16]:
path = "./data/data_splits/" + args.dataset_name

In [17]:
train_indices, test_indices, train_test_indices = [], [], []
for i in range(10):
    train_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 1))
    test_filename = os.path.join(path, '10fold_idx', 'test_idx-{}.txt'.format(i + 1))
    train_indices.append(torch.from_numpy(np.loadtxt(train_filename, dtype=int)).to(torch.long))
    test_indices.append(torch.from_numpy(np.loadtxt(test_filename, dtype=int)).to(torch.long))

if args.feature_augmentation:
    train_test_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 2))
    train_test_indices.append(torch.from_numpy(np.loadtxt(train_test_filename, dtype=int)).to(torch.long))
else:
    train_test_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 3))
    train_test_indices.append(torch.from_numpy(np.loadtxt(train_test_filename, dtype=int)).to(torch.long))

In [18]:
dataset = TUDataset(root='./data/'+args.dataset_name, name=args.dataset_name, 
                    pre_transform=T.Compose([edge_feature_transform]))

In [19]:
classes = torch.unique(dataset.y)
args.n_classes = len(classes)

In [20]:
args.input_size = dataset.num_node_features
args.pos_size = args.pos_enc_dim
args.output_size = args.hidden_size
args.edge_attr_size =  dataset.edge_attr.shape[1]
args.edge_attr_v2_size =  dataset.edge_attr_v2.shape[1]

In [21]:
aug1 = A.Identity()

if args.feature_augmentation:
    # Feature Augmentation
    aug2 = A.RandomChoice([A.FeatureDropout(pf=0.1),
                           A.FeatureMasking(pf=0.1),
                           A.EdgeAttrMasking(pf=0.1)], 1)
else:
    # Structure Augmentation
    aug2 = A.RandomChoice([A.RWSampling(num_seeds=1000, walk_length=10),
                       A.NodeDropping(pn=0.1),
                       A.EdgeRemoving(pe=0.1)], 1)

model_1 = get_model(args)
model_1.to(device)

args.pos_attr = True
args.input_size = args.pos_size
model_2 = get_model(args)
model_2.to(device)

mlp1 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)
mlp2 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)

encoder_model = Encoder(model_1=model_1, model_2=model_2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2).to(device)

In [22]:
optimizer = Adam(encoder_model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
with tqdm(total=100, desc='(T)') as pbar:
    for epoch in range(1, 101):
        loss = train(encoder_model, dataloader, optimizer, device)
        pbar.set_postfix({'loss': loss})
        pbar.update()

(T):   1%|▍                                        | 1/100 [00:04<08:05,  4.91s/it, loss=47.5]

Early stopping


(T):   2%|▊                                          | 2/100 [00:10<08:34,  5.25s/it, loss=27]

Early stopping


(T):   3%|█▏                                       | 3/100 [00:15<08:37,  5.33s/it, loss=26.3]

Early stopping


(T):   4%|█▋                                       | 4/100 [00:21<08:49,  5.52s/it, loss=21.7]

Early stopping


(T):   5%|██                                       | 5/100 [00:27<09:03,  5.72s/it, loss=17.5]

Early stopping


(T):   6%|██▍                                      | 6/100 [00:33<09:07,  5.83s/it, loss=16.9]

Early stopping


(T):   7%|██▊                                      | 7/100 [00:39<09:10,  5.92s/it, loss=14.3]

Early stopping


(T):   8%|███▎                                     | 8/100 [00:46<09:15,  6.03s/it, loss=13.2]

Early stopping


(T):   9%|███▋                                     | 9/100 [00:51<09:00,  5.93s/it, loss=13.7]

Early stopping


(T):  10%|████                                    | 10/100 [00:57<08:50,  5.89s/it, loss=12.1]

Early stopping


(T):  11%|████▍                                   | 11/100 [01:03<08:50,  5.96s/it, loss=10.3]

Early stopping


(T):  12%|████▊                                   | 12/100 [01:09<08:49,  6.02s/it, loss=9.38]

Early stopping


(T):  13%|█████▏                                  | 13/100 [01:15<08:37,  5.94s/it, loss=9.15]

Early stopping


(T):  14%|█████▌                                  | 14/100 [01:21<08:35,  5.99s/it, loss=8.63]

Early stopping


(T):  15%|██████                                  | 15/100 [01:27<08:33,  6.04s/it, loss=6.59]

Early stopping


(T):  16%|██████▍                                 | 16/100 [01:34<08:28,  6.05s/it, loss=7.62]

Early stopping


(T):  17%|██████▊                                 | 17/100 [01:40<08:33,  6.19s/it, loss=6.19]

Early stopping


(T):  18%|███████▏                                | 18/100 [01:46<08:23,  6.14s/it, loss=6.32]

Early stopping


(T):  19%|███████▌                                | 19/100 [01:53<08:26,  6.25s/it, loss=6.17]

Early stopping
Epoch: 20, Loss: 0.2257
Epoch: 20, Loss: 0.6164
Epoch: 20, Loss: 0.8603
Epoch: 20, Loss: 1.0409
Epoch: 20, Loss: 1.1805
Epoch: 20, Loss: 1.3268
Epoch: 20, Loss: 1.5266
Epoch: 20, Loss: 1.6831
Epoch: 20, Loss: 1.8092
Epoch: 20, Loss: 1.9737
Epoch: 20, Loss: 2.0751
Epoch: 20, Loss: 2.2070
Epoch: 20, Loss: 2.5365
Epoch: 20, Loss: 2.7573
Epoch: 20, Loss: 2.9257
Epoch: 20, Loss: 3.1424
Epoch: 20, Loss: 3.5696
Epoch: 20, Loss: 4.0258
Epoch: 20, Loss: 4.1848
Epoch: 20, Loss: 4.4563


(T):  20%|████████▏                                | 20/100 [01:59<08:24,  6.30s/it, loss=4.6]

Epoch: 20, Loss: 4.6036
Early stopping


(T):  21%|████████▍                               | 21/100 [02:05<08:15,  6.28s/it, loss=5.97]

Early stopping


(T):  22%|█████████                                | 22/100 [02:11<08:01,  6.18s/it, loss=6.5]

Early stopping


(T):  23%|█████████▏                              | 23/100 [02:17<07:55,  6.18s/it, loss=5.89]

Early stopping


(T):  24%|█████████▌                              | 24/100 [02:23<07:43,  6.10s/it, loss=5.73]

Early stopping


(T):  25%|██████████                              | 25/100 [02:29<07:40,  6.14s/it, loss=5.04]

Early stopping


(T):  26%|██████████▍                             | 26/100 [02:36<07:36,  6.16s/it, loss=4.57]

Early stopping


(T):  27%|██████████▊                             | 27/100 [02:42<07:35,  6.24s/it, loss=4.67]

Early stopping


(T):  28%|███████████▏                            | 28/100 [02:48<07:31,  6.27s/it, loss=6.06]

Early stopping


(T):  29%|███████████▌                            | 29/100 [02:54<07:18,  6.18s/it, loss=5.52]

Early stopping


(T):  30%|████████████                            | 30/100 [03:00<07:04,  6.06s/it, loss=4.01]

Early stopping


(T):  31%|████████████▍                           | 31/100 [03:06<06:55,  6.02s/it, loss=4.85]

Early stopping


(T):  32%|████████████▊                           | 32/100 [03:12<06:50,  6.03s/it, loss=4.03]

Early stopping


(T):  33%|█████████████▏                          | 33/100 [03:18<06:43,  6.02s/it, loss=5.38]

Early stopping


(T):  34%|█████████████▌                          | 34/100 [03:24<06:42,  6.09s/it, loss=5.52]

Early stopping


(T):  35%|██████████████                          | 35/100 [03:30<06:32,  6.03s/it, loss=4.05]

Early stopping


(T):  36%|██████████████▍                         | 36/100 [03:36<06:28,  6.07s/it, loss=4.42]

Early stopping


(T):  37%|██████████████▊                         | 37/100 [03:43<06:24,  6.10s/it, loss=3.94]

Early stopping


(T):  38%|███████████████▏                        | 38/100 [03:49<06:18,  6.10s/it, loss=3.89]

Early stopping


(T):  39%|███████████████▌                        | 39/100 [03:55<06:10,  6.08s/it, loss=4.01]

Early stopping
Epoch: 40, Loss: 0.1928
Epoch: 40, Loss: 0.3324
Epoch: 40, Loss: 0.4244
Epoch: 40, Loss: 0.4759
Epoch: 40, Loss: 0.5632
Epoch: 40, Loss: 0.7063
Epoch: 40, Loss: 0.8380
Epoch: 40, Loss: 0.9299
Epoch: 40, Loss: 1.3457
Epoch: 40, Loss: 1.7166
Epoch: 40, Loss: 1.8390
Epoch: 40, Loss: 2.0169
Epoch: 40, Loss: 2.2963
Epoch: 40, Loss: 2.4278
Epoch: 40, Loss: 2.6315
Epoch: 40, Loss: 2.7626
Epoch: 40, Loss: 3.1153
Epoch: 40, Loss: 3.1729
Epoch: 40, Loss: 3.5279


(T):  40%|████████████████                        | 40/100 [04:01<06:14,  6.24s/it, loss=3.76]

Epoch: 40, Loss: 3.6443
Epoch: 40, Loss: 3.7600
Early stopping


(T):  41%|████████████████▍                       | 41/100 [04:07<06:04,  6.17s/it, loss=3.38]

Early stopping


(T):  42%|█████████████████▏                       | 42/100 [04:14<06:07,  6.33s/it, loss=2.6]

Early stopping


(T):  43%|█████████████████▏                      | 43/100 [04:20<06:01,  6.34s/it, loss=2.93]

Early stopping


(T):  44%|█████████████████▌                      | 44/100 [04:27<05:57,  6.38s/it, loss=3.67]

Early stopping


(T):  45%|██████████████████                      | 45/100 [04:33<05:49,  6.35s/it, loss=3.57]

Early stopping


(T):  46%|██████████████████▍                     | 46/100 [04:40<05:47,  6.44s/it, loss=2.92]

Early stopping


(T):  47%|██████████████████▊                     | 47/100 [04:46<05:32,  6.27s/it, loss=2.81]

Early stopping


(T):  48%|███████████████████▏                    | 48/100 [04:52<05:23,  6.23s/it, loss=2.85]

Early stopping


(T):  49%|███████████████████▌                    | 49/100 [04:58<05:13,  6.15s/it, loss=3.19]

Early stopping


(T):  50%|████████████████████                    | 50/100 [05:04<05:05,  6.10s/it, loss=4.04]

Early stopping


(T):  51%|████████████████████▍                   | 51/100 [05:10<05:03,  6.20s/it, loss=3.82]

Early stopping


(T):  52%|████████████████████▊                   | 52/100 [05:16<04:54,  6.14s/it, loss=2.71]

Early stopping


(T):  53%|█████████████████████▏                  | 53/100 [05:22<04:40,  5.97s/it, loss=3.75]

Early stopping


(T):  54%|█████████████████████▌                  | 54/100 [05:27<04:25,  5.77s/it, loss=3.26]

Early stopping


(T):  55%|██████████████████████▌                  | 55/100 [05:33<04:18,  5.76s/it, loss=3.1]

Early stopping


(T):  56%|██████████████████████▍                 | 56/100 [05:39<04:15,  5.80s/it, loss=3.52]

Early stopping


(T):  57%|███████████████████████▎                 | 57/100 [05:44<04:02,  5.64s/it, loss=2.7]

Early stopping


(T):  58%|███████████████████████▏                | 58/100 [05:49<03:52,  5.54s/it, loss=2.49]

Early stopping


(T):  59%|███████████████████████▌                | 59/100 [05:54<03:41,  5.40s/it, loss=3.71]

Early stopping
Epoch: 60, Loss: 0.1107
Epoch: 60, Loss: 0.2493
Epoch: 60, Loss: 0.7541
Epoch: 60, Loss: 0.8813
Epoch: 60, Loss: 0.9860
Epoch: 60, Loss: 1.2470
Epoch: 60, Loss: 1.3133
Epoch: 60, Loss: 1.5739
Epoch: 60, Loss: 1.8105
Epoch: 60, Loss: 1.9330
Epoch: 60, Loss: 2.2020
Epoch: 60, Loss: 2.4296
Epoch: 60, Loss: 2.5157
Epoch: 60, Loss: 2.6406
Epoch: 60, Loss: 2.7052
Epoch: 60, Loss: 2.8366
Epoch: 60, Loss: 3.0021
Epoch: 60, Loss: 3.1699
Epoch: 60, Loss: 3.4913


(T):  60%|████████████████████████                | 60/100 [06:00<03:35,  5.38s/it, loss=3.81]

Epoch: 60, Loss: 3.6907
Epoch: 60, Loss: 3.8081
Early stopping


(T):  61%|████████████████████████▍               | 61/100 [06:05<03:26,  5.29s/it, loss=3.04]

Early stopping


(T):  62%|████████████████████████▊               | 62/100 [06:10<03:21,  5.31s/it, loss=2.18]

Early stopping


(T):  63%|█████████████████████████▏              | 63/100 [06:15<03:14,  5.25s/it, loss=2.91]

Early stopping


(T):  64%|█████████████████████████▌              | 64/100 [06:21<03:12,  5.34s/it, loss=2.88]

Early stopping


(T):  65%|██████████████████████████              | 65/100 [06:26<03:02,  5.23s/it, loss=3.35]

Early stopping


(T):  66%|██████████████████████████▍             | 66/100 [06:31<02:56,  5.19s/it, loss=2.14]

Early stopping


(T):  67%|██████████████████████████▊             | 67/100 [06:36<02:45,  5.01s/it, loss=3.31]

Early stopping


(T):  68%|███████████████████████████▏            | 68/100 [06:41<02:43,  5.12s/it, loss=2.73]

Early stopping


(T):  69%|███████████████████████████▌            | 69/100 [06:46<02:38,  5.11s/it, loss=2.58]

Early stopping


(T):  70%|████████████████████████████            | 70/100 [06:51<02:32,  5.07s/it, loss=2.71]

Early stopping


(T):  71%|█████████████████████████████            | 71/100 [06:56<02:27,  5.08s/it, loss=2.3]

Early stopping


(T):  72%|████████████████████████████▊           | 72/100 [07:01<02:23,  5.13s/it, loss=3.09]

Early stopping


(T):  73%|█████████████████████████████▉           | 73/100 [07:07<02:21,  5.26s/it, loss=2.7]

Early stopping


(T):  74%|█████████████████████████████▌          | 74/100 [07:12<02:11,  5.07s/it, loss=3.49]

Early stopping


(T):  75%|██████████████████████████████          | 75/100 [07:17<02:07,  5.08s/it, loss=3.31]

Early stopping


(T):  76%|██████████████████████████████▍         | 76/100 [07:21<01:59,  4.99s/it, loss=2.89]

Early stopping


(T):  77%|██████████████████████████████▊         | 77/100 [07:27<01:56,  5.05s/it, loss=2.17]

Early stopping


(T):  78%|███████████████████████████████▏        | 78/100 [07:32<01:53,  5.15s/it, loss=2.53]

Early stopping


(T):  79%|███████████████████████████████▌        | 79/100 [07:37<01:46,  5.09s/it, loss=1.87]

Early stopping
Epoch: 80, Loss: 0.2416
Epoch: 80, Loss: 0.2967
Epoch: 80, Loss: 0.3485
Epoch: 80, Loss: 0.5261
Epoch: 80, Loss: 0.7674
Epoch: 80, Loss: 0.8207
Epoch: 80, Loss: 0.9470
Epoch: 80, Loss: 1.0469
Epoch: 80, Loss: 1.0968
Epoch: 80, Loss: 1.2086
Epoch: 80, Loss: 1.2677
Epoch: 80, Loss: 1.3345
Epoch: 80, Loss: 1.4046
Epoch: 80, Loss: 1.4606
Epoch: 80, Loss: 1.6121
Epoch: 80, Loss: 1.7213
Epoch: 80, Loss: 2.0449
Epoch: 80, Loss: 2.1261
Epoch: 80, Loss: 2.1911
Epoch: 80, Loss: 2.2812


(T):  80%|████████████████████████████████        | 80/100 [07:42<01:44,  5.21s/it, loss=2.34]

Epoch: 80, Loss: 2.3395
Early stopping


(T):  81%|████████████████████████████████▍       | 81/100 [07:48<01:39,  5.23s/it, loss=1.94]

Early stopping


(T):  82%|████████████████████████████████▊       | 82/100 [07:53<01:33,  5.21s/it, loss=2.73]

Early stopping


(T):  83%|█████████████████████████████████▏      | 83/100 [07:58<01:27,  5.15s/it, loss=2.34]

Early stopping


(T):  84%|██████████████████████████████████▍      | 84/100 [08:03<01:22,  5.13s/it, loss=2.6]

Early stopping


(T):  85%|██████████████████████████████████      | 85/100 [08:08<01:16,  5.09s/it, loss=2.68]

Early stopping


(T):  86%|████████████████████████████████████▉      | 86/100 [08:13<01:11,  5.09s/it, loss=3]

Early stopping


(T):  87%|██████████████████████████████████▊     | 87/100 [08:18<01:04,  4.98s/it, loss=2.19]

Early stopping


(T):  88%|███████████████████████████████████▏    | 88/100 [08:23<00:59,  4.95s/it, loss=1.98]

Early stopping


(T):  89%|███████████████████████████████████▌    | 89/100 [08:28<00:55,  5.01s/it, loss=2.43]

Early stopping


(T):  90%|████████████████████████████████████    | 90/100 [08:33<00:49,  4.93s/it, loss=1.69]

Early stopping


(T):  91%|████████████████████████████████████▍   | 91/100 [08:38<00:44,  4.97s/it, loss=1.67]

Early stopping


(T):  92%|████████████████████████████████████▊   | 92/100 [08:43<00:39,  4.97s/it, loss=1.39]

Early stopping


(T):  93%|█████████████████████████████████████▏  | 93/100 [08:47<00:34,  4.88s/it, loss=1.66]

Early stopping


(T):  94%|█████████████████████████████████████▌  | 94/100 [08:52<00:29,  4.93s/it, loss=1.62]

Early stopping


(T):  95%|██████████████████████████████████████▉  | 95/100 [08:57<00:24,  4.86s/it, loss=1.9]

Early stopping


(T):  96%|██████████████████████████████████████▍ | 96/100 [09:01<00:18,  4.72s/it, loss=2.17]

Early stopping


(T):  97%|██████████████████████████████████████▊ | 97/100 [09:06<00:14,  4.82s/it, loss=2.53]

Early stopping


(T):  98%|███████████████████████████████████████▏| 98/100 [09:11<00:09,  4.79s/it, loss=1.92]

Early stopping


(T):  99%|███████████████████████████████████████▌| 99/100 [09:16<00:04,  4.81s/it, loss=1.99]

Early stopping
Epoch: 100, Loss: 0.1281
Epoch: 100, Loss: 0.2393
Epoch: 100, Loss: 0.3162
Epoch: 100, Loss: 0.3612
Epoch: 100, Loss: 0.3990
Epoch: 100, Loss: 0.4545
Epoch: 100, Loss: 0.5051
Epoch: 100, Loss: 0.5577
Epoch: 100, Loss: 0.7256
Epoch: 100, Loss: 0.9742
Epoch: 100, Loss: 1.0662
Epoch: 100, Loss: 1.1385
Epoch: 100, Loss: 1.2434
Epoch: 100, Loss: 1.3922
Epoch: 100, Loss: 1.4435
Epoch: 100, Loss: 1.5660
Epoch: 100, Loss: 1.8009
Epoch: 100, Loss: 1.9259
Epoch: 100, Loss: 1.9926


(T):  99%|███████████████████████████████████████▌| 99/100 [09:21<00:04,  4.81s/it, loss=2.19]

Epoch: 100, Loss: 2.0465
Epoch: 100, Loss: 2.1915
Early stopping


(T): 100%|███████████████████████████████████████| 100/100 [09:21<00:00,  5.62s/it, loss=2.19]


In [23]:
test_result = test(encoder_model, dataloader, train_test_indices[0], device)

(LR): 100%|████████████████████████| 5000/5000 [00:01<00:00, best test F1Mi=0.741, F1Ma=0.725]
(LR): 100%|█████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.795, F1Ma=0.78]
(LR): 100%|█████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.804, F1Ma=0.79]
(LR): 100%|████████████████████████| 5000/5000 [00:01<00:00, best test F1Mi=0.741, F1Ma=0.741]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.812, F1Ma=0.787]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.741, F1Ma=0.738]
(LR): 100%|█████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.768, F1Ma=0.76]
(LR): 100%|█████████████████████████| 5000/5000 [00:01<00:00, best test F1Mi=0.741, F1Ma=0.73]
(LR): 100%|████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.768, F1Ma=0.764]
(LR): 100%|█████████████████████████| 5000/5000 [00:02<00:00, best test F1Mi=0.741, F1Ma=0.73]


In [24]:
micro_f1_values = [entry['micro_f1']*100 for entry in test_result]

In [25]:
np_micro_f1_values = np.array(micro_f1_values)
micro_f1_mean = np.mean(np_micro_f1_values)
micro_f1_std = np.std(np_micro_f1_values)
uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(np_micro_f1_values, func=np.mean, 
                                                                      n_boot=1000), 95) - np_micro_f1_values.mean()))

print(f'test acc mean = {micro_f1_mean:.4f} ± {micro_f1_std:.4f}')

test acc mean = 76.5179 ± 2.7389
