In [1]:
cd ../..

/people/gerald/Documents/repositories/Educational-French-Question-Answering


In [2]:
import os
import json
import evaluate
import random
from src.data_utils.pb_corpus import FQAGPBDataset
from src.data_utils.corpus import MixedDataset, KeyMapDataset
from src.model.mbart_qg import MBARTQGDataLoaderCollator

In [3]:
os.environ['EFQADATA'] = '/people/gerald/Documents/repositories/Educational-French-Question-Answering/dataset'
os.environ['QA_LOG'] = "/data/workdir/gerald/log"

In [4]:

data_folder = os.path.expandvars("$EFQADATA/source")
datasets_name = ["squad-en-en.pb.json","fquad-fr-fr.pb.json"]
datasets = {}

split = "train"

for dataset_name in datasets_name: 
    with open(os.path.join(data_folder, dataset_name)) as f:
        data = json.load(f)
        datasets[dataset_name.split('.')[0]] = FQAGPBDataset(data['train'], sampler = lambda x : [x[random.randint(0, len(x) - 1)]])


In [5]:
md = KeyMapDataset(MixedDataset(*datasets.values()))

In [6]:
md[0]

{'is_default': True,
 'id': 0,
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to <hl>Saint Bernadette Soubirous<hl> in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question_type': 'NONE',
 'input_lang': 250008,
 'output_lang': 250008,
 'index': 0,
 'dataset_index': 0}

In [7]:
from torch.utils.data import DataLoader
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

In [8]:
dl  = DataLoader(md, batch_size = 2, shuffle=True, collate_fn=MBARTQGDataLoaderCollator(tokenizer))

In [9]:
next(iter(dl))

{'input_ids': tensor([[250008,    239,  45556,  53495,    217,    427,    201,   1837,      6,
           58246,      7,     22,  20938,   1734,  57559,      7,    104,     25,
           70353,    569,  53495,  18750,     95,   1001,    104,     25,  58246,
               7,  61570,     95,  97879,      4,    199,   6181,      7,     82,
             199,  11374,      4,   1609,   5896,      4,    418,   1745,    115,
            3622,      4,     96,     25, 137858,  62195,      4,     96,     25,
          100023,      4,     21,  38397,     82,     96,     25,  44713,  11129,
               4,  48877,    418,   1745,      4,     82,     96,     25,  74896,
               4,   4426,   8266,   2740,   2947,      4,    363,   1745,  16093,
            8266,   2740,      5,  20097,   4360,  30686,      4,    382,    305,
            2489,      6, 126907,      7,   3537,  40279,     18,    807,     21,
           45556,    113, 107056,   5460,      6,  58246,      5,    339,     26,
   

In [14]:
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

from src.model.mbart_qg import MBARTQG, MBARTQGDataLoaderCollator
from src.eval_utils.evaluate_utils import HFMetric, MultiHFMetric

import spacy
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class SpacyTokenizer():
    def __init__(self):
        self.nlp = spacy.load("fr_core_news_lg")
    def __call__(self, x):
        return [t.text for t in self.nlp.tokenizer(x)]
st = SpacyTokenizer()
validation_metrics = MultiHFMetric(
    sacrebleu = HFMetric('sacrebleu', lambda x : x['score'], tokenize = 'intl'),
    rouge = HFMetric('rouge', lambda x : x['rougeL'], tokenizer = st)
)

os.environ['EFQADATA'] = '/people/gerald/Documents/repositories/Educational-French-Question-Answering/dataset'
data_folder = os.path.expandvars("$EFQADATA/source")
train_datasets_name = ["squad-en-en.pb.json","fquad-fr-fr.pb.json"]
valid_datasets_name = ["fquad-fr-fr.pb.json"]
train_datasets = {}
valid_datasets = {}

for dataset_name in train_datasets_name: 
    with open(os.path.join(data_folder, dataset_name)) as f:
        il, ol = dataset_name.split('.')[0].split('-')[-2], dataset_name.split('.')[0].split('-')[-1]
        data = json.load(f)
        train_datasets[dataset_name.split('.')[0]] = FQAGPBDataset(
            data["train"],
            sampler = lambda x : [x[random.randint(0, len(x) - 1)]],
            input_lang = il, output_lang = ol
        )
for dataset_name in valid_datasets_name: 
    with open(os.path.join(data_folder, dataset_name)) as f:
        il, ol = dataset_name.split('.')[0].split('-')[-2], dataset_name.split('.')[0].split('-')[-1]
        data = json.load(f)
        valid_datasets[dataset_name.split('.')[0]] = FQAGPBDataset(
            data["valid"],
            sampler = lambda x : [x[random.randint(0, len(x) - 1)]],
            input_lang = il, output_lang = ol
        )

model = MBARTQG(
    pretrained_name = "facebook/mbart-large-50-many-to-many-mmt",
    fixed_encoder = True,
    validation_callback = validation_metrics, log_dir = os.path.join(os.path.expandvars("$QA_LOG"), 'test')
)

tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.expandvars("$QA_LOG"), name="test")
tb_logger.log_hyperparams({"test": 1 })
lr_monitor = LearningRateMonitor(logging_interval='step')

train_dl  = DataLoader(KeyMapDataset(MixedDataset(*train_datasets.values())), batch_size = 2, shuffle=True, num_workers=8, collate_fn=MBARTQGDataLoaderCollator(model.tokenizer))
valid_dl  = DataLoader(KeyMapDataset(MixedDataset(*valid_datasets.values())), batch_size = 2, shuffle=False, num_workers=8, collate_fn=MBARTQGDataLoaderCollator(model.tokenizer))


checkpoint_callback_val_loss = ModelCheckpoint(monitor='val/loss', save_top_k=2, mode="min", filename="val-loss-checkpoint-{epoch:02d}-{val_loss:.2f}")
checkpoint_callback_val_sacrebleu = ModelCheckpoint(monitor='val/sacrebleu', save_top_k=2, mode="max", filename="val-sacrebleu-checkpoint-{epoch:02d}-{val_loss:.2f}")
checkpoint_callback_val_rouge = ModelCheckpoint(monitor='val/rouge', save_top_k=2, mode="max", filename="val-rouge-checkpoint-{epoch:02d}-{val_loss:.2f}")

callbacks = [
    lr_monitor,
    checkpoint_callback_val_loss,
    checkpoint_callback_val_rouge,
    checkpoint_callback_val_sacrebleu
]



In [None]:
trainer = pl.Trainer(
    logger=tb_logger, 
    log_every_n_steps=1, 
    callbacks=callbacks, 
    enable_progress_bar=True,
    limit_train_batches=10000, 
    max_epochs=250, 
    resume_from_checkpoint="/data/workdir/gerald/log/test/version_9/checkpoints/val-sacrebleu-checkpoint-epoch=23-val_loss=0.00.ckpt",
    accumulate_grad_batches=64,
    accelerator='gpu',
    devices=[1]
)
trainer.fit(
    model,
    train_dl,
    valid_dl
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at /data/workdir/gerald/log/test/version_9/checkpoints/val-sacrebleu-checkpoint-epoch=23-val_loss=0.00.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
  super(SGD, self).__init__(params, defaults)

  | Name  | Type                          | Params
--------------------------------------------------------
0 | model | MBartForConditionalGeneration | 610 M 
--------------------------------------------------------
610 M     Trainable params
0         Non-trainable params
610 M     Total params
2,443.522 Total estimated model params size (MB)
Restored all states from the checkpoint file at /data/workdir/gerald/log/test/version_9/checkpoints/val-sacrebleu-checkpoint-epoch=23-val_loss=0.00.ckpt


Sanity Checking: 0it [00:00, ?it/s]



Training: 10000it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]