In [1]:
import pickle as pkl
import os 
import sys
import numpy as np
from xopen import xopen
import json
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import pandas as pd

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

def simMatrix(A: torch.tensor, B: torch.tensor) -> torch.tensor:
    # Assume A and B are your input tensors of shape (N, d)
    # Example: A = torch.randn(N, d)
    #          B = torch.randn(N, d)

    # Step 1: Normalize A and B
    A_norm = A / A.norm(dim=1, keepdim=True)
    B_norm = B / B.norm(dim=1, keepdim=True)

    # Step 2: Compute the dot product
    cosine_similarity_matrix = torch.mm(A_norm, B_norm.transpose(0, 1))

    # The resulting cosine_similarity_matrix is of shape (N, N)
    # and contains values in the range [-1, 1]
    return cosine_similarity_matrix

DATA_PATH = "/home/ubuntu/proj/data/graph/node_pubmed"
DATA_NAME = "text_graph_pubmed" # "text_graph_pubmed" #"text_graph_aids" #"text_graph_pubmed" # # 
TRAIN_SPLIT_NAME = 'train_index'
VALID_SPLIT_NAME = 'valid_index'
TEST_SPLIT_NAME = 'test_index'

with open(os.path.join(DATA_PATH, f"{DATA_NAME}.pkl"), 'rb') as f:
    graph = pkl.load(f)
with open(os.path.join(DATA_PATH, f"{TRAIN_SPLIT_NAME}.pkl"), 'rb') as f:
    train_split = pkl.load(f)
with open(os.path.join(DATA_PATH, f"{VALID_SPLIT_NAME}.pkl"), 'rb') as f:
    valid_split = pkl.load(f)
with open(os.path.join(DATA_PATH, f"{TEST_SPLIT_NAME}.pkl"), 'rb') as f:
    test_split = pkl.load(f)
k = 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
all_levels_embedding = dict()
for order in range(0, k+1):
    all_levels_embedding[order] = torch.load(os.path.join(DATA_PATH, f"order-{order}.pt"))

In [3]:
from gnn import GCN, SAGE
device = torch.device('cuda:0')
num_features = 768
hidden_channels = 768
num_layers = 2
dropout = 0.5
gnn = 'sage' #'gcn'

if gnn == 'gcn':
    model = GCN(num_features, hidden_channels,
                num_features, num_layers,
                dropout).to(device)
elif gnn == 'sage':
    model = SAGE(num_features, hidden_channels,
                num_features, num_layers,
                dropout).to(device)

lr = 1e-3
num_epochs = 1000

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.NLLLoss() #torch.nn.CrossEntropyLoss()

In [4]:
model.train()
for epoch in range(1, 1 + num_epochs):
    optimizer.zero_grad()
    x = all_levels_embedding[0].to(device)
    edge_index = graph.edge_index.to(device)
    embeddings_mapping = model.encode(x, edge_index)
    target = torch.eye(graph.num_nodes, device=device, dtype=torch.long)
    loss_map = {}
    for t in range(1, num_layers+1):
        corrmatrix = simMatrix(
            all_levels_embedding[num_layers].to(device), 
            embeddings_mapping[t]
            )
        loss_map[t] = criterion(torch.sigmoid(corrmatrix.view(-1)), target.view(-1))
        #loss_map[t] = #F.cross_entropy(torch.sigmoid(corrmatrix.view(-1)), target.view(-1))
    loss = torch.sum(torch.stack([loss_map[i] for i in range(1, num_layers+1)]))
    loss.backward()
    optimizer.step()
    print(f"{epoch=}," + ",".join([f"order {i}: {float(loss_map[i]):.4f}" for i in range(1, num_layers+1)]))

epoch=1,order 1: -0.4935,order 2: -0.4975
epoch=2,order 1: -0.6036,order 2: -0.5838
epoch=3,order 1: -0.6347,order 2: -0.6364
epoch=4,order 1: -0.6476,order 2: -0.6522
epoch=5,order 1: -0.6546,order 2: -0.6687
epoch=6,order 1: -0.6595,order 2: -0.6715
epoch=7,order 1: -0.6632,order 2: -0.6806
epoch=8,order 1: -0.6662,order 2: -0.6868
epoch=9,order 1: -0.6689,order 2: -0.6846
epoch=10,order 1: -0.6713,order 2: -0.6889
epoch=11,order 1: -0.6735,order 2: -0.6858
epoch=12,order 1: -0.6755,order 2: -0.6893
epoch=13,order 1: -0.6773,order 2: -0.6916
epoch=14,order 1: -0.6791,order 2: -0.6956
epoch=15,order 1: -0.6807,order 2: -0.6927
epoch=16,order 1: -0.6822,order 2: -0.6954
epoch=17,order 1: -0.6837,order 2: -0.6954
epoch=18,order 1: -0.6850,order 2: -0.6991
epoch=19,order 1: -0.6863,order 2: -0.6961
epoch=20,order 1: -0.6875,order 2: -0.6986
epoch=21,order 1: -0.6887,order 2: -0.6985
epoch=22,order 1: -0.6897,order 2: -0.6991
epoch=23,order 1: -0.6908,order 2: -0.7002
epoch=24,order 1: -0

In [5]:
SAVE_DIR = "/home/ubuntu/proj/code/axolotl_softprompt/data/cora"
embed = all_levels_embedding[0]
train_pos_tokens, valid_pos_tokens, test_pos_tokens = embed[torch.tensor(train_split)], embed[torch.tensor(valid_split)], embed[torch.tensor(test_split)]
train_pos_tokens = train_pos_tokens.view(-1, 1, 768)
valid_pos_tokens = valid_pos_tokens.view(-1, 1, 768)
test_pos_tokens = test_pos_tokens.view(-1, 1, 768)
torch.save(train_pos_tokens, os.path.join(SAVE_DIR, f'train_textual_order0.pt'))
torch.save(valid_pos_tokens, os.path.join(SAVE_DIR, f'valid_textual_order0.pt'))
torch.save(test_pos_tokens, os.path.join(SAVE_DIR, f'test_textual_order0.pt'))

for i in range(1, num_layers+1):
    # save textual BERT embeddings
    embed = all_levels_embedding[i]
    train_pos_tokens, valid_pos_tokens, test_pos_tokens = embed[torch.tensor(train_split)], embed[torch.tensor(valid_split)], embed[torch.tensor(test_split)]
    train_pos_tokens = train_pos_tokens.view(-1, 1, 768)
    valid_pos_tokens = valid_pos_tokens.view(-1, 1, 768)
    test_pos_tokens = test_pos_tokens.view(-1, 1, 768)
    torch.save(train_pos_tokens, os.path.join(SAVE_DIR, f'train_textual_order{i}.pt'))
    torch.save(valid_pos_tokens, os.path.join(SAVE_DIR, f'valid_textual_order{i}.pt'))
    torch.save(test_pos_tokens, os.path.join(SAVE_DIR, f'test_textual_order{i}.pt'))

    # save GNN embeddings
    embed = embeddings_mapping[i].detach().cpu()
    train_pos_tokens, valid_pos_tokens, test_pos_tokens = embed[torch.tensor(train_split)], embed[torch.tensor(valid_split)], embed[torch.tensor(test_split)]
    train_pos_tokens = train_pos_tokens.view(-1, 1, 768)
    valid_pos_tokens = valid_pos_tokens.view(-1, 1, 768)
    test_pos_tokens = test_pos_tokens.view(-1, 1, 768)
    torch.save(train_pos_tokens, os.path.join(SAVE_DIR, f'train_{gnn}_order{i}.pt'))
    torch.save(valid_pos_tokens, os.path.join(SAVE_DIR, f'valid_{gnn}_order{i}.pt'))
    torch.save(test_pos_tokens, os.path.join(SAVE_DIR, f'test_{gnn}_order{i}.pt'))