## Knowledge graph experiments using SEAL 
### Author: Ridha Alkhabaz 

We are applying the seal paradigm onto FB15k-237

In [1]:
import argparse
import time
import os, sys
import torch
import pdb
import copy as cp
import numpy as np
import os.path as osp
import scipy.sparse as ssp
import torch.nn.functional as F
import torch_geometric.transforms as T
from tqdm import tqdm
from shutil import copy
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader
from torch_sparse import coalesce
from sklearn.metrics import roc_auc_score
from torch_geometric.datasets import FB15k_237
from torch_geometric.data import Data, Dataset, InMemoryDataset, DataLoader
from torch_geometric.utils import to_networkx, to_undirected
from torch_geometric.datasets import RelLinkPredDataset
from utils import *
from torch.nn import (ModuleList, Linear, Conv1d, MaxPool1d, Embedding, ReLU, 
                      Sequential, BatchNorm1d as BN)
from torch_geometric.nn import (GCNConv, SAGEConv, GINConv, 
                                global_sort_pool, global_add_pool, global_mean_pool)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import warnings
warnings.filterwarnings("ignore")



In [2]:
# helpful funcstions and classes 
# this is for model definition and other things 
class DGCNN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, max_z, k=0.6, train_dataset=None, 
                 dynamic_train=False, GNN=GCNConv, use_feature=False, 
                 node_embedding=None):
        super(DGCNN, self).__init__()

        self.use_feature = use_feature
        self.node_embedding = node_embedding

        if k <= 1:  # Transform percentile to number.
            if train_dataset is None:
                k = 30
            else:
                if dynamic_train:
                    sampled_train = train_dataset[:1000]
                else:
                    sampled_train = train_dataset
                num_nodes = sorted([g.num_nodes for g in sampled_train])
                k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
                k = max(10, k)
        self.k = int(k)

        self.max_z = max_z
        self.z_embedding = Embedding(self.max_z, hidden_channels)

        self.convs = ModuleList()
        initial_channels = hidden_channels
        if self.use_feature:
            initial_channels += train_dataset.num_features
        if self.node_embedding is not None:
            initial_channels += node_embedding.embedding_dim

        self.convs.append(GNN(initial_channels, hidden_channels))
        for i in range(0, num_layers-1):
            self.convs.append(GNN(hidden_channels, hidden_channels))
        self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
                            conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
                            conv1d_kws[1], 1)
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.lin1 = Linear(dense_dim, 128)
        self.lin2 = Linear(128, 1)

    def forward(self, z, edge_index, batch, x=None, edge_weight=None, node_id=None):
        z_emb = self.z_embedding(z)
        if z_emb.ndim == 3:  # in case z has multiple integer labels
            z_emb = z_emb.sum(dim=1)
        if self.use_feature and x is not None:
            x = torch.cat([z_emb, x.to(torch.float)], 1)
        else:
            x = z_emb
        if self.node_embedding is not None and node_id is not None:
            n_emb = self.node_embedding(node_id)
            x = torch.cat([x, n_emb], 1)
        xs = [x]

        for conv in self.convs:
            xs += [torch.tanh(conv(xs[-1], edge_index, edge_weight))]
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.
        x = global_sort_pool(x, batch, self.k)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = F.relu(self.conv1(x))
        x = self.maxpool1d(x)
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        # MLP.
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return x
# this is for dataset definition
class SEALDataset(InMemoryDataset):
    def __init__(self, root, data, split_edge, num_hops, percent=100, split='train', 
                 use_coalesce=False, node_label='drnl', ratio_per_hop=1.0, 
                 max_nodes_per_hop=None, directed=False):
        self.data = data
        self.split_edge = split_edge
        self.num_hops = num_hops
        self.percent = int(percent) if percent >= 1.0 else percent
        self.split = split
        self.use_coalesce = use_coalesce
        self.node_label = node_label
        self.ratio_per_hop = ratio_per_hop
        self.max_nodes_per_hop = max_nodes_per_hop
        self.directed = directed
        super(SEALDataset, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        if self.percent == 100:
            name = 'SEAL_{}_data'.format(self.split)
        else:
            name = 'SEAL_{}_data_{}'.format(self.split, self.percent)
        name += '.pt'
        return [name]

    def process(self):
        pos_edge, neg_edge = get_pos_neg_edges(self.split, self.split_edge, 
                                               self.data.edge_index, 
                                               self.data.num_nodes, 
                                               self.percent)

        if self.use_coalesce:  # compress mutli-edge into edge with weight
            self.data.edge_index, self.data.edge_weight = coalesce(
                self.data.edge_index, self.data.edge_weight, 
                self.data.num_nodes, self.data.num_nodes)

        if 'edge_weight' in self.data:
            edge_weight = self.data.edge_weight.view(-1)
        else:
            edge_weight = torch.ones(self.data.edge_index.size(1), dtype=int)
        A = ssp.csr_matrix(
            (edge_weight, (self.data.edge_index[0], self.data.edge_index[1])), 
            shape=(self.data.num_nodes, self.data.num_nodes)
        )

        if self.directed:
            A_csc = A.tocsc()
        else:
            A_csc = None
        
        # Extract enclosing subgraphs for pos and neg edges
        pos_list = extract_enclosing_subgraphs(
            pos_edge, A, self.data.x, 1, self.num_hops, self.node_label, 
            self.ratio_per_hop, self.max_nodes_per_hop, self.directed, A_csc)
        neg_list = extract_enclosing_subgraphs(
            neg_edge, A, self.data.x, 0, self.num_hops, self.node_label, 
            self.ratio_per_hop, self.max_nodes_per_hop, self.directed, A_csc)

        torch.save(self.collate(pos_list + neg_list), self.processed_paths[0])
        del pos_list, neg_list
def train():
    model.train()

    total_loss = 0
    pbar = tqdm(train_loader, ncols=70)
    for data in pbar:
        data = data.to(device)
        optimizer.zero_grad()
        x =  None
        edge_weight = None
        node_id = data.node_id if emb else None
        logits = model(data.z, data.edge_index, data.batch, x, edge_weight, node_id)
        loss = BCEWithLogitsLoss()(logits.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs

    return total_loss / len(train_dataset)
@torch.no_grad()
def test():
    model.eval()

    y_pred, y_true = [], []
    for data in tqdm(val_loader, ncols=70):
        data = data.to(device)
        x =  None
        edge_weight = None
        node_id = data.node_id if emb else None
        logits = model(data.z, data.edge_index, data.batch, x, edge_weight, node_id)
        y_pred.append(logits.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))
    val_pred, val_true = torch.cat(y_pred), torch.cat(y_true)
    pos_val_pred = val_pred[val_true==1]
    neg_val_pred = val_pred[val_true==0]

    y_pred, y_true = [], []
    for data in tqdm(test_loader, ncols=70):
        data = data.to(device)
        x =  None
        edge_weight =  None
        node_id = data.node_id if emb else None
        logits = model(data.z, data.edge_index, data.batch, x, edge_weight, node_id)
        y_pred.append(logits.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))
    test_pred, test_true = torch.cat(y_pred), torch.cat(y_true)
    pos_test_pred = test_pred[test_true==1]
    neg_test_pred = test_pred[test_true==0]
    
   
    results = evaluate_mrr(val_pred, val_true, test_pred, test_true)

    return results
def evaluate_mrr(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred):
    pos_val_pred = pos_val_pred.view(-1, 1)
    neg_val_pred = neg_val_pred.view(-1, 1)
    pos_test_pred = pos_test_pred.view(-1, 1)
    neg_test_pred = neg_test_pred.view(-1, 1)
    optimistic_rank_test = (neg_test_pred > pos_test_pred).sum(dim=1)
    optimistic_rank_val = (neg_val_pred > pos_val_pred).sum(dim=1)
    pessimistic_rank_test = (neg_test_pred >= pos_test_pred).sum(dim=1)
    pessimistic_rank_val = (neg_val_pred >= pos_val_pred).sum(dim=1)
    ranking_list_test = 0.5 * (optimistic_rank_test + pessimistic_rank_test) + 1
    ranking_list_val = 0.5 * (optimistic_rank_val + pessimistic_rank_val) + 1

    results = {}
    valid_mrr =  (1./ranking_list_val.to(torch.float)).mean()
    test_mrr = (1./ranking_list_test.to(torch.float)).mean()
   
    results['MRR'] = (valid_mrr, test_mrr)
    return results
def do_edge_split_v2(data, fast_split=False, val_ratio=0.05, test_ratio=0.1):
    random.seed(234)
    torch.manual_seed(234)
    # print('check oone')
    if not fast_split:
        # print('check two')
        data = train_test_split_edges(data, val_ratio, test_ratio)
        # print('thre')
        edge_index, _ = add_self_loops(data.train_pos_edge_index)
        # print('four')
        data.train_neg_edge_index = negative_sampling(
            edge_index, num_nodes=data.num_nodes,
            num_neg_samples=data.train_pos_edge_index.size(1))
    else:
        # print('check ttwo')
        num_nodes = data.num_nodes
        row, col = data.edge_index
    
        # print('check thre')
        # print(row, col)
        # Return upper triangular portion.
        mask = row < col
        # print(mask.shape)
        row, col = row[mask], col[mask]
        n_v = int(math.floor(val_ratio * row.size(0)))
        n_t = int(math.floor(test_ratio * row.size(0)))
        # Positive edges.
        
        perm = torch.randperm(row.size(0))
        row, col = row[perm], col[perm]
        r, c = row[:n_v], col[:n_v]
        data.val_pos_edge_index = torch.stack([r, c], dim=0)
        r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]
        data.test_pos_edge_index = torch.stack([r, c], dim=0)
        r, c = row[n_v + n_t:], col[n_v + n_t:]
        data.train_pos_edge_index = torch.stack([r, c], dim=0)
        # Negative edges (cannot guarantee (i,j) and (j,i) won't both appear)
        neg_edge_index = negative_sampling(
            data.edge_index, num_nodes=num_nodes,
            num_neg_samples=row.size(0))
        data.val_neg_edge_index = neg_edge_index[:, :n_v]
        data.test_neg_edge_index = neg_edge_index[:, n_v:n_v + n_t]
        data.train_neg_edge_index = neg_edge_index[:, n_v + n_t:]

    split_edge = {'train': {}, 'valid': {}, 'test': {}}
    split_edge['train']['edge'] = data.train_pos_edge_index.t()
    split_edge['train']['edge_neg'] = data.train_neg_edge_index.t()
    split_edge['valid']['edge'] = data.val_pos_edge_index.t()
    split_edge['valid']['edge_neg'] = data.val_neg_edge_index.t()
    split_edge['test']['edge'] = data.test_pos_edge_index.t()
    split_edge['test']['edge_neg'] = data.test_neg_edge_index.t()
    return split_edge

In [4]:
dataset_train = FB15k_237('data')
data_train = dataset_train[0]
dataset_val = FB15k_237('data', split='val')
data_val = dataset_val[0]
dataset_test = FB15k_237('data', split='test')
data_test = dataset_test[0]

In [11]:
data = Data()
data.num_nodes = 14541
data.edge_type = torch.cat((data_train.edge_type, data_val.edge_type, data_test.edge_type),0)
data.edge_index = torch.cat((data_train.edge_index, data_val.edge_index, data_test.edge_index),1)

In [14]:
split_edge = do_edge_split_v2(data)
data.edge_index = split_edge['train']['edge'].t()

Data(num_nodes=14541, edge_type=[310116], val_pos_edge_index=[2, 6226], test_pos_edge_index=[2, 12452], train_pos_edge_index=[2, 196914], train_neg_adj_mask=[14541, 14541], val_neg_edge_index=[2, 6226], test_neg_edge_index=[2, 12452], train_neg_edge_index=[2, 196914], edge_index=[2, 196914])

In [16]:
path = dataset_train.root
train_dataset = eval('SEALDataset')(
    path, 
    data, 
    split_edge, 
    num_hops=1, 
    percent=100, 
    split='train', 
    use_coalesce=False, 
    node_label='drnl', 
    ratio_per_hop=1.0, 
    max_nodes_per_hop=None, 
    directed=True, 
)
val_dataset = eval('SEALDataset')(
    path, 
    data, 
    split_edge, 
    num_hops=1, 
    percent=100, 
    split='valid', 
    use_coalesce=False, 
    node_label='drnl', 
    ratio_per_hop=1.0, 
    max_nodes_per_hop=None, 
    directed=True, 
) 
test_dataset = eval('SEALDataset')(
    path, 
    data, 
    split_edge, 
    num_hops=1, 
    percent=100, 
    split='test', 
    use_coalesce=False, 
    node_label='drnl', 
    ratio_per_hop=1.0, 
    max_nodes_per_hop=None, 
    directed=True, 
)  

In [17]:
max_z = 1000 
train_loader = DataLoader(train_dataset, batch_size=32, 
                          shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, 
                           num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, 
                           num_workers=0)
model = DGCNN(hidden_channels=32, num_layers=3, max_z=max_z, k=0.6, 
                      train_dataset=train_dataset, use_feature=None, 
                      node_embedding=None).to(device)

In [18]:
parameters = list(model.parameters())
optimizer = torch.optim.Adam(params=parameters, lr=0.0001)
total_params = sum(p.numel() for param in parameters for p in param)
print(f'Total number of parameters is {total_params}')

Total number of parameters is 101058


In [19]:
emb = None
start_epoch = 1
loggers = {}
run = 1
for epoch in range(start_epoch, start_epoch + 30):
        loss = train()

        if epoch % 5 == 0:
            results = test()
            # for key, result in results.items():
            #     loggers[key].add_result(run, result)

            # if epoch % 1 == 0:
                # model_name = os.path.join(
                #     args.res_dir, 'run{}_model_checkpoint{}.pth'.format(run+1, epoch))
                # optimizer_name = os.path.join(
                #     args.res_dir, 'run{}_optimizer_checkpoint{}.pth'.format(run+1, epoch))
                # torch.save(model.state_dict(), model_name)
                # torch.save(optimizer.state_dict(), optimizer_name)

            for key, result in results.items():
                valid_res, test_res = result
                to_print = (f'Run: {run + 1:02d}, Epoch: {epoch:02d}, ' +
                            f'Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, ' +
                            f'Test: {100 * test_res:.2f}%')
                print(key)
                print(to_print)
                    # with open(log_file, 'a') as f:
                    #     print(key, file=f)
                    #     print(to_print, file=f)

100%|███████████████████████████| 10817/10817 [03:16<00:00, 55.00it/s]
100%|███████████████████████████| 10817/10817 [03:23<00:00, 53.18it/s]
100%|███████████████████████████| 10817/10817 [03:17<00:00, 54.79it/s]
100%|███████████████████████████| 10817/10817 [03:18<00:00, 54.38it/s]
100%|███████████████████████████| 10817/10817 [03:11<00:00, 56.53it/s]
100%|███████████████████████████████| 345/345 [00:04<00:00, 85.61it/s]
100%|███████████████████████████████| 691/691 [00:08<00:00, 80.51it/s]


MRR
Run: 02, Epoch: 05, Loss: 0.2803, Valid: 72.08%, Test: 71.93%


100%|███████████████████████████| 10817/10817 [03:19<00:00, 54.28it/s]
100%|███████████████████████████| 10817/10817 [03:12<00:00, 56.24it/s]
100%|███████████████████████████| 10817/10817 [03:15<00:00, 55.37it/s]
100%|███████████████████████████| 10817/10817 [03:15<00:00, 55.33it/s]
100%|███████████████████████████| 10817/10817 [03:16<00:00, 55.15it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 109.65it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 113.11it/s]


MRR
Run: 02, Epoch: 10, Loss: 0.2764, Valid: 71.89%, Test: 71.66%


100%|███████████████████████████| 10817/10817 [03:16<00:00, 55.01it/s]
100%|███████████████████████████| 10817/10817 [03:28<00:00, 51.82it/s]
100%|███████████████████████████| 10817/10817 [03:35<00:00, 50.24it/s]
100%|███████████████████████████| 10817/10817 [03:23<00:00, 53.09it/s]
100%|███████████████████████████| 10817/10817 [03:18<00:00, 54.55it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 107.63it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 109.33it/s]


MRR
Run: 02, Epoch: 15, Loss: 0.2745, Valid: 72.68%, Test: 72.42%


100%|███████████████████████████| 10817/10817 [03:15<00:00, 55.38it/s]
100%|███████████████████████████| 10817/10817 [03:16<00:00, 55.04it/s]
100%|███████████████████████████| 10817/10817 [03:16<00:00, 54.96it/s]
100%|███████████████████████████| 10817/10817 [03:16<00:00, 55.05it/s]
100%|███████████████████████████| 10817/10817 [03:16<00:00, 55.02it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 106.86it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 105.75it/s]


MRR
Run: 02, Epoch: 20, Loss: 0.2732, Valid: 72.26%, Test: 72.05%


100%|███████████████████████████| 10817/10817 [03:16<00:00, 55.18it/s]
100%|███████████████████████████| 10817/10817 [03:14<00:00, 55.50it/s]
100%|███████████████████████████| 10817/10817 [03:17<00:00, 54.82it/s]
100%|███████████████████████████| 10817/10817 [03:17<00:00, 54.79it/s]
100%|███████████████████████████| 10817/10817 [03:17<00:00, 54.74it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 104.63it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 109.49it/s]


MRR
Run: 02, Epoch: 25, Loss: 0.2732, Valid: 71.76%, Test: 71.60%


100%|███████████████████████████| 10817/10817 [03:21<00:00, 53.76it/s]
100%|███████████████████████████| 10817/10817 [03:16<00:00, 55.18it/s]
100%|███████████████████████████| 10817/10817 [03:17<00:00, 54.66it/s]
100%|███████████████████████████| 10817/10817 [03:17<00:00, 54.76it/s]
100%|███████████████████████████| 10817/10817 [03:18<00:00, 54.48it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 107.14it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 108.49it/s]

MRR
Run: 02, Epoch: 30, Loss: 0.2713, Valid: 72.08%, Test: 71.90%





In [23]:
torch.save(model.state_dict(), './mods/seal_fbk_epch30.pt')
torch.save(optimizer.state_dict(), './mods/seal_fbk_optim_epch30.pt')

In [24]:
emb = None
start_epoch = 31
loggers = {}
run = 2
for epoch in range(start_epoch, start_epoch + 70):
        loss = train()

        if epoch % 5 == 0:
            results = test()
            # for key, result in results.items():
            #     loggers[key].add_result(run, result)

            # if epoch % 1 == 0:
                # model_name = os.path.join(
                #     args.res_dir, 'run{}_model_checkpoint{}.pth'.format(run+1, epoch))
                # optimizer_name = os.path.join(
                #     args.res_dir, 'run{}_optimizer_checkpoint{}.pth'.format(run+1, epoch))
                # torch.save(model.state_dict(), model_name)
                # torch.save(optimizer.state_dict(), optimizer_name)

            for key, result in results.items():
                valid_res, test_res = result
                to_print = (f'Run: {run + 1:02d}, Epoch: {epoch:02d}, ' +
                            f'Loss: {loss:.4f}, Valid: {100 * valid_res:.2f}%, ' +
                            f'Test: {100 * test_res:.2f}%')
                print(key)
                print(to_print)
                    # with open(log_file, 'a') as f:
                    #     print(key, file=f)
                    #     print(to_print, file=f)

100%|███████████████████████████| 10817/10817 [10:09<00:00, 17.74it/s]
100%|███████████████████████████| 10817/10817 [03:33<00:00, 50.71it/s]
100%|███████████████████████████| 10817/10817 [03:50<00:00, 46.97it/s]
100%|███████████████████████████| 10817/10817 [03:47<00:00, 47.45it/s]
100%|███████████████████████████| 10817/10817 [03:39<00:00, 49.37it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 102.82it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 108.16it/s]


MRR
Run: 03, Epoch: 35, Loss: 0.2702, Valid: 71.08%, Test: 70.82%


100%|███████████████████████████| 10817/10817 [03:26<00:00, 52.41it/s]
100%|███████████████████████████| 10817/10817 [03:19<00:00, 54.11it/s]
100%|███████████████████████████| 10817/10817 [03:24<00:00, 52.81it/s]
100%|███████████████████████████| 10817/10817 [03:21<00:00, 53.71it/s]
100%|███████████████████████████| 10817/10817 [03:21<00:00, 53.66it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 106.05it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 108.28it/s]


MRR
Run: 03, Epoch: 40, Loss: 0.2694, Valid: 71.92%, Test: 71.74%


100%|███████████████████████████| 10817/10817 [03:21<00:00, 53.69it/s]
100%|███████████████████████████| 10817/10817 [03:19<00:00, 54.21it/s]
100%|███████████████████████████| 10817/10817 [03:19<00:00, 54.20it/s]
100%|███████████████████████████| 10817/10817 [03:20<00:00, 53.95it/s]
100%|███████████████████████████| 10817/10817 [03:20<00:00, 53.84it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 106.25it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 110.54it/s]


MRR
Run: 03, Epoch: 45, Loss: 0.2689, Valid: 72.58%, Test: 72.42%


100%|███████████████████████████| 10817/10817 [03:20<00:00, 53.90it/s]
100%|███████████████████████████| 10817/10817 [03:22<00:00, 53.35it/s]
100%|███████████████████████████| 10817/10817 [03:22<00:00, 53.43it/s]
100%|███████████████████████████| 10817/10817 [03:22<00:00, 53.47it/s]
100%|███████████████████████████| 10817/10817 [03:23<00:00, 53.13it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 105.28it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 106.80it/s]


MRR
Run: 03, Epoch: 50, Loss: 0.2692, Valid: 71.75%, Test: 71.59%


100%|███████████████████████████| 10817/10817 [03:21<00:00, 53.73it/s]
100%|███████████████████████████| 10817/10817 [03:20<00:00, 53.84it/s]
100%|███████████████████████████| 10817/10817 [03:22<00:00, 53.46it/s]
100%|███████████████████████████| 10817/10817 [03:23<00:00, 53.17it/s]
100%|███████████████████████████| 10817/10817 [03:23<00:00, 53.08it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 105.22it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 105.95it/s]


MRR
Run: 03, Epoch: 55, Loss: 0.2684, Valid: 72.14%, Test: 71.83%


100%|███████████████████████████| 10817/10817 [03:23<00:00, 53.08it/s]
100%|███████████████████████████| 10817/10817 [03:22<00:00, 53.40it/s]
100%|███████████████████████████| 10817/10817 [03:22<00:00, 53.44it/s]
100%|███████████████████████████| 10817/10817 [03:21<00:00, 53.58it/s]
100%|███████████████████████████| 10817/10817 [04:38<00:00, 38.83it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 105.91it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 107.75it/s]


MRR
Run: 03, Epoch: 60, Loss: 0.2692, Valid: 72.01%, Test: 71.82%


100%|███████████████████████████| 10817/10817 [03:31<00:00, 51.10it/s]
100%|███████████████████████████| 10817/10817 [03:32<00:00, 50.97it/s]
100%|███████████████████████████| 10817/10817 [03:31<00:00, 51.22it/s]
100%|███████████████████████████| 10817/10817 [03:25<00:00, 52.76it/s]
100%|███████████████████████████| 10817/10817 [03:25<00:00, 52.76it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 104.53it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 108.37it/s]


MRR
Run: 03, Epoch: 65, Loss: 0.2684, Valid: 71.30%, Test: 71.07%


100%|███████████████████████████| 10817/10817 [03:24<00:00, 52.87it/s]
100%|███████████████████████████| 10817/10817 [03:24<00:00, 52.78it/s]
100%|███████████████████████████| 10817/10817 [03:25<00:00, 52.55it/s]
100%|███████████████████████████| 10817/10817 [03:24<00:00, 52.88it/s]
100%|███████████████████████████| 10817/10817 [03:25<00:00, 52.64it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 104.04it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 107.05it/s]


MRR
Run: 03, Epoch: 70, Loss: 0.2683, Valid: 71.81%, Test: 71.62%


100%|███████████████████████████| 10817/10817 [03:26<00:00, 52.38it/s]
100%|███████████████████████████| 10817/10817 [03:25<00:00, 52.51it/s]
100%|███████████████████████████| 10817/10817 [03:25<00:00, 52.57it/s]
100%|███████████████████████████| 10817/10817 [03:26<00:00, 52.42it/s]
100%|███████████████████████████| 10817/10817 [03:25<00:00, 52.61it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 103.67it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 106.42it/s]


MRR
Run: 03, Epoch: 75, Loss: 0.2674, Valid: 72.38%, Test: 72.18%


100%|███████████████████████████| 10817/10817 [03:26<00:00, 52.47it/s]
100%|███████████████████████████| 10817/10817 [03:31<00:00, 51.16it/s]
100%|███████████████████████████| 10817/10817 [03:35<00:00, 50.13it/s]
100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.41it/s]
100%|███████████████████████████| 10817/10817 [03:33<00:00, 50.65it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 101.39it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 102.98it/s]


MRR
Run: 03, Epoch: 80, Loss: 0.2672, Valid: 72.60%, Test: 72.51%


100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.45it/s]
100%|███████████████████████████| 10817/10817 [03:33<00:00, 50.70it/s]
100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.46it/s]
100%|███████████████████████████| 10817/10817 [03:33<00:00, 50.62it/s]
100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.46it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 102.80it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 102.00it/s]


MRR
Run: 03, Epoch: 85, Loss: 0.2693, Valid: 71.82%, Test: 71.59%


100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.34it/s]
100%|███████████████████████████| 10817/10817 [03:32<00:00, 50.82it/s]
100%|███████████████████████████| 10817/10817 [03:35<00:00, 50.10it/s]
100%|███████████████████████████| 10817/10817 [03:35<00:00, 50.22it/s]
100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.44it/s]
100%|███████████████████████████████| 345/345 [00:03<00:00, 97.31it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 104.06it/s]


MRR
Run: 03, Epoch: 90, Loss: 0.2685, Valid: 72.09%, Test: 71.90%


100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.46it/s]
100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.50it/s]
100%|███████████████████████████| 10817/10817 [03:34<00:00, 50.47it/s]
100%|███████████████████████████| 10817/10817 [03:35<00:00, 50.28it/s]
100%|███████████████████████████| 10817/10817 [03:35<00:00, 50.30it/s]
100%|███████████████████████████████| 345/345 [00:03<00:00, 99.99it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 102.62it/s]


MRR
Run: 03, Epoch: 95, Loss: 0.2677, Valid: 72.14%, Test: 71.94%


100%|███████████████████████████| 10817/10817 [03:36<00:00, 49.88it/s]
100%|███████████████████████████| 10817/10817 [03:35<00:00, 50.20it/s]
100%|███████████████████████████| 10817/10817 [03:36<00:00, 49.98it/s]
100%|███████████████████████████| 10817/10817 [03:37<00:00, 49.76it/s]
100%|███████████████████████████| 10817/10817 [03:36<00:00, 50.04it/s]
100%|██████████████████████████████| 345/345 [00:03<00:00, 100.22it/s]
100%|██████████████████████████████| 691/691 [00:06<00:00, 101.09it/s]

MRR
Run: 03, Epoch: 100, Loss: 0.2663, Valid: 71.76%, Test: 71.56%





In [25]:
torch.save(model.state_dict(), './mods/seal_fbk_epch100.pt')
torch.save(optimizer.state_dict(), './mods/seal_fbk_optim_epch100.pt')