In [1]:
import argparse
import os
import os.path as osp
import shutil
import time
from itertools import product
from json import dumps
import random
import numpy as np
import seaborn as sns
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch.optim import Adam
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.nn import DataParallel
from torch_geometric.seed import seed_everything
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import DataLoader, Data
from torch import nn
from tqdm import tqdm

from pytorch_metric_learning.losses import NTXentLoss, VICRegLoss
from sklearn.model_selection import GridSearchCV, StratifiedKFold

import train_utils
from data_utils import extract_edge_attributes
from torch_geometric.datasets import TUDataset
from layers.input_encoder import LinearEncoder, LinearEdgeEncoder
from layers.layer_utils import make_gnn_layer
from models.GraphClassification import GraphClassification
from models.model_utils import make_GNN

import GCL.losses as L
import GCL.augmentors as A
from GCL.eval import get_split, SVMEvaluator, LREvaluator
from GCL.models import DualBranchContrast

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Feature Augmentation
parser = argparse.ArgumentParser(f'arguments for training and testing')
parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.')
parser.add_argument('--dataset_name', type=str, default="PROTEINS",
                    choices=("MUTAG", "PROTEINS", "PTC_MR", "IMDBBINARY"), help='Name of dataset')
parser.add_argument('--drop_prob', type=float, default=0.6,
                    help='Probability of zeroing an activation in dropout layers.') 
parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
                        multiple GPUs are available.')
parser.add_argument("--parallel", action="store_true",
                    help="If true, use DataParallel for multi-gpu training")
parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
parser.add_argument('--lr', type=float, default=0.005, help='Learning rate.')
parser.add_argument('--l2_wd', type=float, default=3e-4, help='L2 weight decay.')
parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs.')
parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size of the model")
parser.add_argument("--embedding_size", type=int, default=16, help="Output size of the logistic model")
parser.add_argument("--model_name", type=str, default="KHopGNNConv",
                    choices=("KHopGNNConv"), help="Base GNN model")
parser.add_argument("--K", type=int, default=2, help="Number of hop to consider")
parser.add_argument("--num_layer", type=int, default=2, help="Number of layer for feature encoder")
parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
                    help="Jumping knowledge method")
parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
parser.add_argument("--virtual_node", action="store_true", default=False, 
                    help="If true, add virtual node information in each layer")
parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN")
parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
                    help="Combine method in k-hop aggregation")
parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
                    help="Pooling method in graph classification")
parser.add_argument('--norm_type', type=str, default="Batch",
                    choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
                    help="Normalization method in model")
parser.add_argument('--aggr', type=str, default="add",
                    help='Aggregation method in GNN layer, only works in GraphSAGE')
parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.")
parser.add_argument("--pos_attr", type=bool, default=False, help="Positional attributes.")
parser.add_argument("--feature_augmentation", type=bool, default=True, help="If true, feature augmentation.")

# Structure Augmentation
# parser = argparse.ArgumentParser(f'arguments for training and testing')
# parser.add_argument('--save_dir', type=str, default='./save', help='Base directory for saving information.')
# parser.add_argument('--seed', type=int, default=2, help='Random seed for reproducibility.') # 4 -> 93
# parser.add_argument('--dataset_name', type=str, default="PROTEINS",
#                     choices=("MUTAG", "PROTEINS", "PTC_MR", "IMDBBINARY"), help='Name of dataset')
# parser.add_argument('--drop_prob', type=float, default=0.5,
#                     help='Probability of zeroing an activation in dropout layers.') # 0.5 -> 93
# parser.add_argument('--batch_size', type=int, default=32, help='Batch size per GPU. Scales automatically when \
#                         multiple GPUs are available.') # 32 -> 93
# parser.add_argument("--parallel", action="store_true",
#                     help="If true, use DataParallel for multi-gpu training")
# parser.add_argument('--num_workers', type=int, default=0, help='Number of worker.')
# parser.add_argument('--load_path', type=str, default=None, help='Path to load as a model checkpoint.')
# parser.add_argument('--lr', type=float, default=0.001, help='Learning rate.') # 0.001 ->93
# parser.add_argument('--l2_wd', type=float, default=3e-4, help='L2 weight decay.') # 3e-4 -> 93
# parser.add_argument('--num_epochs', type=int, default=100, help='Number of epochs.') # 350 -> 93
# parser.add_argument("--hidden_size", type=int, default=128, help="Hidden size of the model") # 128 -> 93
# parser.add_argument("--embedding_size", type=int, default=16, help="Output size of the logistic model") # 16 -> 93
# parser.add_argument("--model_name", type=str, default="KHopGNNConv",
#                     choices=("KHopGNNConv"), help="Base GNN model")
# parser.add_argument("--K", type=int, default=2, help="Number of hop to consider") # 2 -> 93
# parser.add_argument("--num_layer", type=int, default=2, help="Number of layer for feature encoder") # 2 -> 93
# parser.add_argument("--JK", type=str, default="sum", choices=("sum", "max", "mean", "attention", "last", "concat"),
#                     help="Jumping knowledge method") # sum -> 93
# parser.add_argument("--residual", default=True, action="store_true", help="If true, use residual connection between each layer")
# parser.add_argument("--virtual_node", action="store_true", default=False, 
#                     help="If true, add virtual node information in each layer")
# parser.add_argument("--eps", type=float, default=0., help="Initial epsilon in GIN")
# parser.add_argument("--train_eps", action="store_true", help="If true, the epsilon in GIN model is trainable")
# parser.add_argument("--combine", type=str, default="geometric", choices=("attention", "geometric"),
#                     help="Combine method in k-hop aggregation") # geometric -> 93
# parser.add_argument("--pooling_method", type=str, default="sum", choices=("mean", "sum", "attention"),
#                     help="Pooling method in graph classification") # sum -> 93
# parser.add_argument('--norm_type', type=str, default="Batch",
#                     choices=("Batch", "Layer", "Instance", "GraphSize", "Pair"),
#                     help="Normalization method in model") # Batch -> 93
# parser.add_argument('--aggr', type=str, default="add",
#                     help='Aggregation method in GNN layer, only works in GraphSAGE')
# parser.add_argument("--patience", type=int, default=20, help="Patient epochs to wait before early stopping.")
# parser.add_argument('--factor', type=float, default=0.5, help='Factor for reducing learning rate scheduler')
# parser.add_argument('--reprocess', action="store_true", help='If true, reprocess the dataset')
# parser.add_argument('--search', action="store_true", help='If true, search hyper-parameters')
# parser.add_argument("--pos_enc_dim", type=int, default=6, help="Initial positional dim.") # 6 -> 93
# parser.add_argument("--pos_attr", type=bool, default=False, help="Positional attributes.")
# parser.add_argument("--feature_augmentation", type=bool, default=False, help="If true, feature augmentation.")

_StoreAction(option_strings=['--feature_augmentation'], dest='feature_augmentation', nargs=None, const=None, default=True, type=<class 'bool'>, choices=None, required=False, help='If true, feature augmentation.', metavar=None)

In [3]:
args = parser.parse_args("")

In [4]:
args.name = args.model_name + "_" + str(args.K) + "_" + str(args.search)

In [5]:
# Set up logging and devices
args.save_dir = train_utils.get_save_dir(args.save_dir, args.name, type=args.dataset_name)
log = train_utils.get_logger(args.save_dir, args.name)
device, args.gpu_ids = train_utils.get_available_devices()

In [6]:
if len(args.gpu_ids) > 1 and args.parallel:
    log.info(f'Using multi-gpu training')
    args.parallel = True
    loader = DataListLoader
    args.batch_size *= max(1, len(args.gpu_ids))
else:
    log.info(f'Using single-gpu training')
    args.parallel = False
    loader = DataLoader

[05.12.24 21:08:54] Using single-gpu training


In [7]:
# Set random seed
seed = args.seed
log.info(f'Using random seed {seed}...')
seed_everything(seed)

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

[05.12.24 21:08:54] Using random seed 2...


In [8]:
def edge_feature_transform(g):
    return extract_edge_attributes(g, args.pos_enc_dim)

In [9]:
tag = str(int(time.time()))

In [10]:
def get_model(args):
    layer = make_gnn_layer(args)
    init_emb = LinearEncoder(args.input_size, args.hidden_size, pos_attr=args.pos_attr)
    init_edge_attr_emb = LinearEdgeEncoder(args.edge_attr_size, args.hidden_size, edge_attr=True)
    init_edge_attr_v2_emb = LinearEdgeEncoder(args.edge_attr_v2_size, args.hidden_size, edge_attr=False)
    
    GNNModel = make_GNN(args)
    
    gnn = GNNModel(
        num_layer=args.num_layer,
        gnn_layer=layer,
        JK=args.JK,
        norm_type=args.norm_type,
        init_emb=init_emb,
        init_edge_attr_emb=init_edge_attr_emb,
        init_edge_attr_v2_emb=init_edge_attr_v2_emb,
        residual=args.residual,
        virtual_node=args.virtual_node,
        drop_prob=args.drop_prob)

    model = GraphClassification(embedding_model=gnn,
                                pooling_method=args.pooling_method,
                                output_size=args.output_size)
    
    model.reset_parameters()

    if args.parallel:
        model = DataParallel(model, args.gpu_ids)
    return model

In [11]:
class Projection(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Projection, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.ReLU()
        )
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.fc(x) + self.linear(x)

In [12]:
def vicreg_loss(embeddings1, embeddings2, lambda_var=25, lambda_cov=25, mu=1):
    """
    Calculate the VICReg loss between two sets of embeddings with the correct handling of covariance differences.
    
    Args:
    embeddings1, embeddings2 (torch.Tensor): Embeddings from two views, shape (batch_size, feature_dim).
    lambda_var (float): Coefficient for the variance loss.
    lambda_cov (float): Coefficient for the covariance loss.
    mu (float): Coefficient for the invariance loss.
    
    Returns:
    torch.Tensor: The total VICReg loss.
    """
    # Invariance Loss
    invariance_loss = F.mse_loss(embeddings1, embeddings2)

    # Variance Loss
    def variance_loss(embeddings1, embeddings2):
        mean_embeddings1 = embeddings1.mean(dim=0)
        mean_embeddings2 = embeddings2.mean(dim=0)
        
        std_dev1 = torch.sqrt((embeddings1 - mean_embeddings1).var(dim=0) + 1e-4)
        std_dev2 = torch.sqrt((embeddings2 - mean_embeddings2).var(dim=0) + 1e-4)
        
        return torch.mean(torch.abs(F.relu(1 - std_dev1) - F.relu(1 - std_dev2)))

    variance_loss_value = variance_loss(embeddings1, embeddings2)

    # Covariance Loss
    def covariance_loss(embeddings1, embeddings2):
        batch_size, feature_dim = embeddings1.size()
        
        embeddings_centered1 = embeddings1 - embeddings1.mean(dim=0)
        embeddings_centered2 = embeddings2 - embeddings2.mean(dim=0)
        
        covariance_matrix1 = torch.matmul(embeddings_centered1.T, embeddings_centered1) / (batch_size - 1)
        covariance_matrix2 = torch.matmul(embeddings_centered2.T, embeddings_centered2) / (batch_size - 1)
        
        covariance_matrix1.fill_diagonal_(0)
        covariance_matrix2.fill_diagonal_(0)
        
        cov_diff = torch.abs(covariance_matrix1.pow(2) - covariance_matrix2.pow(2))
        return cov_diff.sum() / feature_dim

    covariance_loss_value = covariance_loss(embeddings1, embeddings2)

    total_loss = mu * invariance_loss + lambda_var * variance_loss_value + lambda_cov * covariance_loss_value
    
    return total_loss

In [13]:
class Encoder(torch.nn.Module):
    def __init__(self, model_1, model_2, mlp1, mlp2, aug1, aug2):
        super(Encoder, self).__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.mlp1 = mlp1
        self.mlp2 = mlp2
        self.aug1 = aug1
        self.aug2 = aug2
        
    def get_embedding(self, data):
        z, g = self.model_1(data)
        z_pos, g_pos = self.model_2(data)
        
        z = self.mlp1(z)
        g = self.mlp2(g)
        
        z_pos = self.mlp1(z_pos)
        g_pos = self.mlp2(g_pos)

        g = torch.cat((g, g_pos), 1)
        z = torch.cat((z, z_pos), 1)

        return g.detach(), z.detach()

    def forward(self, data):
        data1 = self.aug1(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        data2 = self.aug2(data.x, data.edge_index, data.y, data.pos, data.edge_attr,
                          data.edge_attr_v2, data.batch, data.ptr)
        
        # Structural features
        z1, g1 = self.model_1(data1)
        z2, g2 = self.model_1(data2)
        
        # Positional features
        z1_pos, g1_pos = self.model_2(data1)
        z2_pos, g2_pos = self.model_2(data2)
        
        h1, h2 = [self.mlp1(h) for h in [z1, z2]]
        g1, g2 = [self.mlp2(g) for g in [g1, g2]]
        
        h1_pos, h2_pos = [self.mlp1(h_pos) for h_pos in [z1_pos, z2_pos]]
        g1_pos, g2_pos = [self.mlp2(g_pos) for g_pos in [g1_pos, g2_pos]]
        
        h1 = torch.cat((h1, h1_pos), 1)
        h2 = torch.cat((h2, h2_pos), 1)
        g1 = torch.cat((g1, g1_pos), 1)
        g2 = torch.cat((g2, g2_pos), 1)
        
        return h1, h2, g1, g2

In [14]:
def train(encoder_model, dataloader, optimizer, device):
    best = float("inf")
    cnt_wait = 0
    best_t = 0
    
    loss_func = NTXentLoss(temperature=0.10)
    
    encoder_model.train()
    epoch_loss = 0
    for data in dataloader:
        data = data.to(device)
        optimizer.zero_grad()

        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)

        h1, h2, g1, g2 = encoder_model(data)
        
        embeddings = torch.cat((g1, g2))
        
        # The same index corresponds to a positive pair
        indices = torch.arange(0, g1.size(0), device=device)
        labels = torch.cat((indices, indices))
        
        reg_loss = vicreg_loss(h1, h2, lambda_var=24, lambda_cov=24, mu=1)
        loss = loss_func(embeddings, labels) + 0.005*reg_loss
        
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        
        if epoch % 20 == 0:
            print("Epoch: {0}, Loss: {1:0.4f}".format(epoch, epoch_loss))

        if epoch_loss < best:
            best = epoch_loss
            best_t = epoch
            cnt_wait = 0
            torch.save(encoder_model.state_dict(), './pkl/best_model_'+ args.dataset_name + tag + '.pkl')
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping")
            break
            
    return epoch_loss

In [15]:
def test(encoder_model, dataloader, seeds, device):
    encoder_model.eval()
    x = []
    y = []
    for data in dataloader:
        data = data.to(device)
        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = torch.ones((num_nodes, 1), dtype=torch.float32, device=data.batch.device)
        graph_embeds, node_embeds = encoder_model.get_embedding(data)
        x.append(graph_embeds)
        y.append(data.y)
    x = torch.cat(x, dim=0)
    y = torch.cat(y, dim=0)

    result = []
    random.shuffle(seeds)
    seeds = seeds.tolist()
    for _ in np.arange(10):
        random_seed = seeds.pop()
        split = get_split(num_samples=x.size()[0], train_ratio=0.1, test_ratio=0.8, seed=random_seed)
        result.append(LREvaluator()(x, y, split))
    
    return result

In [16]:
path = "./data/data_splits/" + args.dataset_name

In [17]:
train_indices, test_indices, train_test_indices = [], [], []
for i in range(10):
    train_filename = os.path.join(path, '10fold_idx', 'train_idx-{}.txt'.format(i + 1))
    test_filename = os.path.join(path, '10fold_idx', 'test_idx-{}.txt'.format(i + 1))
    train_indices.append(torch.from_numpy(np.loadtxt(train_filename, dtype=int)).to(torch.long))
    test_indices.append(torch.from_numpy(np.loadtxt(test_filename, dtype=int)).to(torch.long))

if args.feature_augmentation:
    file_name = 'train_test_splits_feat.txt'
else:
    file_name = 'train_test_splits_struct.txt'
    
train_test_filename = os.path.join(path, '10fold_idx', file_name)
train_test_indices.append(torch.from_numpy(np.loadtxt(train_test_filename, dtype=int)).to(torch.long))

In [18]:
dataset = TUDataset(root='./data/'+args.dataset_name, name=args.dataset_name, 
                    pre_transform=T.Compose([edge_feature_transform]))

In [19]:
classes = torch.unique(dataset.y)
args.n_classes = len(classes)

In [20]:
args.input_size = dataset.num_node_features
args.pos_size = args.pos_enc_dim
args.output_size = args.hidden_size
args.edge_attr_size =  dataset.edge_attr.shape[1]
args.edge_attr_v2_size =  dataset.edge_attr_v2.shape[1]

In [21]:
aug1 = A.Identity()

if args.feature_augmentation:
    # Feature Augmentation
    aug2 = A.RandomChoice([A.FeatureDropout(pf=0.1),
                           A.FeatureMasking(pf=0.1),
                           A.EdgeAttrMasking(pf=0.1)], 1)
else:
    # Structure Augmentation
    aug2 = A.RandomChoice([A.RWSampling(num_seeds=1000, walk_length=10),
                       A.NodeDropping(pn=0.1),
                       A.EdgeRemoving(pe=0.1)], 1)

model_1 = get_model(args)
model_1.to(device)

args.pos_attr = True
args.input_size = args.pos_size
model_2 = get_model(args)
model_2.to(device)

mlp1 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)
mlp2 = Projection(input_dim=args.hidden_size, output_dim=args.hidden_size)

encoder_model = Encoder(model_1=model_1, model_2=model_2, mlp1=mlp1, mlp2=mlp2, aug1=aug1, aug2=aug2).to(device)

In [22]:
optimizer = Adam(encoder_model.parameters(), lr=args.lr, weight_decay=args.l2_wd)
dataloader = DataLoader(dataset, batch_size=args.batch_size)
with tqdm(total=100, desc='(T)') as pbar:
    for epoch in range(1, 101):
        loss = train(encoder_model, dataloader, optimizer, device)
        pbar.set_postfix({'loss': loss})
        pbar.update()

(T):   1%|▎                          | 1/100 [00:02<04:55,  2.98s/it, loss=48.5]

Early stopping


(T):   2%|▌                          | 2/100 [00:05<04:49,  2.96s/it, loss=37.8]

Early stopping


(T):   3%|▊                          | 3/100 [00:08<04:45,  2.95s/it, loss=30.4]

Early stopping


(T):   4%|█                          | 4/100 [00:11<04:42,  2.94s/it, loss=29.4]

Early stopping


(T):   5%|█▍                           | 5/100 [00:14<04:42,  2.98s/it, loss=30]

Early stopping


(T):   6%|█▋                           | 6/100 [00:17<04:39,  2.97s/it, loss=26]

Early stopping


(T):   7%|█▉                         | 7/100 [00:20<04:36,  2.97s/it, loss=25.8]

Early stopping


(T):   8%|██▏                        | 8/100 [00:23<04:34,  2.99s/it, loss=22.6]

Early stopping


(T):   9%|██▌                          | 9/100 [00:26<04:31,  2.98s/it, loss=30]

Early stopping


(T):  10%|██▊                         | 10/100 [00:29<04:29,  2.99s/it, loss=28]

Early stopping


(T):  11%|██▊                       | 11/100 [00:32<04:25,  2.98s/it, loss=25.2]

Early stopping


(T):  12%|███                       | 12/100 [00:35<04:21,  2.97s/it, loss=19.6]

Early stopping


(T):  13%|███▍                      | 13/100 [00:38<04:17,  2.96s/it, loss=19.4]

Early stopping


(T):  14%|███▋                      | 14/100 [00:41<04:14,  2.96s/it, loss=22.4]

Early stopping


(T):  15%|███▉                      | 15/100 [00:44<04:12,  2.97s/it, loss=20.6]

Early stopping


(T):  16%|████▏                     | 16/100 [00:47<04:08,  2.96s/it, loss=19.5]

Early stopping


(T):  17%|████▍                     | 17/100 [00:50<04:05,  2.96s/it, loss=14.4]

Early stopping


(T):  18%|████▋                     | 18/100 [00:53<04:02,  2.96s/it, loss=20.7]

Early stopping


(T):  19%|████▉                     | 19/100 [00:56<03:59,  2.95s/it, loss=18.6]

Early stopping
Epoch: 20, Loss: 1.1606
Epoch: 20, Loss: 2.1216
Epoch: 20, Loss: 2.8957
Epoch: 20, Loss: 4.0673
Epoch: 20, Loss: 4.4895
Epoch: 20, Loss: 5.4081
Epoch: 20, Loss: 6.1319
Epoch: 20, Loss: 6.6626
Epoch: 20, Loss: 7.1736
Epoch: 20, Loss: 8.0045
Epoch: 20, Loss: 9.1672
Epoch: 20, Loss: 9.9855
Epoch: 20, Loss: 11.5239
Epoch: 20, Loss: 12.2690
Epoch: 20, Loss: 12.8128
Epoch: 20, Loss: 13.4073
Epoch: 20, Loss: 13.9180
Epoch: 20, Loss: 14.5479
Epoch: 20, Loss: 15.7223
Epoch: 20, Loss: 16.8387


(T):  20%|█████▏                    | 20/100 [00:59<03:57,  2.97s/it, loss=17.7]

Epoch: 20, Loss: 17.6777
Early stopping


(T):  21%|█████▍                    | 21/100 [01:02<03:55,  2.99s/it, loss=20.1]

Early stopping


(T):  22%|█████▋                    | 22/100 [01:05<03:53,  2.99s/it, loss=16.8]

Early stopping


(T):  23%|█████▉                    | 23/100 [01:08<03:49,  2.98s/it, loss=22.1]

Early stopping


(T):  24%|██████▋                     | 24/100 [01:11<03:48,  3.01s/it, loss=20]

Early stopping


(T):  25%|██████▌                   | 25/100 [01:14<03:44,  3.00s/it, loss=11.9]

Early stopping


(T):  26%|██████▊                   | 26/100 [01:17<03:42,  3.01s/it, loss=21.2]

Early stopping


(T):  27%|███████                   | 27/100 [01:20<03:38,  2.99s/it, loss=15.5]

Early stopping


(T):  28%|███████▎                  | 28/100 [01:23<03:35,  2.99s/it, loss=15.3]

Early stopping


(T):  29%|███████▌                  | 29/100 [01:26<03:32,  2.99s/it, loss=13.3]

Early stopping


(T):  30%|███████▊                  | 30/100 [01:29<03:28,  2.98s/it, loss=10.2]

Early stopping


(T):  31%|████████                  | 31/100 [01:32<03:25,  2.99s/it, loss=15.5]

Early stopping


(T):  32%|████████▎                 | 32/100 [01:35<03:22,  2.98s/it, loss=10.5]

Early stopping


(T):  33%|█████████▏                  | 33/100 [01:38<03:19,  2.98s/it, loss=14]

Early stopping


(T):  34%|████████▊                 | 34/100 [01:41<03:15,  2.96s/it, loss=14.7]

Early stopping


(T):  35%|█████████▊                  | 35/100 [01:44<03:12,  2.96s/it, loss=13]

Early stopping


(T):  36%|█████████▎                | 36/100 [01:47<03:09,  2.96s/it, loss=11.7]

Early stopping


(T):  37%|█████████▌                | 37/100 [01:50<03:07,  2.97s/it, loss=18.3]

Early stopping


(T):  38%|█████████▉                | 38/100 [01:53<03:04,  2.97s/it, loss=15.9]

Early stopping


(T):  39%|██████████▏               | 39/100 [01:56<03:00,  2.97s/it, loss=14.8]

Early stopping
Epoch: 40, Loss: 0.9035
Epoch: 40, Loss: 1.7295
Epoch: 40, Loss: 2.5726
Epoch: 40, Loss: 3.5013
Epoch: 40, Loss: 5.1637
Epoch: 40, Loss: 23.8902
Epoch: 40, Loss: 24.9658
Epoch: 40, Loss: 31.8097
Epoch: 40, Loss: 53.3478
Epoch: 40, Loss: 102.0576
Epoch: 40, Loss: 105.6865
Epoch: 40, Loss: 108.9586
Epoch: 40, Loss: 111.2120
Epoch: 40, Loss: 114.9326
Epoch: 40, Loss: 118.1201
Epoch: 40, Loss: 121.2337
Epoch: 40, Loss: 125.0317
Epoch: 40, Loss: 127.9229
Epoch: 40, Loss: 130.3880
Epoch: 40, Loss: 133.4314


(T):  40%|██████████▊                | 40/100 [01:59<02:58,  2.97s/it, loss=136]

Epoch: 40, Loss: 136.1973
Early stopping


(T):  41%|██████████▋               | 41/100 [02:01<02:55,  2.97s/it, loss=46.1]

Early stopping


(T):  42%|██████████▉               | 42/100 [02:04<02:53,  2.98s/it, loss=35.8]

Early stopping


(T):  43%|████████████                | 43/100 [02:07<02:49,  2.98s/it, loss=27]

Early stopping


(T):  44%|███████████▍              | 44/100 [02:10<02:46,  2.97s/it, loss=29.2]

Early stopping


(T):  45%|████████████▌               | 45/100 [02:14<02:46,  3.03s/it, loss=24]

Early stopping


(T):  46%|███████████▉              | 46/100 [02:17<02:43,  3.02s/it, loss=19.1]

Early stopping


(T):  47%|████████████▏             | 47/100 [02:20<02:39,  3.01s/it, loss=19.1]

Early stopping


(T):  48%|████████████▍             | 48/100 [02:23<02:36,  3.01s/it, loss=20.6]

Early stopping


(T):  49%|████████████▋             | 49/100 [02:26<02:33,  3.00s/it, loss=23.9]

Early stopping


(T):  50%|██████████████              | 50/100 [02:29<02:29,  3.00s/it, loss=23]

Early stopping


(T):  51%|█████████████▎            | 51/100 [02:32<02:26,  2.99s/it, loss=19.2]

Early stopping


(T):  52%|█████████████▌            | 52/100 [02:34<02:23,  2.98s/it, loss=20.4]

Early stopping


(T):  53%|█████████████▊            | 53/100 [02:38<02:21,  3.00s/it, loss=19.2]

Early stopping


(T):  54%|██████████████            | 54/100 [02:41<02:17,  2.99s/it, loss=17.2]

Early stopping


(T):  55%|██████████████▎           | 55/100 [02:43<02:14,  2.98s/it, loss=14.4]

Early stopping


(T):  56%|██████████████▌           | 56/100 [02:46<02:10,  2.97s/it, loss=12.4]

Early stopping


(T):  57%|██████████████▊           | 57/100 [02:49<02:07,  2.97s/it, loss=21.6]

Early stopping


(T):  58%|███████████████           | 58/100 [02:52<02:04,  2.98s/it, loss=16.4]

Early stopping


(T):  59%|███████████████▎          | 59/100 [02:55<02:01,  2.97s/it, loss=21.6]

Early stopping
Epoch: 60, Loss: 1.8754
Epoch: 60, Loss: 2.8715
Epoch: 60, Loss: 3.4196
Epoch: 60, Loss: 4.0310
Epoch: 60, Loss: 4.6706
Epoch: 60, Loss: 5.4913
Epoch: 60, Loss: 5.9637
Epoch: 60, Loss: 6.5606
Epoch: 60, Loss: 7.0125
Epoch: 60, Loss: 7.6028
Epoch: 60, Loss: 7.9552
Epoch: 60, Loss: 8.6086
Epoch: 60, Loss: 9.1887
Epoch: 60, Loss: 10.1789
Epoch: 60, Loss: 10.7713
Epoch: 60, Loss: 12.2108
Epoch: 60, Loss: 12.9425
Epoch: 60, Loss: 15.3187
Epoch: 60, Loss: 15.7757
Epoch: 60, Loss: 16.2480


(T):  60%|███████████████▌          | 60/100 [02:58<01:59,  2.98s/it, loss=17.8]

Epoch: 60, Loss: 17.7929
Early stopping


(T):  61%|███████████████▊          | 61/100 [03:01<01:55,  2.97s/it, loss=18.8]

Early stopping


(T):  62%|████████████████          | 62/100 [03:04<01:52,  2.97s/it, loss=16.8]

Early stopping


(T):  63%|████████████████▍         | 63/100 [03:07<01:49,  2.96s/it, loss=18.6]

Early stopping


(T):  64%|████████████████▋         | 64/100 [03:10<01:47,  2.98s/it, loss=15.1]

Early stopping


(T):  65%|████████████████▉         | 65/100 [03:13<01:44,  2.98s/it, loss=14.4]

Early stopping


(T):  66%|█████████████████▏        | 66/100 [03:16<01:41,  3.00s/it, loss=11.2]

Early stopping


(T):  67%|█████████████████▍        | 67/100 [03:19<01:38,  2.98s/it, loss=14.6]

Early stopping


(T):  68%|███████████████████         | 68/100 [03:22<01:35,  2.98s/it, loss=13]

Early stopping


(T):  69%|█████████████████▉        | 69/100 [03:25<01:33,  3.00s/it, loss=15.9]

Early stopping


(T):  70%|██████████████████▏       | 70/100 [03:28<01:29,  2.99s/it, loss=14.3]

Early stopping


(T):  71%|███████████████████▉        | 71/100 [03:31<01:26,  2.98s/it, loss=14]

Early stopping


(T):  72%|██████████████████▋       | 72/100 [03:34<01:23,  2.98s/it, loss=10.1]

Early stopping


(T):  73%|██████████████████▉       | 73/100 [03:37<01:20,  2.97s/it, loss=10.6]

Early stopping


(T):  74%|███████████████████▏      | 74/100 [03:40<01:17,  2.99s/it, loss=12.9]

Early stopping


(T):  75%|███████████████████▌      | 75/100 [03:43<01:14,  2.98s/it, loss=14.4]

Early stopping


(T):  76%|████████████████████▌      | 76/100 [03:46<01:11,  2.98s/it, loss=9.3]

Early stopping


(T):  77%|████████████████████      | 77/100 [03:49<01:08,  2.99s/it, loss=15.2]

Early stopping


(T):  78%|████████████████████▎     | 78/100 [03:52<01:05,  2.99s/it, loss=10.4]

Early stopping


(T):  79%|████████████████████▌     | 79/100 [03:55<01:03,  3.04s/it, loss=11.1]

Early stopping
Epoch: 80, Loss: 0.3069
Epoch: 80, Loss: 0.5990
Epoch: 80, Loss: 1.1483
Epoch: 80, Loss: 1.7736
Epoch: 80, Loss: 1.9947
Epoch: 80, Loss: 2.2362
Epoch: 80, Loss: 2.8479
Epoch: 80, Loss: 3.3948
Epoch: 80, Loss: 3.8949
Epoch: 80, Loss: 4.3448
Epoch: 80, Loss: 4.7929
Epoch: 80, Loss: 5.3110
Epoch: 80, Loss: 6.0005
Epoch: 80, Loss: 6.3920
Epoch: 80, Loss: 6.5954
Epoch: 80, Loss: 6.9360
Epoch: 80, Loss: 7.7255
Epoch: 80, Loss: 8.3702
Epoch: 80, Loss: 9.1048
Epoch: 80, Loss: 9.6553


(T):  80%|████████████████████▊     | 80/100 [03:58<01:01,  3.06s/it, loss=10.1]

Epoch: 80, Loss: 10.0847
Early stopping


(T):  81%|██████████████████████▋     | 81/100 [04:01<00:57,  3.03s/it, loss=10]

Early stopping


(T):  82%|█████████████████████▎    | 82/100 [04:04<00:54,  3.01s/it, loss=14.5]

Early stopping


(T):  83%|█████████████████████▌    | 83/100 [04:07<00:50,  2.99s/it, loss=15.7]

Early stopping


(T):  84%|███████████████████████▌    | 84/100 [04:10<00:47,  2.99s/it, loss=18]

Early stopping


(T):  85%|██████████████████████    | 85/100 [04:13<00:45,  3.02s/it, loss=13.6]

Early stopping


(T):  86%|██████████████████████▎   | 86/100 [04:16<00:42,  3.03s/it, loss=9.94]

Early stopping


(T):  87%|████████████████████████▎   | 87/100 [04:19<00:39,  3.05s/it, loss=15]

Early stopping


(T):  88%|██████████████████████▉   | 88/100 [04:23<00:36,  3.06s/it, loss=10.1]

Early stopping


(T):  89%|███████████████████████▏  | 89/100 [04:26<00:33,  3.09s/it, loss=8.71]

Early stopping


(T):  90%|███████████████████████▍  | 90/100 [04:29<00:31,  3.11s/it, loss=11.4]

Early stopping


(T):  91%|███████████████████████▋  | 91/100 [04:32<00:27,  3.11s/it, loss=10.7]

Early stopping


(T):  92%|███████████████████████▉  | 92/100 [04:35<00:25,  3.13s/it, loss=11.6]

Early stopping


(T):  93%|████████████████████████▏ | 93/100 [04:38<00:21,  3.14s/it, loss=10.5]

Early stopping


(T):  94%|████████████████████████▍ | 94/100 [04:41<00:18,  3.14s/it, loss=7.89]

Early stopping


(T):  95%|████████████████████████▋ | 95/100 [04:45<00:15,  3.14s/it, loss=11.5]

Early stopping


(T):  96%|█████████████████████████▉ | 96/100 [04:48<00:12,  3.13s/it, loss=9.4]

Early stopping


(T):  97%|█████████████████████████▏| 97/100 [04:51<00:09,  3.14s/it, loss=8.17]

Early stopping


(T):  98%|█████████████████████████▍| 98/100 [04:54<00:06,  3.13s/it, loss=7.12]

Early stopping


(T):  99%|█████████████████████████▋| 99/100 [04:57<00:03,  3.15s/it, loss=13.1]

Early stopping
Epoch: 100, Loss: 0.7953
Epoch: 100, Loss: 1.0679
Epoch: 100, Loss: 1.4930
Epoch: 100, Loss: 4.0917
Epoch: 100, Loss: 4.4555
Epoch: 100, Loss: 4.9483
Epoch: 100, Loss: 5.4882
Epoch: 100, Loss: 5.9513
Epoch: 100, Loss: 6.6304
Epoch: 100, Loss: 7.0621
Epoch: 100, Loss: 7.3078
Epoch: 100, Loss: 7.9137
Epoch: 100, Loss: 8.3491
Epoch: 100, Loss: 8.7846
Epoch: 100, Loss: 9.0485
Epoch: 100, Loss: 9.3614
Epoch: 100, Loss: 10.1410
Epoch: 100, Loss: 10.5808
Epoch: 100, Loss: 10.9067
Epoch: 100, Loss: 11.1652


(T): 100%|█████████████████████████| 100/100 [05:00<00:00,  3.01s/it, loss=11.7]

Epoch: 100, Loss: 11.6797
Early stopping





In [23]:
test_result = test(encoder_model, dataloader, train_test_indices[0], device)

(LR): 100%|██████████| 5000/5000 [00:01<00:00, best test F1Mi=0.759, F1Ma=0.744]
(LR): 100%|██████████| 5000/5000 [00:01<00:00, best test F1Mi=0.741, F1Ma=0.735]
(LR): 100%|██████████| 5000/5000 [00:01<00:00, best test F1Mi=0.732, F1Ma=0.717]
(LR): 100%|██████████| 5000/5000 [00:01<00:00, best test F1Mi=0.732, F1Ma=0.717]
(LR): 100%|██████████| 5000/5000 [00:01<00:00, best test F1Mi=0.732, F1Ma=0.728]
(LR): 100%|███████████| 5000/5000 [00:01<00:00, best test F1Mi=0.75, F1Ma=0.731]
(LR): 100%|███████████| 5000/5000 [00:01<00:00, best test F1Mi=0.75, F1Ma=0.733]
(LR): 100%|██████████| 5000/5000 [00:01<00:00, best test F1Mi=0.732, F1Ma=0.729]
(LR): 100%|███████████| 5000/5000 [00:01<00:00, best test F1Mi=0.75, F1Ma=0.731]
(LR): 100%|██████████| 5000/5000 [00:01<00:00, best test F1Mi=0.786, F1Ma=0.782]


In [24]:
micro_f1_values = [entry['micro_f1']*100 for entry in test_result]

In [25]:
np_micro_f1_values = np.array(micro_f1_values)
micro_f1_mean = np.mean(np_micro_f1_values)
micro_f1_std = np.std(np_micro_f1_values)
uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(np_micro_f1_values, func=np.mean, 
                                                                      n_boot=1000), 95) - np_micro_f1_values.mean()))

print(f'test acc mean = {micro_f1_mean:.4f} ± {micro_f1_std:.4f}')

test acc mean = 74.6429 ± 1.6071
