In [22]:
import numpy as np
import pandas as pd
import math
import random

import torch
from torch.optim import AdamW
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

from safetensors.torch import save_file

In [2]:
# Load the data in chunks and filter right away (easier on memory and time)
path = "../data_sets/arxiv-metadata-oai-snapshot.json"

columns = ["title", "abstract", "categories"]
chunks_filt = []

for chunk in pd.read_json(path, lines=True, chunksize=10000):
    chunk_filt = chunk[columns]
    chunks_filt.append(chunk_filt)

data = pd.concat(chunks_filt, ignore_index=True)

In [3]:
# Split the categories string into a list of single categories
data.categories = data.categories.str.split(" ")
data.head()

Unnamed: 0,title,abstract,categories
0,Calculation of prompt diphoton production cros...,A fully differential calculation in perturba...,[hep-ph]
1,Sparsity-certifying Graph Decompositions,"We describe a new algorithm, the $(k,\ell)$-...","[math.CO, cs.CG]"
2,The evolution of the Earth-Moon system based o...,The evolution of Earth-Moon system is descri...,[physics.gen-ph]
3,A determinant of Stirling cycle numbers counts...,We show that a determinant of Stirling cycle...,[math.CO]
4,From dyadic $\Lambda_{\alpha}$ to $\Lambda_{\a...,In this paper we show how to compute the $\L...,"[math.CA, math.FA]"


In [4]:
# Check for the most frequently published under categories (excluding cross-listing)
num_labels = 20
cat_ranking = data.categories.map(lambda s: s[0]).value_counts()

cats = cat_ranking.index.tolist()[:num_labels]
cat_ranking.iloc[:num_labels]

categories
hep-ph                137418
cs.CV                 122466
quant-ph              119233
hep-th                109071
cs.LG                 108537
astro-ph               94246
gr-qc                  67256
cond-mat.mes-hall      66164
cs.CL                  66147
cond-mat.mtrl-sci      65379
math.AP                52843
astro-ph.GA            50283
cond-mat.str-el        50130
math.CO                48775
astro-ph.SR            45457
astro-ph.HE            42273
astro-ph.CO            42268
cond-mat.stat-mech     42007
math.PR                41237
math.AG                37692
Name: count, dtype: int64

In [5]:
# Just for fun: check how many papers are listed under how many different categories
data.categories.map(lambda s: len(s)).value_counts()

categories
1     1458652
2      829424
3      335912
4      108365
5       30931
6        6672
7         961
8         161
9          33
10         14
11          2
13          1
Name: count, dtype: int64

In [66]:
# Keep only the main arXiv category for each paper (forget cross-listed ones); then only keep entries whose cat is in top 20 cats
N = 1000

data_filt = data.copy()
data_filt.categories = data_filt.categories.map(lambda x: x[0])
data_filt = data_filt[data_filt.categories.isin(cats)]

arXiv_data = data_filt.groupby('categories', group_keys=False).sample(N, random_state=3)

In [67]:
# Split into train/test data; here, train share is 80%
n_train = .8

train_data = arXiv_data.groupby('categories', group_keys=False).head(math.floor(n_train*N)).sample(frac=1, random_state=42)
test_data = arXiv_data.groupby('categories', group_keys=False).tail(math.ceil(N-n_train*N))

In [11]:
# Check that each top 20 category appear exactly 800 times
train_data.categories.value_counts()

categories
cs.CV                 800
cond-mat.mes-hall     800
math.CO               800
quant-ph              800
astro-ph.GA           800
astro-ph.SR           800
math.AG               800
cond-mat.str-el       800
cond-mat.mtrl-sci     800
astro-ph.HE           800
math.PR               800
hep-th                800
hep-ph                800
cs.CL                 800
gr-qc                 800
astro-ph.CO           800
cs.LG                 800
astro-ph              800
cond-mat.stat-mech    800
math.AP               800
Name: count, dtype: int64

In [12]:
# Turn arXiv data into torch Dataset
class ArxivDataset(Dataset):
    def __init__(self, text, target):
        self.text = text
        self.target = target

    def __len__(self):
        return len(self.text['input_ids'])

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.text.items()}
        label = torch.tensor(self.target[idx], dtype=torch.long)
        return item, label

In [13]:
# Translate between encoded and decoded labels
label2id = {key: val for val, key in enumerate(sorted(cats))}
id2label = {val: key for key, val in label2id.items()}

# Encode labels
def label_encoder(data_frame):
    return np.array([label2id[x] for x in data_frame.categories.tolist()])

# Tokenize text input; drop abstract with probability p
def feature_encoder(data_frame, tokenize, max_len, p):
    # drop the abstract w/ prob p
    li = []
    for i, row in data_frame.iterrows():
        rd = random.random()
        if rd > p:
            li.append(f"Title: {row.title}, Abstract: {row.abstract}")
        else:
            li.append(f"Title: {row.title}, Abstract: ")

    return tokenize(
            li,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_len
        )

In [52]:
# Choose model data
model_name = 'distilbert-base-cased'
max_length = 512
batch_size = 8
torch.manual_seed(42)

tokenizer = DistilBertTokenizer.from_pretrained(model_name)

In [None]:
# Encode labels, features from our train/test datasets prepared above and convert to torch Dataset
labels = label_encoder(train_data)
features = feature_encoder(train_data, tokenizer, max_length, p=0.2)
train_sample = ArxivDataset(features, labels)
train_loader = DataLoader(train_sample, batch_size=batch_size, shuffle=True)

In [68]:
# For testing, choose p=1, i.e. no abstracts are included
labels = label_encoder(test_data)
features = feature_encoder(test_data, tokenizer, max_length, p=1.0)
test_sample = ArxivDataset(features, labels)
test_loader = DataLoader(test_sample, batch_size=batch_size, shuffle=True)

In [19]:
# To improve accuracy and time, we choose decaying learning rates: high ones for the final layers and very small ones for base layers (as we want to fine-tune)
def get_optimiser_parameters(model, base_lr, weight_decay, lr_decay=0.95):
    # Get model layers
    layers = [model.distilbert.embeddings] + list(model.distilbert.transformer.layer)
    layers.reverse()
    param_groups = []

    # Add classifier head
    param_groups.append({
        'params': model.classifier.parameters(),
        'lr': base_lr,
        'weight_decay': weight_decay
    })

    # Decay learning rate through the layers
    for i, layer in enumerate(layers):
        lr = base_lr * (lr_decay ** i)
        param_groups.append({
            'params': layer.parameters(),
            'lr': lr,
            'weight_decay': weight_decay
        })

    return param_groups

In [20]:
# Set up the model, loss function and optimiser (using lr decay as defined above)
model = DistilBertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    label2id=label2id,
    id2label=id2label
)

loss_function = nn.CrossEntropyLoss()

base_lr = 1e-5
weight_decay = 1e-2
optimiser_params = get_optimiser_parameters(model, base_lr, weight_decay)
optimiser = AdamW(optimiser_params)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
# Training loop
n_epochs = 5

print("Training...")
model.train()
for epoch in range(n_epochs):
    epoch_loss = 0.0
    correct = 0
    number = 0

    for i, batch in enumerate(train_loader):
        print(f'Training batch {i + 1}/{len(train_loader)}')

        features, labels = batch
        optimiser.zero_grad()
        outputs = model(**features)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)

        loss = loss_function(logits, labels)
        epoch_loss += loss.item()

        acc_labels = (labels == preds).sum().item()
        correct += acc_labels
        num_lab = labels.numel()
        number += num_lab
        acc = acc_labels / num_lab

        loss.backward()
        optimiser.step()

        print(f"Loss = {loss.item():.4f}, Acc = {acc:.4f}")

    print(f"Epoch {epoch + 1}: Loss = {epoch_loss:.4f}, Accuracy = {correct / number:.4f}")

Training...
Training batch 1/2000
Loss = 2.9883, Acc = 0.0000
Training batch 2/2000
Loss = 3.0597, Acc = 0.0000
Training batch 3/2000
Loss = 3.0383, Acc = 0.1250
Training batch 4/2000
Loss = 3.0437, Acc = 0.0000
Training batch 5/2000
Loss = 3.0591, Acc = 0.0000
Training batch 6/2000
Loss = 3.0272, Acc = 0.0000
Training batch 7/2000
Loss = 3.0201, Acc = 0.0000
Training batch 8/2000
Loss = 3.0190, Acc = 0.0000
Training batch 9/2000
Loss = 3.0559, Acc = 0.0000
Training batch 10/2000
Loss = 3.0070, Acc = 0.1250
Training batch 11/2000
Loss = 3.0687, Acc = 0.0000
Training batch 12/2000
Loss = 3.0282, Acc = 0.0000
Training batch 13/2000
Loss = 2.9158, Acc = 0.1250
Training batch 14/2000
Loss = 2.9222, Acc = 0.1250
Training batch 15/2000
Loss = 3.0397, Acc = 0.0000
Training batch 16/2000
Loss = 3.0425, Acc = 0.0000
Training batch 17/2000
Loss = 3.0083, Acc = 0.0000
Training batch 18/2000
Loss = 2.9627, Acc = 0.0000
Training batch 19/2000
Loss = 2.9601, Acc = 0.1250
Training batch 20/2000
Loss 

In [23]:
save_file(model.state_dict(), 'arXiv_multi_class.safetensors')

In [69]:
# Testing loop
print('Testing...')

epoch_loss = 0.0
correct = 0
number = 0

model.eval()
with torch.no_grad():
    for i, batch in enumerate(test_loader):
        print(f'Testing batch {i + 1}/{len(test_loader)}')

        features, labels = batch
        optimiser.zero_grad()
        outputs = model(**features)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)

        loss = loss_function(logits, labels)
        epoch_loss += loss.item()

        acc_labels = (labels == preds).sum().item()
        correct += acc_labels
        num_lab = labels.numel()
        number += num_lab
        acc = acc_labels / num_lab

        print(f"Loss = {loss.item():.4f}, Acc = {acc:.4f}")

print(f"Loss = {epoch_loss:.4f}, Accuracy = {correct / number:.4f}")

Testing...
Testing batch 1/500
Loss = 0.5258, Acc = 0.8750
Testing batch 2/500
Loss = 0.8026, Acc = 0.6250
Testing batch 3/500
Loss = 2.2231, Acc = 0.3750
Testing batch 4/500
Loss = 1.7356, Acc = 0.5000
Testing batch 5/500
Loss = 0.6652, Acc = 0.8750
Testing batch 6/500
Loss = 2.3263, Acc = 0.2500
Testing batch 7/500
Loss = 1.3173, Acc = 0.6250
Testing batch 8/500
Loss = 1.1849, Acc = 0.6250
Testing batch 9/500
Loss = 2.7510, Acc = 0.2500
Testing batch 10/500
Loss = 0.7739, Acc = 0.7500
Testing batch 11/500
Loss = 1.0772, Acc = 0.6250
Testing batch 12/500
Loss = 2.4410, Acc = 0.5000
Testing batch 13/500
Loss = 1.0691, Acc = 0.7500
Testing batch 14/500
Loss = 0.9057, Acc = 0.7500
Testing batch 15/500
Loss = 1.6623, Acc = 0.3750
Testing batch 16/500
Loss = 1.8628, Acc = 0.3750
Testing batch 17/500
Loss = 0.8815, Acc = 0.7500
Testing batch 18/500
Loss = 0.4630, Acc = 0.7500
Testing batch 19/500
Loss = 1.0615, Acc = 0.5000
Testing batch 20/500
Loss = 0.9111, Acc = 0.7500
Testing batch 21/5

In [24]:
# Define a function that classifies new titles
def classify(s):
    # Tokenize input string
    tok_title = tokenizer(
            s,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_length
        )

    # Run model on tokenized input
    model.eval()
    with torch.no_grad():
        outputs = model(**tok_title)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)

        print(f"class: {id2label[preds[0].item()]}")

In [25]:
classify('Pseudo-algebroids for categorical rings')

class: math.AG


In [26]:
classify('Free electron trajectories in graphene')

class: cond-mat.mes-hall


In [34]:
classify('Equivariant DW Invariants from Physical Fluxes')

class: hep-th
