In [1]:
import pickle as pkl
import os 
import sys
import numpy as np
from xopen import xopen
import json
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import pandas as pd

import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

def simMatrix(A: torch.Tensor, B: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of ``A`` and the rows of ``B``.

    Args:
        A: 2-D tensor of shape (N, d) (``torch.mm`` requires 2-D inputs).
        B: 2-D tensor of shape (M, d); its feature dim must match ``A``.
        eps: lower bound applied to every row norm, so an all-zero row yields
            similarity 0 instead of NaN from a division by zero.

    Returns:
        Tensor of shape (N, M) whose entry [i, j] is cos(A[i], B[j]), in [-1, 1].
    """
    # Step 1: L2-normalise each row (norm clamped to avoid 0/0 -> NaN).
    A_norm = A / A.norm(dim=1, keepdim=True).clamp_min(eps)
    B_norm = B / B.norm(dim=1, keepdim=True).clamp_min(eps)

    # Step 2: dot products of unit vectors are the cosine similarities.
    return torch.mm(A_norm, B_norm.transpose(0, 1))

DATA_PATH = "/home/ubuntu/proj/data/graph/node_cora"
DATA_NAME = "text_graph_cora"  # other datasets used: "text_graph_pubmed", "text_graph_aids"
TRAIN_SPLIT_NAME = 'train_index'
VALID_SPLIT_NAME = 'valid_index'
TEST_SPLIT_NAME = 'test_index'


def _load_pickle(name: str):
    """Unpickle ``<DATA_PATH>/<name>.pkl`` and return the stored object.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file;
    only point DATA_PATH at files you produced yourself or otherwise trust.
    """
    with open(os.path.join(DATA_PATH, f"{name}.pkl"), 'rb') as f:
        return pkl.load(f)


graph = _load_pickle(DATA_NAME)
train_split = _load_pickle(TRAIN_SPLIT_NAME)
valid_split = _load_pickle(VALID_SPLIT_NAME)
test_split = _load_pickle(TEST_SPLIT_NAME)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from torch import Tensor
from torch.nn import Embedding, Sequential, Linear, ModuleList, ReLU
from typing import Callable, Union
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
from torch_geometric.nn import MessagePassing
class MLP(torch.nn.Module):
    """Feed-forward network: (Linear + ReLU) hidden blocks followed by a bare Linear.

    Args:
        input_channels_node: size of each input feature vector.
        hidden_channels: width of every hidden layer.
        output_channels: size of the output vector.
        readout: stored as ``self.readout`` but not used by ``forward``; kept for
            interface compatibility.
        num_layers: total number of Linear layers, must be >= 1. With
            ``num_layers == 1`` the model is a single
            ``Linear(input_channels_node, output_channels)``.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, input_channels_node, hidden_channels, output_channels, readout='add', num_layers=3):
        super(MLP, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.readout = readout
        self.num_layers = num_layers
        self.mlp = ModuleList()
        if num_layers == 1:
            # No hidden layer: map inputs straight to outputs. Previously this case built
            # two blocks but forward() ran only the first, returning hidden-sized output.
            self.mlp.append(Sequential(Linear(input_channels_node, output_channels)))
        else:
            self.mlp.append(Sequential(Linear(input_channels_node, hidden_channels), ReLU()))
            for _ in range(num_layers - 2):
                self.mlp.append(Sequential(Linear(hidden_channels, hidden_channels), ReLU()))
            # Last layer has no activation, so it emits unnormalised scores.
            self.mlp.append(Sequential(Linear(hidden_channels, output_channels)))

    def forward(self, x):
        """Apply every block in order; the result's last dim is ``output_channels``."""
        for block in self.mlp:
            x = block(x)
        return x

In [3]:
SAVE_DIR = "/home/ubuntu/proj/code/axolotl_softprompt/data/cora"
pos_type = "textual"
order = 2


def _load_pos_tokens(split: str):
    """Load ``<SAVE_DIR>/<split>_<pos_type>_order<order>.pt`` with ``torch.load``."""
    # NOTE(review): torch.load unpickles the file; only load files you trust.
    return torch.load(os.path.join(SAVE_DIR, f'{split}_{pos_type}_order{order}.pt'))


def _read_jsonl(split: str) -> list:
    """Read ``<SAVE_DIR>/<split>.jsonl`` into a list with one decoded JSON value per line."""
    with xopen(os.path.join(SAVE_DIR, f'{split}.jsonl')) as fin:
        return [json.loads(line) for line in tqdm(fin)]


train_pos_tokens = _load_pos_tokens('train')
valid_pos_tokens = _load_pos_tokens('valid')
test_pos_tokens = _load_pos_tokens('test')
train_samples = _read_jsonl('train')
valid_samples = _read_jsonl('valid')
test_samples = _read_jsonl('test')

1624it [00:00, 274626.04it/s]
542it [00:00, 240766.02it/s]
542it [00:00, 179665.91it/s]


In [4]:
# Per-split labels: index the graph's label tensor with the loaded split indices.
y_train, y_valid, y_test = graph.y[train_split], graph.y[valid_split], graph.y[test_split]

In [5]:
# Size the classifier from the largest label id across all splits, not the count of
# distinct train labels: CrossEntropyLoss needs output_channels > max label, which
# len(unique(y_train)) undercounts if a class is missing from train or ids have gaps.
# Assumes labels are non-negative integer class ids — TODO confirm for new datasets.
num_classes = max(int(y.max()) for y in (y_train, y_valid, y_test) if y.numel() > 0) + 1

In [6]:
import torch
from torch.utils.data import TensorDataset, DataLoader

BATCH_SIZE = 64
SHUFFLE_SEED = 0  # fixes the train-shuffle order so reruns are reproducible

# Pair each split's token features with its labels.
train_dataset = TensorDataset(train_pos_tokens, y_train)
valid_dataset = TensorDataset(valid_pos_tokens, y_valid)
test_dataset = TensorDataset(test_pos_tokens, y_test)

# Only the training loader shuffles; its dedicated seeded generator makes the
# per-epoch order deterministic across kernel restarts.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
SEED = 0
torch.manual_seed(SEED)  # reproducible weight initialisation

# input_channels_node must equal the per-token feature width that train()/test()
# reshape batches to (768 here).
model = MLP(input_channels_node=768, hidden_channels=1024, output_channels=num_classes)
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

In [8]:
def train(dataloader, input_dim: int = 768) -> float:
    """Run one training epoch over ``dataloader`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and ``device``.

    Args:
        dataloader: iterable of ``(features, labels)`` batches.
        input_dim: feature width; each feature batch is reshaped to ``(-1, input_dim)``.

    Returns:
        Unweighted mean of the per-batch losses (NaN if ``dataloader`` is empty).
    """
    loss_list = []
    model.train()
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.view(-1, input_dim).to(device)
        batch_y = batch_y.to(device)
        # Clear gradients *before* backward(): if a previous run of this cell was
        # interrupted after backward(), stale gradients would otherwise accumulate.
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y.long())
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())
    return np.mean(loss_list)

def test(dataloader, input_dim: int = 768):
    """Predict every batch in ``dataloader`` with the notebook-level ``model``.

    Args:
        dataloader: iterable of ``(features, labels)`` batches.
        input_dim: feature width; each feature batch is reshaped to ``(-1, input_dim)``.

    Returns:
        ``(y_true, y_pred)``: flat numpy arrays of labels and argmax predictions.
    """
    y_true_list, y_pred_list = [], []
    model.eval()
    # Inference only: no_grad avoids building autograd graphs, saving memory and time.
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            output = model(batch_x.view(-1, input_dim).to(device))
            y_pred = output.argmax(dim=-1)
            # Labels never need to visit the device; keep them on the CPU.
            y_true_list.append(batch_y.cpu().numpy().reshape(-1))
            y_pred_list.append(y_pred.cpu().numpy().reshape(-1))
    return np.concatenate(y_true_list), np.concatenate(y_pred_list)

In [9]:
num_epochs = 100

# Model selection on validation accuracy: remember the metrics of the best epoch.
# Start at -inf so the first epoch is always recorded; starting at 0 left best_res
# empty (-> KeyError below) whenever validation accuracy never rose above 0.
best_val_res = float('-inf')
best_res = {}
for epoch in range(1, num_epochs+1):
    loss = train(train_dataloader)
    y_true_train, y_pred_train = test(train_dataloader)
    y_true_valid, y_pred_valid = test(valid_dataloader)
    y_true_test, y_pred_test = test(test_dataloader)
    train_score = np.mean(y_true_train == y_pred_train)
    valid_score = np.mean(y_true_valid == y_pred_valid)
    test_score = np.mean(y_true_test == y_pred_test)
    if valid_score > best_val_res:  # strict '>' keeps the earliest epoch on ties
        best_val_res = valid_score
        best_res = {
            'epoch': epoch,
            'train_score': train_score,
            'valid_score': valid_score,
            'test_score': test_score,
        }
    print(f"{epoch=:3d}, {loss=:.3f}, {train_score=:.3f}, {valid_score=:.3f}, {test_score=:.3f}")

if not best_res:
    raise RuntimeError("No epochs were run; set num_epochs >= 1.")
epoch, train_score, valid_score, test_score = best_res['epoch'], best_res['train_score'], best_res['valid_score'], best_res['test_score']
print(f"Best Results: {epoch=:3d}, {train_score=:.3f}, {valid_score=:.3f}, {test_score=:.3f}")

epoch=  1, loss=1.717, train_score=0.415, valid_score=0.423, test_score=0.393
epoch=  2, loss=1.245, train_score=0.683, valid_score=0.664, test_score=0.668
epoch=  3, loss=0.996, train_score=0.727, valid_score=0.716, test_score=0.720
epoch=  4, loss=0.801, train_score=0.754, valid_score=0.738, test_score=0.760
epoch=  5, loss=0.710, train_score=0.780, valid_score=0.758, test_score=0.755
epoch=  6, loss=0.641, train_score=0.807, valid_score=0.777, test_score=0.793
epoch=  7, loss=0.583, train_score=0.799, valid_score=0.786, test_score=0.808
epoch=  8, loss=0.550, train_score=0.813, valid_score=0.804, test_score=0.801
epoch=  9, loss=0.535, train_score=0.848, valid_score=0.823, test_score=0.821
epoch= 10, loss=0.524, train_score=0.821, valid_score=0.803, test_score=0.793
epoch= 11, loss=0.532, train_score=0.832, valid_score=0.808, test_score=0.817
epoch= 12, loss=0.512, train_score=0.843, valid_score=0.823, test_score=0.832
epoch= 13, loss=0.492, train_score=0.850, valid_score=0.823, tes