In [1]:
import argparse
import os
import os.path as osp
import shutil
import time
from itertools import product
from json import dumps
import random
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.optim import Adam
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import DataParallel
from torch_geometric.seed import seed_everything
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import DataLoader, Data
from torch import nn
from tqdm import tqdm

from pytorch_metric_learning.losses import NTXentLoss, VICRegLoss
from sklearn.model_selection import GridSearchCV, StratifiedKFold

import train_utils
from data_utils import extract_edge_attributes
from torch_geometric.datasets import TUDataset
from layers.input_encoder import LinearEncoder, LinearEdgeEncoder
from layers.layer_utils import make_gnn_layer
from models.GraphClassification import GraphClassification
from models.model_utils import make_GNN

import GCL.losses as L
import GCL.augmentors as A
from GCL.eval import get_split, SVMEvaluator, LREvaluator
from GCL.models import DualBranchContrast

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Feature Augmentation
# parser = argparse.ArgumentParser(f'arguments for training and testing')
# parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
# parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.')
# parser.add_argument('--dataset_name', type=str, default="COLLAB",
#                     choices=("MUTAG", "PROTEINS", "PTC_MR", "IMDBBINARY", "COLLAB"), help='Name of dataset')
# parser.add_argument('--drop_prob', type=float, default=0.5,
#                     help='Probability of zeroing an activation in dropout layers.') 
# parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
#                         multiple GPUs are available.')
# parser.add_argument("--parallel", action="store_true",
#                     help="If true, use DataParallel for multi-gpu training")
# parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
# parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
# parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.')
# parser.add_argument('--l2_wd', type=float, default=3e-3, help='L2 weight decay.')
# parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs.')
# parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size of the model")
# parser.add_argument("--embedding_size", type=int, default=16, help="Output size of the logistic model")
# parser.add_argument("--model_name", type=str, default="KHopGNNConv",
#                     choices=("KHopGNNConv"), help="Base GNN model")
# parser.add_argument("--K", type=int, default=2, help="Number of hop to consider")
# parser.add_argument("--num_layer", type=int, default=2, help="Number of layer for feature encoder")
# parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
#                     help="Jumping knowledge method")
# parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
# parser.add_argument("--virtual_node", action="store_true", default=False, 
#                     help="If true, add virtual node information in each layer")
# parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN")
# parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
# parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
#                     help="Combine method in k-hop aggregation")
# parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
#                     help="Pooling method in graph classification")
# parser.add_argument('--norm_type', type=str, default="Batch",
#                     choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
#                     help="Normalization method in model")
# parser.add_argument('--aggr', type=str, default="add",
#                     help='Aggregation method in GNN layer, only works in GraphSAGE')
# parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
# parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
# parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
# parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
# parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.")
# parser.add_argument("--pos_attr", type=bool, default=False, help="Positional attributes.")
# parser.add_argument("--feature_augmentation", type=bool, default=True, help="If true, feature augmentation.")

# Structure Augmentation
parser = argparse.ArgumentParser(f'arguments for training and testing')
parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.') # 4 -> 93
parser.add_argument('--dataset_name', type=str, default="COLLAB",
                    choices=("MUTAG", "PROTEINS", "PTC_MR", "IMDBBINARY"), help='Name of dataset')
parser.add_argument('--drop_prob', type=float, default=0.5,
                    help='Probability of zeroing an activation in dropout layers.') # 0.5 -> 93
parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
                        multiple GPUs are available.') # 32 -> 93
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.') # 0.001 ->93
parser.add_argument('--l2_wd', type=float, default=3e-3, help='L2 weight decay.') # 3e-4 -> 93
parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs.') # 350 -> 93
parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size of the model") # 128 -> 93
parser.add_argument("--embedding_size", type=int, default=16, help="Output size of the logistic model") # 16 -> 93
parser.add_argument("--model_name", type=str, default="KHopGNNConv",
                    choices=("KHopGNNConv"), help="Base GNN model")
parser.add_argument("--K", type=int, default=2, help="Number of hop to consider") # 2 -> 93
parser.add_argument("--num_layer", type=int, default=2, help="Number of layer for feature encoder") # 2 -> 93
parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
                    help="Jumping knowledge method") # sum -> 93
parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
parser.add_argument("--virtual_node", action="store_true", default=False, 
                    help="If true, add virtual node information in each layer")
parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN")
parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
                    help="Combine method in k-hop aggregation") # geometric -> 93
parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
                    help="Pooling method in graph classification") # sum -> 93
parser.add_argument('--norm_type', type=str, default="Batch",
                    choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
                    help="Normalization method in model") # Batch -> 93
parser.add_argument('--aggr', type=str, default="add",
                    help='Aggregation method in GNN layer, only works in GraphSAGE')
parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.") # 6 -> 93
parser.add_argument("--pos_attr", type=bool, default=False, help="Positional attributes.")
parser.add_argument("--feature_augmentation", type=bool, default=False, help="If true, feature augmentation.")

_StoreAction(option_strings=['--feature_augmentation'], dest='feature_augmentation', nargs=None, const=None, default=False, type=<class 'bool'>, choices=None, required=False, help='If true, feature augmentation.', metavar=None)

In [3]:
args = parser.parse_args("")

In [4]:
args.name = args.model_name + "_" + str(args.K) + "_" + str(args.search)

In [5]:
# Set up logging and devices
args.save_dir = train_utils.get_save_dir(args.save_dir, args.name, type=args.dataset_name)
log = train_utils.get_logger(args.save_dir, args.name)
device, args.gpu_ids = train_utils.get_available_devices()

In [6]:
if len(args.gpu_ids) > 1 and args.parallel:
    log.info(f'Using multi-gpu training')
    args.parallel = True
    loader = DataListLoader
    args.batch_size *= max(1, len(args.gpu_ids))
else:
    log.info(f'Using single-gpu training')
    args.parallel = False
    loader = DataLoader

[05.15.24 17:14:31] Using single-gpu training


In [7]:
# Set random seed
seed = args.seed
log.info(f'Using random seed {seed}...')
seed_everything(seed)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

[05.15.24 17:14:31] Using random seed 2...


In [8]:
def edge_feature_transform(g):
    return extract_edge_attributes(g, args.pos_enc_dim)

In [9]:
tag = str(int(time.time()))

In [10]:
def get_model(args):
    layer = make_gnn_layer(args)
    init_emb = LinearEncoder(args.input_size, args.hidden_size, pos_attr=args.pos_attr)
    init_edge_attr_emb = LinearEdgeEncoder(args.edge_attr_size, args.hidden_size, edge_attr=True)
    init_edge_attr_v2_emb = LinearEdgeEncoder(args.edge_attr_v2_size, args.hidden_size, edge_attr=False)
    
    GNNModel = make_GNN(args)
    
    gnn = GNNModel(
        num_layer=args.num_layer,
        gnn_layer=layer,
        JK=args.JK,
        norm_type=args.norm_type,
        init_emb=init_emb,
        init_edge_attr_emb=init_edge_attr_emb,
        init_edge_attr_v2_emb=init_edge_attr_v2_emb,
        residual=args.residual,
        virtual_node=args.virtual_node,
        drop_prob=args.drop_prob)

    model = GraphClassification(embedding_model=gnn,
                                pooling_method=args.pooling_method,
                                output_size=args.output_size)
    
    model.reset_parameters()

    if args.parallel:
        model = DataParallel(model, args.gpu_ids)
    return model

In [11]:
class Projection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Projection, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x) + self.linear(x)

In [12]:
def vicreg_loss(embeddings1, embeddings2, lambda_var=25, lambda_cov=25, mu=1):
    """
    Calculate the VICReg loss between two sets of embeddings with the correct handling of covariance differences.
    
    Args:
    embeddings1, embeddings2 (torch.Tensor): Embeddings from two views, shape (batch_size, feature_dim).
    lambda_var (float): Coefficient for the variance loss.
    lambda_cov (float): Coefficient for the covariance loss.
    mu (float): Coefficient for the invariance loss.
    
    Returns:
    torch.Tensor: The total VICReg loss.
    """
    # Invariance Loss
    invariance_loss = F.mse_loss(embeddings1, embeddings2)

    # Variance Loss
    def variance_loss(embeddings1, embeddings2):
        mean_embeddings1 = embeddings1.mean(dim=0)
        mean_embeddings2 = embeddings2.mean(dim=0)
        
        std_dev1 = torch.sqrt((embeddings1 - mean_embeddings1).var(dim=0) + 1e-4)
        std_dev2 = torch.sqrt((embeddings2 - mean_embeddings2).var(dim=0) + 1e-4)
        
        return torch.mean(torch.abs(F.relu(1 - std_dev1) - F.relu(1 - std_dev2)))

    variance_loss_value = variance_loss(embeddings1, embeddings2)

    # Covariance Loss
    def covariance_loss(embeddings1, embeddings2):
        batch_size, feature_dim = embeddings1.size()
        
        embeddings_centered1 = embeddings1 - embeddings1.mean(dim=0)
        embeddings_centered2 = embeddings2 - embeddings2.mean(dim=0)
        
        covariance_matrix1 = torch.matmul(embeddings_centered1.T, embeddings_centered1) / (batch_size - 1)
        covariance_matrix2 = torch.matmul(embeddings_centered2.T, embeddings_centered2) / (batch_size - 1)
        
        covariance_matrix1.fill_diagonal_(0)
        covariance_matrix2.fill_diagonal_(0)
        
        cov_diff = torch.abs(covariance_matrix1.pow(2) - covariance_matrix2.pow(2))
        return cov_diff.sum() / feature_dim

    covariance_loss_value = covariance_loss(embeddings1, embeddings2)

    total_loss = mu * invariance_loss + lambda_var * variance_loss_value + lambda_cov * covariance_loss_value
    
    return total_loss

In [13]:
class Encoder(torch.nn.Module):
    def __init__(self, model_1, model_2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2
        
    def get_embedding(self, data):
        z, g = self.model_1(data)
        z_pos, g_pos = self.model_2(data)
        
        z = self.mlp1(z)
        g = self.mlp2(g)
        
        z_pos = self.mlp1(z_pos)
        g_pos = self.mlp2(g_pos)

        g = torch.cat((g, g_pos), 1)
        z = torch.cat((z, z_pos), 1)

        return g.detach(), z.detach()

    def forward(self, data):
        data1 = self.aug1(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        data2 = self.aug2(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        
        # Structural features
        z1, g1 = self.model_1(data1)
        z2, g2 = self.model_1(data2)
        
        # Positional features
        z1_pos, g1_pos = self.model_2(data1)
        z2_pos, g2_pos = self.model_2(data2)
        
        h1, h2 = [self.mlp1(h) for h in [z1, z2]]
        g1, g2 = [self.mlp2(g) for g in [g1, g2]]
        
        h1_pos, h2_pos = [self.mlp1(h_pos) for h_pos in [z1_pos, z2_pos]]
        g1_pos, g2_pos = [self.mlp2(g_pos) for g_pos in [g1_pos, g2_pos]]
        
        h1 = torch.cat((h1, h1_pos), 1)
        h2 = torch.cat((h2, h2_pos), 1)
        g1 = torch.cat((g1, g1_pos), 1)
        g2 = torch.cat((g2, g2_pos), 1)
        
        return h1, h2, g1, g2

In [14]:
def train(encoder_model, dataloader, optimizer, device):
    best = float("inf")
    cnt_wait = 0
    best_t = 0
    
    loss_func = NTXentLoss(temperature=0.10)
    
    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        h1, h2, g1, g2 = encoder_model(data)
        
        embeddings = torch.cat((g1, g2))
        
        # The same index corresponds to a positive pair
        indices = torch.arange(0, g1.size(0), device=device)
        labels = torch.cat((indices, indices))
        
        reg_loss = vicreg_loss(h1, h2, lambda_var=24, lambda_cov=24, mu=1)
        loss = loss_func(embeddings, labels) + 0.005*reg_loss
        
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        
        if epoch % 20 == 0:
            print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, epoch_loss))

        if epoch_loss < best:
            best = epoch_loss
            best_t = epoch
            cnt_wait = 0
            torch.save(encoder_model.state_dict(), './pkl/best_model_'+ args.dataset_name + tag + '.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break
            
    return epoch_loss

In [15]:
def test(encoder_model, dataloader, seeds, device):
    encoder_model.eval()
    x = []
    y = []
    for data in dataloader:
        data = data.to(device)
        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
        graph_embeds, node_embeds = encoder_model.get_embedding(data)
        x.append(graph_embeds)
        y.append(data.y)
    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)

    result = []
    random.shuffle(seeds)
    seeds = seeds.tolist()
    for _ in np.arange(10):
        random_seed = seeds.pop()
        split = get_split(num_samples=x.size()[0], train_ratio=0.1, test_ratio=0.8, seed=random_seed)
        result.append(LREvaluator()(x, y, split))
    
    return result

In [16]:
path = "./data/data_splits/" + args.dataset_name

In [17]:
train_indices, test_indices, train_test_indices = [], [], []
for i in range(10):
    train_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 1))
    test_filename = os.path.join(path, '10fold_idx', 'test_idx-{}.txt'.format(i + 1))
    train_indices.append(torch.from_numpy(np.loadtxt(train_filename, dtype=int)).to(torch.long))
    test_indices.append(torch.from_numpy(np.loadtxt(test_filename, dtype=int)).to(torch.long))

if args.feature_augmentation:
    file_name = 'train_test_splits_feat.txt'
else:
    file_name = 'train_test_splits_struct.txt'
    
train_test_filename = os.path.join(path, '10fold_idx', file_name)
train_test_indices.append(torch.from_numpy(np.loadtxt(train_test_filename, dtype=int)).to(torch.long))

In [18]:
dataset = TUDataset(root='./data/'+args.dataset_name, name=args.dataset_name, 
                    pre_transform=T.Compose([edge_feature_transform]))

In [19]:
# Determine the maximum degree in the dataset
max_degree = max([data.num_nodes for data in dataset])

# Apply the OneHotDegree transform
dataset.transform = T.OneHotDegree(max_degree)

In [20]:
classes = torch.unique(dataset.y)
args.n_classes = len(classes)

In [21]:
args.input_size = dataset.num_node_features
args.pos_size = args.pos_enc_dim
args.output_size = args.hidden_size
args.edge_attr_size =  dataset.edge_attr.shape[1]
args.edge_attr_v2_size =  dataset.edge_attr_v2.shape[1]

In [22]:
aug1 = A.Identity()

if args.feature_augmentation:
    # Feature Augmentation
    aug2 = A.RandomChoice([A.FeatureDropout(pf=0.1),
                           A.FeatureMasking(pf=0.1),
                           A.EdgeAttrMasking(pf=0.1)], 1)
else:
    # Structure Augmentation
    aug2 = A.RandomChoice([A.RWSampling(num_seeds=1000, walk_length=10),
                       A.NodeDropping(pn=0.1),
                       A.EdgeRemoving(pe=0.1)], 1)

model_1 = get_model(args)
model_1.to(device)

args.pos_attr = True
args.input_size = args.pos_size
model_2 = get_model(args)
model_2.to(device)

mlp1 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)
mlp2 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)

encoder_model = Encoder(model_1=model_1, model_2=model_2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2).to(device)

In [23]:
optimizer = Adam(encoder_model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
with tqdm(total=100, desc='(T)') as pbar:
    for epoch in range(1, 101):
        loss = train(encoder_model, dataloader, optimizer, device)
        pbar.set_postfix({'loss': loss})
        pbar.update()

(T):   1%|▎                          | 1/100 [00:13<21:49, 13.23s/it, loss=28.9]

Early stopping


(T):   2%|▌                            | 2/100 [00:26<21:23, 13.10s/it, loss=14]

Early stopping


(T):   3%|▊                          | 3/100 [00:38<20:43, 12.82s/it, loss=10.4]

Early stopping


(T):   4%|█                          | 4/100 [00:51<20:44, 12.96s/it, loss=8.77]

Early stopping


(T):   5%|█▍                          | 5/100 [01:05<20:45, 13.11s/it, loss=7.5]

Early stopping


(T):   6%|█▌                         | 6/100 [01:19<21:03, 13.44s/it, loss=6.07]

Early stopping


(T):   7%|█▉                          | 7/100 [01:32<20:27, 13.20s/it, loss=6.3]

Early stopping


(T):   8%|██▏                        | 8/100 [01:44<20:01, 13.06s/it, loss=6.14]

Early stopping


(T):   9%|██▍                        | 9/100 [01:57<19:30, 12.86s/it, loss=5.47]

Early stopping


(T):  10%|██▌                       | 10/100 [02:10<19:21, 12.91s/it, loss=5.39]

Early stopping


(T):  11%|██▊                       | 11/100 [02:23<19:09, 12.92s/it, loss=4.95]

Early stopping


(T):  12%|███▏                       | 12/100 [02:36<18:54, 12.89s/it, loss=4.7]

Early stopping


(T):  13%|███▍                      | 13/100 [02:49<18:44, 12.92s/it, loss=4.04]

Early stopping


(T):  14%|███▋                      | 14/100 [03:01<18:13, 12.72s/it, loss=4.56]

Early stopping


(T):  15%|████                       | 15/100 [03:13<17:57, 12.67s/it, loss=3.9]

Early stopping


(T):  16%|████▏                     | 16/100 [03:25<17:22, 12.41s/it, loss=3.88]

Early stopping


(T):  17%|████▍                     | 17/100 [03:38<17:13, 12.46s/it, loss=3.52]

Early stopping


(T):  18%|████▋                     | 18/100 [03:50<17:02, 12.47s/it, loss=3.57]

Early stopping


(T):  19%|████▉                     | 19/100 [04:03<17:03, 12.64s/it, loss=3.67]

Early stopping
Epoch: 20, Loss: 0.1477
Epoch: 20, Loss: 0.4663
Epoch: 20, Loss: 0.6459
Epoch: 20, Loss: 0.8417
Epoch: 20, Loss: 1.0859
Epoch: 20, Loss: 1.2827
Epoch: 20, Loss: 1.3929
Epoch: 20, Loss: 1.5838
Epoch: 20, Loss: 1.7436
Epoch: 20, Loss: 1.8834
Epoch: 20, Loss: 2.0066
Epoch: 20, Loss: 2.1292
Epoch: 20, Loss: 2.2544
Epoch: 20, Loss: 2.3778
Epoch: 20, Loss: 2.5188
Epoch: 20, Loss: 2.6326
Epoch: 20, Loss: 2.8924
Epoch: 20, Loss: 3.0431
Epoch: 20, Loss: 3.1561
Epoch: 20, Loss: 3.3377


(T):  20%|█████▏                    | 20/100 [04:16<16:43, 12.55s/it, loss=3.49]

Epoch: 20, Loss: 3.4895
Early stopping


(T):  21%|█████▍                    | 21/100 [04:28<16:31, 12.55s/it, loss=3.57]

Early stopping


(T):  22%|█████▋                    | 22/100 [04:41<16:37, 12.79s/it, loss=3.49]

Early stopping


(T):  23%|█████▉                    | 23/100 [04:54<16:25, 12.80s/it, loss=3.99]

Early stopping


(T):  24%|██████▏                   | 24/100 [05:07<16:13, 12.81s/it, loss=3.75]

Early stopping


(T):  25%|██████▌                   | 25/100 [05:19<15:46, 12.62s/it, loss=3.82]

Early stopping


(T):  26%|██████▊                   | 26/100 [05:33<16:00, 12.98s/it, loss=3.35]

Early stopping


(T):  27%|███████                   | 27/100 [05:46<15:50, 13.03s/it, loss=3.35]

Early stopping


(T):  28%|███████▌                   | 28/100 [06:01<16:24, 13.67s/it, loss=3.1]

Early stopping


(T):  29%|███████▌                  | 29/100 [06:14<15:46, 13.34s/it, loss=3.29]

Early stopping


(T):  30%|███████▊                  | 30/100 [06:26<15:14, 13.07s/it, loss=3.26]

Early stopping


(T):  31%|████████                  | 31/100 [06:39<14:55, 12.98s/it, loss=3.48]

Early stopping


(T):  32%|████████▎                 | 32/100 [06:52<14:48, 13.06s/it, loss=3.63]

Early stopping


(T):  33%|████████▌                 | 33/100 [07:05<14:21, 12.86s/it, loss=3.51]

Early stopping


(T):  34%|████████▊                 | 34/100 [07:17<13:57, 12.69s/it, loss=3.83]

Early stopping


(T):  35%|█████████                 | 35/100 [07:30<13:53, 12.83s/it, loss=3.21]

Early stopping


(T):  36%|█████████▋                 | 36/100 [07:43<13:38, 12.80s/it, loss=3.3]

Early stopping


(T):  37%|█████████▌                | 37/100 [07:56<13:21, 12.72s/it, loss=3.17]

Early stopping


(T):  38%|█████████▉                | 38/100 [08:08<12:57, 12.54s/it, loss=3.17]

Early stopping


(T):  39%|██████████▏               | 39/100 [08:20<12:43, 12.51s/it, loss=3.78]

Early stopping
Epoch: 40, Loss: 0.1632
Epoch: 40, Loss: 0.4969
Epoch: 40, Loss: 0.6512
Epoch: 40, Loss: 0.8241
Epoch: 40, Loss: 1.0494
Epoch: 40, Loss: 1.1830
Epoch: 40, Loss: 1.3394
Epoch: 40, Loss: 1.5723
Epoch: 40, Loss: 1.7169
Epoch: 40, Loss: 1.8813
Epoch: 40, Loss: 2.0701
Epoch: 40, Loss: 2.2687
Epoch: 40, Loss: 2.4364
Epoch: 40, Loss: 2.5968
Epoch: 40, Loss: 2.7483
Epoch: 40, Loss: 2.8747
Epoch: 40, Loss: 3.1136
Epoch: 40, Loss: 3.3035
Epoch: 40, Loss: 3.4405
Epoch: 40, Loss: 3.5747


(T):  40%|██████████▍               | 40/100 [08:32<12:15, 12.25s/it, loss=3.72]

Epoch: 40, Loss: 3.7250
Early stopping


(T):  41%|██████████▋               | 41/100 [08:44<12:07, 12.33s/it, loss=3.19]

Early stopping


(T):  42%|██████████▉               | 42/100 [08:57<11:58, 12.39s/it, loss=3.59]

Early stopping


(T):  43%|███████████▏              | 43/100 [09:08<11:24, 12.02s/it, loss=3.47]

Early stopping


(T):  44%|███████████▍              | 44/100 [09:20<11:19, 12.14s/it, loss=3.18]

Early stopping


(T):  45%|████████████▏              | 45/100 [09:32<10:52, 11.87s/it, loss=3.3]

Early stopping


(T):  46%|███████████▉              | 46/100 [09:43<10:27, 11.61s/it, loss=3.45]

Early stopping


(T):  47%|████████████▏             | 47/100 [09:55<10:24, 11.79s/it, loss=3.35]

Early stopping


(T):  48%|████████████▍             | 48/100 [10:07<10:16, 11.86s/it, loss=3.34]

Early stopping


(T):  49%|█████████████▏             | 49/100 [10:19<10:07, 11.91s/it, loss=3.5]

Early stopping


(T):  50%|█████████████             | 50/100 [10:31<09:55, 11.92s/it, loss=3.14]

Early stopping


(T):  51%|█████████████▎            | 51/100 [10:42<09:39, 11.82s/it, loss=3.07]

Early stopping


(T):  52%|██████████████             | 52/100 [10:54<09:26, 11.81s/it, loss=3.2]

Early stopping


(T):  53%|█████████████▊            | 53/100 [11:06<09:14, 11.79s/it, loss=3.45]

Early stopping


(T):  54%|██████████████            | 54/100 [11:17<08:57, 11.69s/it, loss=3.79]

Early stopping


(T):  55%|██████████████▊            | 55/100 [11:29<08:49, 11.76s/it, loss=3.8]

Early stopping


(T):  56%|██████████████▌           | 56/100 [11:41<08:41, 11.86s/it, loss=3.52]

Early stopping


(T):  57%|██████████████▊           | 57/100 [11:53<08:29, 11.84s/it, loss=3.57]

Early stopping


(T):  58%|███████████████           | 58/100 [12:04<08:09, 11.66s/it, loss=3.61]

Early stopping


(T):  59%|███████████████▎          | 59/100 [12:16<07:53, 11.55s/it, loss=3.79]

Early stopping
Epoch: 60, Loss: 0.1557
Epoch: 60, Loss: 0.3975
Epoch: 60, Loss: 0.5411
Epoch: 60, Loss: 0.6880
Epoch: 60, Loss: 0.9485
Epoch: 60, Loss: 1.1092
Epoch: 60, Loss: 1.3306
Epoch: 60, Loss: 1.5925
Epoch: 60, Loss: 1.8374
Epoch: 60, Loss: 1.9645
Epoch: 60, Loss: 2.1259
Epoch: 60, Loss: 2.3426
Epoch: 60, Loss: 2.5367
Epoch: 60, Loss: 2.7201
Epoch: 60, Loss: 2.9546
Epoch: 60, Loss: 3.1577
Epoch: 60, Loss: 3.4906
Epoch: 60, Loss: 3.7213
Epoch: 60, Loss: 3.8764
Epoch: 60, Loss: 4.0395


(T):  60%|███████████████▌          | 60/100 [12:27<07:33, 11.34s/it, loss=4.29]

Epoch: 60, Loss: 4.2919
Early stopping


(T):  61%|███████████████▊          | 61/100 [12:39<07:36, 11.70s/it, loss=3.63]

Early stopping


(T):  62%|████████████████          | 62/100 [12:50<07:18, 11.54s/it, loss=3.91]

Early stopping


(T):  63%|████████████████▍         | 63/100 [13:02<07:05, 11.50s/it, loss=3.61]

Early stopping


(T):  64%|████████████████▋         | 64/100 [13:13<06:52, 11.46s/it, loss=4.02]

Early stopping


(T):  65%|████████████████▉         | 65/100 [13:24<06:36, 11.33s/it, loss=4.07]

Early stopping


(T):  66%|█████████████████▏        | 66/100 [13:35<06:25, 11.34s/it, loss=3.96]

Early stopping


(T):  67%|█████████████████▍        | 67/100 [13:47<06:19, 11.50s/it, loss=3.76]

Early stopping


(T):  68%|█████████████████▋        | 68/100 [13:58<06:01, 11.30s/it, loss=3.51]

Early stopping


(T):  69%|██████████████████▋        | 69/100 [14:10<05:58, 11.58s/it, loss=3.1]

Early stopping


(T):  70%|██████████████████▏       | 70/100 [14:22<05:44, 11.48s/it, loss=3.11]

Early stopping


(T):  71%|██████████████████▍       | 71/100 [14:33<05:33, 11.51s/it, loss=2.96]

Early stopping


(T):  72%|██████████████████▋       | 72/100 [14:45<05:22, 11.53s/it, loss=2.97]

Early stopping


(T):  73%|██████████████████▉       | 73/100 [14:56<05:11, 11.54s/it, loss=3.05]

Early stopping


(T):  74%|███████████████████▏      | 74/100 [15:09<05:06, 11.81s/it, loss=3.12]

Early stopping


(T):  75%|████████████████████▎      | 75/100 [15:22<05:07, 12.31s/it, loss=3.2]

Early stopping


(T):  76%|███████████████████▊      | 76/100 [15:33<04:45, 11.88s/it, loss=3.31]

Early stopping


(T):  77%|████████████████████      | 77/100 [15:45<04:30, 11.78s/it, loss=3.34]

Early stopping


(T):  78%|████████████████████▎     | 78/100 [15:56<04:14, 11.55s/it, loss=3.89]

Early stopping


(T):  79%|████████████████████▌     | 79/100 [16:07<04:03, 11.58s/it, loss=4.12]

Early stopping
Epoch: 80, Loss: 0.1906
Epoch: 80, Loss: 0.4731
Epoch: 80, Loss: 0.7016
Epoch: 80, Loss: 0.8845
Epoch: 80, Loss: 1.1903
Epoch: 80, Loss: 1.4075
Epoch: 80, Loss: 1.5759
Epoch: 80, Loss: 1.7573
Epoch: 80, Loss: 1.9189
Epoch: 80, Loss: 2.0614
Epoch: 80, Loss: 2.2154
Epoch: 80, Loss: 2.3807
Epoch: 80, Loss: 2.5331
Epoch: 80, Loss: 2.6879
Epoch: 80, Loss: 2.8682
Epoch: 80, Loss: 3.0640
Epoch: 80, Loss: 3.3219
Epoch: 80, Loss: 3.5284
Epoch: 80, Loss: 3.7420
Epoch: 80, Loss: 3.9208


(T):  80%|████████████████████▊     | 80/100 [16:19<03:49, 11.50s/it, loss=4.08]

Epoch: 80, Loss: 4.0807
Early stopping


(T):  81%|█████████████████████     | 81/100 [16:31<03:41, 11.65s/it, loss=4.33]

Early stopping


(T):  82%|█████████████████████▎    | 82/100 [16:43<03:32, 11.83s/it, loss=4.24]

Early stopping


(T):  83%|█████████████████████▌    | 83/100 [16:57<03:31, 12.44s/it, loss=3.69]

Early stopping


(T):  84%|█████████████████████▊    | 84/100 [17:09<03:16, 12.27s/it, loss=3.76]

Early stopping


(T):  85%|██████████████████████    | 85/100 [17:21<03:03, 12.21s/it, loss=3.76]

Early stopping


(T):  86%|██████████████████████▎   | 86/100 [17:34<02:56, 12.60s/it, loss=3.61]

Early stopping


(T):  87%|██████████████████████▌   | 87/100 [17:46<02:40, 12.38s/it, loss=3.39]

Early stopping


(T):  88%|██████████████████████▉   | 88/100 [17:59<02:28, 12.40s/it, loss=3.49]

Early stopping


(T):  89%|███████████████████████▏  | 89/100 [18:11<02:15, 12.35s/it, loss=3.45]

Early stopping


(T):  90%|███████████████████████▍  | 90/100 [18:22<02:01, 12.16s/it, loss=3.43]

Early stopping


(T):  91%|███████████████████████▋  | 91/100 [18:34<01:48, 12.06s/it, loss=3.24]

Early stopping


(T):  92%|███████████████████████▉  | 92/100 [18:46<01:35, 11.88s/it, loss=3.24]

Early stopping


(T):  93%|████████████████████████▏ | 93/100 [18:58<01:24, 12.13s/it, loss=3.18]

Early stopping


(T):  94%|████████████████████████▍ | 94/100 [19:12<01:15, 12.57s/it, loss=3.06]

Early stopping


(T):  95%|████████████████████████▋ | 95/100 [19:25<01:04, 12.83s/it, loss=3.04]

Early stopping


(T):  96%|████████████████████████▉ | 96/100 [19:37<00:49, 12.48s/it, loss=3.22]

Early stopping


(T):  97%|█████████████████████████▏| 97/100 [19:49<00:36, 12.16s/it, loss=3.24]

Early stopping


(T):  98%|█████████████████████████▍| 98/100 [20:00<00:23, 12.00s/it, loss=3.23]

Early stopping


(T):  99%|█████████████████████████▋| 99/100 [20:11<00:11, 11.78s/it, loss=3.19]

Early stopping
Epoch: 100, Loss: 0.1325
Epoch: 100, Loss: 0.3592
Epoch: 100, Loss: 0.4958
Epoch: 100, Loss: 0.6571
Epoch: 100, Loss: 0.9247
Epoch: 100, Loss: 1.0702
Epoch: 100, Loss: 1.2044
Epoch: 100, Loss: 1.3528
Epoch: 100, Loss: 1.5015
Epoch: 100, Loss: 1.6499
Epoch: 100, Loss: 1.8329
Epoch: 100, Loss: 2.0070
Epoch: 100, Loss: 2.1455
Epoch: 100, Loss: 2.2925
Epoch: 100, Loss: 2.4604
Epoch: 100, Loss: 2.6255
Epoch: 100, Loss: 2.8785
Epoch: 100, Loss: 3.0619
Epoch: 100, Loss: 3.2015
Epoch: 100, Loss: 3.3546


(T): 100%|█████████████████████████| 100/100 [20:22<00:00, 12.23s/it, loss=3.53]

Epoch: 100, Loss: 3.5297
Early stopping





In [24]:
test_result = test(encoder_model, dataloader, train_test_indices[0], device)

(LR): 100%|███████████| 5000/5000 [00:02<00:00, best test F1Mi=0.764, F1Ma=0.73]
(LR): 100%|██████████| 5000/5000 [00:02<00:00, best test F1Mi=0.752, F1Ma=0.722]
(LR): 100%|██████████| 5000/5000 [00:02<00:00, best test F1Mi=0.752, F1Ma=0.722]
(LR): 100%|██████████| 5000/5000 [00:02<00:00, best test F1Mi=0.774, F1Ma=0.753]
(LR): 100%|██████████| 5000/5000 [00:02<00:00, best test F1Mi=0.752, F1Ma=0.734]
(LR): 100%|██████████| 5000/5000 [00:02<00:00, best test F1Mi=0.752, F1Ma=0.722]
(LR): 100%|███████████| 5000/5000 [00:02<00:00, best test F1Mi=0.76, F1Ma=0.729]
(LR): 100%|███████████| 5000/5000 [00:02<00:00, best test F1Mi=0.75, F1Ma=0.731]
(LR): 100%|███████████| 5000/5000 [00:02<00:00, best test F1Mi=0.768, F1Ma=0.75]
(LR): 100%|██████████| 5000/5000 [00:02<00:00, best test F1Mi=0.758, F1Ma=0.724]


In [25]:
micro_f1_values = [entry['micro_f1']*100 for entry in test_result]

In [26]:
np_micro_f1_values = np.array(micro_f1_values)
micro_f1_mean = np.mean(np_micro_f1_values)
micro_f1_std = np.std(np_micro_f1_values)

print(f'test acc mean = {micro_f1_mean:.4f} ± {micro_f1_std:.4f}')

test acc mean = 75.8200 ± 0.7769
