In [None]:
import random
import wandb
from tqdm import tqdm
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import plotly.express as px

import torch

# Record the PyTorch build in use and the CUDA version string it reports.
print(torch.__version__)
print(torch.version.cuda)

import torch_geometric
import torch_geometric.nn as pyg_nn
from torch_geometric.utils import to_networkx
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Entities
from torch_geometric.nn import GATConv
from torch_geometric.utils import k_hop_subgraph


import collections
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl
import plotly.express as px

import seaborn as sns
from functools import partial
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score


import collections
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV


import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Seed the run with a fixed value so repeated executions are comparable.
pl.seed_everything(11)

## 1.4 Contrastive Learning

In [None]:
from torch_geometric.datasets import FB15k_237

In [None]:
# Load the three FB15k-237 splits from the same root directory; each dataset
# holds a single graph object, hence the [0].
train_data, val_data, test_data = (
    FB15k_237('data/FB15k_237', split=split_name)[0]
    for split_name in ('train', 'val', 'test')
)

In [None]:
# Node features: one row per node, one column per relation type. Entry [n, r]
# is the number of training edges of type r whose first endpoint
# (edge_index[0]) is node n.
num_relation_types = len(torch.unique(train_data.edge_type))

# Count in int64 with a single scatter-add and cast to float16 once at the end.
# Incrementing a float16 cell one edge at a time silently stops at 2048
# (2048 + 1 is not representable in half precision), and the per-edge Python
# loop is far slower than the vectorised form.
relation_counts = torch.zeros((train_data.num_nodes, num_relation_types), dtype=torch.long)
relation_counts.index_put_(
    (train_data.edge_index[0].long(), train_data.edge_type.long()),
    torch.ones(len(train_data.edge_type), dtype=relation_counts.dtype),
    accumulate=True,  # duplicate (node, relation) pairs must add up, not overwrite
)
train_node_features = relation_counts.to(torch.float16)

In [None]:
class ContrastiveModel(pl.LightningModule):
    """R-GCN encoder trained with a triplet margin loss on node embeddings.

    The model embeds every node of the graph passed in a batch and then pulls
    each anchor embedding towards its positive and away from its negative.

    Args:
        triplets: Required. Three equal-length index tensors
            ``(anchors, positives, negatives)`` that index rows of the node
            embedding matrix. They are shuffled once and split 90% / 10% into
            a training and a validation part.
        in_dim: Size of the input node features.
        hidden_dim: Size of every R-GCN layer output (and of the embeddings).
        num_relations: Number of edge types handed to ``RGCNConv``.
        dropout: Dropout probability applied after each non-final layer.
        num_hidden: Number of additional ``hidden_dim -> hidden_dim`` layers
            after the first one.
        num_bases: Forwarded to ``RGCNConv`` as ``num_bases``.

    Raises:
        ValueError: If ``triplets`` is not supplied.
    """

    def __init__(self, triplets=None, in_dim=237, hidden_dim=256, num_relations = 237, dropout=0.3, num_hidden = 0, num_bases:int=None):
        super().__init__()

        if triplets is None:
            # Fail with a readable message instead of "cannot unpack NoneType".
            raise ValueError(
                "ContrastiveModel requires `triplets` as a tuple "
                "(anchors, positives, negatives) of index tensors"
            )
        self.anchors, self.positives, self.negatives = triplets

        # Shuffle all three tensors with the same permutation so that triples
        # stay aligned, then hold out the last 10% for validation.
        num_anchors = len(self.anchors)
        permutation = torch.randperm(num_anchors)

        self.anchors = self.anchors[permutation]
        self.positives = self.positives[permutation]
        self.negatives = self.negatives[permutation]

        num_train = int(0.9 * num_anchors)

        self.train_anchors, self.train_pos, self.train_neg = self.anchors[:num_train], self.positives[:num_train], self.negatives[:num_train]
        self.val_anchors, self.val_pos, self.val_neg = self.anchors[num_train:], self.positives[num_train:], self.negatives[num_train:]

        self.layers = torch.nn.ModuleList()
        self.layers.append(pyg_nn.RGCNConv(in_dim, hidden_dim, num_relations, num_bases=num_bases))
        for _ in range(num_hidden):
            self.layers.append(pyg_nn.RGCNConv(hidden_dim, hidden_dim, num_relations, num_bases=num_bases))

        self.activation = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout)
        self.norm = torch.nn.LayerNorm(hidden_dim)
        self.loss = torch.nn.TripletMarginLoss()
        self.save_hyperparameters()

    def forward(self, x, edge_index, edge_type):
        """Return one ``hidden_dim``-sized embedding per node.

        Every layer except the last is followed by ReLU, dropout and the
        shared LayerNorm; the last layer's output is returned as is.
        """
        for layer in self.layers[:-1]:
            x = layer(x, edge_index, edge_type)
            x = self.activation(x)
            x = self.dropout(x)
            x = self.norm(x)

        x = self.layers[-1](x, edge_index, edge_type)
        return x

    def training_step(self, batch):
        """Embed the batch graph and score the training triples."""
        emb = self(batch.x.float(), batch.edge_index, batch.edge_type)
        loss = self.loss(emb[self.train_anchors], emb[self.train_pos], emb[self.train_neg])

        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch):
        """Embed the batch graph and score the held-out triples."""
        emb = self(batch.x.float(), batch.edge_index, batch.edge_type)
        # The criterion is stored as `self.loss`; the previous `self.cost`
        # does not exist and raised AttributeError as soon as validation ran.
        loss = self.loss(emb[self.val_anchors], emb[self.val_pos], emb[self.val_neg])

        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at a fixed learning rate of 1e-3."""
        return torch.optim.AdamW(self.parameters(), lr=1e-3)

In [None]:
# The whole training graph is a single sample, so the loader yields exactly one
# batch per epoch containing every node and edge.
data_list = [
    torch_geometric.data.Data(
        x=train_node_features,
        edge_index=train_data.edge_index,
        edge_type=train_data.edge_type,
    )
]
loader = torch_geometric.loader.DataLoader(
    data_list,
    shuffle=False,
    batch_size=1,
    num_workers=2,
)

# (anchor, positive, negative) node triples derived from the training edges.
triplets = torch_geometric.utils.structured_negative_sampling(train_data.edge_index)

model = ContrastiveModel(triplets=triplets, num_bases=42)
trainer = pl.Trainer(
    max_epochs=100,
    enable_model_summary=False,
    enable_progress_bar=False,
    log_every_n_steps=1,
    accelerator='auto',
)

# Only a training loader is passed, so Lightning skips the validation loop.
trainer.fit(model, loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/moritzduck/miniconda3/envs/graphml/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/moritzduck/miniconda3/envs/graphml/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
Missing logger folder: /Users/moritzduck/workspaces/graphml/project2/lightning_logs
/

In [None]:
# Switch off dropout and compute the final node embeddings once.
model.eval()
x = data_list[0].x.float()
edge_index = data_list[0].edge_index
edge_type = data_list[0].edge_type

# Run the forward pass without recording an autograd graph instead of building
# the graph and detaching afterwards. The result is the same gradient-free
# tensor. `no_grad` (rather than `inference_mode`) is used on purpose: the
# embeddings are later fed into modules that are trained with autograd.
with torch.no_grad():
    embeddings = model(x, edge_index, edge_type)

In [None]:
#torch.save(embeddings, "embeddings.pt")
#embeddings = torch.load("embeddings.pt")

### Train a scoring function

In [None]:
class Scoring(pl.LightningModule):
    """Bilinear two-class scorer over pairs of fixed node embeddings.

    A single ``Bilinear`` layer maps an embedding pair to two sigmoid outputs:
    column 0 is trained towards 1 for observed pairs, column 1 towards 1 for
    randomly sampled pairs. ``classify`` returns column 0.

    Args:
        num_relations: Stored on the module; not used by the computation here.
        num_nodes: Exclusive upper bound for the random node indices drawn as
            negatives. Values <= 0 (including the default ``-1``) fall back to
            the number of rows of the embedding table.
        hidden_dim: Size of each embedding vector.
        lr: AdamW learning rate.
        batch_size: Stored on the module; not used by the computation here.
        node_embeddings: Optional embedding table to sample negatives from.
            When omitted, the notebook-level ``embeddings`` tensor is used,
            which matches the previous behaviour.
    """

    def __init__(self,num_relations = 237,num_nodes = -1,hidden_dim = 256,lr = 1e-3,batch_size: int = 1, node_embeddings=None):
        super().__init__()

        # Prefer an explicitly supplied table; otherwise keep reading the
        # notebook-level `embeddings` so existing calls behave as before.
        self.embeddings = embeddings if node_embeddings is None else node_embeddings

        self.lr = lr
        self.batch_size = batch_size
        self.num_relations = num_relations
        # torch.randint(0, -1, ...) is invalid, so a non-positive value means
        # "use every row of the embedding table".
        self.num_nodes = num_nodes if num_nodes > 0 else self.embeddings.shape[0]

        self.Wr = torch.nn.Bilinear(hidden_dim, hidden_dim, 2, bias = False)
        self.sig = torch.nn.Sigmoid()
        self.loss = torch.nn.BCELoss()

    @staticmethod
    def _one_hot_labels(num_rows, hot_column, device):
        """Return a (num_rows, 2) float tensor with ones in ``hot_column``."""
        labels = torch.zeros((num_rows, 2), dtype=torch.float, device=device)
        labels[:, hot_column] = 1
        return labels

    def _probabilities(self, h, t):
        """Sigmoid of the bilinear form, shape (len(h), 2)."""
        return self.sig(self.Wr(h, t))

    def training_step(self, batch):
        """Score the observed pairs plus an equal number of random pairs."""
        target, source = batch
        num_pairs = target.shape[0]
        device = target.device

        # Labels are created on the batch's device so the step also works
        # when the trainer is not running on the CPU.
        pos_labels = self._one_hot_labels(num_pairs, 0, device)
        neg_labels = self._one_hot_labels(num_pairs, 1, device)

        # Negatives: both sides are replaced by uniformly drawn nodes. The two
        # draws stay in this order so seeded runs are unchanged.
        idx_target = torch.randint(0, self.num_nodes, (num_pairs,))
        idx_source = torch.randint(0, self.num_nodes, (num_pairs,))
        neg_samples_target = self.embeddings[idx_target].to(device)
        neg_samples_source = self.embeddings[idx_source].to(device)

        h = torch.cat((target, neg_samples_target), dim = 0)
        t = torch.cat((source, neg_samples_source), dim = 0)
        y = torch.cat((pos_labels, neg_labels), dim = 0)

        loss = self.loss(self._probabilities(h, t), y)
        self.log("train_loss", loss)
        return loss


    def validation_step(self, batch):
        """Loss on observed pairs only (no negatives are sampled here)."""
        target, source = batch
        y = self._one_hot_labels(target.shape[0], 0, target.device)
        loss = self.loss(self._probabilities(target, source), y)

        self.log("val_loss", loss)
        return loss

    def classify(self, h, t):
        """Return the column-0 probability for every (h, t) pair."""
        return self._probabilities(h, t)[:,0]

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)


In [None]:
def subset_to_r(data, relationship_type):
    """Return the embedding pairs of all edges in ``data`` with the given type.

    Edges are selected by comparing ``data.edge_type`` with
    ``relationship_type``. For each selected edge the result contains a tuple
    ``(embeddings[edge_index[0]], embeddings[edge_index[1]])``, where
    ``embeddings`` is the notebook-level embedding table read at call time.
    The result is a plain list and is empty when no edge matches.
    """
    selected_edges = data.edge_index[:, data.edge_type == relationship_type]
    first_endpoints = embeddings[selected_edges[0]]
    second_endpoints = embeddings[selected_edges[1]]
    return [(first, second) for first, second in zip(first_endpoints, second_endpoints)]
    

In [None]:
# One independent Scoring model is fitted per relation type; models[r] is the
# scorer trained on the edges of relation r.
models = []

# Placeholders only: both names are rebound when the metrics are computed.
topk_acc = torch.tensor(0.0)
mmr_acc = torch.tensor(0.0)

# Identical trainer settings for every relation.
scorer_trainer_options = dict(
    max_epochs=200,
    log_every_n_steps=1,
    accelerator='cpu',
    enable_model_summary=False,
    enable_progress_bar=False,
)

for rel_type in range(237):
    train_set = subset_to_r(train_data, rel_type)
    validation_set = subset_to_r(val_data, rel_type)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=1024)
    val_loader = torch.utils.data.DataLoader(validation_set, batch_size=1024)

    # NOTE: `model` previously referred to the contrastive encoder; from here
    # on it names the most recently trained scorer.
    model = Scoring(num_nodes = train_data.num_nodes, batch_size=1024)
    trainer = pl.Trainer(**scorer_trainer_options)
    trainer.fit(model, train_loader, val_loader)
    models.append(model)


In [None]:
def get_scores(models, r_type):
    """Score every test edge of relation ``r_type`` with every model.

    Args:
        models: Sequence of objects exposing ``classify(h, t)`` that returns
            one score per pair in the batch.
        r_type: Relation type whose test edges are scored.

    Returns:
        ``None`` when the relation has no test edges. Otherwise a tensor of
        shape ``(len(models), num_edges)`` where row ``m`` holds the scores of
        ``models[m]``. As before, when there is exactly one test edge the
        sample axis is squeezed away and the shape is ``(len(models),)``.
    """
    test_set = subset_to_r(test_data, r_type)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=2048, shuffle=False)

    per_batch_scores = []
    # Pure inference: no autograd graph is needed for ranking.
    with torch.no_grad():
        for first_endpoints, second_endpoints in test_loader:
            per_batch_scores.append(
                torch.stack([model.classify(first_endpoints, second_endpoints) for model in models])
            )

    if not per_batch_scores:
        return None

    # Batches must be joined along the sample axis (dim 1). Joining along
    # dim 0 stacked extra "model" rows and failed whenever two batches had
    # different sizes, i.e. for any relation with more than 2048 test edges.
    scores = torch.cat(per_batch_scores, dim=1)
    return scores.squeeze(1)

In [None]:
from torchmetrics.retrieval import RetrievalMRR, RetrievalHitRate

# Per-relation ranking metrics. For every test edge the 237 relation scorers
# are ranked against each other and the edge's true relation is the only hit.
topk_accs = []
mmr_accs = []
weights = []

for rel_type in range(237):
    ranking = get_scores(models, rel_type)
    if ranking is None:
        # No test edges for this relation.
        continue

    # Transpose to (test edges, candidate relations).
    ranking = ranking.detach().numpy().T
    if ranking.ndim == 1:
        # A single test edge comes back without its sample axis.
        ranking = ranking.reshape(1, -1)

    num_queries, num_candidates = ranking.shape

    true_labels = np.full(num_queries, float(rel_type))

    # One-hot relevance: only the column of the true relation is a hit.
    targets = np.zeros((num_queries, num_candidates))
    targets[:, rel_type] = 1

    # Query id per cell, so each test edge (row) is ranked as its own query.
    indices = np.tile(np.arange(num_queries).reshape(-1, 1), (1, num_candidates))

    preds_t = torch.tensor(ranking)
    targets_t = torch.tensor(targets)
    indices_t = torch.tensor(indices)

    topk_acc = RetrievalHitRate(top_k=10)(preds_t, targets_t, indices_t)
    mmr_acc = RetrievalMRR()(preds_t, targets_t, indices_t)

    # Weight each relation by its number of test edges for the later average.
    weights.append(num_queries)
    topk_accs.append(topk_acc)
    mmr_accs.append(mmr_acc)


In [None]:
# Aggregate the per-relation metrics into one number each, weighting every
# relation by its number of test edges.
weights = np.array(weights)
topk_accs = np.array(topk_accs)
mmr_accs = np.array(mmr_accs)

topk_acc, mmr_acc = (
    np.average(per_relation, weights=weights)
    for per_relation in (topk_accs, mmr_accs)
)

# "TopK" is the hit rate at 10; the value printed as "MMR" comes from RetrievalMRR.
print(f'TopK Accuracy: {topk_acc:.4f}')
print(f'MMR Accuracy: {mmr_acc:.4f}')

TopK Accuracy: 0.7249
MMR Accuracy: 0.3661
