In [1]:
import os
import re
import nltk
import json
import pickle
from tqdm import tqdm
import numpy as np
import torch
import sqlite3
import pandas as pd
import random
import transformers
import logging

In [2]:
import itertools
import sys
# Make the project's source directory importable. The same string is used for
# the membership test and for the append, so re-running this cell never adds a
# duplicate entry to sys.path (previously the test used a trailing slash and
# therefore never matched the appended entry).
SRC_DIR = '/home/ryparmar/experimental-martin/pretraining/src'
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

import util, io_util, eval
from torch.utils.data import DataLoader, TensorDataset
from model import Encoder as Model

## PRETRAINING

In [4]:
class Config:
    """Run settings for the notebook: task, data paths, model and training values.

    Everything is a plain attribute assigned in ``__init__``. A small set of
    run-time values (token ids and the device) is attached afterwards through
    :meth:`add`.
    """

    # The only attribute names that add() will attach to an instance.
    _ADDABLE = ('cls_token_id', 'pad_token_id', 'device')

    def __init__(self):
        # --- what to run ---
        self.mode = 'pretraining'
        self.task = 'BFS+ICT'

        # --- data locations (FEVER-based paths are active) ---
        # Alternative CTK paths, kept for reference:
        #   claims_path          = "/mnt/data/factcheck/CTK/par5/ctk-data"
        #   articles_path        = "/mnt/data/factcheck/CTK/par5/interim/ctk_filtered.db"
        #                          (or "/mnt/data/factcheck/CTK/par4/interim/ctk_filtered.db")
        #   articles_chunks_path = '/mnt/data/factcheck/ict_chunked_data/ids-chunks-288-pretraining-ctk_filtered.pkl'
        self.claims_path = "/mnt/data/factcheck/fever/data-cs/fever-data"
        self.articles_path = "/mnt/data/factcheck/fever/data-cs/fever/fever.db"
        self.articles_chunks_path = '/mnt/data/factcheck/ict_chunked_data/ids-chunks-288-pretraining-wiki_cs.pkl'

        # --- model ---
        self.model_weight = "/home/ryparmar/trained_models/debug.w"
        self.bert_model = "bert-base-multilingual-cased"

        # --- training / evaluation values ---
        self.learning_rate = 1e-5
        self.max_seq = 288
        self.epoch = 1
        self.bs = 64
        self.test_bs = 64
        self.remove_prob = 0.9

        # --- hardware ---
        self.use_cuda = torch.cuda.is_available()
        self.devices = "0" if self.use_cuda else ""

        # Checkpoint to resume from; a falsy value (e.g. False) disables resuming.
        self.continue_training = "/home/ryparmar/trained_models/mbert_wiki_pre_10ep-bfs_10ep-ict_1e-5_288_best"
        self.logger = logging.getLogger(__name__)

    def add(self, name, val):
        """Store ``val`` under attribute ``name`` when ``name`` is a supported run-time field.

        Supported names are ``'cls_token_id'``, ``'pad_token_id'`` and
        ``'device'``; any other name is ignored without error.
        """
        if name in self._ADDABLE:
            setattr(self, name, val)
        
config = Config()

In [5]:
def optimizer_to(optim, device):
    """Move the tensors stored in ``optim.state`` to ``device``, in place.

    Each value of ``optim.state`` is handled as follows: a tensor is moved
    directly; a dict has each of its tensor values moved; anything else is left
    untouched. A tensor's gradient, when present, is moved as well.
    """
    def _relocate(tensor):
        # Rebind .data so the existing tensor object stays referenced by the optimizer.
        tensor.data = tensor.data.to(device)
        grad = tensor._grad
        if grad is not None:
            grad.data = grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            # Top-level tensors are handled for completeness; it is unclear
            # whether optimizer state ever holds any.
            _relocate(entry)
        elif isinstance(entry, dict):
            for value in entry.values():
                if isinstance(value, torch.Tensor):
                    _relocate(value)


def instantiate_model(config, tokenizer):
    """Build the encoder and its optimizer, optionally resuming from a checkpoint.

    Args:
        config: Settings object. ``learning_rate``, ``continue_training`` and
            ``use_cuda`` are read here; ``configure_devices`` sets ``devices``
            and ``device`` on it before the model is created.
        tokenizer: Not used by this function; kept so existing calls still work.

    Returns:
        ``(model, optimizer, metrics)``. ``metrics`` is always ``None``. When
        ``config.use_cuda`` is true the model is moved to the GPU and wrapped
        in ``torch.nn.DataParallel``.
    """
    configure_devices(config)
    model = Model(config)
    # NOTE(review): transformers.AdamW is deprecated in newer transformers
    # releases; verify checkpoint compatibility before switching optimizers.
    optimizer = transformers.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0)
    metrics = None

    if config.continue_training:
        # Load onto the CPU first; model and optimizer state are moved below.
        state_dict = torch.load(config.continue_training, map_location='cpu')
        model.load_state_dict(state_dict['model'])
        if 'optimizer_state_dict' in state_dict:
            optimizer.load_state_dict(state_dict['optimizer_state_dict'])
            # Loading the optimizer state also restores its param groups, so
            # re-apply the configured learning rate.
            for group in optimizer.param_groups:
                group['lr'] = config.learning_rate

        # The training history is optional metadata. Only a missing key is
        # tolerated (and reported); any other error is no longer hidden.
        try:
            print(f"Loaded model:\nEpochs: {state_dict['epoch']}\nLoss: {state_dict['loss']}\n"
                  f"Recall: {state_dict['rec']}\nMRR: {state_dict['mrr']}")
        except KeyError as missing:
            print(f"Loaded model (no {missing} entry in checkpoint)")

    if config.use_cuda:
        model = model.cuda()
        optimizer_to(optimizer, config.device)
        model = torch.nn.DataParallel(model, device_ids=config.devices)
    return model, optimizer, metrics


def configure_devices(config):
    """Record the visible CUDA devices on ``config``.

    Sets ``config.devices`` to the list of CUDA device indices reported by
    torch and ``config.device`` to the first of them when ``config.use_cuda``
    is true, otherwise to the string ``"cpu"``.
    """
    gpu_count = torch.cuda.device_count()
    config.devices = list(range(gpu_count))
    if config.use_cuda:
        config.device = config.devices[0]
    else:
        config.device = "cpu"


def get_loader(data, batch_size):
    """Wrap the tensor ``data`` in a shuffling ``DataLoader``.

    Batches contain ``batch_size`` items; a final incomplete batch is dropped.
    Each batch is a one-element sequence holding the batched tensor.
    """
    dataset = TensorDataset(data)
    loader_options = dict(batch_size=batch_size, shuffle=True, sampler=None, drop_last=True)
    return DataLoader(dataset, **loader_options)

def ids2docs(ids, id2doc: dict):
    """Look up each id in ``id2doc`` and return the results as a list.

    Every element of ``ids`` is converted with ``int()`` before the lookup, so
    one-element tensors and numeric strings are accepted.
    """
    docs = []
    for raw_id in ids:
        docs.append(id2doc[int(raw_id)])
    return docs

In [6]:
# Load the tokenizer named in the config instead of repeating the model name here.
# NOTE(review): presumably the same checkpoint the encoder loads; confirm in model.Encoder.
tokenizer = transformers.BertTokenizerFast.from_pretrained(config.bert_model)

# Register the ids of the [CLS] and padding tokens on the config. The tokenizer
# exposes them directly, so there is no need to encode the token strings.
config.add('cls_token_id', tokenizer.cls_token_id)
config.add('pad_token_id', tokenizer.pad_token_id)

In [7]:
def tok(x):
    """Print the token strings that the notebook-level ``tokenizer`` gives for the ids in ``x``."""
    tokens = tokenizer.convert_ids_to_tokens(x)
    print(tokens)

In [8]:
model, optimizer, metrics = instantiate_model(config, tokenizer)
loss_fn = torch.nn.CrossEntropyLoss()

Loaded model:
Epochs: [0, 1, 2, 3]
Loss: [0.16748802861127624, 0.15736810512863084, 0.1498402876436489, 0.14569516037571323]
 Recall: [0.440744, 0.450195, 0.377738, 0.528953]
MRR: [0.259687, 0.266514, 0.216023, 0.319852]


In [9]:
metrics = eval.Metrics(metrics)

# Main

In [9]:
doc_chunks = util.make_chunks(config.articles_path, tokenizer, config, save_chunks=True)
# doc_chunks = dict mapping document id -> list of chunks;
# chunk = list of sentences; sentence = list of token ids

In [10]:
type(doc_chunks)

dict

In [11]:
ids = list(doc_chunks.keys())

In [12]:
print(len(doc_chunks), len(doc_chunks[ids[0]]), len(doc_chunks[ids[0]][0]))

2309706 3 6


In [13]:
tok(doc_chunks[ids[0]][0][0])

['Z', '##lín', '15', '.', 'srpna', '(', 'Č', '##T', '##K', ')', '-', 'Kraj', '##ský', 'sou', '##d', 've', 'Z', '##lín', '##ě', 'dok', '##on', '##čuje', 'dok', '##azo', '##vání', 'v', 'kor', '##up', '##ční', 'ka', '##uze', 'kolem', 'fina', '##nční', '##ho', 'ú', '##řadu', 'v', 'Kromě', '##říž', '##i', '.']


In [14]:
util.nested_list_len(doc_chunks[ids[0]][0])

239

### Convert dict of chunks into list of all chunks 

In [28]:
doc_chunks = [chunk for doc, chunks in doc_chunks.items() for chunk in chunks]

In [29]:
len(doc_chunks)

1143110

In [30]:
tok(doc_chunks[6][0])

['Kromě', 'prof', '##esi', '##on', '##álních', 'astronom', '##ů', 'se', 'astronomi', '##i', 'v', '##ěn', '##uje', 'i', 'řada', 'astronom', '##ů', 'amat', '##ér', '##ských', '.']


In [15]:
dev_chunks = util.make_chunks("/mnt/data/factcheck/fever/data-cs/fever/fever.db", 
                                tokenizer, config, as_eval=True, save_chunks=True)
dev_articles_ids = list(dev_chunks.keys())
dev_chunks, dev_chunks_mask = util.process_chunks(dev_chunks, config)

Padding chunks...: 100%|██████████| 451629/451629 [00:14<00:00, 30760.37it/s]


In [16]:
claims_dev, evidence_dev, labels_dev = util.load_claims('dev', config,
                                             path='/mnt/data/factcheck/fever/data-cs/fever-data/dev.jsonl')
claims_dev, claims_dev_mask = util.process_claims(claims_dev, tokenizer, config, _pad_max=True)

Loaded 9999 claims from dev split.


## Evaluation

In [None]:
# Sample the documents: keep the first 1000 entries of dev_chunks (in iteration
# order) together with the matching entries of dev_chunks_mask.
sdev_ch = dict(itertools.islice(dev_chunks.items(), 1000))
sdev_m = {doc_key: dev_chunks_mask[doc_key] for doc_key in sdev_ch}
c = len(sdev_ch)  # number of sampled documents
    
# print(type(sdev_ch), type(sdev_m))

In [12]:
eval_claim_embeddings, eval_document_embeddings = eval.evaluation_preprocessing(claims_dev, claims_dev_mask, 
                                                                                dev_chunks, dev_chunks_mask, model, config)

Generating chunks embeddings...: 100%|██████████| 451629/451629 [00:04<00:00, 91534.08it/s] 
Embedding given chunks...: 100%|██████████| 7057/7057 [38:54<00:00,  3.02it/s]
Embedding given chunks...: 100%|██████████| 157/157 [00:51<00:00,  3.05it/s]


In [13]:
print(eval_claim_embeddings.shape, eval_document_embeddings.shape)

(9999, 512) (451629, 512)


In [50]:
evidence_dev[1][0][0][2]

'Sammy Cahn'

In [51]:
predicted = np.array(['Sammy Cahn', 'Sammy', 'Cahn', 'Sammy Cahn'])

In [52]:
ranks = np.where(predicted == evidence_dev[1][0][0][2])[0][-1]

In [53]:
ranks = [np.where(predicted_evidence == ev)[0][-1] for ev in evidence if ev in predicted_evidence]

3

In [64]:
kk = 20
precision, recall, f1, mrr = retriever_score(eval_document_embeddings,dev_articles_ids, eval_claim_embeddings, 
                                            evidence_dev, labels_dev, config, k=kk)
print(f"F1: {f1}\tPrecision@{kk}: {precision}\tRecall@{kk}: {recall}\tMRR@{kk}: {mrr}")

Calculating evaluation metrics: 9999it [57:35,  2.89it/s]

F1: 0.055573	Precision@20: 0.029298	Recall@20: 0.538554	MRR@20: 0.34086650690207226





## Training

In [17]:
# The loader yields batches of example indices: claim indices when fine-tuning,
# otherwise indices into doc_chunks.
num_examples = len(claims_train) if config.mode == 'finetuning' else len(doc_chunks)
# torch.arange creates the index tensor directly instead of first building a
# Python list with one element per example.
loader = get_loader(torch.arange(num_examples), config.bs)

# Position -> document id, in doc_chunks iteration order. Falls back to an empty
# list when doc_chunks is not a dict.
id2doc = dict(enumerate(doc_chunks)) if isinstance(doc_chunks, dict) else []

In [24]:
batch = next(iter(loader))
batch = batch[0]

In [25]:
batch

tensor([ 678107,  893890,  116736,  719527,  500809,  313594,  528943,  928537,
        1350184, 1197479,  672918, 2300976, 1733225,  502688,  913880,  184345,
        1088734,  450399, 1796903,   79222,  886595, 1352382, 1331826,  370023,
          37339, 1208541, 1469085, 1892516, 1020169,  303151,  358553,  423013])

In [26]:
# ids2docs(batch, id2doc)

In [27]:
batch = next(iter(loader))
batch = batch[0]
query, query_mask, context, context_mask = util.get_pretraining_batch(ids2docs(batch, id2doc), doc_chunks, 
                                                                            tokenizer, config)
print(f"{query.shape} {context.shape}")
print(f"{query} {context}")

torch.Size([32, 287]) torch.Size([32, 287])
tensor([[  101, 64121, 44254,  ...,     0,     0,     0],
        [  101, 10685, 24204,  ...,     0,     0,     0],
        [  101, 87631,   112,  ...,     0,     0,     0],
        ...,
        [  101, 23488, 13341,  ...,     0,     0,     0],
        [  101, 53068, 10333,  ...,     0,     0,     0],
        [  101, 14074, 17513,  ...,     0,     0,     0]]) tensor([[  101,   294, 13188,  ...,     0,     0,     0],
        [  101, 28096, 10738,  ...,     0,     0,     0],
        [  101, 21416, 10193,  ...,     0,     0,     0],
        ...,
        [  101, 10469, 11798,  ...,     0,     0,     0],
        [  101, 23837, 10738,  ...,     0,     0,     0],
        [  101, 42392, 10413,  ...,     0,     0,     0]])


In [None]:
# Train for config.epoch epochs. Every batch is scored against itself: query i
# must rank context i highest among all contexts of the batch (in-batch negatives).
for epoch_num in range(config.epoch):
    model.train()
    batch_num = len(loader)
    num_training_examples, running_loss = 0, 0.0
    for batch in tqdm(loader, total=batch_num):
        optimizer.zero_grad()
        batch = batch[0]  # the loader wraps the index tensor in a one-element sequence
        num_training_examples += batch.size(0)
        if config.mode == 'finetuning':
            query, query_mask, \
            context, context_mask = util.get_finetuning_batch(batch, claims_train, claims_train_mask, evidence_train,
                                                            doc_chunks, chunks_mask, articles_ids, config)
        else:
            query, query_mask, \
            context, context_mask = util.get_pretraining_batch(ids2docs(batch, id2doc), doc_chunks,
                                                                        tokenizer, config)

        query_cls_out = model(x=query, x_mask=query_mask)
        context_cls_out = model(x=context, x_mask=context_mask)
        # Score of every query against every context in the batch.
        logit = torch.matmul(query_cls_out, context_cls_out.transpose(-2, -1))
        # The matching context of query i sits at column i.
        correct_class = torch.arange(len(query)).to(config.device)
        loss = loss_fn(logit, correct_class)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * batch.size(0)

    # Example-weighted mean loss, computed and reported once per epoch
    # (max() guards against an empty loader).
    epoch_avg_loss = running_loss / max(num_training_examples, 1)
    print(f"Epoch {epoch_num}: average loss {epoch_avg_loss:.6f}")

### Make Chunks

In [None]:
wiki_json = io_util.load_json(wiki_path)

In [None]:
wiki_docs = util.transform_wiki(wiki_json)

In [None]:
pars, par_ids = io_util.load_db(ctk_path, limit=500000) #TODO REMOVE  # returns paragraphs and paragraph ids

In [None]:
print(len(pars), len(par_ids))

In [None]:
def remove_invalid_pars(pars, par_ids):
    """Keep only paragraphs that end with a single full stop.

    A paragraph is kept when, after stripping surrounding whitespace, it ends
    with ``'.'`` but not with ``'...'``. The original (unstripped) paragraph is
    what gets returned.

    Returns:
        ``(kept_paragraphs, kept_paragraph_ids)`` as two parallel lists.
    """
    kept_pars, kept_ids = [], []
    for par_id, par in tqdm(zip(par_ids, pars)):
        text = par.strip()
        if text.endswith('...') or not text.endswith('.'):
            continue
        kept_pars.append(par)
        kept_ids.append(par_id)
    return kept_pars, kept_ids

In [None]:
pars, par_ids = remove_invalid_pars(pars, par_ids)

In [None]:
print(len(pars), len(par_ids))

In [None]:
doc_id, par_id  = par_ids[4].split('_')
print(doc_id, par_id)

In [None]:
docs = util.transform_ctk(pars, par_ids)

In [None]:
docs[doc_id]

In [None]:
docs[doc_id][par_id]

In [None]:
print(len(docs))

In [None]:
docs_tokenized = util.tokenize_documents(docs, tokenizer)

In [None]:
print(len(docs_tokenized))

In [None]:
len(docs_tokenized[doc_id][par_id])

In [None]:
for s in docs_tokenized[doc_id][par_id]:
    print(tokenizer.convert_ids_to_tokens(s))

In [None]:
# The two adjacent f-strings form one message. The former extra pair of
# parentheses (plus comma) made the argument a tuple, so print() showed the
# tuple's repr with escaped "\n" instead of two lines.
print(f"#chunks: {len(docs_tokenized[doc_id])}\n"
      f"#sentences in paragraph {par_id}: {len(docs_tokenized[doc_id][par_id])}\n")

In [None]:
doc_chunks = util.create_chunks(docs_tokenized, tokenizer, config)

In [None]:
len(doc_chunks)

In [None]:
print(len(doc_chunks[doc_id]))

In [None]:
for i, ch in enumerate(doc_chunks[doc_id]):
    print(f"chunk: {i}")
    for s in ch:
        print(tokenizer.convert_ids_to_tokens(s))

In [None]:
if config.task.upper() == 'ICT' and config.mode == 'pretraining': 
    doc_chunks = util.flatten_chunks(doc_chunks)
    print(doc_chunks[0])

In [None]:
ictc, icts = ict_pretraining_targets_and_contexts([0,1,2], flat_chunks, config)

In [None]:
def get_loader(data, batch_size):
    """Return a shuffling ``DataLoader`` over ``data`` that drops a final incomplete batch.

    NOTE(review): this re-defines the identical ``get_loader`` from the helper
    cell near the top of the notebook; one of the two definitions can be removed.
    """
    return DataLoader(
        TensorDataset(data),
        sampler=None,
        shuffle=True,
        drop_last=True,
        batch_size=batch_size,
    )

In [None]:
id2doc = {i: doc_id for i, (doc_id, _) in enumerate(doc_chunks.items())} if isinstance(doc_chunks, dict) else []

In [None]:
id2doc[0]

In [None]:
c, s = bfs_pretraining_targets_and_contexts([0,1,2], [[]] + chunks, config)

In [None]:
print(tokenizer.convert_ids_to_tokens(c[1]), '\n', tokenizer.convert_ids_to_tokens(s[1]))

In [None]:
len(s)

In [None]:
# chunks = util.create_chunks(docs_tokenized, tokenizer, config)

In [None]:
chunks, titles = util.make_chunks(fever_path, tokenizer, config, save_chunks=False)

dev_articles_ids = titles
dev_chunks, dev_masks = util.process_chunks(chunks, config)

### Claims

In [None]:
claims_dev, evidence_dev, labels_dev = util.load_claims('dev', config)

In [None]:
evidence_dev

In [None]:
claims_dev, claims_dev_mask = util.process_claims(claims_dev, tokenizer, config)

In [None]:
len(claims_dev)

### Get batch

In [None]:
# Config stores the batch size as `bs`; it has no `batch_size` attribute, so
# reading `config.batch_size` raised AttributeError.
loader = ( get_loader(torch.tensor([i for i in range(len(claims_train))]), config.bs)
                    if config.mode == 'finetuning'
                    else get_loader(torch.tensor([i for i in range(len(chunks))]), config.bs) )

In [None]:
batch = next(iter(loader))
batch = batch[0]

In [None]:
target, target_mask, context, context_mask = util.get_pretraining_batch(batch, chunks, tokenizer, config)

In [None]:
target.shape

In [None]:
context.shape

In [None]:
target[0]

In [None]:
context[1]