In [1]:
import argparse
import os
import os.path as osp
import shutil
import time
from itertools import product
from json import dumps
import random
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.optim import Adam
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import DataParallel
from torch_geometric.seed import seed_everything
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import DataLoader, Data
from torch import nn
from tqdm import tqdm

from pytorch_metric_learning.losses import NTXentLoss, VICRegLoss
from sklearn.model_selection import GridSearchCV, StratifiedKFold

import train_utils
from data_utils import extract_edge_attributes
from torch_geometric.datasets import TUDataset
from layers.input_encoder import LinearEncoder, LinearEdgeEncoder
from layers.layer_utils import make_gnn_layer
from models.GraphClassification import GraphClassification
from models.model_utils import make_GNN

import GCL.losses as L
import GCL.augmentors as A
from GCL.eval import get_split, SVMEvaluator, LREvaluator
from GCL.models import DualBranchContrast

import warnings
warnings.filterwarnings('ignore')

In [2]:
parser = argparse.ArgumentParser(f'arguments for training and testing')
parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.')
parser.add_argument('--dataset_name', type=str, default="COLLAB",
                    choices=("MUTAG", "PROTEINS", "PTC_MR", "IMDBBINARY", "COLLAB"), help='Name of dataset')
parser.add_argument('--drop_prob', type=float, default=0.5,
                    help='Probability of zeroing an activation in dropout layers.') 
parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
                        multiple GPUs are available.')
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
parser.add_argument('--l2_wd', type=float, default=3e-3, help='L2 weight decay.')
parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs.')
parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size of the model")
parser.add_argument("--embedding_size", type=int, default=16, help="Output size of the logistic model")
parser.add_argument("--model_name", type=str, default="KHopGNNConv",
                    choices=("KHopGNNConv"), help="Base GNN model")
parser.add_argument("--K", type=int, default=2, help="Number of hop to consider")
parser.add_argument("--num_layer", type=int, default=2, help="Number of layer for feature encoder")
parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
                    help="Jumping knowledge method")
parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
parser.add_argument("--virtual_node", action="store_true", default=False, 
                    help="If true, add virtual node information in each layer")
parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN")
parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
                    help="Combine method in k-hop aggregation")
parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
                    help="Pooling method in graph classification")
parser.add_argument('--norm_type', type=str, default="Batch",
                    choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
                    help="Normalization method in model")
parser.add_argument('--aggr', type=str, default="add",
                    help='Aggregation method in GNN layer, only works in GraphSAGE')
parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.")
parser.add_argument("--pos_attr", type=bool, default=False, help="Positional attributes.")
parser.add_argument("--feature_augmentation", type=bool, default=True, help="If true, feature augmentation.")

_StoreAction(option_strings=['--feature_augmentation'], dest='feature_augmentation', nargs=None, const=None, default=True, type=<class 'bool'>, choices=None, required=False, help='If true, feature augmentation.', metavar=None)

In [3]:
args = parser.parse_args("")

In [4]:
args.name = args.model_name + "_" + str(args.K) + "_" + str(args.search)

In [5]:
torch.set_num_threads(1)

In [6]:
# Set up logging and devices
args.save_dir = train_utils.get_save_dir(args.save_dir, args.name, type=args.dataset_name)
log = train_utils.get_logger(args.save_dir, args.name)
device, args.gpu_ids = train_utils.get_available_devices()

In [7]:
if len(args.gpu_ids) > 1 and args.parallel:
    log.info(f'Using multi-gpu training')
    args.parallel = True
    loader = DataListLoader
    args.batch_size *= max(1, len(args.gpu_ids))
else:
    log.info(f'Using single-gpu training')
    args.parallel = False
    loader = DataLoader

[02.14.25 09:14:56] Using single-gpu training


In [8]:
# Set random seed
seed = args.seed
log.info(f'Using random seed {seed}...')
seed_everything(seed)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

[02.14.25 09:14:56] Using random seed 2...


In [9]:
def edge_feature_transform(g):
    return extract_edge_attributes(g, args.pos_enc_dim)

In [10]:
tag = str(int(time.time()))

In [11]:
def get_model(args):
    layer = make_gnn_layer(args)
    init_emb = LinearEncoder(args.input_size, args.hidden_size, pos_attr=args.pos_attr)
    init_edge_attr_emb = LinearEdgeEncoder(args.edge_attr_size, args.hidden_size, edge_attr=True)
    init_edge_attr_v2_emb = LinearEdgeEncoder(args.edge_attr_v2_size, args.hidden_size, edge_attr=False)
    
    GNNModel = make_GNN(args)
    
    gnn = GNNModel(
        num_layer=args.num_layer,
        gnn_layer=layer,
        JK=args.JK,
        norm_type=args.norm_type,
        init_emb=init_emb,
        init_edge_attr_emb=init_edge_attr_emb,
        init_edge_attr_v2_emb=init_edge_attr_v2_emb,
        residual=args.residual,
        virtual_node=args.virtual_node,
        drop_prob=args.drop_prob)

    model = GraphClassification(embedding_model=gnn,
                                pooling_method=args.pooling_method,
                                output_size=args.output_size)
    
    model.reset_parameters()

    if args.parallel:
        model = DataParallel(model, args.gpu_ids)
    return model

In [12]:
class Projection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Projection, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x) + self.linear(x)

In [13]:
def vicreg_loss(embeddings1, embeddings2, lambda_var=25, lambda_cov=25, mu=1):
    """
    Calculate the VICReg loss between two sets of embeddings with the correct handling of covariance differences.
    
    Args:
    embeddings1, embeddings2 (torch.Tensor): Embeddings from two views, shape (batch_size, feature_dim).
    lambda_var (float): Coefficient for the variance loss.
    lambda_cov (float): Coefficient for the covariance loss.
    mu (float): Coefficient for the invariance loss.
    
    Returns:
    torch.Tensor: The total VICReg loss.
    """
    # Invariance Loss
    invariance_loss = F.mse_loss(embeddings1, embeddings2)

    # Variance Loss
    def variance_loss(embeddings1, embeddings2):
        mean_embeddings1 = embeddings1.mean(dim=0)
        mean_embeddings2 = embeddings2.mean(dim=0)
        
        std_dev1 = torch.sqrt((embeddings1 - mean_embeddings1).var(dim=0) + 1e-4)
        std_dev2 = torch.sqrt((embeddings2 - mean_embeddings2).var(dim=0) + 1e-4)
        
        return torch.mean(torch.abs(F.relu(1 - std_dev1) - F.relu(1 - std_dev2)))

    variance_loss_value = variance_loss(embeddings1, embeddings2)

    # Covariance Loss
    def covariance_loss(embeddings1, embeddings2):
        batch_size, feature_dim = embeddings1.size()
        
        embeddings_centered1 = embeddings1 - embeddings1.mean(dim=0)
        embeddings_centered2 = embeddings2 - embeddings2.mean(dim=0)
        
        covariance_matrix1 = torch.matmul(embeddings_centered1.T, embeddings_centered1) / (batch_size - 1)
        covariance_matrix2 = torch.matmul(embeddings_centered2.T, embeddings_centered2) / (batch_size - 1)
        
        covariance_matrix1.fill_diagonal_(0)
        covariance_matrix2.fill_diagonal_(0)
        
        cov_diff = torch.abs(covariance_matrix1.pow(2) - covariance_matrix2.pow(2))
        return cov_diff.sum() / feature_dim

    covariance_loss_value = covariance_loss(embeddings1, embeddings2)

    total_loss = mu * invariance_loss + lambda_var * variance_loss_value + lambda_cov * covariance_loss_value
    
    return total_loss

In [14]:
class Encoder(torch.nn.Module):
    def __init__(self, model_1, model_2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2
        
    def get_embedding(self, data):
        z, g = self.model_1(data)
        z_pos, g_pos = self.model_2(data)
        
        z = self.mlp1(z)
        g = self.mlp2(g)
        
        z_pos = self.mlp1(z_pos)
        g_pos = self.mlp2(g_pos)

        g = torch.cat((g, g_pos), 1)
        z = torch.cat((z, z_pos), 1)

        return g.detach(), z.detach()

    def forward(self, data):
        data1 = self.aug1(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        data2 = self.aug2(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        
        # Structural features
        z1, g1 = self.model_1(data1)
        z2, g2 = self.model_1(data2)
        
        # Positional features
        z1_pos, g1_pos = self.model_2(data1)
        z2_pos, g2_pos = self.model_2(data2)
        
        h1, h2 = [self.mlp1(h) for h in [z1, z2]]
        g1, g2 = [self.mlp2(g) for g in [g1, g2]]
        
        h1_pos, h2_pos = [self.mlp1(h_pos) for h_pos in [z1_pos, z2_pos]]
        g1_pos, g2_pos = [self.mlp2(g_pos) for g_pos in [g1_pos, g2_pos]]
        
        h1 = torch.cat((h1, h1_pos), 1)
        h2 = torch.cat((h2, h2_pos), 1)
        g1 = torch.cat((g1, g1_pos), 1)
        g2 = torch.cat((g2, g2_pos), 1)
        
        return h1, h2, g1, g2

In [15]:
def train(encoder_model, dataloader, optimizer, device):
    best = float("inf")
    cnt_wait = 0
    best_t = 0
    
    loss_func = NTXentLoss(temperature=0.10)
    
    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        h1, h2, g1, g2 = encoder_model(data)
        
        embeddings = torch.cat((g1, g2))
        
        # The same index corresponds to a positive pair
        indices = torch.arange(0, g1.size(0), device=device)
        labels = torch.cat((indices, indices))
        
        reg_loss = vicreg_loss(h1, h2, lambda_var=24, lambda_cov=24, mu=1)
        loss = loss_func(embeddings, labels) + 0.005*reg_loss
        
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        
        if epoch % 20 == 0:
            print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, epoch_loss))

        if epoch_loss < best:
            best = epoch_loss
            best_t = epoch
            cnt_wait = 0
            torch.save(encoder_model.state_dict(), './pkl/best_model_'+ args.dataset_name + tag + '.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break
            
    return epoch_loss

In [16]:
def test(encoder_model, dataloader, seeds, device):
    encoder_model.eval()
    x = []
    y = []
    for data in dataloader:
        data = data.to(device)
        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
        graph_embeds, node_embeds = encoder_model.get_embedding(data)
        x.append(graph_embeds)
        y.append(data.y)
    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)

    result = []
        
    random.shuffle(seeds)
    seeds = seeds.tolist()
    for _ in np.arange(10):
        random_seed = seeds.pop()
        split = get_split(num_samples=x.size()[0], train_ratio=0.1, test_ratio=0.8, seed=random_seed)
        result.append(LREvaluator()(x, y, split))
    
    return result

In [17]:
path = "./data/data_splits/" + args.dataset_name

In [18]:
train_indices, test_indices, train_test_indices = [], [], []
for i in range(10):
    train_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 1))
    test_filename = os.path.join(path, '10fold_idx', 'test_idx-{}.txt'.format(i + 1))
    train_indices.append(torch.from_numpy(np.loadtxt(train_filename, dtype=int)).to(torch.long))
    test_indices.append(torch.from_numpy(np.loadtxt(test_filename, dtype=int)).to(torch.long))

if args.feature_augmentation:
    train_test_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 2))
    train_test_indices.append(torch.from_numpy(np.loadtxt(train_test_filename, dtype=int)).to(torch.long))
else:
    train_test_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 3))
    train_test_indices.append(torch.from_numpy(np.loadtxt(train_test_filename, dtype=int)).to(torch.long))

In [19]:
dataset = TUDataset(root='./data/'+args.dataset_name, name=args.dataset_name, 
                    pre_transform=T.Compose([edge_feature_transform]))

In [20]:
# Determine the maximum degree in the dataset
max_degree = max([data.num_nodes for data in dataset])

# Apply the OneHotDegree transform
dataset.transform = T.OneHotDegree(max_degree)

In [21]:
classes = torch.unique(dataset.y)
args.n_classes = len(classes)

In [22]:
args.input_size = dataset.num_node_features
args.pos_size = args.pos_enc_dim
args.output_size = args.hidden_size
args.edge_attr_size =  dataset.edge_attr.shape[1]
args.edge_attr_v2_size =  dataset.edge_attr_v2.shape[1]

In [23]:
aug1 = A.Identity()

if args.feature_augmentation:
    # Feature Augmentation
    aug2 = A.RandomChoice([A.FeatureDropout(pf=0.1),
                           A.FeatureMasking(pf=0.1),
                           A.EdgeAttrMasking(pf=0.1)], 1)
else:
    # Structure Augmentation
    aug2 = A.RandomChoice([A.RWSampling(num_seeds=1000, walk_length=10),
                       A.NodeDropping(pn=0.1),
                       A.EdgeRemoving(pe=0.1)], 1)

model_1 = get_model(args)
model_1.to(device)

args.pos_attr = True
args.input_size = args.pos_size
model_2 = get_model(args)
model_2.to(device)

mlp1 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)
mlp2 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)

encoder_model = Encoder(model_1=model_1, model_2=model_2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2).to(device)

In [24]:
optimizer = Adam(encoder_model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
with tqdm(total=100, desc='(T)') as pbar:
    for epoch in range(1, 101):
        loss = train(encoder_model, dataloader, optimizer, device)
        pbar.set_postfix({'loss': loss})
        pbar.update()

(T):   1%|▍                                      | 1/100 [00:39<1:04:29, 39.08s/it, loss=27.7]

Early stopping


(T):   2%|▊                                        | 2/100 [01:10<56:42, 34.72s/it, loss=17.5]

Early stopping


(T):   3%|█▏                                       | 3/100 [01:39<51:25, 31.80s/it, loss=15.6]

Early stopping


(T):   4%|█▋                                       | 4/100 [02:07<48:25, 30.27s/it, loss=16.4]

Early stopping


(T):   5%|██                                       | 5/100 [02:34<46:36, 29.44s/it, loss=15.6]

Early stopping


(T):   6%|██▍                                      | 6/100 [03:02<45:11, 28.84s/it, loss=13.4]

Early stopping


(T):   7%|██▊                                      | 7/100 [03:30<44:08, 28.48s/it, loss=15.1]

Early stopping


(T):   8%|███▎                                     | 8/100 [03:58<43:18, 28.24s/it, loss=11.3]

Early stopping


(T):   9%|███▋                                     | 9/100 [04:25<42:34, 28.07s/it, loss=9.53]

Early stopping


(T):  10%|████                                    | 10/100 [04:53<41:59, 27.99s/it, loss=7.86]

Early stopping


(T):  11%|████▍                                   | 11/100 [05:24<42:48, 28.85s/it, loss=8.37]

Early stopping


(T):  12%|████▊                                   | 12/100 [06:00<45:35, 31.09s/it, loss=8.38]

Early stopping


(T):  13%|█████▏                                  | 13/100 [06:34<46:27, 32.04s/it, loss=11.2]

Early stopping


(T):  14%|█████▌                                  | 14/100 [07:10<47:31, 33.15s/it, loss=9.09]

Early stopping


(T):  15%|██████                                  | 15/100 [07:44<47:24, 33.46s/it, loss=7.12]

Early stopping


(T):  16%|██████▍                                 | 16/100 [08:18<47:07, 33.66s/it, loss=6.52]

Early stopping


(T):  17%|██████▊                                 | 17/100 [08:52<46:37, 33.71s/it, loss=7.71]

Early stopping


(T):  18%|███████▏                                | 18/100 [09:24<45:25, 33.24s/it, loss=6.98]

Early stopping


(T):  19%|███████▊                                 | 19/100 [09:53<42:52, 31.76s/it, loss=6.2]

Early stopping
Epoch: 20, Loss: 0.2806
Epoch: 20, Loss: 0.6440
Epoch: 20, Loss: 0.9793
Epoch: 20, Loss: 1.5018
Epoch: 20, Loss: 1.8629
Epoch: 20, Loss: 2.1796
Epoch: 20, Loss: 2.4085
Epoch: 20, Loss: 2.6465
Epoch: 20, Loss: 2.9136
Epoch: 20, Loss: 3.0658
Epoch: 20, Loss: 3.2338
Epoch: 20, Loss: 3.3626
Epoch: 20, Loss: 3.5782
Epoch: 20, Loss: 3.7291
Epoch: 20, Loss: 3.9097
Epoch: 20, Loss: 4.1138
Epoch: 20, Loss: 4.5478
Epoch: 20, Loss: 5.0096
Epoch: 20, Loss: 5.1490
Epoch: 20, Loss: 5.6244


(T):  20%|████████                                | 20/100 [10:19<40:14, 30.18s/it, loss=5.84]

Epoch: 20, Loss: 5.8355
Early stopping


(T):  21%|████████▍                               | 21/100 [10:45<38:01, 28.88s/it, loss=5.19]

Early stopping


(T):  22%|████████▊                               | 22/100 [11:11<36:17, 27.92s/it, loss=5.56]

Early stopping


(T):  23%|█████████▏                              | 23/100 [11:42<37:16, 29.05s/it, loss=5.93]

Early stopping


(T):  24%|█████████▌                              | 24/100 [12:17<38:54, 30.72s/it, loss=4.64]

Early stopping


(T):  25%|██████████                              | 25/100 [12:52<40:06, 32.09s/it, loss=4.69]

Early stopping


(T):  26%|██████████▍                             | 26/100 [13:27<40:43, 33.01s/it, loss=5.04]

Early stopping


(T):  27%|███████████                              | 27/100 [14:01<40:22, 33.19s/it, loss=4.6]

Early stopping


(T):  28%|███████████▏                            | 28/100 [14:36<40:31, 33.76s/it, loss=7.86]

Early stopping


(T):  29%|███████████▌                            | 29/100 [15:11<40:18, 34.06s/it, loss=7.72]

Early stopping


(T):  30%|████████████                            | 30/100 [15:46<39:58, 34.27s/it, loss=7.05]

Early stopping


(T):  31%|████████████▍                           | 31/100 [16:20<39:23, 34.26s/it, loss=6.52]

Early stopping


(T):  32%|████████████▊                           | 32/100 [16:52<38:02, 33.57s/it, loss=5.28]

Early stopping


(T):  33%|█████████████▏                          | 33/100 [17:20<35:45, 32.03s/it, loss=6.98]

Early stopping


(T):  34%|█████████████▉                           | 34/100 [17:48<33:57, 30.87s/it, loss=5.8]

Early stopping


(T):  35%|██████████████                          | 35/100 [18:17<32:31, 30.02s/it, loss=5.64]

Early stopping


(T):  36%|██████████████▍                         | 36/100 [18:44<31:15, 29.31s/it, loss=7.13]

Early stopping


(T):  37%|██████████████▊                         | 37/100 [19:12<30:09, 28.73s/it, loss=5.86]

Early stopping


(T):  38%|███████████████▏                        | 38/100 [19:39<29:11, 28.25s/it, loss=5.66]

Early stopping


(T):  39%|███████████████▌                        | 39/100 [20:05<28:14, 27.79s/it, loss=6.71]

Early stopping
Epoch: 40, Loss: 0.3080
Epoch: 40, Loss: 0.6871
Epoch: 40, Loss: 0.8316
Epoch: 40, Loss: 1.3723
Epoch: 40, Loss: 1.6913
Epoch: 40, Loss: 1.9221
Epoch: 40, Loss: 2.1981
Epoch: 40, Loss: 2.4949
Epoch: 40, Loss: 2.6755
Epoch: 40, Loss: 2.8083
Epoch: 40, Loss: 3.1045
Epoch: 40, Loss: 3.2185
Epoch: 40, Loss: 3.4709
Epoch: 40, Loss: 3.6751
Epoch: 40, Loss: 3.8145
Epoch: 40, Loss: 4.0001
Epoch: 40, Loss: 4.3805
Epoch: 40, Loss: 4.9088
Epoch: 40, Loss: 5.2007
Epoch: 40, Loss: 5.4482


(T):  40%|████████████████                        | 40/100 [20:32<27:27, 27.46s/it, loss=5.58]

Epoch: 40, Loss: 5.5761
Early stopping


(T):  41%|████████████████▍                       | 41/100 [20:57<26:18, 26.76s/it, loss=5.61]

Early stopping


(T):  42%|████████████████▊                       | 42/100 [21:20<24:49, 25.68s/it, loss=4.66]

Early stopping


(T):  43%|█████████████████▏                      | 43/100 [21:43<23:34, 24.81s/it, loss=5.31]

Early stopping


(T):  44%|█████████████████▌                      | 44/100 [22:06<22:35, 24.21s/it, loss=4.14]

Early stopping


(T):  45%|██████████████████▍                      | 45/100 [22:28<21:39, 23.62s/it, loss=4.6]

Early stopping


(T):  46%|██████████████████▍                     | 46/100 [22:51<20:58, 23.31s/it, loss=5.76]

Early stopping


(T):  47%|██████████████████▊                     | 47/100 [23:13<20:20, 23.04s/it, loss=4.91]

Early stopping


(T):  48%|███████████████████▏                    | 48/100 [23:36<19:51, 22.92s/it, loss=5.54]

Early stopping


(T):  49%|███████████████████▌                    | 49/100 [23:59<19:26, 22.86s/it, loss=5.46]

Early stopping


(T):  50%|████████████████████                    | 50/100 [24:22<19:14, 23.10s/it, loss=4.77]

Early stopping


(T):  51%|████████████████████▍                   | 51/100 [24:45<18:48, 23.04s/it, loss=4.65]

Early stopping


(T):  52%|████████████████████▊                   | 52/100 [25:08<18:24, 23.02s/it, loss=4.29]

Early stopping


(T):  53%|█████████████████████▏                  | 53/100 [25:31<17:54, 22.87s/it, loss=5.11]

Early stopping


(T):  54%|█████████████████████▌                  | 54/100 [25:53<17:26, 22.76s/it, loss=6.01]

Early stopping


(T):  55%|██████████████████████                  | 55/100 [26:15<16:57, 22.62s/it, loss=4.06]

Early stopping


(T):  56%|██████████████████████▍                 | 56/100 [26:38<16:34, 22.61s/it, loss=3.87]

Early stopping


(T):  57%|███████████████████████▎                 | 57/100 [27:01<16:11, 22.59s/it, loss=5.2]

Early stopping


(T):  58%|███████████████████████▏                | 58/100 [27:23<15:46, 22.53s/it, loss=4.09]

Early stopping


(T):  59%|███████████████████████▌                | 59/100 [27:45<15:22, 22.51s/it, loss=6.15]

Early stopping
Epoch: 60, Loss: 0.2079
Epoch: 60, Loss: 0.5961
Epoch: 60, Loss: 0.8339
Epoch: 60, Loss: 1.1812
Epoch: 60, Loss: 1.3866
Epoch: 60, Loss: 1.9315
Epoch: 60, Loss: 2.1323
Epoch: 60, Loss: 2.4042
Epoch: 60, Loss: 2.6347
Epoch: 60, Loss: 2.8201
Epoch: 60, Loss: 2.9954
Epoch: 60, Loss: 3.1037
Epoch: 60, Loss: 3.1847
Epoch: 60, Loss: 3.3417
Epoch: 60, Loss: 3.9948
Epoch: 60, Loss: 4.1320
Epoch: 60, Loss: 4.3713
Epoch: 60, Loss: 4.6619
Epoch: 60, Loss: 4.8751
Epoch: 60, Loss: 4.9994


(T):  60%|████████████████████████                | 60/100 [28:08<14:57, 22.43s/it, loss=5.13]

Epoch: 60, Loss: 5.1256
Early stopping


(T):  61%|████████████████████████▍               | 61/100 [28:31<14:42, 22.62s/it, loss=6.07]

Early stopping


(T):  62%|████████████████████████▊               | 62/100 [28:53<14:16, 22.54s/it, loss=4.65]

Early stopping


(T):  63%|█████████████████████████▊               | 63/100 [29:15<13:51, 22.48s/it, loss=4.2]

Early stopping


(T):  64%|█████████████████████████▌              | 64/100 [29:38<13:29, 22.48s/it, loss=3.89]

Early stopping


(T):  65%|██████████████████████████              | 65/100 [30:01<13:14, 22.70s/it, loss=3.84]

Early stopping


(T):  66%|██████████████████████████▍             | 66/100 [30:24<12:53, 22.76s/it, loss=4.06]

Early stopping


(T):  67%|██████████████████████████▊             | 67/100 [30:47<12:31, 22.77s/it, loss=3.73]

Early stopping


(T):  68%|███████████████████████████▏            | 68/100 [31:10<12:09, 22.80s/it, loss=4.31]

Early stopping


(T):  69%|████████████████████████████▎            | 69/100 [31:33<11:50, 22.93s/it, loss=4.5]

Early stopping


(T):  70%|████████████████████████████            | 70/100 [31:56<11:25, 22.86s/it, loss=3.88]

Early stopping


(T):  71%|████████████████████████████▍           | 71/100 [32:19<11:03, 22.88s/it, loss=3.79]

Early stopping


(T):  72%|████████████████████████████▊           | 72/100 [32:41<10:36, 22.75s/it, loss=3.61]

Early stopping


(T):  73%|█████████████████████████████▏          | 73/100 [33:04<10:14, 22.75s/it, loss=4.27]

Early stopping


(T):  74%|█████████████████████████████▌          | 74/100 [33:27<09:54, 22.86s/it, loss=4.33]

Early stopping


(T):  75%|██████████████████████████████          | 75/100 [33:49<09:27, 22.70s/it, loss=3.86]

Early stopping


(T):  76%|██████████████████████████████▍         | 76/100 [34:12<09:02, 22.62s/it, loss=4.44]

Early stopping


(T):  77%|██████████████████████████████▊         | 77/100 [34:34<08:37, 22.49s/it, loss=4.02]

Early stopping


(T):  78%|███████████████████████████████▏        | 78/100 [34:56<08:12, 22.40s/it, loss=5.18]

Early stopping


(T):  79%|███████████████████████████████▌        | 79/100 [35:18<07:50, 22.43s/it, loss=4.28]

Early stopping
Epoch: 80, Loss: 0.1696
Epoch: 80, Loss: 0.3684
Epoch: 80, Loss: 0.6747
Epoch: 80, Loss: 0.9490
Epoch: 80, Loss: 1.3546
Epoch: 80, Loss: 1.7762
Epoch: 80, Loss: 1.8769
Epoch: 80, Loss: 2.0354
Epoch: 80, Loss: 2.1360
Epoch: 80, Loss: 2.4165
Epoch: 80, Loss: 2.5276
Epoch: 80, Loss: 2.6570
Epoch: 80, Loss: 2.9548
Epoch: 80, Loss: 3.0930
Epoch: 80, Loss: 3.3046
Epoch: 80, Loss: 3.5717
Epoch: 80, Loss: 3.7572
Epoch: 80, Loss: 4.1012
Epoch: 80, Loss: 4.2687
Epoch: 80, Loss: 4.3436


(T):  80%|████████████████████████████████        | 80/100 [35:41<07:26, 22.34s/it, loss=4.44]

Epoch: 80, Loss: 4.4446
Early stopping


(T):  81%|████████████████████████████████▍       | 81/100 [36:03<07:03, 22.31s/it, loss=4.23]

Early stopping


(T):  82%|████████████████████████████████▊       | 82/100 [36:25<06:41, 22.32s/it, loss=3.02]

Early stopping


(T):  83%|██████████████████████████████████       | 83/100 [36:47<06:19, 22.31s/it, loss=4.2]

Early stopping


(T):  84%|█████████████████████████████████▌      | 84/100 [37:10<05:57, 22.33s/it, loss=4.75]

Early stopping


(T):  85%|██████████████████████████████████      | 85/100 [37:32<05:35, 22.34s/it, loss=4.08]

Early stopping


(T):  86%|███████████████████████████████████▎     | 86/100 [37:55<05:12, 22.35s/it, loss=4.9]

Early stopping


(T):  87%|██████████████████████████████████▊     | 87/100 [38:17<04:51, 22.46s/it, loss=5.32]

Early stopping


(T):  88%|███████████████████████████████████▏    | 88/100 [38:40<04:30, 22.51s/it, loss=5.59]

Early stopping


(T):  89%|████████████████████████████████████▍    | 89/100 [39:02<04:07, 22.49s/it, loss=5.2]

Early stopping


(T):  90%|████████████████████████████████████    | 90/100 [39:25<03:44, 22.48s/it, loss=3.58]

Early stopping


(T):  91%|████████████████████████████████████▍   | 91/100 [39:47<03:21, 22.44s/it, loss=3.88]

Early stopping


(T):  92%|████████████████████████████████████▊   | 92/100 [40:10<02:59, 22.46s/it, loss=3.76]

Early stopping


(T):  93%|█████████████████████████████████████▏  | 93/100 [40:32<02:37, 22.54s/it, loss=3.63]

Early stopping


(T):  94%|█████████████████████████████████████▌  | 94/100 [40:55<02:14, 22.45s/it, loss=4.07]

Early stopping


(T):  95%|██████████████████████████████████████  | 95/100 [41:17<01:52, 22.48s/it, loss=3.73]

Early stopping


(T):  96%|██████████████████████████████████████▍ | 96/100 [41:39<01:29, 22.42s/it, loss=4.48]

Early stopping


(T):  97%|██████████████████████████████████████▊ | 97/100 [42:02<01:07, 22.44s/it, loss=4.29]

Early stopping


(T):  98%|███████████████████████████████████████▏| 98/100 [42:24<00:44, 22.39s/it, loss=3.32]

Early stopping


(T):  99%|███████████████████████████████████████▌| 99/100 [42:47<00:22, 22.42s/it, loss=3.19]

Early stopping
Epoch: 100, Loss: 0.1253
Epoch: 100, Loss: 0.3703
Epoch: 100, Loss: 0.4330
Epoch: 100, Loss: 0.5695
Epoch: 100, Loss: 0.7208
Epoch: 100, Loss: 0.8521
Epoch: 100, Loss: 0.9192
Epoch: 100, Loss: 1.0250
Epoch: 100, Loss: 1.1582
Epoch: 100, Loss: 1.2172
Epoch: 100, Loss: 1.4543
Epoch: 100, Loss: 1.6082
Epoch: 100, Loss: 1.7354
Epoch: 100, Loss: 1.7981
Epoch: 100, Loss: 2.0492
Epoch: 100, Loss: 2.2590
Epoch: 100, Loss: 2.4072
Epoch: 100, Loss: 2.5505
Epoch: 100, Loss: 2.7144
Epoch: 100, Loss: 2.8168


(T): 100%|███████████████████████████████████████| 100/100 [43:09<00:00, 25.90s/it, loss=3.01]

Epoch: 100, Loss: 3.0101
Early stopping





In [None]:
test_result = test(encoder_model, dataloader, train_test_indices[0], device)

In [None]:
micro_f1_values = [entry['micro_f1']*100 for entry in test_result]

In [None]:
np_micro_f1_values = np.array(micro_f1_values)
micro_f1_mean = np.mean(np_micro_f1_values)
micro_f1_std = np.std(np_micro_f1_values)

print(f'test acc mean = {micro_f1_mean:.4f} ± {micro_f1_std:.4f}')