In [1]:
import json
import math
import os
import pickle

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from ogb.nodeproppred import NodePropPredDataset
from ogb.nodeproppred import Evaluator
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    RobertaTokenizerFast,
    RobertaForSequenceClassification,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    AutoConfig,
    DataCollatorWithPadding,
    get_scheduler,
    get_linear_schedule_with_warmup
)

In [2]:
# --- Dataset / model identifiers and storage locations ---
DATASET = "ogbn-products"
MODEL = "roberta-base"
DATASET_ROOT = "/nlp/scr/ananjan/graph_datasets/"
OUTPUT_ROOT = "/nlp/scr/ananjan/graph_models/roberta_products/"

# --- Optimisation hyper-parameters ---
NUM_EPOCHS, BATCH_SIZE = 2, 16
LR = 5e-5
WARMUP = 100  # number of scheduler warm-up steps

# --- Logging cadence and tokenizer sequence length ---
LOG_STEPS = 100
MAX_LEN = 512

In [3]:
# Load the OGB node-property-prediction dataset named by DATASET,
# using DATASET_ROOT as its root directory.

dataset = NodePropPredDataset(root=DATASET_ROOT, name=DATASET)

In [4]:
# List the contents of the ogbn_products mapping directory (shell command).
! ls /nlp/scr/ananjan/graph_datasets/ogbn_products/mapping

Amazon-3M.raw	   labelidx2productcategory.csv  README.md
Amazon-3M.raw.zip  nodeidx2asin.csv


In [5]:
# Fetch the dataset's index split and unpack its train / valid / test entries.

split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = (split_idx[part] for part in ("train", "valid", "test"))

In [6]:
# Get Labels for Node Classification
# The dataset's first item unpacks into (graph, label); the trailing expression
# displays how many label rows there are.

graph, label = dataset[0]
len(label)

2449029

In [7]:
# Load Text Label Mappings
# The CSV location is derived from DATASET_ROOT / DATASET instead of repeating
# the absolute path, so changing the config cell is enough to relocate the data.
# The row count printed below is later used as the classifier's num_labels.

label_mapping_csv = os.path.join(
    DATASET_ROOT, DATASET.replace("-", "_"), "mapping", "labelidx2productcategory.csv"
)
labelidx2productcategeory = pd.read_csv(label_mapping_csv)
print(len(labelidx2productcategeory))
labelidx2productcategeory.head()

47


Unnamed: 0,label idx,product category
0,0,Home & Kitchen
1,1,Health & Personal Care
2,2,Beauty
3,3,Sports & Outdoors
4,4,Books


In [8]:
# Load ASINs
# node2asin has one row per node with an 'asin' column. The CSV location is
# derived from DATASET_ROOT / DATASET rather than a repeated absolute path.

node2asin_csv = os.path.join(
    DATASET_ROOT, DATASET.replace("-", "_"), "mapping", "nodeidx2asin.csv"
)
node2asin = pd.read_csv(node2asin_csv)
print(len(node2asin))
node2asin.head()

2449029


Unnamed: 0,node idx,asin
0,0,B00902X3L2
1,1,B000FW4BGM
2,2,B001NTYWFQ
3,3,B003DTLNVA
4,4,B00KFTCE28


In [9]:
# Load Product Title Mappings
# Builds txt_mapping: "uid" -> stripped "title", from the line-delimited JSON
# files of the Amazon-3M raw dump. The keys are later looked up by ASIN.

raw_text_dir = os.path.join(
    DATASET_ROOT, DATASET.replace("-", "_"), "mapping", "Amazon-3M.raw"
)

txt_mapping = {}

# trn.json is read before tst.json, so for a uid present in both files the
# tst.json title is the one that is kept (same order as the original two loops).
for raw_file in ("trn.json", "tst.json"):
    # Explicit encoding so the result does not depend on the machine's locale.
    with open(os.path.join(raw_text_dir, raw_file), "r", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            txt_mapping[record["uid"]] = record["title"].strip()

In [10]:
# Dataset Creation
# Builds two parallel arrays, indexed by node position:
#   'text'   - the title looked up in txt_mapping via the node's ASIN
#   'labels' - the first entry of each row of `label`

# Rows of node2asin are matched to rows of `label` by position, so the two
# must describe the same number of nodes.
assert len(node2asin) == len(label), "node2asin and label must have the same length"

# Pull the ASIN column out once; `.iloc[idx]['asin']` builds a whole row Series
# on every iteration.
asins = node2asin['asin'].to_numpy()

dataset_dict = {
    # A missing ASIN raises KeyError, exactly as before.
    'text': np.array([txt_mapping[asin] for asin in tqdm(asins, total=len(asins))]),
    'labels': np.asarray(label)[:, 0].copy(),
}

2449029it [01:08, 35697.79it/s]


In [11]:
def _take_split(indices):
    """Return a Dataset holding the 'text' and 'labels' rows selected by `indices`."""
    return Dataset.from_dict({
        'text': dataset_dict['text'][indices],
        'labels': dataset_dict['labels'][indices],
    })


# One Dataset per split, then bundled under the split names used downstream.
train_dataset = _take_split(train_idx)
valid_dataset = _take_split(valid_idx)
test_dataset = _take_split(test_idx)

fin_dataset = DatasetDict({
    'train': train_dataset,
    'valid': valid_dataset,
    'test': test_dataset
})

In [12]:
# Disabled debugging variant: builds all three splits from the first 100 rows.
# NOTE(review): it indexes dataset_dict['label'], but the dict built above uses
# the key 'labels'; fix the key before re-enabling this cell.

# train_dataset, valid_dataset, test_dataset = {}, {}, {}

# train_dataset['text'] = dataset_dict['text'][:100]
# train_dataset['label'] = dataset_dict['label'][:100]
# train_dataset = Dataset.from_dict(train_dataset)

# valid_dataset['text'] = dataset_dict['text'][:100]
# valid_dataset['label'] = dataset_dict['label'][:100]
# valid_dataset = Dataset.from_dict(valid_dataset)

# test_dataset['text'] = dataset_dict['text'][:100]
# test_dataset['label'] = dataset_dict['label'][:100]
# test_dataset = Dataset.from_dict(test_dataset)

# fin_dataset = DatasetDict({
#     'train': train_dataset,
#     'valid': valid_dataset,
#     'test': test_dataset
# })

# train_dataset[1]

In [13]:
# Instantiate the tokenizer that matches the pretrained checkpoint named by MODEL.

tokenizer = AutoTokenizer.from_pretrained(MODEL)

def tokenizer_helper(batch):
    """Tokenize batch["text"], truncating and padding every sequence to MAX_LEN."""
    encoded = tokenizer(
        batch["text"],
        truncation=True,
        max_length=MAX_LEN,
        padding="max_length",
    )
    return encoded

In [14]:
# Tokenize dataset
# Each split of fin_dataset is replaced by its tokenized version, formatted to
# yield torch tensors for the three columns the model consumes.
# The per-split result lives in `tokenized_split`: reusing the name `dataset`
# here would overwrite the OGB dataset object loaded earlier and break any
# re-run of the cells that call dataset.get_idx_split() / dataset[0].

for split in list(fin_dataset.keys()):
    tokenized_split = fin_dataset[split].map(tokenizer_helper, batched=True, batch_size=BATCH_SIZE)
    tokenized_split.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    fin_dataset[split] = tokenized_split

Map:   0%|          | 0/196615 [00:00<?, ? examples/s]

Map:   0%|          | 0/39323 [00:00<?, ? examples/s]

Map:   0%|          | 0/2213091 [00:00<?, ? examples/s]

In [15]:
# Accuracy metric plus the callback that turns (logits, labels) into a metric result.

metric = evaluate.load("accuracy")

def compute_metrics(pred):
    """Unpack `pred` into (logits, labels) and score the argmax over the last axis."""
    logits, labels = pred
    predicted_classes = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predicted_classes, references=labels)

In [16]:
# Build the sequence classifier from the MODEL checkpoint, with one output
# label per row of the label-mapping table.

n_classes = len(labelidx2productcategeory)
model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=n_classes)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# # HF Init and Training
# NOTE(review): this disabled Trainer setup passes fin_dataset['test'] as
# eval_dataset together with load_best_model_at_end=True, so re-enabling it
# as-is would presumably select the checkpoint on the test split; confirm and
# consider fin_dataset['valid'] instead.

# training_args = TrainingArguments(
#     output_dir=OUTPUT_ROOT,
#     num_train_epochs=NUM_EPOCHS,
#     per_device_train_batch_size=BATCH_SIZE,
#     per_device_eval_batch_size=BATCH_SIZE,
#     evaluation_strategy="epoch",
#     logging_dir=f"{OUTPUT_ROOT}/logs",
#     logging_steps=LOG_STEPS,
#     learning_rate=LR,
#     warmup_steps=WARMUP,
#     save_strategy="epoch",
#     load_best_model_at_end=True,
#     save_total_limit=1
# )

# trainer = Trainer(
#     model = model,
#     args = training_args,
#     train_dataset = fin_dataset['train'],
#     eval_dataset = fin_dataset['test'],
#     compute_metrics=compute_metrics
# )

# trainer.train()

In [None]:
# tokenizer.save_pretrained(OUTPUT_ROOT + "checkpoint-2842/")

In [19]:
# Data loaders, optimizer and learning-rate schedule for the manual training loop.

train_dataloader = DataLoader(fin_dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(fin_dataset['test'], batch_size=BATCH_SIZE)

optimizer = AdamW(model.parameters(), lr=LR)

# Total number of steps the schedule spans: every training batch, for every epoch.
num_training_steps = len(train_dataloader) * NUM_EPOCHS
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP,
    num_training_steps=num_training_steps,
)

In [20]:
# PyTorch Train (required because HF logging is broken)
# The training body had been commented out, which left this cell evaluating an
# unchanged model NUM_EPOCHS times while the optimizer, scheduler and progress
# bar created for it were never used. It is restored here, and the duplicated
# evaluation loop is factored into evaluate_accuracy().

def evaluate_accuracy(model, dataloader, device, desc=None):
    """Score `model` on every batch of `dataloader` and return metric.compute().

    Args:
        model: classifier called as model(**batch); its output must expose .logits.
        dataloader: iterable of dict batches of tensors that include a "labels" entry.
        device: torch device the batch tensors are moved to before the forward pass.
        desc: optional label for the progress bar.

    Returns:
        Whatever the "accuracy" metric's compute() returns for the accumulated batches.
    """
    model.eval()
    # A fresh metric object per call, so results never mix between dataloaders.
    accuracy = evaluate.load("accuracy")
    for batch in tqdm(dataloader, desc=desc):
        batch_gpu = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch_gpu)
        predictions = torch.argmax(outputs.logits, dim=-1)
        accuracy.add_batch(predictions=predictions, references=batch["labels"])
    return accuracy.compute()


progress_bar = tqdm(range(num_training_steps))
# Fall back to CPU so the cell still runs (slowly) on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(NUM_EPOCHS):
    num_steps = 0
    model.train()
    for batch in train_dataloader:
        batch_gpu = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch_gpu)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        num_steps += 1

        if num_steps % LOG_STEPS == 0:
            print(f'Step {num_steps}')
            # .item() prints the plain float rather than the tensor repr
            print(f'Train Loss {loss.item()}')

    print(f'Epoch {epoch} done')

    test_accuracy = evaluate_accuracy(model, eval_dataloader, device, desc="test")
    print(f'Test Accuracy {test_accuracy}')

    train_accuracy = evaluate_accuracy(model, train_dataloader, device, desc="train")
    print(f'Train Accuracy {train_accuracy}')

  0%|▍                                                                                                    | 100/24578 [00:28<1:32:31,  4.41it/s]

Step 100
Train Loss 1.6136910915374756


  1%|▊                                                                                                    | 200/24578 [00:51<1:34:29,  4.30it/s]

Step 200
Train Loss 1.442386507987976


  1%|█▏                                                                                                   | 300/24578 [01:14<1:34:58,  4.26it/s]

Step 300
Train Loss 1.5180968046188354


  2%|█▋                                                                                                   | 400/24578 [01:38<1:36:00,  4.20it/s]

Step 400
Train Loss 1.5969632863998413


  2%|██                                                                                                   | 500/24578 [02:01<1:36:10,  4.17it/s]

Step 500
Train Loss 1.0579503774642944


  2%|██▍                                                                                                  | 600/24578 [02:25<1:35:51,  4.17it/s]

Step 600
Train Loss 0.5781848430633545


  3%|██▉                                                                                                  | 700/24578 [02:50<1:35:57,  4.15it/s]

Step 700
Train Loss 0.3811384439468384


  3%|███▎                                                                                                 | 800/24578 [03:14<1:35:26,  4.15it/s]

Step 800
Train Loss 0.8008793592453003


  4%|███▋                                                                                                 | 900/24578 [03:38<1:34:44,  4.17it/s]

Step 900
Train Loss 1.1412488222122192


  4%|████                                                                                                | 1000/24578 [04:02<1:35:02,  4.13it/s]

Step 1000
Train Loss 0.6490986943244934


  4%|████▍                                                                                               | 1100/24578 [04:26<1:33:51,  4.17it/s]

Step 1100
Train Loss 1.0110530853271484


  5%|████▉                                                                                               | 1200/24578 [04:50<1:33:33,  4.16it/s]

Step 1200
Train Loss 1.2406967878341675


  5%|█████▎                                                                                              | 1300/24578 [05:14<1:33:06,  4.17it/s]

Step 1300
Train Loss 1.0595496892929077


  6%|█████▋                                                                                              | 1400/24578 [05:38<1:32:52,  4.16it/s]

Step 1400
Train Loss 0.3004203140735626


  6%|██████                                                                                              | 1500/24578 [06:02<1:32:16,  4.17it/s]

Step 1500
Train Loss 1.1974068880081177


  7%|██████▌                                                                                             | 1600/24578 [06:26<1:32:00,  4.16it/s]

Step 1600
Train Loss 0.5197347402572632


  7%|██████▉                                                                                             | 1700/24578 [06:50<1:31:27,  4.17it/s]

Step 1700
Train Loss 1.2766294479370117


  7%|███████▎                                                                                            | 1800/24578 [07:14<1:31:26,  4.15it/s]

Step 1800
Train Loss 1.4921202659606934


  8%|███████▋                                                                                            | 1900/24578 [07:38<1:30:45,  4.16it/s]

Step 1900
Train Loss 1.3602311611175537


  8%|████████▏                                                                                           | 2000/24578 [08:02<1:30:10,  4.17it/s]

Step 2000
Train Loss 0.5925938487052917


  9%|████████▌                                                                                           | 2100/24578 [08:26<1:29:53,  4.17it/s]

Step 2100
Train Loss 0.7237738370895386


  9%|████████▉                                                                                           | 2200/24578 [08:50<1:29:33,  4.16it/s]

Step 2200
Train Loss 0.9800254106521606


  9%|█████████▎                                                                                          | 2300/24578 [09:14<1:29:06,  4.17it/s]

Step 2300
Train Loss 0.8928384780883789


 10%|█████████▊                                                                                          | 2400/24578 [09:38<1:28:41,  4.17it/s]

Step 2400
Train Loss 0.5322636365890503


 10%|██████████▏                                                                                         | 2500/24578 [10:02<1:28:14,  4.17it/s]

Step 2500
Train Loss 0.3320101499557495


 11%|██████████▌                                                                                         | 2600/24578 [10:26<1:28:18,  4.15it/s]

Step 2600
Train Loss 1.063615083694458


 11%|██████████▉                                                                                         | 2700/24578 [10:50<1:27:52,  4.15it/s]

Step 2700
Train Loss 0.9852722883224487


 11%|███████████▍                                                                                        | 2800/24578 [11:14<1:27:13,  4.16it/s]

Step 2800
Train Loss 0.4797027111053467


 12%|███████████▊                                                                                        | 2900/24578 [11:38<1:26:42,  4.17it/s]

Step 2900
Train Loss 0.7585819959640503


 12%|████████████▏                                                                                       | 3000/24578 [12:02<1:26:13,  4.17it/s]

Step 3000
Train Loss 0.7866160869598389


 13%|████████████▌                                                                                       | 3100/24578 [12:26<1:25:47,  4.17it/s]

Step 3100
Train Loss 0.8969007134437561


 13%|█████████████                                                                                       | 3200/24578 [12:50<1:25:32,  4.17it/s]

Step 3200
Train Loss 0.4984815716743469


 13%|█████████████▍                                                                                      | 3300/24578 [13:14<1:24:59,  4.17it/s]

Step 3300
Train Loss 0.7566836476325989


 14%|█████████████▊                                                                                      | 3400/24578 [13:38<1:24:44,  4.17it/s]

Step 3400
Train Loss 0.615092396736145


 14%|██████████████▏                                                                                     | 3500/24578 [14:02<1:24:20,  4.16it/s]

Step 3500
Train Loss 0.7658035755157471


 15%|██████████████▋                                                                                     | 3600/24578 [14:26<1:23:49,  4.17it/s]

Step 3600
Train Loss 0.5718626379966736


 15%|███████████████                                                                                     | 3700/24578 [14:50<1:23:35,  4.16it/s]

Step 3700
Train Loss 1.043565273284912


 15%|███████████████▍                                                                                    | 3800/24578 [15:14<1:23:06,  4.17it/s]

Step 3800
Train Loss 0.11381398141384125


 16%|███████████████▊                                                                                    | 3900/24578 [15:38<1:22:56,  4.15it/s]

Step 3900
Train Loss 1.2894556522369385


 16%|████████████████▎                                                                                   | 4000/24578 [16:02<1:22:16,  4.17it/s]

Step 4000
Train Loss 0.08020049333572388


 17%|████████████████▋                                                                                   | 4100/24578 [16:26<1:21:53,  4.17it/s]

Step 4100
Train Loss 0.9065076112747192


 17%|█████████████████                                                                                   | 4200/24578 [16:50<1:21:30,  4.17it/s]

Step 4200
Train Loss 0.8499460816383362


 17%|█████████████████▍                                                                                  | 4300/24578 [17:14<1:21:01,  4.17it/s]

Step 4300
Train Loss 0.5296209454536438


 18%|█████████████████▉                                                                                  | 4400/24578 [17:38<1:20:41,  4.17it/s]

Step 4400
Train Loss 0.5739572644233704


 18%|██████████████████▎                                                                                 | 4500/24578 [18:02<1:20:18,  4.17it/s]

Step 4500
Train Loss 0.573982834815979


 19%|██████████████████▋                                                                                 | 4600/24578 [18:26<1:19:54,  4.17it/s]

Step 4600
Train Loss 0.7007826566696167


 19%|███████████████████                                                                                 | 4700/24578 [18:50<1:19:33,  4.16it/s]

Step 4700
Train Loss 0.9088789224624634


 20%|███████████████████▌                                                                                | 4800/24578 [19:14<1:19:06,  4.17it/s]

Step 4800
Train Loss 0.9400466680526733


 20%|███████████████████▉                                                                                | 4900/24578 [19:38<1:18:50,  4.16it/s]

Step 4900
Train Loss 0.9197524189949036


 20%|████████████████████▎                                                                               | 5000/24578 [20:02<1:18:15,  4.17it/s]

Step 5000
Train Loss 0.7390686273574829


 21%|████████████████████▊                                                                               | 5100/24578 [20:26<1:17:56,  4.17it/s]

Step 5100
Train Loss 0.7270731925964355


 21%|█████████████████████▏                                                                              | 5200/24578 [20:50<1:17:32,  4.17it/s]

Step 5200
Train Loss 0.6190200448036194


 22%|█████████████████████▌                                                                              | 5300/24578 [21:14<1:17:17,  4.16it/s]

Step 5300
Train Loss 0.391155481338501


 22%|█████████████████████▉                                                                              | 5400/24578 [21:38<1:16:40,  4.17it/s]

Step 5400
Train Loss 0.7287760972976685


 22%|██████████████████████▍                                                                             | 5500/24578 [22:02<1:16:16,  4.17it/s]

Step 5500
Train Loss 0.9941000938415527


 23%|██████████████████████▊                                                                             | 5600/24578 [22:26<1:15:53,  4.17it/s]

Step 5600
Train Loss 1.1041549444198608


 23%|███████████████████████▏                                                                            | 5700/24578 [22:50<1:16:11,  4.13it/s]

Step 5700
Train Loss 0.8472114205360413


 24%|███████████████████████▌                                                                            | 5800/24578 [23:14<1:15:06,  4.17it/s]

Step 5800
Train Loss 0.9700510501861572


 24%|████████████████████████                                                                            | 5900/24578 [23:38<1:14:45,  4.16it/s]

Step 5900
Train Loss 0.7808459401130676


 24%|████████████████████████▍                                                                           | 6000/24578 [24:02<1:14:29,  4.16it/s]

Step 6000
Train Loss 0.499031662940979


 25%|████████████████████████▊                                                                           | 6100/24578 [24:26<1:13:50,  4.17it/s]

Step 6100
Train Loss 0.8003214001655579


 25%|█████████████████████████▏                                                                          | 6200/24578 [24:50<1:13:28,  4.17it/s]

Step 6200
Train Loss 0.3552267849445343


 26%|█████████████████████████▋                                                                          | 6300/24578 [25:14<1:13:02,  4.17it/s]

Step 6300
Train Loss 1.0562047958374023


 26%|██████████████████████████                                                                          | 6400/24578 [25:38<1:12:37,  4.17it/s]

Step 6400
Train Loss 1.4638938903808594


 26%|██████████████████████████▍                                                                         | 6500/24578 [26:02<1:12:14,  4.17it/s]

Step 6500
Train Loss 0.8084973096847534


 27%|██████████████████████████▊                                                                         | 6600/24578 [26:26<1:11:48,  4.17it/s]

Step 6600
Train Loss 0.6308437585830688


 27%|███████████████████████████▎                                                                        | 6700/24578 [26:50<1:11:26,  4.17it/s]

Step 6700
Train Loss 0.5023441314697266


 28%|███████████████████████████▋                                                                        | 6800/24578 [27:14<1:11:12,  4.16it/s]

Step 6800
Train Loss 0.44109368324279785


 28%|████████████████████████████                                                                        | 6900/24578 [27:38<1:10:43,  4.17it/s]

Step 6900
Train Loss 1.0222188234329224


 28%|████████████████████████████▍                                                                       | 7000/24578 [28:02<1:10:19,  4.17it/s]

Step 7000
Train Loss 0.5697863698005676


 29%|████████████████████████████▉                                                                       | 7100/24578 [28:26<1:09:58,  4.16it/s]

Step 7100
Train Loss 0.3187579810619354


 29%|█████████████████████████████▎                                                                      | 7200/24578 [28:50<1:09:33,  4.16it/s]

Step 7200
Train Loss 0.8490167856216431


 30%|█████████████████████████████▋                                                                      | 7300/24578 [29:14<1:09:08,  4.17it/s]

Step 7300
Train Loss 0.9186674952507019


 30%|██████████████████████████████                                                                      | 7400/24578 [29:38<1:08:42,  4.17it/s]

Step 7400
Train Loss 0.2599632740020752


 31%|██████████████████████████████▌                                                                     | 7500/24578 [30:02<1:08:19,  4.17it/s]

Step 7500
Train Loss 1.0679067373275757


 31%|██████████████████████████████▉                                                                     | 7600/24578 [30:26<1:07:52,  4.17it/s]

Step 7600
Train Loss 0.8579830527305603


 31%|███████████████████████████████▎                                                                    | 7700/24578 [30:50<1:07:32,  4.16it/s]

Step 7700
Train Loss 0.7357269525527954


 32%|███████████████████████████████▋                                                                    | 7800/24578 [31:14<1:07:07,  4.17it/s]

Step 7800
Train Loss 0.6790230870246887


 32%|████████████████████████████████▏                                                                   | 7900/24578 [31:38<1:06:38,  4.17it/s]

Step 7900
Train Loss 0.5566328763961792


 33%|████████████████████████████████▌                                                                   | 8000/24578 [32:02<1:06:17,  4.17it/s]

Step 8000
Train Loss 0.5200955271720886


 33%|████████████████████████████████▉                                                                   | 8100/24578 [32:26<1:05:50,  4.17it/s]

Step 8100
Train Loss 1.0459331274032593


 33%|█████████████████████████████████▎                                                                  | 8200/24578 [32:50<1:05:27,  4.17it/s]

Step 8200
Train Loss 0.5879429578781128


 34%|█████████████████████████████████▊                                                                  | 8300/24578 [33:14<1:05:10,  4.16it/s]

Step 8300
Train Loss 0.7399945259094238


 34%|██████████████████████████████████▏                                                                 | 8400/24578 [33:38<1:04:40,  4.17it/s]

Step 8400
Train Loss 0.9430161118507385


 35%|██████████████████████████████████▌                                                                 | 8500/24578 [34:02<1:04:15,  4.17it/s]

Step 8500
Train Loss 0.7458324432373047


 35%|██████████████████████████████████▉                                                                 | 8600/24578 [34:26<1:04:10,  4.15it/s]

Step 8600
Train Loss 1.1350581645965576


 35%|███████████████████████████████████▍                                                                | 8700/24578 [34:50<1:03:23,  4.17it/s]

Step 8700
Train Loss 0.3719581961631775


 36%|███████████████████████████████████▊                                                                | 8800/24578 [35:14<1:03:04,  4.17it/s]

Step 8800
Train Loss 0.7535197138786316


 36%|████████████████████████████████████▏                                                               | 8900/24578 [35:38<1:02:39,  4.17it/s]

Step 8900
Train Loss 0.3853336274623871


 37%|████████████████████████████████████▌                                                               | 9000/24578 [36:02<1:02:17,  4.17it/s]

Step 9000
Train Loss 0.8278952836990356


 37%|█████████████████████████████████████                                                               | 9100/24578 [36:26<1:01:56,  4.16it/s]

Step 9100
Train Loss 0.6033402681350708


 37%|█████████████████████████████████████▍                                                              | 9200/24578 [36:50<1:01:28,  4.17it/s]

Step 9200
Train Loss 0.9524716734886169


 38%|█████████████████████████████████████▊                                                              | 9300/24578 [37:14<1:01:05,  4.17it/s]

Step 9300
Train Loss 0.3319493532180786


 38%|██████████████████████████████████████▏                                                             | 9400/24578 [37:38<1:00:42,  4.17it/s]

Step 9400
Train Loss 0.5668054223060608


 39%|██████████████████████████████████████▋                                                             | 9500/24578 [38:03<1:00:18,  4.17it/s]

Step 9500
Train Loss 1.0928988456726074


 39%|███████████████████████████████████████▊                                                              | 9600/24578 [38:26<59:54,  4.17it/s]

Step 9600
Train Loss 0.4000135660171509


 39%|████████████████████████████████████████▎                                                             | 9700/24578 [38:50<59:28,  4.17it/s]

Step 9700
Train Loss 0.2266545593738556


 40%|████████████████████████████████████████▋                                                             | 9800/24578 [39:15<59:03,  4.17it/s]

Step 9800
Train Loss 0.5241392254829407


 40%|█████████████████████████████████████████                                                             | 9900/24578 [39:38<58:41,  4.17it/s]

Step 9900
Train Loss 0.8208949565887451


 41%|█████████████████████████████████████████                                                            | 10000/24578 [40:03<58:30,  4.15it/s]

Step 10000
Train Loss 0.7854084372520447


 41%|█████████████████████████████████████████▌                                                           | 10100/24578 [40:27<57:55,  4.17it/s]

Step 10100
Train Loss 0.5106386542320251


 42%|█████████████████████████████████████████▉                                                           | 10200/24578 [40:51<57:32,  4.17it/s]

Step 10200
Train Loss 0.22628937661647797


 42%|██████████████████████████████████████████▎                                                          | 10300/24578 [41:15<57:12,  4.16it/s]

Step 10300
Train Loss 0.9260489344596863


 42%|██████████████████████████████████████████▋                                                          | 10400/24578 [41:39<56:42,  4.17it/s]

Step 10400
Train Loss 0.6641349196434021


 43%|███████████████████████████████████████████▏                                                         | 10500/24578 [42:03<56:19,  4.17it/s]

Step 10500
Train Loss 0.3083038032054901


 43%|███████████████████████████████████████████▌                                                         | 10600/24578 [42:27<55:51,  4.17it/s]

Step 10600
Train Loss 0.4260109066963196


 44%|███████████████████████████████████████████▉                                                         | 10700/24578 [42:51<55:32,  4.16it/s]

Step 10700
Train Loss 0.4095516502857208


 44%|████████████████████████████████████████████▍                                                        | 10800/24578 [43:15<55:06,  4.17it/s]

Step 10800
Train Loss 0.2940039336681366


 44%|████████████████████████████████████████████▊                                                        | 10900/24578 [43:39<54:40,  4.17it/s]

Step 10900
Train Loss 0.5046456456184387


 45%|█████████████████████████████████████████████▏                                                       | 11000/24578 [44:03<54:18,  4.17it/s]

Step 11000
Train Loss 0.6651312708854675


 45%|█████████████████████████████████████████████▌                                                       | 11100/24578 [44:27<53:56,  4.16it/s]

Step 11100
Train Loss 0.9722136855125427


 46%|██████████████████████████████████████████████                                                       | 11200/24578 [44:51<53:40,  4.15it/s]

Step 11200
Train Loss 0.39296862483024597


 46%|██████████████████████████████████████████████▍                                                      | 11300/24578 [45:15<53:11,  4.16it/s]

Step 11300
Train Loss 0.765688955783844


 46%|██████████████████████████████████████████████▊                                                      | 11400/24578 [45:39<52:42,  4.17it/s]

Step 11400
Train Loss 0.6031844019889832


 47%|███████████████████████████████████████████████▎                                                     | 11500/24578 [46:03<52:19,  4.17it/s]

Step 11500
Train Loss 0.17470712959766388


 47%|███████████████████████████████████████████████▋                                                     | 11600/24578 [46:27<51:52,  4.17it/s]

Step 11600
Train Loss 0.46477001905441284


 48%|████████████████████████████████████████████████                                                     | 11700/24578 [46:51<51:29,  4.17it/s]

Step 11700
Train Loss 0.24786162376403809


 48%|████████████████████████████████████████████████▍                                                    | 11800/24578 [47:15<51:05,  4.17it/s]

Step 11800
Train Loss 0.7010849118232727


 48%|████████████████████████████████████████████████▉                                                    | 11900/24578 [47:39<50:44,  4.16it/s]

Step 11900
Train Loss 0.40603065490722656


 49%|█████████████████████████████████████████████████▎                                                   | 12000/24578 [48:03<50:14,  4.17it/s]

Step 12000
Train Loss 0.043240271508693695


 49%|█████████████████████████████████████████████████▋                                                   | 12100/24578 [48:27<49:54,  4.17it/s]

Step 12100
Train Loss 1.3467246294021606


 50%|██████████████████████████████████████████████████▏                                                  | 12200/24578 [48:51<49:30,  4.17it/s]

Step 12200
Train Loss 0.6607776284217834


 50%|██████████████████████████████████████████████████▌                                                  | 12289/24578 [49:12<41:49,  4.90it/s]

Epoch 0 done


KeyboardInterrupt: 

In [19]:
# Dump Logits
# Reload a fine-tuned checkpoint and prepare an unshuffled loader over every
# node's text, so that row i of the dumped logits corresponds to node i.

# NOTE(review): OUT_MODEL is not under OUTPUT_ROOT and looks copied from another
# notebook - confirm it is the checkpoint fine-tuned on this dataset. A
# checkpoint trained with a different number of labels is expected to fail to
# load with the num_labels passed below.
OUT_MODEL = '/nlp/scr/ananjan/graph_models/roberta_all/checkpoint-2842'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# num_labels comes from this notebook's label table; the previous name
# `labelidx2arxivcategeory` is not defined anywhere in this notebook.
model = AutoModelForSequenceClassification.from_pretrained(OUT_MODEL, num_labels=len(labelidx2productcategeory))
dump_dataset = Dataset.from_dict(dataset_dict)
dump_dataset = dump_dataset.map(tokenizer_helper, batched=True, batch_size=BATCH_SIZE)
dump_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# shuffle=False keeps the logits in node order.
dataloader = DataLoader(dump_dataset, shuffle=False, batch_size=BATCH_SIZE)

Map:   0%|          | 0/169343 [00:00<?, ? examples/s]

In [22]:
# Run the reloaded model over every batch of `dataloader` (in order) and pickle
# the stacked logits. tqdm is already imported in the notebook's import cell.

model.to(device)
model.eval()  # make inference mode explicit rather than relying on how the model was loaded

# Keep one array per batch and stack once at the end, instead of growing a
# Python list with one entry per example.
logit_batches = []
for batch in tqdm(dataloader):
    batch_gpu = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch_gpu)
    logit_batches.append(outputs.logits.cpu().numpy())

# An empty dataloader yields an empty array, as np.array([]) did before.
total_logits = np.concatenate(logit_batches, axis=0) if logit_batches else np.array([])

# The file is named after this notebook's dataset; the previous "..._arxiv.pkl"
# name would overwrite the logits dumped for a different dataset.
with open("/nlp/scr/ananjan/graph_embeddings/finetuned/roberta_logits_products.pkl", 'wb') as f:
    pickle.dump(total_logits, f)

100%|███████████████████████████████████████████████████████████████████| 10584/10584 [27:16<00:00,  6.47it/s]
