In [1]:
import torch
import sys

# Make modules living under /workspace/kbqa/ importable from this notebook.
sys.path.append("/workspace/kbqa/")
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [2]:
from tqdm import tqdm
import jsonlines
import networkx as nx
import pandas as pd
import numpy as np
import torch
import json
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import TrainingArguments, Trainer
from transformers.models.graphormer.collating_graphormer import GraphormerDataCollator
from transformers import GraphormerForGraphClassification
from transformers.models.graphormer.collating_graphormer import algos_graphormer

2023-08-14 10:34:55.262616: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-08-14 10:34:55.437630: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-08-14 10:34:56.028157: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-08-14 10:34:56.028274: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinf

In [82]:
# --- run configuration ---------------------------------------------------
dataset_type = 't5-xl-ssm'   # seq2seq variant whose subgraphs / beam outputs are used
train_bs, eval_bs = 64, 64   # per-device train / eval batch sizes
data_prep = False            # True -> rebuild the transformed-graph JSONL files
push_to_hub = True
# Non-empty -> evaluate these weights and skip training; empty -> train from model_name.
model_weights = 'hle2000/graphsormer_subgraphs_reranking_t5xl'
model_name = 'clefourrier/graphormer-base-pcqm4mv2'
num_epochs = 50

# Run directory name: the hub repo name when reusing weights, else "<base model>_mse".
if model_weights:
    model_save_name = model_weights.split('/')[-1]
else:
    model_save_name = f"{model_name}_mse"

In [83]:
# Hub dataset with the mined subgraphs for each supported seq2seq variant.
# An unknown dataset_type raises KeyError here instead of silently loading the
# large-model subgraphs (which would not match the beams loaded later).
SUBGRAPH_DATASETS = {
    't5-xl-ssm': 'Mintaka_Subgraphs_T5_xl_ssm',
    't5-large-ssm': 'Mintaka_Subgraphs_T5_large_ssm',
}
path = SUBGRAPH_DATASETS[dataset_type]
subgraphs_dataset = load_dataset(f'hle2000/{path}')

In [84]:
# Materialise both splits as pandas DataFrames for row-wise processing.
train_df, test_df = (subgraphs_dataset[split].to_pandas() for split in ('train', 'test'))

#### Transforming the data

In [85]:
def transform_graph(graph_data, answer_entity, ground_truth_entity):
    """Convert a node-link graph dict into the flat dict consumed by ``preprocess``.

    Args:
        graph_data: dict with ``nodes`` (each carrying a ``type``) and ``links``
            (each carrying ``source`` / ``target`` node indices).
        answer_entity: candidate answer id represented by this subgraph.
        ground_truth_entity: gold answer as a string, or a collection of gold ids.

    Returns:
        dict with ``edge_index`` (``[sources, targets]``, ``[[], []]`` when there
        are no links), ``num_nodes``, ``y`` (``[1.0]`` if the candidate is a gold
        answer, else ``[0.0]``), ``node_feat`` (one ``[type_id]`` per node) and a
        single dummy ``edge_attr`` ``[[0]]``.

    Raises:
        ValueError: if a node has a type outside the known set. Skipping such a
            node would misalign ``node_feat`` with the node indices in ``edge_index``.
    """
    import re

    # Integer feature id for each known node type.
    node_type_ids = {
        'INTERNAL': 1,
        'ANSWER_CANDIDATE_ENTITY': 2,
        'QUESTIONS_ENTITY': 3,
    }

    nodes = graph_data['nodes']
    links = graph_data['links']

    # Two parallel lists: sources and targets of every link.
    edge_index = [
        [link['source'] for link in links],
        [link['target'] for link in links],
    ]

    # Label: 1.0 when the candidate is one of the gold answers.
    if isinstance(ground_truth_entity, str):
        # Whole-token match: a plain substring test would wrongly accept
        # e.g. 'Q7' as matching 'Q7245'.
        gold_tokens = set(re.findall(r"\w+", ground_truth_entity))
        is_answer = answer_entity == ground_truth_entity or answer_entity in gold_tokens
    else:
        is_answer = answer_entity in ground_truth_entity
    y = 1.0 if is_answer else 0.0

    node_feat = []
    for node in nodes:
        type_id = node_type_ids.get(node['type'])
        if type_id is None:
            raise ValueError(f"Unknown node type: {node['type']!r}")
        node_feat.append([type_id])

    return {
        'edge_index': edge_index,
        'num_nodes': len(nodes),
        'y': [y],
        'node_feat': node_feat,
        'edge_attr': [[0]],
    }


In [86]:
def create_adjacency_matrix(edge_list, num_nodes=None):
    """Build a dense, directed 0/1 adjacency matrix from an edge index.

    Args:
        edge_list: pair ``(sources, targets)`` of integer sequences; edge ``k``
            goes from ``sources[k]`` to ``targets[k]``.
        num_nodes: optional node count. When given, the matrix is at least
            ``num_nodes x num_nodes``, so nodes without edges (including trailing
            isolated nodes) still get a row and column. When omitted, the size is
            ``max node id + 1`` as before (``0 x 0`` for an empty edge list).

    Returns:
        np.ndarray of dtype int32 with ``adj[src, dst] == 1`` for every edge;
        duplicate edges collapse to a single 1.
    """
    sources = np.asarray(edge_list[0], dtype=np.int64)
    targets = np.asarray(edge_list[1], dtype=np.int64)

    # `initial=-1` makes max() of an empty edge list yield size 0 instead of raising.
    size = int(max(sources.max(initial=-1), targets.max(initial=-1))) + 1
    if num_nodes is not None:
        size = max(size, int(num_nodes))

    adjacency_matrix = np.zeros((size, size), dtype=np.int32)
    # Pair edges positionally (like zip) and set them all in one vectorized write.
    n_pairs = min(sources.size, targets.size)
    adjacency_matrix[sources[:n_pairs], targets[:n_pairs]] = 1
    return adjacency_matrix

In [87]:
def preprocess(item):
    """Convert a transformed graph dict into the feature dict used for Graphormer.

    Args:
        item: dict with ``edge_index`` (two parallel lists of source / target node
            indices), ``num_nodes``, ``y``, ``node_feat`` and ``edge_attr``; an
            optional ``labels`` entry is used as the target instead of ``y``.

    Returns:
        dict of numpy arrays / values (edge_index, num_nodes, y, node_feat,
        input_nodes, edge_attr, attn_bias, attn_edge_type, spatial_pos,
        in_degree, out_degree, input_edges, labels). If the shortest-path /
        edge-path computation fails, ``input_edges`` gets a zero-length third
        axis and ``spatial_pos`` is all ones. Callers that drop graphs with
        ``input_edges.shape[2] == 0`` then skip the graph instead of crashing here.
    """
    num_nodes = item["num_nodes"]
    num_edge_feats = len(item["edge_attr"])

    adj = create_adjacency_matrix(item["edge_index"])

    # All edges share one dummy type, so the edge-type tensor is all zeros. It is
    # built up front so it exists on both the success and the fallback path.
    attn_edge_type = np.zeros((num_nodes, num_nodes, num_edge_feats), dtype=np.int64)

    try:
        # All-pairs shortest paths, computed once and inside the error handling.
        shortest_path_result, path = algos_graphormer.floyd_warshall(adj)
        max_dist = np.amax(shortest_path_result)
        input_edges = algos_graphormer.gen_edge_input(max_dist, path, attn_edge_type)
    except Exception:
        # Fallback for graphs the Graphormer algorithms cannot handle. The empty
        # distance axis marks the item as unusable, and a zero distance matrix
        # (rather than None) keeps the feature dict below constructible.
        input_edges = np.zeros((num_nodes, num_nodes, 0, num_edge_feats), dtype=np.int64)
        shortest_path_result = np.zeros(adj.shape, dtype=np.int64)

    # NOTE(review): the source graphs are stored as directed, yet in_degree and
    # out_degree are both row sums of adj (i.e. out-degrees). Kept as-is so inputs
    # stay consistent with how existing checkpoints were presumably trained;
    # np.sum(adj, axis=0) would give true in-degrees.
    degree = np.sum(adj, axis=1).reshape(-1) + 1

    return {
        "edge_index": np.array(item["edge_index"]),
        "num_nodes": num_nodes,
        "y": item["y"],
        "node_feat": np.array(item["node_feat"]),
        # node_feat doubles as the input node representation.
        "input_nodes": np.array(item["node_feat"]),
        "edge_attr": np.array(item["edge_attr"]),
        # num_nodes + 1: presumably one extra slot for Graphormer's graph token.
        "attn_bias": np.zeros((num_nodes + 1, num_nodes + 1), dtype=np.single),
        "attn_edge_type": attn_edge_type,
        "spatial_pos": shortest_path_result.astype(np.int64) + 1,
        "in_degree": degree,
        "out_degree": degree.copy(),
        "input_edges": input_edges + 1,
        # Use an explicit "labels" entry when present, otherwise fall back to y.
        "labels": item.get("labels", item["y"]),
    }

In [88]:
from ast import literal_eval
from unidecode import unidecode
def try_literal_eval(s):
    """Parse ``s`` as a Python literal, returning ``s`` unchanged if it is not one.

    Used to turn the stringified ``graph`` column back into a dict. Two errors
    both mean "not a literal" and fall back to returning the input as-is:
    ``ValueError`` (a non-literal expression, or a non-string input) and
    ``SyntaxError`` (unparsable text).
    """
    try:
        return literal_eval(s)
    except (ValueError, SyntaxError):
        return s
    
def transform_data(df, save_path):
    """Transform every row's graph and write the kept records to a JSONL file.

    Each output line is a JSON object with ``original_graph``, ``question``,
    ``answerEntity``, ``groundTruthAnswerEntity``, ``correct`` and
    ``transformed_graph``. A row is skipped, and counted in the printed summary,
    when its graph cannot be parsed or transformed, or when the graph has no edges.

    Args:
        df: DataFrame with ``graph``, ``question``, ``answerEntity``,
            ``groundTruthAnswerEntity`` and ``correct`` columns.
        save_path: output ``.jsonl`` path; missing parent directories are created.
    """
    import os

    records = []
    n_failed = 0
    n_edgeless = 0
    for _, row in tqdm(df.iterrows(), total=len(df), desc="Transforming graphs"):
        try:
            graph_data = try_literal_eval(row['graph'])  # stringified dict -> dict
            transformed_graph = transform_graph(
                graph_data, row['answerEntity'], row['groundTruthAnswerEntity']
            )
        except Exception:
            # Best effort: malformed graphs are skipped, but counted and reported.
            n_failed += 1
            continue

        # Keep only graphs with at least one edge (an edgeless graph may come
        # back as [] or as [[], []]).
        edge_index = transformed_graph["edge_index"]
        if not edge_index or len(edge_index[0]) == 0:
            n_edgeless += 1
            continue

        records.append({
            'original_graph': graph_data,
            'question': row['question'],
            'answerEntity': row['answerEntity'],
            'groundTruthAnswerEntity': row['groundTruthAnswerEntity'],
            'correct': row['correct'],
            'transformed_graph': transformed_graph,
        })

    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as file:
        for record in records:
            file.write(json.dumps(record) + '\n')

    print(
        f"{save_path}: wrote {len(records)} graphs; skipped {n_failed} "
        f"unparsable and {n_edgeless} edgeless rows"
    )

In [89]:
# JSONL files holding the transformed graphs of each split.
train_trans_path, test_trans_path = (
    f'/workspace/storage/new_subgraph_dataset/{dataset_type}/graph_class/transformed_graphs_train.jsonl',
    f'/workspace/storage/new_subgraph_dataset/{dataset_type}/graph_class/transformed_graphs_test.jsonl',
)

In [90]:
# Rebuild the transformed JSONL files only on request; otherwise the files
# already at test_trans_path / train_trans_path are read as-is.
if data_prep:
    for split_df, split_path in ((test_df, test_trans_path), (train_df, train_trans_path)):
        transform_data(split_df, split_path)

### Preparing the data

In [91]:
class CustomGraphDataset(Dataset):
    """Map-style dataset of preprocessed graphs read from a JSONL file.

    Each line of ``file_path`` is a JSON object whose ``transformed_graph`` entry
    is run through ``preprocess``. Graphs whose ``input_edges`` have an empty
    third axis are dropped.
    """

    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r') as handle:
            for record in map(json.loads, handle):
                features = preprocess(record['transformed_graph'])
                if features['input_edges'].shape[2] == 0:
                    continue  # unusable graph (no path / edge information)
                self.data.append(features)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Build the train / test datasets from the transformed JSONL files.
train_dataset, test_dataset = (
    CustomGraphDataset(trans_path) for trans_path in (train_trans_path, test_trans_path)
)

In [92]:
# Number of training graphs kept after preprocessing / filtering.
len(train_dataset)

87449

#### Training

In [93]:
import numpy as np
import evaluate


# Scores strictly above this value count as a positive prediction.
threshold = 0.5
# Classification metrics run on thresholded scores, MAE on the raw scores.
metric_classifier = evaluate.combine([
    "accuracy",
    "f1",
    "precision",
    "recall",
    "hyperml/balanced_accuracy",
])
metric_regression = evaluate.combine(["mae"])


def compute_metrics(eval_pred):
    """Compute MAE on raw scores plus thresholded classification metrics.

    Args:
        eval_pred: ``(predictions, labels)`` pair from the Trainer. ``predictions``
            is either the score array itself or a tuple/list whose first element
            holds the scores (the remaining elements being other model outputs).

    Returns:
        dict with MAE on the raw scores, plus accuracy / f1 / precision / recall /
        balanced accuracy on the scores thresholded at ``threshold``.
    """
    predictions, labels = eval_pred
    # Only unwrap when it really is a container of outputs; indexing a bare score
    # array with [0] would silently keep just the first sample.
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]

    # Flatten (N, 1) scores / labels to (N,) so both metric suites see 1-D inputs.
    scores = np.asarray(predictions).reshape(-1)
    references = np.asarray(labels).reshape(-1)

    results = metric_regression.compute(predictions=scores, references=references)
    results.update(
        metric_classifier.compute(predictions=scores > threshold, references=references)
    )
    return results

In [95]:
# NOTE(review): this silently overrides `push_to_hub = True` from the configuration
# cell, so nothing below is pushed to the hub. Consider setting it once in the config cell.
push_to_hub = False

In [96]:
# Load either previously trained re-ranker weights (evaluation only, no training)
# or the base Graphormer checkpoint (training from scratch); both use a single-output head.
checkpoint = model_weights if model_weights else model_name
model = GraphormerForGraphClassification.from_pretrained(
    checkpoint,
    num_classes=1,
    ignore_mismatched_sizes=True,
)

# Re-publish previously trained weights under the dataset-specific repo name.
if model_weights and push_to_hub:
    model.push_to_hub(
        commit_message='previous trained best checkpoint',
        repo_id=f'hle2000/graphsormer_subgraphs_reranking_{dataset_type}',
    )

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/191M [00:00<?, ?B/s]

In [97]:
from torch.utils.data.sampler import WeightedRandomSampler
import numpy as np

class CustomTrainer(Trainer):
    """``Trainer`` that samples training examples with inverse-class-frequency weights.

    An example's class is ``int(example["y"][0])``. Each class is then drawn
    equally often in expectation, which counteracts label imbalance.
    """

    def get_labels(self, dataset=None):
        """Return the integer class of every example.

        Args:
            dataset: dataset to read labels from; defaults to ``self.train_dataset``.
        """
        if dataset is None:
            dataset = self.train_dataset
        return [int(example["y"][0]) for example in dataset]

    def _get_train_sampler(self, train_dataset=None) -> torch.utils.data.Sampler:
        # The optional argument lets this override also match Trainer versions that
        # pass the dataset explicitly; when omitted, self.train_dataset is used.
        labels = self.get_labels(train_dataset)
        return self.create_sampler(labels)

    def create_sampler(self, target):
        """Build a ``WeightedRandomSampler`` with one inverse-class-frequency weight per example.

        Args:
            target: sequence of class labels, one per training example. Labels can
                be any sortable values; they need not be exactly ``0..K-1``.

        Returns:
            WeightedRandomSampler that draws ``len(target)`` indices.
        """
        target = np.asarray(target)
        # inverse[i] is the position of target[i] among the sorted unique classes,
        # so weights are looked up by class position rather than by label value.
        _, inverse, counts = np.unique(target, return_inverse=True, return_counts=True)
        class_weight = 1.0 / counts
        samples_weight = torch.from_numpy(class_weight[inverse.reshape(-1)]).double()
        return WeightedRandomSampler(samples_weight, len(samples_weight))

In [98]:
# Specify the arguments for the trainer.
training_args = TrainingArguments(
    # Where checkpoints and logs go, and where metrics are reported.
    output_dir=f"/workspace/storage/subgraphs_reranking_results/{dataset_type}/results/{model_save_name}",
    logging_dir=f"/workspace/storage/subgraphs_reranking_results/{dataset_type}/logs/{model_save_name}",
    report_to='wandb',
    # Optimisation schedule.
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_bs,
    per_device_eval_batch_size=eval_bs,
    warmup_steps=500,
    weight_decay=0.01,
    # Evaluate, log and checkpoint every 500 steps; at the end, reload the
    # checkpoint with the best balanced accuracy.
    evaluation_strategy="steps",
    logging_steps=500,
    save_steps=500,
    metric_for_best_model="balanced_accuracy",
    load_best_model_at_end=True,
)

In [99]:
# Batch the per-graph feature dicts with Graphormer's collator.
data_collator = GraphormerDataCollator()

# Trainer with class-balanced sampling (CustomTrainer) and the metrics defined above.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

In [100]:
# Train only when no previously trained re-ranker weights were supplied.
if not model_weights:
    train_results = trainer.train()
    # Save the model the trainer finished with, then optionally publish it.
    trainer.save_model(f"/workspace/storage/subgraphs_reranking_results/{dataset_type}/results/{model_save_name}/best_checkpoint")
    if push_to_hub:
        trainer.push_to_hub(commit_message='best checkpoint')

### Eval

In [101]:
# Evaluate the current model on the trainer's eval dataset (test_dataset) using compute_metrics.
evaluate_res = trainer.evaluate()
evaluate_res

{'eval_loss': 0.19735507667064667,
 'eval_mae': 0.3600324040272491,
 'eval_accuracy': 0.6814940111547957,
 'eval_f1': 0.3529302498374663,
 'eval_precision': 0.2360541682196546,
 'eval_recall': 0.6990434142752023,
 'eval_balanced_accuracy': 0.689023690850276,
 'eval_runtime': 14.0095,
 'eval_samples_per_second': 1561.365,
 'eval_steps_per_second': 24.412}

#### Re-ranking

In [102]:
def read_jsonl(path):
    """Load a JSON Lines file into a DataFrame with one row per JSON object.

    Args:
        path: path to a UTF-8 ``.jsonl`` file with one JSON object per line.

    Returns:
        pd.DataFrame whose columns are the objects' keys. Blank lines are skipped.
        The file is closed before returning.
    """
    with open(path, encoding="utf-8") as handle:
        records = [json.loads(line) for line in handle if line.strip()]
    return pd.DataFrame(records)

# df that holds the transformed graphs: one row per (question, candidate answer)
# record written by transform_data for the test split.
test_trans_df = read_jsonl(test_trans_path)

100%|██████████| 21880/21880 [00:00<00:00, 2317283.33it/s]


In [103]:
# Peek at the first transformed test records.
test_trans_df.head()

Unnamed: 0,original_graph,question,answerEntity,groundTruthAnswerEntity,correct,transformed_graph
0,"{'directed': True, 'multigraph': False, 'graph...",What man was a famous American author and also...,Q191050,Q7245,False,"{'edge_index': [[0, 1, 2, 3], [0, 0, 0, 0]], '..."
1,"{'directed': True, 'multigraph': False, 'graph...",What man was a famous American author and also...,Q3259878,Q7245,False,"{'edge_index': [[1, 1, 2, 3, 4], [0, 1, 0, 4, ..."
2,"{'directed': True, 'multigraph': False, 'graph...",What man was a famous American author and also...,Q7245,Q7245,True,"{'edge_index': [[1, 3, 3], [0, 0, 2]], 'num_no..."
3,"{'directed': True, 'multigraph': False, 'graph...",What man was a famous American author and also...,Q1074614,Q7245,False,"{'edge_index': [[1, 2, 3, 4, 5], [2, 0, 0, 1, ..."
4,"{'directed': True, 'multigraph': False, 'graph...",What man was a famous American author and also...,Q15133865,Q7245,False,"{'edge_index': [[1, 1, 2, 2, 3, 3, 4, 5, 5], [..."


In [105]:
# Confirm which seq2seq variant's beam outputs are loaded next.
dataset_type

't5-xl-ssm'

In [106]:
# Getting the 200 beams from the seq2seq outputs. Uses the same dataset_type keys
# as the subgraph dataset; an unknown value raises KeyError instead of silently
# falling back to the xl outputs.
SEQ2SEQ_OUTPUT_DATASETS = {
    't5-xl-ssm': 'Mintaka_T5_xl_ssm_outputs',
    't5-large-ssm': 'Mintaka_T5_large_ssm_outputs',
}
path = SEQ2SEQ_OUTPUT_DATASETS[dataset_type]
test_res_csv = load_dataset(f'hle2000/{path}', ignore_verifications=True)['test'].to_pandas()



Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating test split:   0%|          | 0/4000 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/16000 [00:00<?, ? examples/s]

In [107]:
class EvalGraphDataset(Dataset):
    """In-memory dataset of one question's candidate graphs, used for re-ranking.

    Each graph is run through ``preprocess``. Graphs whose ``input_edges`` have an
    empty third axis are dropped together with their correctness flag, so
    ``data[i]`` and ``correct[i]`` always describe the same candidate.
    """

    def __init__(self, is_corrects, graphs):
        self.data = []
        self.correct = []
        candidates = ((flag, preprocess(graph)) for flag, graph in zip(is_corrects, graphs))
        for flag, features in candidates:
            if features['input_edges'].shape[2] == 0:
                continue
            self.data.append(features)
            self.correct.append(flag)

    def get_new_correct(self):
        """Correctness flags of the graphs that survived filtering, aligned with ``data``."""
        return self.correct

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [108]:
# Re-rank the seq2seq beam candidates of every test question with the graph model.
#   seq2seq_correct: questions WITHOUT subgraphs whose seq2seq top-1 answer is right
#   top1_total:      questions WITH subgraphs (gold in beams) whose seq2seq top-1 is right
#   top200_total:    questions not counted in seq2seq_correct whose gold answer is in the beams
#   final_acc:       questions WITH subgraphs where the graph model's top-scored candidate is correct
final_acc, top200_total, top1_total, seq2seq_correct = 0, 0, 0, 0

# Index the candidate graphs by question once, instead of scanning the whole frame
# per question. groupby keeps row order within each group, so candidate positions
# are unchanged.
graphs_by_question = {
    question: frame for question, frame in test_trans_df.groupby("question", sort=False)
}

# Inference only: pin the device, switch to eval mode, and skip autograd below.
model.to(device)
model.eval()

for _, row in tqdm(test_res_csv.iterrows(), total=len(test_res_csv)):
    target = row["target"]
    # NOTE(review): assumes the beam answers are the columns after the first two
    # and before the last one of test_res_csv. Confirm against the dataset schema.
    all_beams = set(row.tolist()[2:-1])
    curr_question_df = graphs_by_question.get(row["question"])

    if curr_question_df is None:  # no subgraph for this question: keep the seq2seq answer
        if row["answer_0"] == target:
            seq2seq_correct += 1
        elif target in all_beams:
            top200_total += 1
        continue

    if target not in all_beams:  # gold answer not among the beams: re-ranking cannot help
        continue

    top1_total += 1 if row["answer_0"] == target else 0
    top200_total += 1

    current_dataset = EvalGraphDataset(
        curr_question_df["correct"].tolist(),
        curr_question_df["transformed_graph"].tolist(),
    )
    filtered_is_correct = current_dataset.get_new_correct()
    if len(current_dataset) == 0:
        # Every candidate graph was filtered out during preprocessing. Fall back to
        # the seq2seq top-1 answer instead of reusing a stale index.
        final_acc += 1 if row["answer_0"] == target else 0
        continue

    # Score all remaining candidates of this question in a single batch.
    batch = data_collator([current_dataset[i] for i in range(len(current_dataset))])
    with torch.no_grad():
        outputs = model(
            input_nodes=batch['input_nodes'].to(device),
            input_edges=batch['input_edges'].to(device),
            attn_bias=batch['attn_bias'].to(device),
            in_degree=batch['in_degree'].to(device),
            out_degree=batch['out_degree'].to(device),
            spatial_pos=batch['spatial_pos'].to(device),
            attn_edge_type=batch['attn_edge_type'].to(device),
        )
    best_idx = outputs.logits.flatten().argmax().item()
    if filtered_is_correct[best_idx]:
        final_acc += 1

# Final re-ranking, top-1 and top-200 accuracy over all test questions.
n_questions = len(test_res_csv)
reranking_res = (final_acc + seq2seq_correct) / n_questions
top200 = (top200_total + seq2seq_correct) / n_questions
top1 = (top1_total + seq2seq_correct) / n_questions

        

4000it [00:30, 132.17it/s]


In [80]:
# NOTE(review): this cell's execution count ([80]) is lower than the re-ranking
# cell's ([108]), so the displayed tuple may predate the latest run. Re-run to refresh.
top1, reranking_res, top200

(0.25425, 0.2355, 0.64375)