In [1]:
import numpy as np
import pandas as pd
import math
import random

import torch
from torch.optim import AdamW, SGD
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from safetensors.torch import save_file, load_file

# Preparing arXiv data

Here we load the arXiv data set arxiv-metadata-oai-snapshot.json from Kaggle (https://www.kaggle.com/datasets/Cornell-University/arxiv/data) into a pandas DataFrame and perform the following cleaning operations:

- restrict to title, abstract, categories columns
- restrict to the num_labels=20 most popular categories
- perform balanced sampling of papers such that each category appears exactly N=1000 times

NB: the code can easily be adapted to include more categories for classification.

In [2]:
# Load the data in chunks and filter right away (easier on memory and time)
path = "../data_sets/arxiv-metadata-oai-snapshot.json"

columns = ["title", "abstract", "categories"]
chunks_filt = []

for chunk in pd.read_json(path, lines=True, chunksize=10000):
    chunk_filt = chunk[columns]
    chunks_filt.append(chunk_filt)

data = pd.concat(chunks_filt, ignore_index=True)

In [3]:
# Split the categories string into a list of single categories
data.categories = data.categories.str.split(" ")
data.head()

Unnamed: 0,title,abstract,categories
0,Calculation of prompt diphoton production cros...,A fully differential calculation in perturba...,[hep-ph]
1,Sparsity-certifying Graph Decompositions,"We describe a new algorithm, the $(k,\ell)$-...","[math.CO, cs.CG]"
2,The evolution of the Earth-Moon system based o...,The evolution of Earth-Moon system is descri...,[physics.gen-ph]
3,A determinant of Stirling cycle numbers counts...,We show that a determinant of Stirling cycle...,[math.CO]
4,From dyadic $\Lambda_{\alpha}$ to $\Lambda_{\a...,In this paper we show how to compute the $\L...,"[math.CA, math.FA]"


In [17]:
# Check for the most frequently published under categories (including cross-listing)
num_labels = 20
cat_ranking_cross = data.categories.explode().value_counts()

cats = cat_ranking_cross.index.tolist()[:num_labels]
cat_ranking_cross.iloc[:num_labels]

categories
cs.LG                 223626
hep-ph                188414
hep-th                174604
quant-ph              163662
cs.CV                 159690
cs.AI                 132333
gr-qc                 114478
astro-ph              105380
cond-mat.mtrl-sci     100648
cond-mat.mes-hall      96044
cs.CL                  87718
math.MP                84637
math-ph                84636
cond-mat.str-el        78278
cond-mat.stat-mech     77357
astro-ph.CO            72188
math.CO                71920
stat.ML                71882
astro-ph.GA            70770
math.AP                67927
Name: count, dtype: int64

In [35]:
# Just for fun: check how many papers are listed under how many different categories
data.categories.map(lambda s: len(s)).value_counts()

categories
1     1458652
2      829424
3      335912
4      108365
5       30931
6        6672
7         961
8         161
9          33
10         14
11          2
13          1
Name: count, dtype: int64

In [37]:
# A paper with 13 cross-listed categories?! I have to know what it's about...
idx = data.categories.map(lambda s: len(s)).idxmax()
data.loc[idx,:]

title         The finite harmonic oscillator and its associa...
abstract        A system of functions (signals) on the finit...
categories    [cs.IT, cs.CR, cs.DM, math-ph, math.GR, math.I...
Name: 77549, dtype: object

In [5]:
# Perform balanced sampling on the dataset, so as to retain each top 20 category exactly N times (to avoid bias in training)
N = 1000

cat_count = pd.Series([0] * len(cats), index=cats)
selected_rows = []
shuffled_data = data.sample(frac=1, random_state=42)

for i, row in shuffled_data.iterrows():
    top_cats = set(row.categories) & set(cats)

    if len(top_cats) < len(row.categories): continue

    if all(cat_count[cat] < N for cat in top_cats):
        selected_rows.append(row)
        for cat in top_cats:
            cat_count[cat] += 1

    if all(cat_count[cat] >= N for cat in cats):
        break

arXiv_data_cross = pd.DataFrame(selected_rows).sample(frac=1, random_state=42).reset_index(drop=True)

In [6]:
# Split the data set into training and test sets; here the train share is 80%
n_train = .8

train_data_split = arXiv_data_cross.explode('categories').groupby('categories', group_keys=False).head(n_train * N)
train_data = train_data_split.groupby(train_data_split.index).agg(lambda x: list(x) if x.name == 'categories' else x.iloc[0])

test_data_split = arXiv_data_cross.explode('categories').groupby('categories', group_keys=False).tail(N - n_train * N)
test_data = test_data_split.groupby(test_data_split.index).agg(lambda x: list(x) if x.name == 'categories' else x.iloc[0])

In [38]:
# Check that each top 20 category appears exactly 800 times
train_data.explode('categories').categories.value_counts()

categories
cond-mat.mtrl-sci     800
stat.ML               800
cs.CV                 800
hep-th                800
quant-ph              800
gr-qc                 800
astro-ph.CO           800
hep-ph                800
math.CO               800
astro-ph              800
astro-ph.GA           800
cs.AI                 800
cs.CL                 800
cond-mat.str-el       800
cond-mat.mes-hall     800
math.MP               800
math-ph               800
cond-mat.stat-mech    800
math.AP               800
cs.LG                 800
Name: count, dtype: int64

# Training and Testing

Now we can start preparing the chosen data for training/testing of the chosen model---here it's distilbert-base-cased: (relatively) lightweight, cased in order to recognize terms named after people (e.g. Hall). For training/testing we use the Hugging Face trainer API

- First, we turn our sample arXiv data into a Hugging Face Dataset, using both title and abstract; however, since we only want to test with titles, we drop abstracts with probability p=0.2
- Since each paper can be listed under multiple categories, this is a multi-label classification model; we use multi-hot encoded labels (i.e. vectors of dim=N=20)
- Instead of choosing a single category as the winner (using torch.argmax), we look at probabilities for each category (torch.sigmoid) and choose a cutoff at 50%

In [7]:
# Turn arXiv data into torch Dataset class
class ArXivDatasetHF(Dataset):
    def __init__(self, text, target):
        self.text = text
        self.target = target

    def __len__(self):
        return len(self.text['input_ids'])

    def __getitem__(self, idx):
        dct = {key: val[idx] for key, val in self.text.items()}
        dct["labels"] = torch.tensor(self.target[idx], dtype=torch.float32)
        return dct

In [8]:
# Translate between encoded and decoded labels
label2id = {key: val for val, key in enumerate(sorted(cats))}
id2label = {val: key for key, val in label2id.items()}

# Encode labels with multi-het
def label_encoder(data_frame):
    labels = []
    for label in data_frame.categories.tolist():
        id = [0] * len(cats)
        for l in label:
            id[label2id[l]] = 1
        labels.append(id)
    return np.array(labels)

# Tokenize text input
def feature_encoder(data_frame, tokenize, max_len, p):
    # Drop abstract w/ prob p
    li = []
    for i, row in data_frame.iterrows():
        rd = random.random()
        if rd > p:
            # NB: we include Title: and Abstract: in front, hoping the model would learn to understand what we provide
            li.append(f"Title: {row.title}, Abstract: {row.abstract}")
        else:
            li.append(f"Title: {row.title}, Abstract: ")

    return tokenize(
            li,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_len
        )

In [9]:
# Choose model data
model_name = 'distilbert-base-cased'  # 'tbs17/MathBERT'
max_length = 512
batch_size = 8
torch.manual_seed(42)
tokenizer = DistilBertTokenizer.from_pretrained(model_name)

# Encode labels, features from our train/test datasets prepared above and convert to HF Dataset
labels = label_encoder(train_data)
features = feature_encoder(train_data, tokenizer, max_length, p=0.2)
train_sample = ArXivDatasetHF(features, labels)
train_loader = DataLoader(train_sample, batch_size=batch_size, shuffle=True)

features = feature_encoder(test_data, tokenizer, max_length, p=1.0)   # p=1 --> no abstracts in the test sample
labels = label_encoder(test_data)
test_sample = ArXivDatasetHF(features, labels)
test_loader = DataLoader(test_sample, batch_size=batch_size, shuffle=False)

In [11]:
# Set up model to use; we choose 'multi_label_classification' for the problem type
model = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    problem_type='multi_label_classification',
    label2id=label2id,
    id2label=id2label
)

# Define metrics used to evaluate model performance; accuracy is subset accuracy for each label (appropriate for multi-label); F1 score measures precision and recall for multi-label predictions
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    probs = torch.sigmoid(torch.tensor(logits)).numpy()
    preds = (probs > 0.5).astype(int)

    acc = accuracy_score(labels.reshape(-1), preds.reshape(-1))
    f1 = f1_score(labels, preds, average='micro')

    return {"accuracy": acc, "f1_score": f1}

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Define arguments for training/testing and HF Trainer
train_args = TrainingArguments(
    output_dir="./output",
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=5,
    weight_decay=1e-2,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_strategy="steps",
    logging_steps=1,
    logging_first_step=True,
    report_to=["none"],
    load_best_model_at_end=True
)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_sample,
    eval_dataset=test_sample,
    processing_class=tokenizer,
    compute_metrics=compute_metrics
)

In [13]:
# Train the model for 5 epochs; evaluate every 50 steps
trainer.train()

Step,Training Loss,Validation Loss,Accuracy,F1 Score
50,0.4087,0.407774,0.934959,0.0
100,0.3267,0.303306,0.934959,0.0
150,0.2433,0.266671,0.934959,0.0
200,0.2368,0.249959,0.934959,0.0
250,0.2527,0.242081,0.934959,0.0
300,0.2274,0.235767,0.934959,0.0
350,0.2451,0.22942,0.934959,0.0
400,0.1976,0.223043,0.934959,0.0
450,0.2514,0.217483,0.934959,0.0
500,0.2305,0.211378,0.934959,0.0


TrainOutput(global_step=7625, training_loss=0.12811356942453345, metrics={'train_runtime': 61884.934, 'train_samples_per_second': 0.986, 'train_steps_per_second': 0.123, 'total_flos': 8082442648473600.0, 'train_loss': 0.12811356942453345, 'epoch': 5.0})

In [None]:
def classify(s):
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Tokenize input string
    tok_title = tokenizer(
        s,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=max_length
    )

    # Move tokenized input to the device
    tok_title = {k: v.to(device) for k, v in tok_title.items()}

    # Run model on tokenized input
    with torch.no_grad():
        outputs = model(**tok_title)
        logits = outputs.logits
        preds = torch.sigmoid(logits)

        # Output the 5 most probable arXiv categories with corresp probabilities
        top5_prob, top5_idx = torch.topk(preds, 5, dim=1)

    for x in range(5):
        print(f"{id2label[top5_idx[0, x].item()]}: {top5_prob[0, x].item() * 100:.2f}%")

In [41]:
classify("Super Yang-Mills Theory on Toric Orbifolds")

hep-th: 71.24%
math.MP: 34.33%
math-ph: 32.58%
hep-ph: 8.50%
cond-mat.str-el: 4.82%


In [29]:
classify("Exploring new fractional quantum Hall states")

cond-mat.mes-hall: 51.69%
cond-mat.str-el: 40.40%
quant-ph: 25.42%
cond-mat.mtrl-sci: 5.83%
cond-mat.stat-mech: 3.47%


In [40]:
classify("Data science in the era of reinforcement learning")

stat.ML: 65.71%
cs.LG: 44.37%
cs.AI: 19.67%
quant-ph: 2.53%
cs.CV: 1.85%
