# NAR inference pipeline for DistilBERT

In [None]:
#| default_exp 60-nar-inference-pipeline-for-distilbert

In [None]:
#| hide
from nbdev.showdoc import *
import nbdev; nbdev.nbdev_export()

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import os, torch, torch.nn.functional as F, pickle
from tqdm.auto import tqdm
from xcai.basics import *
from xcai.models.MMM0XX import DBT007

comet_ml is installed but `COMET_API_KEY` is not set.


In [None]:
# Turn Weights & Biases logging off for this notebook session.
os.environ.update(WANDB_MODE='disabled')

In [None]:
#| export
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ['WANDB_PROJECT']='xc-nlg_20-nar-training-pipeline-for-distilbert'

In [None]:
#| export
pkl_dir = '/home/scai/phd/aiz218323/scratch/datasets'
pkl_file = f'{pkl_dir}/processed/wikiseealso_data_distilbert-base-uncased_xcnlg_ngame.pkl'

# Load the preprocessed data block; later cells use its train/test datasets,
# label information (`n_lbl`, `lbl_info`) and `collator`.
# Unpickling is only safe because this is a locally produced, trusted file.
with open(pkl_file, 'rb') as file:
    block = pickle.load(file)

## BM25

In [None]:
#| export
import tiktoken, math, pickle
from stop_words import get_stop_words
from langdetect import detect
from typing import List
from rank_bm25 import BM25Okapi

In [None]:
#| export
def remove_multilingual_stopwords(text: str) -> str:
    """Remove stop words from `text` using the stop-word list of its detected language.

    Best-effort: if the language cannot be detected, or no stop-word list exists
    for it, `text` is returned unchanged. Words are compared case-insensitively,
    and the kept words are re-joined with single spaces.
    """
    # `except Exception` keeps the best-effort fallback but, unlike a bare
    # `except:`, does not swallow KeyboardInterrupt/SystemExit (this function
    # runs inside multiprocessing workers, which must remain interruptible).
    try: lang = detect(text)
    except Exception: return text

    # No stop-word list for the detected language -> leave the text as is.
    try: stop_words = set(get_stop_words(lang))
    except Exception: return text

    return " ".join(word for word in text.split() if word.lower() not in stop_words)

encoder = tiktoken.encoding_for_model("gpt-4")

def preprocess_func(text: str) -> List[str]:
    """Lower-case `text`, encode it with the GPT-4 tiktoken encoding, and
    return the token ids as strings (used as BM25 terms)."""
    return list(map(str, encoder.encode(text.lower())))

def tokenize(text):
    "Strip stop words from `text`, then turn it into a list of token-id strings."
    cleaned = remove_multilingual_stopwords(text)
    return preprocess_func(cleaned)

def tokenizer(text):
    "Apply `tokenize` to every document in the sequence `text`, with a progress bar."
    progress = tqdm(text, total=len(text))
    return list(map(tokenize, progress))

def get_scores(text, k=200):
    """Score every query in `text` against the module-level `bm25` index.

    Returns a list with one `(scores, indices)` tuple of tensors per query,
    holding the top-`k` BM25 scores and the corresponding label indices.
    `k` is clamped to the number of indexed labels, because `torch.topk`
    raises when asked for more elements than the tensor has.
    """
    # NOTE(review): relies on the global `bm25`. Inside a multiprocessing Pool
    # this works only if workers inherit it (fork start method) — confirm.
    preds = []
    for query in tqdm(text, total=len(text)):
        sc = torch.tensor(bm25.get_scores(tokenize(query)))
        sc, idx = torch.topk(sc, min(k, sc.numel()))
        preds.append((sc, idx))
    return preds

In [None]:
#| export
from multiprocessing import Pool
from itertools import chain

In [None]:
n_proc, n_lbl = 8, block.n_lbl
bsz = math.ceil(n_lbl/n_proc)

# Split the label texts into `n_proc` contiguous chunks, one per worker process.
lbl_text = [block.train.dset.data.lbl_info['input_text'][i*bsz:(i+1)*bsz] for i in range(n_proc)]

# Tokenize the chunks in parallel. `Pool.map` returns results in input order,
# so flattening the per-chunk lists restores the original label order.
# (Each worker already shows its own tqdm bar inside `tokenizer`.)
with Pool(processes=n_proc) as pool:
    lbl_text = list(chain.from_iterable(pool.map(tokenizer, lbl_text)))

data_dir = '/home/scai/phd/aiz218323/scratch/outputs/60-nar-inference-pipeline-for-distilbert'
os.makedirs(data_dir, exist_ok=True)  # the output directory may not exist on a fresh run
with open(f'{data_dir}/wikiseealso-lbl.bow', 'wb') as file:
    pickle.dump(lbl_text, file)


In [None]:
#| export
# Reload the tokenized label corpus (one list of token-id strings per label).
data_dir = '/home/scai/phd/aiz218323/scratch/outputs/60-nar-inference-pipeline-for-distilbert'
with open(f'{data_dir}/wikiseealso-lbl.bow', 'rb') as file: lbl_text = pickle.load(file)

In [None]:
#| export
# Build the BM25 (Okapi) index over the tokenized label corpus. Each label is a
# list of token-id strings produced by `tokenizer`. `get_scores` queries this global.
bm25 = BM25Okapi(lbl_text)

In [None]:
#| export
n_proc = 8
n_data = block.test.dset.data.n_data
bsz = math.ceil(n_data / n_proc)

# Split the test-query texts into `n_proc` contiguous chunks, one per worker.
data_text = block.test.dset.data.data_info['input_text']
data_text = [
    data_text[i * bsz:(i + 1) * bsz]
    for i in range(n_proc)
]

In [None]:
#| export
# Score each chunk of test queries in parallel. `Pool.map` preserves chunk
# order, so flattening gives one (scores, indices) pair per test query, in
# the same order as the serial `get_scores` call would produce.
# NOTE(review): workers reach the global `bm25` through `get_scores`; this
# relies on the fork start method — confirm on non-Linux platforms.
with Pool(processes=n_proc) as pool:
    output = list(chain.from_iterable(pool.map(get_scores, data_text)))

data_dir = '/home/scai/phd/aiz218323/scratch/outputs/60-nar-inference-pipeline-for-distilbert'
os.makedirs(data_dir, exist_ok=True)
# Persist the BM25 retrieval results themselves (not the label corpus).
with open(f'{data_dir}/wikiseealso-bm25-output.bow', 'wb') as file:
    pickle.dump(output, file)

In [None]:
# Serial (single-process) BM25 scoring of every test query. This overwrites
# `output` with one (scores, indices) pair per query.
# NOTE(review): running this cell produced "IOPub message rate exceeded";
# consider throttling or disabling the per-query tqdm bar in `get_scores`.
output = get_scores(block.test.dset.data.data_info['input_text'])

  0%|          | 0/177515 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [None]:
data_dir = '/home/scai/phd/aiz218323/scratch/outputs/60-nar-inference-pipeline-for-distilbert'
# Save the BM25 retrieval results held in `output`. The file is named
# "bm25-output", so it must not contain the tokenized label corpus.
with open(f'{data_dir}/wikiseealso-bm25-output.bow', 'wb') as file: 
    pickle.dump(output, file)

## DBT007

In [None]:
# Per-token weights over the training set's `lbl2data_input_ids` (presumably
# IDF, per the helper's name), passed to DBT007 below as `vocab_weights`.
# NOTE(review): n_cols=30522 presumably matches the distilbert-base-uncased
# vocabulary size used to build this block — confirm.
tok_idf = get_tok_idf(block.train.dset, field='lbl2data_input_ids', n_cols=30522)

In [None]:
args = XCLearningArguments(
    # Training run whose checkpoints are evaluated here (get_best_model reads args.output_dir).
    output_dir='/home/scai/phd/aiz218323/scratch/outputs/20-nar-training-pipeline-for-distilbert-2-1',

    # Batching and precision.
    per_device_train_batch_size=1024,
    per_device_eval_batch_size=1024,
    fp16=True,

    # Optimisation schedule.
    num_train_epochs=100,
    learning_rate=2e-4,
    weight_decay=0.1,
    adam_epsilon=1e-8,
    warmup_steps=0,

    # Logging, evaluation and checkpointing cadence.
    # NOTE(review): recent `transformers` versions rename `evaluation_strategy`
    # to `eval_strategy` — confirm what XCLearningArguments expects.
    logging_first_step=True,
    evaluation_strategy='steps',
    eval_steps=2000,
    save_strategy='steps',
    save_steps=2000,
    save_total_limit=5,

    # Generation / representation-based prediction settings.
    predict_with_generation=True,
    generation_num_beams=10,
    generation_length_penalty=1.5,
    representation_num_beams=200,
    representation_accumulation_steps=100,

    # Keys locating the gold labels in each batch.
    label_names=['lbl2data_idx'],
    target_indices_key='plbl2data_idx',
    target_pointer_key='plbl2data_data2ptr',
)

In [None]:
# Evaluate on a reproducible random subset of 2000 test points (seed=50).
test_dset = block.test.dset.sample(n=2000, seed=50)
# Precision/recall metric over all `block.n_lbl` labels. It uses the test
# subset's label filterer and the training data-label matrix as `prop`
# (presumably for propensity-scored variants — confirm). Cut-offs: P@10/R@200
# for generation, and P@{1,3,5,10}/R@{10,100,200} for representation-based predictions.
metric = PrecRecl(block.n_lbl, test_dset.data.data_lbl_filterer, prop=block.train.dset.data.data_lbl,
                  pk=10, rk=200, rep_pk=[1, 3, 5, 10], rep_rk=[10, 100, 200])

In [None]:
# Locate a checkpoint in the training run's output_dir (presumably the best
# one, per the helper's name) and load it as a DBT007 model, weighting the
# vocabulary with `tok_idf`.
# NOTE(review): tn_targ / ig_tok / reduction semantics are defined in
# xcai.models.MMM0XX — confirm before changing them.
mname = get_best_model(args.output_dir)
model = DBT007.from_pretrained(mname, tn_targ=10_000, ig_tok=0, vocab_weights=tok_idf, reduction='mean')

In [None]:
# Label trie built from the data block and passed to the learner; presumably
# used to constrain generation (predict_with_generation=True) — confirm.
trie = XCTrie.from_block(block)

In [None]:
# Bundle the model, arguments, trie, datasets, collator and metric into an
# XCLearner. The visible cells only use it for prediction on `test_dset`.
learn = XCLearner(
    model=model, 
    args=args,
    trie=trie,
    train_dataset=block.train.dset,
    eval_dataset=test_dset,
    data_collator=block.collator,
    compute_metrics=metric,
)

In [None]:
# Run prediction on the sampled test subset. Metrics are presumably computed
# by the `compute_metrics=metric` passed to the learner — confirm.
o = learn.predict(test_dset)

In [None]:
# Display the metrics attached to the prediction output.
o.metrics