# Language Model Pre-training

This notebook shows how to pretrain a ULMFiT language model on the **ArxivPapers** dataset. You can download the pretrained model from <https://github.com/sotagents/axcell/releases/download/v1.0/lm.pth.xz>.

In [2]:
# Settings shared by the cells below.
# Passed as `bs` when the language-model databunch is built or reloaded.
BATCH_SIZE = 256
# Passed as `bptt` when the language-model databunch is built or reloaded.
BPTT = 80
# Passed as `vocab_sz` to the SPProcessor that trains the tokenizer.
VOCAB_SIZE = 30000
# Number of sentences sampled from the corpus for tokenizer training.
UNIGRAM_MODEL_SENTENCES = 5000000

In [2]:
from axcell.data.paper_collection import PaperCollection
from pathlib import Path

# Directory holding the extracted papers of the ArxivPapers dataset;
# it is the input of PaperCollection.from_files in the next cell.
PAPERS_PATH = Path('./data/arxiv-papers/papers')

In [4]:
# Load the paper collection from PAPERS_PATH. load_tables=False is passed
# through to from_files (presumably it skips table loading, which this
# notebook never uses -- confirm against PaperCollection).
%time pc = PaperCollection.from_files(PAPERS_PATH, load_tables=False)

CPU times: user 2min 49s, sys: 11.3 s, total: 3min
Wall time: 5min 46s


In [4]:
from axcell.helpers.datasets import read_arxiv_papers

# Base URL of the v1.0 release assets.
V1_URL = 'https://github.com/sotagents/axcell/releases/download/v1.0/'
# Compressed CSV index of the ArxivPapers dataset.
ARXIV_PAPERS_URL = V1_URL + 'arxiv-papers.csv.xz'
arxiv_papers = read_arxiv_papers(ARXIV_PAPERS_URL)

# Sanity check: the number of loaded papers must equal the number of index
# rows whose `status` is 'success'.
assert len(pc) == (arxiv_papers.status == 'success').sum()

In [3]:
import re
import pandas as pd
import numpy as np

# Anchor token: the literal prefix "xxanchor-" plus every character up to
# (not including) the next space.
# NOTE(review): `[^ ]` excludes only the space character, so a token that
# ends a line also consumes the newline and the first word of the next
# line -- confirm this is intended before reusing the pattern.
anchors_re = re.compile(r"xxanchor-[^ ]*")
# Reference token: the literal prefix "xxref-" plus every character up to
# (not including) the next space (same caveat as above).
refs_re = re.compile(r"xxref-[^ ]*")


def remove_anchors(s):
    """Return ``s`` with every match of ``anchors_re`` deleted.

    Only the matched token is dropped; whitespace around it is left as is.
    """
    without_anchors = anchors_re.sub(repl="", string=s)
    return without_anchors

def replace_references(s):
    """Return ``s`` with every match of ``refs_re`` collapsed to ``xxref``.

    The per-reference suffix is discarded so that all references share a
    single token.
    """
    generic_token = "xxref"
    return refs_re.sub(repl=generic_token, string=s)

def clean_text(s):
    """Normalise the markers in ``s``.

    Applies ``remove_anchors`` first and ``replace_references`` second, and
    returns the resulting string.
    """
    return replace_references(remove_anchors(s))

def get_texts(pc):
    """Build one plain-text document per paper and return them as a dataframe.

    Papers are visited in ascending ``paper_id`` order. A paper whose ``text``
    has no ``fragments`` attribute is skipped. Each document is a
    ``Title`` / ``Abstract`` / ``Body`` preamble followed by the fragment
    texts joined with newlines; a section header line is emitted whenever a
    fragment's header differs from the previous fragment's header. The
    assembled string is passed through ``clean_text``.

    Args:
        pc: iterable of papers exposing ``paper_id`` and ``text``.

    Returns:
        ``pandas.DataFrame`` with a single ``text`` column, one row per
        paper that was kept.
    """
    documents = []
    ordered_papers = sorted(pc, key=lambda paper: paper.paper_id)
    for paper in ordered_papers:
        paper_text = paper.text
        if hasattr(paper_text, "fragments"):
            preamble = f"Title\n{paper_text.title}\n\nAbstract\n{paper_text.abstract}\n\nBody\n"
            body_lines = []
            current_header = None
            for fragment in paper_text.fragments:
                fragment_header = fragment.header
                if current_header != fragment_header:
                    # New section: emit its header once, followed by a blank line.
                    body_lines.append(fragment_header + "\n")
                    current_header = fragment_header
                body_lines.append(fragment.text)
            documents.append(clean_text(preamble + '\n'.join(body_lines)))
    return pd.DataFrame({'text': documents})

In [8]:
# Assemble one cleaned text document per paper (single-column dataframe).
%time texts = get_texts(pc)

CPU times: user 1min 25s, sys: 3.19 s, total: 1min 28s
Wall time: 1min 28s


In [19]:
# Preview the first 100 characters of the last document.
print(texts.text.iloc[-1][:100]+'...')

Title
VQA-LOL: Visual Question Answering under the Lens of Logic

Abstract
Logical connectives and t...


In [32]:
# Uncomment to cache the assembled dataframe on disk; the next cell reloads
# it from this same path, so the cache must exist before that cell is run.
# texts.to_pickle("/data/arxiv/dumps/arxiv-papers-texts-dataframe.pkl")

In [4]:
# Reload the cached dataframe instead of rebuilding it from the papers.
# The file is the one written by the (commented-out) to_pickle call in the
# previous cell; only unpickle files you produced yourself.
texts = pd.read_pickle("/data/arxiv/dumps/arxiv-papers-texts-dataframe.pkl")

Reduce the number of sentences to avoid a SentencePiece `bad_alloc` (out-of-memory) error.

In [5]:
# Split every document into its non-blank lines; each line is treated as
# one "sentence" for tokenizer training.
sentences = [
    sentence
    for text in texts.text.values
    for sentence in text.split('\n')
    if sentence.strip()
]

# Fixed seed so that the same subset is drawn on every run.
np.random.seed(12345)

# With replace=False, np.random.choice raises ValueError when asked for more
# items than the population holds, so cap the sample at the corpus size
# (e.g. when running on a subset of the dataset).
sample_size = min(UNIGRAM_MODEL_SENTENCES, len(sentences))
# Passing the population size as an int samples from np.arange(n), which
# yields the same indices as passing range(n) without materialising it.
indices = np.random.choice(len(sentences), size=sample_size, replace=False)
sentences = [sentences[index] for index in indices]

In [3]:
from fastai.text import *

# Working directory for tokenizer and model artifacts; created if missing.
BASE_PATH = Path('./models')
BASE_PATH.mkdir(parents=True, exist_ok=True)

# Untrained processor configured with a vocabulary of VOCAB_SIZE entries;
# it is trained on the sampled sentences in the next cell.
processor = SPProcessor(vocab_sz=VOCAB_SIZE, mark_fields=True)

In [7]:
# Train the tokenizer on the sampled sentences. The recorded output shows
# the call returning the directory models/tmp.
%time processor.train_func(sentences, BASE_PATH)

CPU times: user 1h 40min 51s, sys: 36.1 s, total: 1h 41min 27s
Wall time: 41min 29s


PosixPath('models/tmp')

In [4]:
# Re-create the processor from the model and vocabulary files under
# BASE_PATH/tmp (the directory reported by the training cell) so the
# tokenizer does not have to be retrained.
processor = SPProcessor(sp_model=BASE_PATH / "tmp" / "spm.model", sp_vocab=BASE_PATH / "tmp" / "spm.vocab", mark_fields=True)

In [7]:
%%time

# Build the language-model databunch from the `text` column: tokenize with
# the processor above, hold out a random 10% for validation (fixed seed),
# label for language modelling and batch with BATCH_SIZE / BPTT.
data_lm = (
    TextList.from_df(
        texts, BASE_PATH, cols="text", processor=processor
    ).split_by_rand_pct(0.1, seed=12345)
     .label_for_lm()
     .databunch(bs=BATCH_SIZE, bptt=BPTT)
)

CPU times: user 17min 29s, sys: 51 s, total: 18min 20s
Wall time: 26min 28s


In [8]:
# Uncomment to cache the processed databunch; the next cell reloads a file
# with this name from BASE_PATH, so the cache must exist before it is run.
# data_lm.save('arxiv-papers-texts-data_lm.pkl')

In [5]:
# Reload the cached databunch from BASE_PATH with the same batch size and
# BPTT instead of re-tokenizing the corpus.
data_lm = load_data(BASE_PATH, 'arxiv-papers-texts-data_lm.pkl', bs=BATCH_SIZE, bptt=BPTT)



In [6]:
# AWD_LSTM language-model learner on the databunch above. pretrained=False
# is passed, in line with this notebook's goal of pre-training the model
# itself. Accuracy and perplexity are reported after every epoch.
# to_fp16 is applied with clip=0.1 (presumably mixed-precision training with
# a gradient-clipping threshold -- confirm against the fastai docs).
learn = language_model_learner(
    data_lm, AWD_LSTM, drop_mult=0.1,
    pretrained=False, metrics=[accuracy, Perplexity()]
).to_fp16(clip=0.1)

# One-cycle schedule with cyc_len=12 and max_lr=0.01.
# NOTE(review): the recorded output below lists only epochs 0-9 although 12
# are requested -- confirm whether the run was interrupted or the output
# was truncated.
learn.fit_one_cycle(cyc_len=12, max_lr=0.01, moms=(0.8, 0.7), div_factor=10, wd=0.1)

epoch,train_loss,valid_loss,accuracy,perplexity,time
0,3.019458,3.264306,0.392344,26.161938,1:54:36
1,3.056603,3.422664,0.376507,30.651068,1:53:43
2,3.141768,3.550231,0.362796,34.821327,1:53:26
3,3.090492,3.525985,0.36687,33.987396,1:53:16
4,3.107407,3.491773,0.370532,32.844139,1:54:11
5,3.059378,3.445549,0.375365,31.360525,1:54:10
6,3.030591,3.368207,0.382388,29.026358,1:53:57
7,2.965446,3.278792,0.39136,26.543692,1:53:37
8,2.919746,3.163137,0.404793,23.644709,1:53:10
9,2.812866,3.019272,0.421912,20.47644,1:53:43


In [7]:
# Save the encoder under the name 'pretrained-on-papers_enc'.
learn.save_encoder('pretrained-on-papers_enc')

In [8]:
# Save the learner state; the name suggests the optimizer state is included
# (presumably via the default of learn.save -- confirm).
learn.save('pretrained-on-papers_learner_with_opt')