In [1]:
import argparse
import os
import os.path as osp
import shutil
import time
from itertools import product
from json import dumps
import random
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.optim import Adam
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import DataParallel
from torch_geometric.seed import seed_everything
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import DataLoader, Data
from torch import nn
from tqdm import tqdm

from pytorch_metric_learning.losses import NTXentLoss, VICRegLoss
from sklearn.model_selection import GridSearchCV, StratifiedKFold

import train_utils
from data_utils import extract_edge_attributes
from torch_geometric.datasets import TUDataset
from layers.input_encoder import LinearEncoder, LinearEdgeEncoder
from layers.layer_utils import make_gnn_layer
from models.GraphClassification import GraphClassification
from models.model_utils import make_GNN

import GCL.losses as L
import GCL.augmentors as A
from GCL.eval import get_split, SVMEvaluator, LREvaluator
from GCL.models import DualBranchContrast

import warnings
warnings.filterwarnings('ignore')

In [2]:
parser = argparse.ArgumentParser(f'arguments for training and testing')
parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.')
parser.add_argument('--dataset_name', type=str, default="PROTEINS",
                    choices=("MUTAG", "PROTEINS", "PTC_MR", "IMDBBINARY"), help='Name of dataset')
parser.add_argument('--drop_prob', type=float, default=0.5,
                    help='Probability of zeroing an activation in dropout layers.')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
                        multiple GPUs are available.')
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
parser.add_argument('--l2_wd', type=float, default=3e-4, help='L2 weight decay.')
parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs.')
parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size of the model")
parser.add_argument("--embedding_size", type=int, default=16, help="Output size of the logistic model")
parser.add_argument("--model_name", type=str, default="KHopGNNConv",
                    choices=("KHopGNNConv"), help="Base GNN model")
parser.add_argument("--K", type=int, default=2, help="Number of hop to consider")
parser.add_argument("--num_layer", type=int, default=2, help="Number of layer for feature encoder")
parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
                    help="Jumping knowledge method")
parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
parser.add_argument("--virtual_node", action="store_true", default=False, 
                    help="If true, add virtual node information in each layer")
parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN")
parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
                    help="Combine method in k-hop aggregation")
parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
                    help="Pooling method in graph classification")
parser.add_argument('--norm_type', type=str, default="Batch",
                    choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
                    help="Normalization method in model")
parser.add_argument('--aggr', type=str, default="add",
                    help='Aggregation method in GNN layer, only works in GraphSAGE')
parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.")
parser.add_argument("--pos_attr", type=bool, default=False, help="Positional attributes.")

_StoreAction(option_strings=['--pos_attr'], dest='pos_attr', nargs=None, const=None, default=False, type=<class 'bool'>, choices=None, required=False, help='Positional attributes.', metavar=None)

In [3]:
args = parser.parse_args("")

In [4]:
args.name = args.model_name + "_" + str(args.K) + "_" + str(args.search)

In [5]:
torch.set_num_threads(1)

In [6]:
# Set up logging and devices
args.save_dir = train_utils.get_save_dir(args.save_dir, args.name, type=args.dataset_name)
log = train_utils.get_logger(args.save_dir, args.name)
device, args.gpu_ids = train_utils.get_available_devices()

In [7]:
if len(args.gpu_ids) > 1 and args.parallel:
    log.info(f'Using multi-gpu training')
    args.parallel = True
    loader = DataListLoader
    args.batch_size *= max(1, len(args.gpu_ids))
else:
    log.info(f'Using single-gpu training')
    args.parallel = False
    loader = DataLoader

[02.13.25 11:37:43] Using single-gpu training


In [8]:
# Set random seed
seed = args.seed
log.info(f'Using random seed {seed}...')
seed_everything(seed)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

[02.13.25 11:37:43] Using random seed 2...


In [9]:
def edge_feature_transform(g):
    return extract_edge_attributes(g, args.pos_enc_dim)

In [10]:
tag = str(int(time.time()))

In [11]:
def get_model(args):
    layer = make_gnn_layer(args)
    init_emb = LinearEncoder(args.input_size, args.hidden_size, pos_attr=args.pos_attr)
    init_edge_attr_emb = LinearEdgeEncoder(args.edge_attr_size, args.hidden_size, edge_attr=True)
    init_edge_attr_v2_emb = LinearEdgeEncoder(args.edge_attr_v2_size, args.hidden_size, edge_attr=False)
    
    GNNModel = make_GNN(args)
    
    gnn = GNNModel(
        num_layer=args.num_layer,
        gnn_layer=layer,
        JK=args.JK,
        norm_type=args.norm_type,
        init_emb=init_emb,
        init_edge_attr_emb=init_edge_attr_emb,
        init_edge_attr_v2_emb=init_edge_attr_v2_emb,
        residual=args.residual,
        virtual_node=args.virtual_node,
        drop_prob=args.drop_prob)

    model = GraphClassification(embedding_model=gnn,
                                pooling_method=args.pooling_method,
                                output_size=args.output_size)
    
    model.reset_parameters()

    if args.parallel:
        model = DataParallel(model, args.gpu_ids)
    return model

In [12]:
class Projection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Projection, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x) + self.linear(x)

In [13]:
def vicreg_loss(embeddings1, embeddings2, lambda_var=25, lambda_cov=25, mu=1):
    """
    Calculate the VICReg loss between two sets of embeddings with the correct handling of covariance differences.
    
    Args:
    embeddings1, embeddings2 (torch.Tensor): Embeddings from two views, shape (batch_size, feature_dim).
    lambda_var (float): Coefficient for the variance loss.
    lambda_cov (float): Coefficient for the covariance loss.
    mu (float): Coefficient for the invariance loss.
    
    Returns:
    torch.Tensor: The total VICReg loss.
    """
    # Invariance Loss
    invariance_loss = F.mse_loss(embeddings1, embeddings2)

    # Variance Loss
    def variance_loss(embeddings1, embeddings2):
        mean_embeddings1 = embeddings1.mean(dim=0)
        mean_embeddings2 = embeddings2.mean(dim=0)
        
        std_dev1 = torch.sqrt((embeddings1 - mean_embeddings1).var(dim=0) + 1e-4)
        std_dev2 = torch.sqrt((embeddings2 - mean_embeddings2).var(dim=0) + 1e-4)
        
        return torch.mean(torch.abs(F.relu(1 - std_dev1) - F.relu(1 - std_dev2)))

    variance_loss_value = variance_loss(embeddings1, embeddings2)

    # Covariance Loss
    def covariance_loss(embeddings1, embeddings2):
        batch_size, feature_dim = embeddings1.size()
        
        embeddings_centered1 = embeddings1 - embeddings1.mean(dim=0)
        embeddings_centered2 = embeddings2 - embeddings2.mean(dim=0)
        
        covariance_matrix1 = torch.matmul(embeddings_centered1.T, embeddings_centered1) / (batch_size - 1)
        covariance_matrix2 = torch.matmul(embeddings_centered2.T, embeddings_centered2) / (batch_size - 1)
        
        covariance_matrix1.fill_diagonal_(0)
        covariance_matrix2.fill_diagonal_(0)
        
        cov_diff = torch.abs(covariance_matrix1.pow(2) - covariance_matrix2.pow(2))
        return cov_diff.sum() / feature_dim

    covariance_loss_value = covariance_loss(embeddings1, embeddings2)

    total_loss = mu * invariance_loss + lambda_var * variance_loss_value + lambda_cov * covariance_loss_value
    
    return total_loss

In [14]:
class Encoder(torch.nn.Module):
    def __init__(self, model_1, model_2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2
        
    def get_embedding(self, data):
        z, g = self.model_1(data)
        z_pos, g_pos = self.model_2(data)
        
        z = self.mlp1(z)
        g = self.mlp2(g)
        
        z_pos = self.mlp1(z_pos)
        g_pos = self.mlp2(g_pos)

        g = torch.cat((g, g_pos), 1)
        z = torch.cat((z, z_pos), 1)

        return g.detach(), z.detach()

    def forward(self, data):
        data1 = self.aug1(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        data2 = self.aug2(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        
        # Structural features
        z1, g1 = self.model_1(data1)
        z2, g2 = self.model_1(data2)
        
        # Positional features
        z1_pos, g1_pos = self.model_2(data1)
        z2_pos, g2_pos = self.model_2(data2)
        
        h1, h2 = [self.mlp1(h) for h in [z1, z2]]
        g1, g2 = [self.mlp2(g) for g in [g1, g2]]
        
        h1_pos, h2_pos = [self.mlp1(h_pos) for h_pos in [z1_pos, z2_pos]]
        g1_pos, g2_pos = [self.mlp2(g_pos) for g_pos in [g1_pos, g2_pos]]
        
        h1 = torch.cat((h1, h1_pos), 1)
        h2 = torch.cat((h2, h2_pos), 1)
        g1 = torch.cat((g1, g1_pos), 1)
        g2 = torch.cat((g2, g2_pos), 1)
        
        return h1, h2, g1, g2

In [15]:
def train(encoder_model, dataloader, optimizer, device):
    best = float("inf")
    cnt_wait = 0
    best_t = 0
    
    loss_func = NTXentLoss(temperature=0.10)
    
    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        h1, h2, g1, g2 = encoder_model(data)
        
        embeddings = torch.cat((g1, g2))
        
        # The same index corresponds to a positive pair
        indices = torch.arange(0, g1.size(0), device=device)
        labels = torch.cat((indices, indices))
        
        reg_loss = vicreg_loss(h1, h2, lambda_var=25, lambda_cov=25, mu=1)
        loss = loss_func(embeddings, labels) + 0.009*reg_loss
        
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        
        if epoch % 20 == 0:
            print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, epoch_loss))

        if epoch_loss < best:
            best = epoch_loss
            best_t = epoch
            cnt_wait = 0
            torch.save(encoder_model.state_dict(), './pkl/best_model_'+ args.dataset_name + tag + '.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break
            
    return epoch_loss

In [16]:
def test(encoder_model, dataloader, seeds, device):
    encoder_model.eval()
    x = []
    y = []
    for data in dataloader:
        data = data.to(device)
        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
        graph_embeds, node_embeds = encoder_model.get_embedding(data)
        x.append(graph_embeds)
        y.append(data.y)
    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)

    result = []
        
    random.shuffle(seeds)
    seeds = seeds.tolist()
    for _ in np.arange(10):
        random_seed = seeds.pop()
        split = get_split(num_samples=x.size()[0], train_ratio=0.1, test_ratio=0.8, seed=random_seed)
        result.append(LREvaluator()(x, y, split))
    
    return result

In [17]:
path = "./data/data_splits/" + args.dataset_name

In [18]:
train_indices, test_indices, train_test_indices = [], [], []
for i in range(10):
    train_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 1))
    test_filename = os.path.join(path, '10fold_idx', 'test_idx-{}.txt'.format(i + 1))
    train_indices.append(torch.from_numpy(np.loadtxt(train_filename, dtype=int)).to(torch.long))
    test_indices.append(torch.from_numpy(np.loadtxt(test_filename, dtype=int)).to(torch.long))

train_test_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 2))
train_test_indices.append(torch.from_numpy(np.loadtxt(train_test_filename, dtype=int)).to(torch.long))

In [19]:
dataset = TUDataset(root='./data/'+args.dataset_name, name=args.dataset_name, 
                    pre_transform=T.Compose([edge_feature_transform]))

In [20]:
classes = torch.unique(dataset.y)
args.n_classes = len(classes)

In [21]:
args.input_size = dataset.num_node_features
args.pos_size = args.pos_enc_dim
args.output_size = args.hidden_size
args.edge_attr_size =  dataset.edge_attr.shape[1]
args.edge_attr_v2_size =  dataset.edge_attr_v2.shape[1]

In [22]:
aug1 = A.RandomChoice([A.FeatureDropout(pf=0.1),
                           A.FeatureMasking(pf=0.1),
                           A.EdgeAttrMasking(pf=0.1)], 1)

aug2 = A.RandomChoice([A.RWSampling(num_seeds=1000, walk_length=10),
                       A.NodeDropping(pn=0.1),
                       A.EdgeRemoving(pe=0.1)], 1)

model_1 = get_model(args)
model_1.to(device)

args.pos_attr = True
args.input_size = args.pos_size
model_2 = get_model(args)
model_2.to(device)

mlp1 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)
mlp2 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)

encoder_model = Encoder(model_1=model_1, model_2=model_2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2).to(device)

In [23]:
optimizer = Adam(encoder_model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
with tqdm(total=100, desc='(T)') as pbar:
    for epoch in range(1, 101):
        loss = train(encoder_model, dataloader, optimizer, device)
        pbar.set_postfix({'loss': loss})
        pbar.update()

(T):   1%|▍                                        | 1/100 [00:04<06:59,  4.23s/it, loss=59.6]

Early stopping


(T):   2%|▊                                        | 2/100 [00:08<07:11,  4.40s/it, loss=45.2]

Early stopping


(T):   3%|█▏                                       | 3/100 [00:13<07:11,  4.45s/it, loss=41.5]

Early stopping


(T):   4%|█▋                                       | 4/100 [00:18<07:25,  4.64s/it, loss=38.3]

Early stopping


(T):   5%|██                                       | 5/100 [00:23<07:45,  4.90s/it, loss=36.5]

Early stopping


(T):   6%|██▌                                        | 6/100 [00:29<08:00,  5.12s/it, loss=28]

Early stopping


(T):   7%|██▊                                      | 7/100 [00:34<07:50,  5.06s/it, loss=34.1]

Early stopping


(T):   8%|███▎                                     | 8/100 [00:38<07:30,  4.89s/it, loss=23.8]

Early stopping


(T):   9%|███▋                                     | 9/100 [00:43<07:17,  4.81s/it, loss=27.5]

Early stopping


(T):  10%|████                                    | 10/100 [00:47<07:05,  4.73s/it, loss=34.7]

Early stopping


(T):  11%|████▍                                   | 11/100 [00:52<06:54,  4.66s/it, loss=22.2]

Early stopping


(T):  12%|████▊                                   | 12/100 [00:57<06:59,  4.76s/it, loss=27.5]

Early stopping


(T):  13%|█████▍                                    | 13/100 [01:02<07:11,  4.95s/it, loss=25]

Early stopping


(T):  14%|█████▌                                  | 14/100 [01:07<07:02,  4.92s/it, loss=23.9]

Early stopping


(T):  15%|██████                                  | 15/100 [01:12<07:10,  5.06s/it, loss=22.7]

Early stopping


(T):  16%|██████▍                                 | 16/100 [01:17<06:53,  4.92s/it, loss=21.7]

Early stopping


(T):  17%|██████▊                                 | 17/100 [01:22<06:42,  4.85s/it, loss=20.4]

Early stopping


(T):  18%|███████▏                                | 18/100 [01:26<06:29,  4.75s/it, loss=23.2]

Early stopping


(T):  19%|███████▌                                | 19/100 [01:31<06:19,  4.69s/it, loss=19.1]

Early stopping
Epoch: 20, Loss: 0.5069
Epoch: 20, Loss: 2.1861
Epoch: 20, Loss: 2.9472
Epoch: 20, Loss: 4.8554
Epoch: 20, Loss: 5.3580
Epoch: 20, Loss: 5.8757
Epoch: 20, Loss: 6.8246
Epoch: 20, Loss: 7.9041
Epoch: 20, Loss: 8.3205
Epoch: 20, Loss: 8.7998
Epoch: 20, Loss: 11.0796
Epoch: 20, Loss: 13.3746
Epoch: 20, Loss: 13.9348
Epoch: 20, Loss: 14.9302
Epoch: 20, Loss: 15.4240
Epoch: 20, Loss: 16.6865
Epoch: 20, Loss: 17.8733
Epoch: 20, Loss: 19.2533
Epoch: 20, Loss: 19.7038


(T):  20%|████████                                | 20/100 [01:35<06:14,  4.69s/it, loss=20.5]

Epoch: 20, Loss: 20.1502
Epoch: 20, Loss: 20.5472
Early stopping


(T):  21%|████████▍                               | 21/100 [01:41<06:31,  4.95s/it, loss=15.4]

Early stopping


(T):  22%|████████▊                               | 22/100 [01:46<06:17,  4.84s/it, loss=29.1]

Early stopping


(T):  23%|█████████▏                              | 23/100 [01:50<06:04,  4.73s/it, loss=25.6]

Early stopping


(T):  24%|█████████▌                              | 24/100 [01:55<05:54,  4.66s/it, loss=23.9]

Early stopping


(T):  25%|██████████                              | 25/100 [01:59<05:44,  4.59s/it, loss=23.2]

Early stopping


(T):  26%|██████████▍                             | 26/100 [02:03<05:35,  4.53s/it, loss=22.3]

Early stopping


(T):  27%|██████████▊                             | 27/100 [02:08<05:28,  4.50s/it, loss=21.5]

Early stopping


(T):  28%|███████████▏                            | 28/100 [02:13<05:31,  4.61s/it, loss=21.2]

Early stopping


(T):  29%|████████████▏                             | 29/100 [02:17<05:31,  4.67s/it, loss=19]

Early stopping


(T):  30%|████████████                            | 30/100 [02:22<05:23,  4.63s/it, loss=21.8]

Early stopping


(T):  31%|████████████▍                           | 31/100 [02:26<05:15,  4.58s/it, loss=22.4]

Early stopping


(T):  32%|████████████▊                           | 32/100 [02:31<05:07,  4.53s/it, loss=19.8]

Early stopping


(T):  33%|█████████████▏                          | 33/100 [02:35<04:59,  4.47s/it, loss=15.6]

Early stopping


(T):  34%|█████████████▌                          | 34/100 [02:40<04:59,  4.54s/it, loss=12.8]

Early stopping


(T):  35%|██████████████                          | 35/100 [02:45<04:58,  4.59s/it, loss=16.9]

Early stopping


(T):  36%|██████████████▍                         | 36/100 [02:49<04:42,  4.41s/it, loss=19.8]

Early stopping


(T):  37%|██████████████▊                         | 37/100 [02:53<04:32,  4.32s/it, loss=19.3]

Early stopping


(T):  38%|███████████████▏                        | 38/100 [02:57<04:34,  4.42s/it, loss=13.8]

Early stopping


(T):  39%|███████████████▌                        | 39/100 [03:02<04:29,  4.43s/it, loss=17.5]

Early stopping
Epoch: 40, Loss: 1.1820
Epoch: 40, Loss: 2.0806
Epoch: 40, Loss: 2.7755
Epoch: 40, Loss: 3.7279
Epoch: 40, Loss: 5.4143
Epoch: 40, Loss: 5.9102
Epoch: 40, Loss: 6.8360
Epoch: 40, Loss: 7.1130
Epoch: 40, Loss: 7.8233
Epoch: 40, Loss: 8.4112
Epoch: 40, Loss: 8.5550
Epoch: 40, Loss: 9.6106
Epoch: 40, Loss: 10.0490
Epoch: 40, Loss: 10.2390
Epoch: 40, Loss: 10.3545
Epoch: 40, Loss: 10.5842
Epoch: 40, Loss: 11.6652
Epoch: 40, Loss: 13.1184
Epoch: 40, Loss: 13.3543
Epoch: 40, Loss: 14.5302


(T):  40%|████████████████                        | 40/100 [03:06<04:19,  4.32s/it, loss=15.8]

Epoch: 40, Loss: 15.8335
Early stopping


(T):  41%|████████████████▍                       | 41/100 [03:10<04:19,  4.40s/it, loss=13.2]

Early stopping


(T):  42%|████████████████▊                       | 42/100 [03:15<04:14,  4.38s/it, loss=16.4]

Early stopping


(T):  43%|█████████████████▏                      | 43/100 [03:19<04:14,  4.46s/it, loss=10.3]

Early stopping


(T):  44%|█████████████████▌                      | 44/100 [03:24<04:06,  4.40s/it, loss=18.7]

Early stopping


(T):  45%|██████████████████                      | 45/100 [03:28<04:04,  4.44s/it, loss=20.4]

Early stopping


(T):  46%|██████████████████▍                     | 46/100 [03:33<03:59,  4.43s/it, loss=14.9]

Early stopping


(T):  47%|███████████████████▋                      | 47/100 [03:37<03:50,  4.34s/it, loss=16]

Early stopping


(T):  48%|███████████████████▏                    | 48/100 [03:41<03:39,  4.22s/it, loss=16.5]

Early stopping


(T):  49%|███████████████████▌                    | 49/100 [03:45<03:36,  4.24s/it, loss=11.4]

Early stopping


(T):  50%|████████████████████                    | 50/100 [03:49<03:31,  4.23s/it, loss=12.2]

Early stopping


(T):  51%|████████████████████▍                   | 51/100 [03:54<03:38,  4.47s/it, loss=13.1]

Early stopping


(T):  52%|████████████████████▊                   | 52/100 [03:59<03:44,  4.68s/it, loss=14.2]

Early stopping


(T):  53%|█████████████████████▏                  | 53/100 [04:05<03:48,  4.86s/it, loss=8.55]

Early stopping


(T):  54%|█████████████████████▌                  | 54/100 [04:09<03:39,  4.77s/it, loss=14.4]

Early stopping


(T):  55%|██████████████████████                  | 55/100 [04:14<03:30,  4.68s/it, loss=15.5]

Early stopping


(T):  56%|██████████████████████▍                 | 56/100 [04:18<03:22,  4.61s/it, loss=14.7]

Early stopping


(T):  57%|██████████████████████▊                 | 57/100 [04:23<03:16,  4.58s/it, loss=13.8]

Early stopping


(T):  58%|███████████████████████▏                | 58/100 [04:27<03:12,  4.58s/it, loss=13.3]

Early stopping


(T):  59%|███████████████████████▌                | 59/100 [04:32<03:05,  4.52s/it, loss=11.4]

Early stopping
Epoch: 60, Loss: 0.1766
Epoch: 60, Loss: 0.8937
Epoch: 60, Loss: 1.9420
Epoch: 60, Loss: 2.8620
Epoch: 60, Loss: 2.9711
Epoch: 60, Loss: 3.2968
Epoch: 60, Loss: 3.4780
Epoch: 60, Loss: 3.6210
Epoch: 60, Loss: 4.1132
Epoch: 60, Loss: 4.3299
Epoch: 60, Loss: 4.9357
Epoch: 60, Loss: 5.1429
Epoch: 60, Loss: 7.1111
Epoch: 60, Loss: 7.5363
Epoch: 60, Loss: 7.6428
Epoch: 60, Loss: 7.8175
Epoch: 60, Loss: 8.1108
Epoch: 60, Loss: 9.6247
Epoch: 60, Loss: 9.9594


(T):  60%|████████████████████████                | 60/100 [04:37<03:05,  4.65s/it, loss=10.7]

Epoch: 60, Loss: 10.5909
Epoch: 60, Loss: 10.7421
Early stopping


(T):  61%|████████████████████████▍               | 61/100 [04:42<03:08,  4.83s/it, loss=8.47]

Early stopping


(T):  62%|████████████████████████▊               | 62/100 [04:47<03:03,  4.83s/it, loss=8.19]

Early stopping


(T):  63%|█████████████████████████▏              | 63/100 [04:52<03:01,  4.92s/it, loss=10.3]

Early stopping


(T):  64%|█████████████████████████▌              | 64/100 [04:56<02:52,  4.79s/it, loss=22.1]

Early stopping


(T):  65%|██████████████████████████              | 65/100 [05:01<02:45,  4.72s/it, loss=19.4]

Early stopping


(T):  66%|██████████████████████████▍             | 66/100 [05:05<02:38,  4.66s/it, loss=13.3]

Early stopping


(T):  67%|██████████████████████████▊             | 67/100 [05:10<02:33,  4.64s/it, loss=13.6]

Early stopping


(T):  68%|███████████████████████████▏            | 68/100 [05:14<02:26,  4.58s/it, loss=12.9]

Early stopping


(T):  69%|███████████████████████████▌            | 69/100 [05:19<02:21,  4.58s/it, loss=12.3]

Early stopping


(T):  70%|████████████████████████████            | 70/100 [05:24<02:17,  4.58s/it, loss=7.17]

Early stopping


(T):  71%|████████████████████████████▍           | 71/100 [05:28<02:12,  4.56s/it, loss=14.9]

Early stopping


(T):  72%|████████████████████████████▊           | 72/100 [05:33<02:07,  4.55s/it, loss=15.5]

Early stopping


(T):  73%|█████████████████████████████▏          | 73/100 [05:37<02:04,  4.63s/it, loss=12.9]

Early stopping


(T):  74%|█████████████████████████████▌          | 74/100 [05:42<01:58,  4.56s/it, loss=14.2]

Early stopping


(T):  75%|██████████████████████████████          | 75/100 [05:46<01:54,  4.56s/it, loss=13.8]

Early stopping


(T):  76%|██████████████████████████████▍         | 76/100 [05:51<01:49,  4.58s/it, loss=9.44]

Early stopping


(T):  77%|██████████████████████████████▊         | 77/100 [05:56<01:46,  4.61s/it, loss=12.9]

Early stopping


(T):  78%|███████████████████████████████▏        | 78/100 [06:00<01:42,  4.68s/it, loss=9.04]

Early stopping


(T):  79%|███████████████████████████████▌        | 79/100 [06:05<01:37,  4.66s/it, loss=10.6]

Early stopping
Epoch: 80, Loss: 0.9314
Epoch: 80, Loss: 1.2710
Epoch: 80, Loss: 1.9919
Epoch: 80, Loss: 2.1255
Epoch: 80, Loss: 2.7446
Epoch: 80, Loss: 2.8742
Epoch: 80, Loss: 3.2143
Epoch: 80, Loss: 3.7770
Epoch: 80, Loss: 4.1744
Epoch: 80, Loss: 4.3178
Epoch: 80, Loss: 4.8850
Epoch: 80, Loss: 5.1040
Epoch: 80, Loss: 5.7503
Epoch: 80, Loss: 6.2849
Epoch: 80, Loss: 6.5430
Epoch: 80, Loss: 6.6875
Epoch: 80, Loss: 7.6747
Epoch: 80, Loss: 7.8482
Epoch: 80, Loss: 9.0617


(T):  80%|████████████████████████████████        | 80/100 [06:09<01:29,  4.48s/it, loss=10.8]

Epoch: 80, Loss: 10.4546
Epoch: 80, Loss: 10.7827
Early stopping


(T):  81%|████████████████████████████████▍       | 81/100 [06:14<01:24,  4.44s/it, loss=9.03]

Early stopping


(T):  82%|████████████████████████████████▊       | 82/100 [06:18<01:21,  4.51s/it, loss=14.1]

Early stopping


(T):  83%|█████████████████████████████████▏      | 83/100 [06:22<01:15,  4.44s/it, loss=10.1]

Early stopping


(T):  84%|█████████████████████████████████▌      | 84/100 [06:27<01:10,  4.41s/it, loss=12.6]

Early stopping


(T):  85%|██████████████████████████████████      | 85/100 [06:31<01:07,  4.47s/it, loss=9.15]

Early stopping


(T):  86%|██████████████████████████████████▍     | 86/100 [06:36<01:03,  4.51s/it, loss=9.86]

Early stopping


(T):  87%|██████████████████████████████████▊     | 87/100 [06:41<00:58,  4.52s/it, loss=9.81]

Early stopping


(T):  88%|███████████████████████████████████▏    | 88/100 [06:45<00:54,  4.57s/it, loss=10.4]

Early stopping


(T):  89%|███████████████████████████████████▌    | 89/100 [06:49<00:49,  4.46s/it, loss=12.2]

Early stopping


(T):  90%|████████████████████████████████████    | 90/100 [06:54<00:43,  4.39s/it, loss=10.3]

Early stopping


(T):  91%|████████████████████████████████████▍   | 91/100 [06:58<00:39,  4.36s/it, loss=8.49]

Early stopping


(T):  92%|████████████████████████████████████▊   | 92/100 [07:02<00:33,  4.22s/it, loss=9.63]

Early stopping


(T):  93%|██████████████████████████████████████▏  | 93/100 [07:06<00:30,  4.35s/it, loss=9.2]

Early stopping


(T):  94%|█████████████████████████████████████▌  | 94/100 [07:11<00:26,  4.45s/it, loss=7.07]

Early stopping


(T):  95%|██████████████████████████████████████  | 95/100 [07:16<00:22,  4.48s/it, loss=7.77]

Early stopping


(T):  96%|██████████████████████████████████████▍ | 96/100 [07:20<00:17,  4.49s/it, loss=11.1]

Early stopping


(T):  97%|██████████████████████████████████████▊ | 97/100 [07:24<00:13,  4.35s/it, loss=8.83]

Early stopping


(T):  98%|███████████████████████████████████████▏| 98/100 [07:29<00:08,  4.38s/it, loss=7.95]

Early stopping


(T):  99%|███████████████████████████████████████▌| 99/100 [07:33<00:04,  4.45s/it, loss=9.65]

Early stopping
Epoch: 100, Loss: 1.0972
Epoch: 100, Loss: 1.7667
Epoch: 100, Loss: 1.9132
Epoch: 100, Loss: 2.8157
Epoch: 100, Loss: 3.1807
Epoch: 100, Loss: 3.3105
Epoch: 100, Loss: 4.2527
Epoch: 100, Loss: 4.6620
Epoch: 100, Loss: 5.5781
Epoch: 100, Loss: 6.4165
Epoch: 100, Loss: 6.5938
Epoch: 100, Loss: 7.6808
Epoch: 100, Loss: 7.9153
Epoch: 100, Loss: 8.0672
Epoch: 100, Loss: 8.2872
Epoch: 100, Loss: 8.4384
Epoch: 100, Loss: 10.4190
Epoch: 100, Loss: 10.7538
Epoch: 100, Loss: 11.4345
Epoch: 100, Loss: 11.9728


(T): 100%|███████████████████████████████████████| 100/100 [07:37<00:00,  4.58s/it, loss=12.3]

Epoch: 100, Loss: 12.3317
Early stopping





In [24]:
test_result = test(encoder_model, dataloader, train_test_indices[0], device)

(LR): 100%|████████████████████████| 5000/5000 [00:01<00:00, best test F1Mi=0.786, F1Ma=0.771]
(LR): 100%|████████████████████████| 5000/5000 [00:00<00:00, best test F1Mi=0.759, F1Ma=0.742]
(LR): 100%|████████████████████████| 5000/5000 [00:00<00:00, best test F1Mi=0.759, F1Ma=0.742]
(LR): 100%|████████████████████████| 5000/5000 [00:01<00:00, best test F1Mi=0.759, F1Ma=0.756]
(LR): 100%|████████████████████████| 5000/5000 [00:00<00:00, best test F1Mi=0.759, F1Ma=0.746]
(LR): 100%|████████████████████████| 5000/5000 [00:01<00:00, best test F1Mi=0.768, F1Ma=0.759]
(LR): 100%|████████████████████████| 5000/5000 [00:00<00:00, best test F1Mi=0.759, F1Ma=0.755]
(LR): 100%|████████████████████████| 5000/5000 [00:01<00:00, best test F1Mi=0.786, F1Ma=0.774]
(LR): 100%|████████████████████████| 5000/5000 [00:01<00:00, best test F1Mi=0.804, F1Ma=0.786]
(LR): 100%|████████████████████████| 5000/5000 [00:00<00:00, best test F1Mi=0.786, F1Ma=0.775]


In [25]:
micro_f1_values = [entry['micro_f1']*100 for entry in test_result]

In [26]:
np_micro_f1_values = np.array(micro_f1_values)
micro_f1_mean = np.mean(np_micro_f1_values)
micro_f1_std = np.std(np_micro_f1_values)
uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(np_micro_f1_values, func=np.mean, 
                                                                      n_boot=1000), 95) - np_micro_f1_values.mean()))

print(f'test acc mean = {micro_f1_mean:.4f} ± {micro_f1_std:.4f}')

test acc mean = 77.2321 ± 1.5593
