In [1]:
from ogb.nodeproppred import NodePropPredDataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch
from ogb.nodeproppred import Evaluator
import pandas as pd
import math
from tqdm import tqdm
from datasets import Dataset, DatasetDict
import numpy as np
from transformers import (
    RobertaTokenizerFast,
    RobertaForSequenceClassification,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    AutoConfig,
    DataCollatorWithPadding,
    get_scheduler,
    get_linear_schedule_with_warmup
)
import evaluate

In [3]:
! ls /nlp/scr/ananjan/graph_models/mpnet_all

checkpoint-5684


In [4]:
DATASET = "ogbn-arxiv"
DATASET_ROOT = "/nlp/scr/ananjan/graph_datasets/"
OUTPUT_ROOT = "/nlp/scr/ananjan/graph_models/mpnet_all/"
MODEL = "/nlp/scr/ananjan/graph_models/mpnet_all/checkpoint-5684"
MODE = 'all'
BATCH_SIZE = 32
NUM_EPOCHS = 2
LR = 5e-5
WARMUP = 100
LOG_STEPS = 100
MAX_LEN = 512

In [5]:
# Load Dataset

dataset = NodePropPredDataset(name = DATASET, root = DATASET_ROOT)

In [6]:
# Get Splits

split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]

In [7]:
# Get Labels for Node Classification

graph, label = dataset[0]
len(label)

169343

In [8]:
# Load Text Label Mappings

labelidx2arxivcategeory = pd.read_csv('/nlp/scr/ananjan/graph_datasets/ogbn_arxiv/mapping/labelidx2arxivcategeory.csv')
print(len(labelidx2arxivcategeory))
labelidx2arxivcategeory.head()

40


Unnamed: 0,label idx,arxiv category
0,0,arxiv cs na
1,1,arxiv cs mm
2,2,arxiv cs lo
3,3,arxiv cs cy
4,4,arxiv cs cr


In [9]:
# Load Node-Paper Id Mappings

nodeidx2paperid = pd.read_csv('/nlp/scr/ananjan/graph_datasets/ogbn_arxiv/mapping/nodeidx2paperid.csv')
nodeidx2paperid.head()

Unnamed: 0,node idx,paper id
0,0,9657784
1,1,39886162
2,2,116214155
3,3,121432379
4,4,231147053


In [10]:
# Load Paper Mappings

titleabs = pd.read_csv('/nlp/scr/ananjan/graph_datasets/ogbn_arxiv/mapping/titleabs.tsv', sep='\t')
titleabs.head()

Unnamed: 0,paperid,title,abstract
0,200971.0,ontology as a source for rule generation,This paper discloses the potential of OWL (Web...
1,549074.0,a novel methodology for thermal analysis a 3 d...,The semiconductor industry is reaching a fasci...
2,630234.0,spreadsheets on the move an evaluation of mobi...,The power of mobile devices has increased dram...
3,803423.0,multi view metric learning for multi view vide...,Traditional methods on video summarization are...
4,1102481.0,big data analytics in future internet of things,Current research on Internet of Things (IoT) m...


In [11]:
# Make reverse index for text df

reverse_index= {}

paperids = titleabs["paperid"].tolist()
for idx, paperid in enumerate(paperids):
    if (not math.isnan(paperid)):
        reverse_index[int(paperid)] = idx

In [12]:
# Dataset Creation

dataset_dict = {'text': [], 'labels': []}

for idx, l in tqdm(enumerate(label)):
    dataset_dict['labels'].append(l[0])
    paper_id = nodeidx2paperid.iloc[idx]['paper id']
    reference_idx = reverse_index[paper_id]
    title = titleabs.iloc[reference_idx]['title']
    abstract = titleabs.iloc[reference_idx]['abstract']
    if (MODE == 'title'):
        dataset_dict['text'].append("Title: " + title)
    elif (MODE == 'abstract'):
        dataset_dict['text'].append(" Abstract: " + abstract)
    else:
        dataset_dict['text'].append("Title: " + title + " Abstract: " + abstract)

dataset_dict['text'] = np.array(dataset_dict['text'])
dataset_dict['labels'] = np.array(dataset_dict['labels'])

169343it [00:18, 9266.31it/s]


In [13]:
train_dataset, valid_dataset, test_dataset = {}, {}, {}

train_dataset['text'] = dataset_dict['text'][train_idx]
train_dataset['labels'] = dataset_dict['labels'][train_idx]
train_dataset = Dataset.from_dict(train_dataset)

valid_dataset['text'] = dataset_dict['text'][valid_idx]
valid_dataset['labels'] = dataset_dict['labels'][valid_idx]
valid_dataset = Dataset.from_dict(valid_dataset)

test_dataset['text'] = dataset_dict['text'][test_idx]
test_dataset['labels'] = dataset_dict['labels'][test_idx]
test_dataset = Dataset.from_dict(test_dataset)

fin_dataset = DatasetDict({
    'train': train_dataset,
    'valid': valid_dataset,
    'test': test_dataset
})

In [14]:
# train_dataset, valid_dataset, test_dataset = {}, {}, {}

# train_dataset['text'] = dataset_dict['text'][:100]
# train_dataset['label'] = dataset_dict['label'][:100]
# train_dataset = Dataset.from_dict(train_dataset)

# valid_dataset['text'] = dataset_dict['text'][:100]
# valid_dataset['label'] = dataset_dict['label'][:100]
# valid_dataset = Dataset.from_dict(valid_dataset)

# test_dataset['text'] = dataset_dict['text'][:100]
# test_dataset['label'] = dataset_dict['label'][:100]
# test_dataset = Dataset.from_dict(test_dataset)

# fin_dataset = DatasetDict({
#     'train': train_dataset,
#     'valid': valid_dataset,
#     'test': test_dataset
# })

# train_dataset[1]

In [15]:
# Init tokenizer

tokenizer = AutoTokenizer.from_pretrained(MODEL)

def tokenizer_helper(batch):
    return tokenizer(batch["text"], padding="max_length", truncation=True, max_length=MAX_LEN)

In [16]:
# Tokenize dataset

for split in fin_dataset:
    dataset = fin_dataset[split]
    dataset = dataset.map(tokenizer_helper, batched=True, batch_size=BATCH_SIZE)
    dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    fin_dataset[split] = dataset

Map:   0%|          | 0/90941 [00:00<?, ? examples/s]

Map:   0%|          | 0/29799 [00:00<?, ? examples/s]

Map:   0%|          | 0/48603 [00:00<?, ? examples/s]

In [17]:
# Compute Metrics for HF

metric = evaluate.load("accuracy")

def compute_metrics(pred):
    logits, labels = pred
    return metric.compute(predictions=np.argmax(logits, axis=-1), references=labels)

In [18]:
# Init model

model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=len(labelidx2arxivcategeory))

In [19]:
# # HF Init and Training

# training_args = TrainingArguments(
#     output_dir=OUTPUT_ROOT,
#     num_train_epochs=NUM_EPOCHS,
#     per_device_train_batch_size=BATCH_SIZE,
#     per_device_eval_batch_size=BATCH_SIZE,
#     evaluation_strategy="epoch",
#     logging_dir=f"{OUTPUT_ROOT}/logs",
#     logging_steps=LOG_STEPS,
#     learning_rate=LR,
#     warmup_steps=WARMUP,
#     save_strategy="epoch",
#     load_best_model_at_end=True,
#     save_total_limit=1
# )

# trainer = Trainer(
#     model = model,
#     args = training_args,
#     train_dataset = fin_dataset['train'],
#     eval_dataset = fin_dataset['test'],
#     compute_metrics=compute_metrics
# )

# trainer.train()

In [20]:
! ls /nlp/scr/ananjan/graph_models/mpnet_all/

checkpoint-5684


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [21]:
# tokenizer.save_pretrained(OUTPUT_ROOT + "checkpoint-5684/")

In [22]:
# PyTorch Init

# train_dataloader = DataLoader(fin_dataset['train'], shuffle=True, batch_size=BATCH_SIZE)
eval_dataloader = DataLoader(fin_dataset['test'], batch_size=BATCH_SIZE)
# optimizer = AdamW(model.parameters(), lr=LR)
# num_training_steps = NUM_EPOCHS * len(train_dataloader)
# lr_scheduler = get_linear_schedule_with_warmup(
#     optimizer=optimizer, num_warmup_steps=WARMUP, num_training_steps=num_training_steps
# )

In [23]:
# PyTorch Train (required because HF logging is broken)
# progress_bar = tqdm(len(eval_dataloader))
device = torch.device("cuda")
model.to(device)

# for epoch in range(NUM_EPOCHS):
#     num_steps = 0
#     model.train()
#     for batch in train_dataloader:
#         batch_gpu = {k: v.to(device) for k, v in batch.items()}
#         outputs = model(**batch_gpu)
#         loss = outputs.loss
#         loss.backward()

#         optimizer.step()
#         lr_scheduler.step()
#         optimizer.zero_grad()
#         progress_bar.update(1)
#         num_steps += 1
        
#         if (num_steps%LOG_STEPS == 0):
#             print(f'Step {num_steps}')
#             print(f'Train Loss {loss}')
    
#     print(f'Epoch {epoch} done')
#     model.eval()
    
#     metric = evaluate.load("accuracy")
#     for batch in eval_dataloader:
#         batch_gpu = {k: v.to(device) for k, v in batch.items()}
#         with torch.no_grad():
#             outputs = model(**batch_gpu)

#         logits = outputs.logits
#         predictions = torch.argmax(logits, dim=-1)
#         progress_bar.update(1)
#         metric.add_batch(predictions=predictions, references=batch["labels"])
#     print(f'Test Accuracy {metric.compute()}')
    
#     metric = evaluate.load("accuracy")
#     for batch in train_dataloader:
#         batch_gpu = {k: v.to(device) for k, v in batch.items()}
#         with torch.no_grad():
#             outputs = model(**batch_gpu)

#         logits = outputs.logits
#         predictions = torch.argmax(logits, dim=-1)
#         metric.add_batch(predictions=predictions, references=batch["labels"])
#     print(f'Train Accuracy {metric.compute()}')

MPNetForSequenceClassification(
  (mpnet): MPNetModel(
    (embeddings): MPNetEmbeddings(
      (word_embeddings): Embedding(30527, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): MPNetEncoder(
      (layer): ModuleList(
        (0-11): 12 x MPNetLayer(
          (attention): MPNetAttention(
            (attn): MPNetSelfAttention(
              (q): Linear(in_features=768, out_features=768, bias=True)
              (k): Linear(in_features=768, out_features=768, bias=True)
              (v): Linear(in_features=768, out_features=768, bias=True)
              (o): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
 

In [84]:
model.eval()
    
metric = evaluate.load("accuracy")
for batch in tqdm(eval_dataloader):
    batch_gpu = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch_gpu)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    progress_bar.update(1)
    metric.add_batch(predictions=predictions, references=batch["labels"])
print(f'Test Accuracy {metric.compute()}')


3137it [1:01:20, 78.89s/it]                                                          | 0/1519 [00:00<?, ?it/s][A
3139it [1:01:20, 42.51s/it]                                                  | 2/1519 [00:00<01:34, 16.11it/s][A
3141it [1:01:20, 25.64s/it]                                                  | 4/1519 [00:00<01:33, 16.28it/s][A
3143it [1:01:20, 16.38s/it]                                                  | 6/1519 [00:00<01:32, 16.33it/s][A
3145it [1:01:20, 10.82s/it]                                                  | 8/1519 [00:00<01:32, 16.38it/s][A
3147it [1:01:20,  7.29s/it]                                                 | 10/1519 [00:00<01:32, 16.37it/s][A
3149it [1:01:21,  4.99s/it]                                                 | 12/1519 [00:00<01:31, 16.38it/s][A
3151it [1:01:21,  3.45s/it]                                                 | 14/1519 [00:00<01:31, 16.41it/s][A
3153it [1:01:21,  2.40s/it]                                                 | 16/1519 [

3421it [1:01:37, 16.38it/s]                                                | 284/1519 [00:17<01:15, 16.38it/s][A
3423it [1:01:37, 16.38it/s]                                                | 286/1519 [00:17<01:15, 16.38it/s][A
3425it [1:01:37, 16.39it/s]                                                | 288/1519 [00:17<01:15, 16.39it/s][A
3427it [1:01:38, 16.39it/s]                                                | 290/1519 [00:17<01:14, 16.39it/s][A
3429it [1:01:38, 16.39it/s]                                                | 292/1519 [00:17<01:14, 16.39it/s][A
3431it [1:01:38, 16.39it/s]                                                | 294/1519 [00:17<01:14, 16.38it/s][A
3433it [1:01:38, 16.39it/s]                                                | 296/1519 [00:18<01:14, 16.39it/s][A
3435it [1:01:38, 16.38it/s]                                                | 298/1519 [00:18<01:14, 16.39it/s][A
3437it [1:01:38, 16.39it/s]                                                | 300/1519 [0

3705it [1:01:55, 16.24it/s]████▏                                           | 568/1519 [00:34<00:58, 16.24it/s][A
3707it [1:01:55, 16.25it/s]████▎                                           | 570/1519 [00:34<00:58, 16.25it/s][A
3709it [1:01:55, 16.26it/s]████▎                                           | 572/1519 [00:35<00:58, 16.26it/s][A
3711it [1:01:55, 16.25it/s]████▍                                           | 574/1519 [00:35<00:58, 16.25it/s][A
3713it [1:01:55, 16.26it/s]████▌                                           | 576/1519 [00:35<00:58, 16.25it/s][A
3715it [1:01:55, 16.26it/s]████▋                                           | 578/1519 [00:35<00:57, 16.26it/s][A
3717it [1:01:55, 16.26it/s]████▋                                           | 580/1519 [00:35<00:57, 16.26it/s][A
3719it [1:01:55, 16.24it/s]████▊                                           | 582/1519 [00:35<00:57, 16.25it/s][A
3721it [1:01:56, 16.24it/s]████▉                                           | 584/1519 [0

3989it [1:02:12, 16.16it/s]█████████████████▎                              | 852/1519 [00:52<00:41, 16.16it/s][A
3991it [1:02:12, 16.15it/s]█████████████████▎                              | 854/1519 [00:52<00:41, 16.16it/s][A
3993it [1:02:12, 16.15it/s]█████████████████▍                              | 856/1519 [00:52<00:41, 16.16it/s][A
3995it [1:02:12, 16.15it/s]█████████████████▌                              | 858/1519 [00:52<00:40, 16.16it/s][A
3997it [1:02:13, 16.17it/s]█████████████████▋                              | 860/1519 [00:52<00:40, 16.16it/s][A
3999it [1:02:13, 16.15it/s]█████████████████▋                              | 862/1519 [00:52<00:40, 16.16it/s][A
4001it [1:02:13, 16.16it/s]█████████████████▊                              | 864/1519 [00:53<00:40, 16.16it/s][A
4003it [1:02:13, 16.15it/s]█████████████████▉                              | 866/1519 [00:53<00:40, 16.15it/s][A
4005it [1:02:13, 16.17it/s]██████████████████                              | 868/1519 [0

4273it [1:02:30, 16.11it/s]█████████████████████████████▌                 | 1136/1519 [01:09<00:23, 16.11it/s][A
4275it [1:02:30, 16.12it/s]█████████████████████████████▋                 | 1138/1519 [01:10<00:23, 16.12it/s][A
4277it [1:02:30, 16.12it/s]█████████████████████████████▊                 | 1140/1519 [01:10<00:23, 16.12it/s][A
4279it [1:02:30, 16.13it/s]█████████████████████████████▊                 | 1142/1519 [01:10<00:23, 16.12it/s][A
4281it [1:02:30, 16.11it/s]█████████████████████████████▉                 | 1144/1519 [01:10<00:23, 16.12it/s][A
4283it [1:02:30, 16.11it/s]██████████████████████████████                 | 1146/1519 [01:10<00:23, 16.11it/s][A
4285it [1:02:30, 16.09it/s]██████████████████████████████▏                | 1148/1519 [01:10<00:23, 16.10it/s][A
4287it [1:02:31, 16.10it/s]██████████████████████████████▏                | 1150/1519 [01:10<00:22, 16.10it/s][A
4289it [1:02:31, 16.11it/s]██████████████████████████████▎                | 1152/1519 [0

4557it [1:02:47, 16.04it/s]██████████████████████████████████████████▌    | 1420/1519 [01:27<00:06, 16.04it/s][A
4559it [1:02:48, 16.03it/s]██████████████████████████████████████████▌    | 1422/1519 [01:27<00:06, 16.04it/s][A
4561it [1:02:48, 16.04it/s]██████████████████████████████████████████▋    | 1424/1519 [01:27<00:05, 16.04it/s][A
4563it [1:02:48, 16.05it/s]██████████████████████████████████████████▊    | 1426/1519 [01:28<00:05, 16.05it/s][A
4565it [1:02:48, 16.06it/s]██████████████████████████████████████████▊    | 1428/1519 [01:28<00:05, 16.05it/s][A
4567it [1:02:48, 16.03it/s]██████████████████████████████████████████▉    | 1430/1519 [01:28<00:05, 16.04it/s][A
4569it [1:02:48, 16.02it/s]███████████████████████████████████████████    | 1432/1519 [01:28<00:05, 16.02it/s][A
4571it [1:02:48, 15.77it/s]███████████████████████████████████████████▏   | 1434/1519 [01:28<00:05, 15.78it/s][A
4573it [1:02:48, 15.87it/s]███████████████████████████████████████████▏   | 1436/1519 [0

Test Accuracy {'accuracy': 0.72573709441804}


In [24]:
# Dump Logits
import pickle

# OUT_MODEL = '/nlp/scr/ananjan/graph_models/roberta_all/checkpoint-2842'
device = torch.device("cuda")
# model = AutoModelForSequenceClassification.from_pretrained(OUT_MODEL, num_labels=len(labelidx2arxivcategeory))
dump_dataset = Dataset.from_dict(dataset_dict)
dump_dataset = dump_dataset.map(tokenizer_helper, batched=True, batch_size=BATCH_SIZE)
dump_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

dataloader = DataLoader(dump_dataset, shuffle=False, batch_size=BATCH_SIZE)

Map:   0%|          | 0/169343 [00:00<?, ? examples/s]

In [25]:
from tqdm import tqdm

total_logits = []
model.to(device)
for batch in tqdm(dataloader):
    batch_gpu = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch_gpu)
        logits = list(outputs.logits.cpu().numpy())
        total_logits += logits
total_logits = np.array(total_logits)
        
with open("/nlp/scr/ananjan/graph_embeddings/finetuned/mpnet_logits_arxiv.pkl", 'wb') as f:
    pickle.dump(total_logits, f)

100%|█████████████████████████████| 5292/5292 [37:24<00:00,  2.36it/s]
