In [1]:
import pickle as pkl
import os 
import sys
import numpy as np
from xopen import xopen
import json
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import pandas as pd

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

def simMatrix(A: torch.Tensor, B: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of ``A`` and the rows of ``B``.

    Args:
        A: 2-D tensor of shape (N, d).
        B: 2-D tensor of shape (M, d).
        eps: Lower bound applied to every row norm before dividing. Without
            it an all-zero row divides by zero and fills its output with NaN;
            with it such a row yields similarity 0 to everything.

    Returns:
        Tensor of shape (N, M); entry (i, j) is the cosine similarity of
        ``A[i]`` and ``B[j]``, in [-1, 1] up to floating-point rounding.
    """
    # Step 1: scale every row to unit length (norm clamped away from zero).
    A_norm = A / A.norm(dim=1, keepdim=True).clamp_min(eps)
    B_norm = B / B.norm(dim=1, keepdim=True).clamp_min(eps)

    # Step 2: dot products between all unit-row pairs.
    return torch.mm(A_norm, B_norm.transpose(0, 1))

DATA_PATH = "/home/ubuntu/proj/data/graph/node_pubmed"
DATA_NAME = "text_graph_pubmed"  # other dataset name left in the original comment: "text_graph_aids"
TRAIN_SPLIT_NAME = 'train_index'
TEST_SPLIT_NAME = 'test_index'


def _load_pickle(stem):
    """Unpickle ``<DATA_PATH>/<stem>.pkl`` and return the stored object.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only point DATA_PATH at files you produced yourself or trust.
    """
    with open(os.path.join(DATA_PATH, f"{stem}.pkl"), 'rb') as f:
        return pkl.load(f)


# The graph object plus the train/test splits used to index it later.
graph = _load_pickle(DATA_NAME)
train_split = _load_pickle(TRAIN_SPLIT_NAME)
test_split = _load_pickle(TEST_SPLIT_NAME)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from torch import Tensor
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU
from typing import Callable, Union
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
from torch_geometric.nn import MessagePassing
class MLP(torch.nn.Module):
    """Fully connected network: ``num_layers`` Linear layers with ReLU between them.

    Layout for ``num_layers >= 2`` is
    ``Linear(in, hidden)+ReLU -> [Linear(hidden, hidden)+ReLU] * (num_layers-2)
    -> Linear(hidden, out)``. For ``num_layers == 1`` it is a single
    ``Linear(in, out)``. The final layer has no activation, so the output
    is raw scores.

    Args:
        input_channels_node: Size of the last dimension of the input.
        hidden_channels: Width of every hidden layer.
        output_channels: Size of the last dimension of the output.
        readout: Stored on the instance as ``self.readout``; not used by
            this class itself.
        num_layers: Number of Linear layers; must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, input_channels_node, hidden_channels, output_channels, readout='add', num_layers=3):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.readout = readout
        self.num_layers = num_layers
        self.mlp = ModuleList()

        # Hidden blocks: the first maps from the input width, the rest are
        # hidden -> hidden. Built in order so ``len(self.mlp) == num_layers``.
        in_channels = input_channels_node
        for _ in range(num_layers - 1):
            self.mlp.append(Sequential(
                Linear(in_channels, hidden_channels),
                ReLU(),
            ))
            in_channels = hidden_channels

        # Output block: no activation.
        self.mlp.append(Sequential(
            Linear(in_channels, output_channels)
        ))

    def forward(self, x):
        """Apply every block in order and return the final layer's output."""
        # Iterate over the blocks themselves rather than range(num_layers) so
        # the output layer can never be skipped.
        for block in self.mlp:
            x = block(x)
        return x

In [3]:
# --- Experiment switches -------------------------------------------------
encoder_type = 'angle'
relevance_type = 'pos'
pos_type = "gcn"
order = 1
SAVE_DIR = f"/home/ubuntu/proj/code/axolotl_softprompt/data/pubmed/{relevance_type}"

# Input feature width: 768 when the encoder is 'bert', 1024 for anything else.
if encoder_type == 'bert':
    input_channels_node = 768
else:
    input_channels_node = 1024

# --- Saved token tensors for each split ------------------------------------
train_pos_tokens = torch.load(
    os.path.join(SAVE_DIR, f'train_{pos_type}_order{order}-{encoder_type}.pt')
)
test_pos_tokens = torch.load(
    os.path.join(SAVE_DIR, f'test_{pos_type}_order{order}-{encoder_type}.pt')
)

# --- Labels ----------------------------------------------------------------
# Select each split's labels from the graph; the class count is the number of
# distinct labels that occur in the training split.
y_train = graph.y[train_split]
y_test = graph.y[test_split]
num_classes = len(torch.unique(y_train))

In [4]:
import torch
from torch.utils.data import TensorDataset, DataLoader

# Pair each token tensor with its label tensor so they are batched together.
train_dataset = TensorDataset(train_pos_tokens, y_train)
test_dataset = TensorDataset(test_pos_tokens, y_test)

# Batches of 64; only the training data is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

model = MLP(
    input_channels_node=input_channels_node,
    hidden_channels=1024,
    output_channels=num_classes
    )

# Keep using the second GPU when there is one, but do not crash on hosts
# with a single GPU or none at all.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

In [5]:
def train(dataloader):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``device`` and ``input_channels_node``.

    Args:
        dataloader: Iterable yielding ``(batch_x, batch_y)`` pairs.

    Returns:
        Mean of the per-batch loss values as computed by ``np.mean``.
    """
    loss_list = []
    model.train()
    for batch_x, batch_y in dataloader:
        # Collapse the features to 2-D: (rows, input_channels_node).
        batch_x = batch_x.view(-1, input_channels_node)
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        # Clear gradients BEFORE the backward pass so anything left over
        # (e.g. from an interrupted cell) cannot leak into this update.
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y.long())
        loss.backward()
        optimizer.step()  # Update parameters based on gradients.

        loss_list.append(loss.item())
    return np.mean(loss_list)

def test(dataloader):
    """Predict a class for every sample in ``dataloader``.

    Uses the notebook-level ``model``, ``device`` and
    ``input_channels_node``. Puts the model in eval mode and runs without
    gradient tracking.

    Args:
        dataloader: Iterable yielding ``(batch_x, batch_y)`` pairs.

    Returns:
        ``(y_true, y_pred)``: two flat NumPy arrays holding the labels and
        the argmax predictions, concatenated in iteration order.
    """
    y_true_list, y_pred_list = [], []
    model.eval()
    # Evaluation needs no gradients; skipping autograd saves memory and time
    # without changing the predictions.
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.view(-1, input_channels_node)
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            output = model(batch_x)
            y_pred = output.argmax(dim=-1)
            y_true_list.append(batch_y.detach().cpu().numpy().reshape(-1))
            y_pred_list.append(y_pred.detach().cpu().numpy().reshape(-1))
    y_true = np.concatenate(y_true_list)
    y_pred = np.concatenate(y_pred_list)
    return y_true, y_pred

In [6]:
num_epochs = 100

# Placeholders for best-result tracking; nothing in this loop updates them.
best_val_res = 0
best_res = {}
for epoch in range(1, num_epochs + 1):
    # One training pass, then score both splits.
    loss = train(train_dataloader)
    y_true_train, y_pred_train = test(train_dataloader)
    y_true_test, y_pred_test = test(test_dataloader)

    # Accuracy: fraction of samples whose prediction equals the label.
    train_score = (y_true_train == y_pred_train).mean()
    test_score = (y_true_test == y_pred_test).mean()
    print(f"{epoch=:3d}, {loss=:.3f}, {train_score=:.3f}, {test_score=:.3f}")

epoch=  1, loss=0.405, train_score=0.874, test_score=0.837
epoch=  2, loss=0.336, train_score=0.887, test_score=0.850
epoch=  3, loss=0.304, train_score=0.903, test_score=0.860
epoch=  4, loss=0.274, train_score=0.912, test_score=0.866
epoch=  5, loss=0.250, train_score=0.917, test_score=0.853
epoch=  6, loss=0.222, train_score=0.931, test_score=0.861
epoch=  7, loss=0.199, train_score=0.941, test_score=0.870
epoch=  8, loss=0.174, train_score=0.942, test_score=0.849
epoch=  9, loss=0.157, train_score=0.953, test_score=0.864
epoch= 10, loss=0.138, train_score=0.957, test_score=0.846
epoch= 11, loss=0.119, train_score=0.963, test_score=0.844
epoch= 12, loss=0.104, train_score=0.969, test_score=0.867
epoch= 13, loss=0.104, train_score=0.963, test_score=0.850
epoch= 14, loss=0.096, train_score=0.976, test_score=0.852
epoch= 15, loss=0.074, train_score=0.986, test_score=0.861
epoch= 16, loss=0.072, train_score=0.984, test_score=0.851
epoch= 17, loss=0.058, train_score=0.981, test_score=0.8