In [3]:
%load_ext autoreload
%autoreload 2
import torch
from argparse import ArgumentParser
import pytorch_lightning as pl

from datasets import load_metric

from project.data.data import *
from project.models.models import *
from project.metrics.metrics import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# Toy data for the metrics smoke test: one plain string per sample, and each
# prediction is identical to its (single) reference.
predictions = ["hello there general kenobi", "foo bar foobar"]
references = ["hello there general kenobi", "foo bar foobar"]

In [7]:
# Wrap the toy predictions/references and run compute_metrics on them.
# The wrapper is bound to `metric_input` instead of `input` so the built-in
# input() is not shadowed for the rest of the kernel session.
metric_input = Input(predictions, references)
metric = Metrics().compute_metrics(metric_input)

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Preprocess raw data:

In [2]:
# One-off preprocessing step, currently disabled; uncomment both lines to re-run it.
#processor = RaceDataProcessor()
#processor.process_data("RACE", "LON")

## Train model

In [10]:
# Assemble a single parser from the data module, model and Trainer options.
parser = ArgumentParser()
parser = RaceDataModule.add_model_specific_args(parser)
parser = RaceModule.add_model_specific_args(parser)
parser = pl.Trainer.add_argparse_args(parser)

# Run configuration, spelled out as an explicit argv token list
# (one option per line) instead of a whitespace-split string.
args = parser.parse_args([
    "--data_path", "LON",
    "--batch_size", "16",
    "--num_workers", "6",
    "--d_model", "768",
    "--nhead", "8",
    "--num_layers", "1",
    "--learning_rate", "1e-5",
    "--special_tokens", "[CON]", "[QUE]", "[ANS]", "[DIS]",
    "--dm_pretrained_model", "distilbert-base-cased",
    "--m_pretrained_model", "distilbert-base-cased",
    "--gpus", "1",
    "--max_epochs", "5",
    "--check_val_every_n_epoch", "1",
])

# Seed the global RNGs before any model or data object is created.
pl.seed_everything(1234)


# Module and data module:
def customed_collate_fn(batch, tokenizer):
    """Collate raw samples into tokenized model inputs and targets.

    Every sample is read through its "article", "answer", "question" and
    "distractors" keys. The input text is "<con> article <ans> answer"
    (space separated). The target text is "<que> question <dis> " followed by
    the distractors glued together with the distractor token, with no spaces
    around that separator.

    Args:
        batch: iterable of sample mappings providing the keys listed above.
        tokenizer: callable whose ``additional_special_tokens`` unpacks into
            exactly four tokens, taken in the order context, question,
            answer, distractor.

    Returns:
        dict with keys "inputs" and "targets", each holding whatever the
        tokenizer returns for the corresponding list of texts when called
        with padding=True, truncation=True and return_tensors="pt".
    """
    con_token, que_token, ans_token, dis_token = tokenizer.additional_special_tokens

    def encode(texts):
        # Same tokenizer settings for both sides of the batch.
        return tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    # Build (input, target) text pairs sample by sample, then split them.
    text_pairs = [
        (
            " ".join((con_token, sample["article"], ans_token, sample["answer"])),
            " ".join((que_token, sample["question"], dis_token, dis_token.join(sample["distractors"]))),
        )
        for sample in batch
    ]
    source_texts = [source for source, _ in text_pairs]
    target_texts = [target for _, target in text_pairs]

    return {"inputs": encode(source_texts), "targets": encode(target_texts)}


# The data module receives the parsed args together with the collate function
# defined above; the model is built from the same args.
data_module = RaceDataModule(args, customed_collate_fn)
module = RaceModule(args)


# Callbacks:
# `dirpath` is used as a plain directory path, so format placeholders placed
# there are never filled in; the "{epoch}"/"{val_loss}" template belongs in
# `filename`. Checkpoints are now written as ./checkpoint/fx-<epoch>-<val_loss>.ckpt.
# NOTE(review): monitoring "val_loss" presumes the module logs a metric with
# that exact name -- confirm in RaceModule.
checkpoint = ModelCheckpoint(
    dirpath="./checkpoint",
    filename="fx-{epoch:02d}-{val_loss:.7f}",
    monitor="val_loss"
)


# Trainer:
# Trainer options come from the parsed args; the checkpoint callback is
# registered through `callbacks=[...]`. Passing a ModelCheckpoint instance via
# `checkpoint_callback=` is deprecated in PyTorch Lightning (that argument is
# meant to be a boolean switch) and stops working in later releases.
trainer = pl.Trainer.from_argparse_args(
    args,
    callbacks=[checkpoint]
)


#trainer.fit(module, data_module)

Global seed set to 1234
Global seed set to 1234
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
GPU available: True, used: True
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
TPU available: None, using: 0 TPU cores


In [12]:
# Call the data module hooks by hand, in this order, so the cells below can
# use its tokenizer and test dataloader without going through trainer.fit().
data_module.prepare_data()
data_module.setup()

In [92]:
# Round-trip check: tokenize three sample strings, then decode the first one
# back to text with special tokens stripped.
# `tokenizer` is bound here (not only in a later cell) so this cell also works
# when the notebook is run top to bottom on a fresh kernel.
# NOTE(review): the sample strings are throwaway input; consider replacing
# them with neutral text before sharing the notebook.
tokenizer = data_module.tokenizer
tokenizer.decode(tokenizer(["con cac", "cai lon", "dit me may"])["input_ids"][0], skip_special_tokens=True)

'con cac'

In [14]:
# Take the data module's tokenizer and pull the first batch of the test loader.
tokenizer = data_module.tokenizer
batch = next(iter(data_module.test_dataloader()))

# Split the collated batch into its two tokenized halves.
inputs = batch["inputs"]
targets = batch["targets"]

In [21]:
# Forward pass on the sampled batch. The target ids and their padding mask
# both drop the last position; each padding mask is True wherever the
# corresponding attention_mask is 0.
generated = module(
    target=targets["input_ids"][:, :-1],
    memory=module.encode(inputs),
    input_key_padding_mask=targets["attention_mask"][:, :-1].eq(0),
    memory_key_padding_mask=inputs["attention_mask"].eq(0),
)

In [28]:
# Shape of the full target ids, including the last position that the forward
# pass slices off with [:, :-1].
targets["input_ids"].shape

torch.Size([16, 55])

In [26]:
# Shape of the arg-max indices taken along dimension 1 of the model output.
generated.argmax(dim=1).shape

torch.Size([16, 54])