In [None]:
# Adapted from the Hugging Face translation example notebook: https://github.com/huggingface/notebooks/blob/master/examples/translation.ipynb

In [None]:
# Pretrained checkpoint to fine-tune, and the name used for the fine-tuned artifact.
model_checkpoint = 't5-small'
model_name = 'sparql-translator-t5-2021-06-20'

# Paths are relative to the notebook's working directory.
model_path = f'../../data/models/{model_name}'               # where trainer.save_model writes
ds_path = '../../data/dataset/lc-quad-wikidata-2021-06-20'   # dataset saved with save_to_disk
tokenizer_dir = '../../data/tokenizers/lc-quad-wikidata-2021-06-20'  # NOTE(review): not referenced in the visible cells

In [None]:
from tqdm import tqdm

In [None]:
from datasets import load_dataset, load_metric, Dataset, load_from_disk
# Load the pre-split dataset (indexed below by 'train' and 'test') from local disk.
raw_datasets = load_from_disk(dataset_path=ds_path)

In [None]:
# Preprocessing

In [None]:
from transformers import AutoTokenizer
# Use the tokenizer that ships with the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
# Sanity check: show the first English test question next to its tokenization.
first_test_question = raw_datasets['test']['translation'][0]['en']
print(first_test_question, tokenizer(first_test_question))

In [None]:
# Task prefix prepended to every English question in preprocess_function;
# inputs at inference time should start with the same text.
prefix = "translate English to Sparql: "

In [None]:
# Longest source question and longest target query in the training split.
# NOTE(review): these are *character* counts (len of the raw strings), yet they
# are later used as *token* max_length limits, and the prompt prefix is not
# counted. Presumably they are a loose upper bound on token counts; TODO confirm.
train_pairs = raw_datasets['train']['translation']
max_input_length, max_target_length = 0, 0
for pair in tqdm(train_pairs):
    max_input_length = max(max_input_length, len(pair['en']))
    max_target_length = max(max_target_length, len(pair['sparql']))

In [None]:
# Character-length maxima (source, target) computed from the training split.
print(f"{max_input_length} {max_target_length}")

In [None]:
source_lang = "en"
target_lang = "sparql"

def preprocess_function(examples):
    """Tokenize a batch of English -> SPARQL pairs for seq2seq training.

    Args:
        examples: a batch as passed by ``Dataset.map(..., batched=True)``;
            ``examples["translation"]`` is a list of dicts keyed by
            ``source_lang`` and ``target_lang``.

    Returns:
        The tokenizer output for the prefixed source questions, with an extra
        ``"labels"`` entry holding the token ids of the target queries.
    """
    pairs = examples["translation"]
    inputs = [prefix + ex[source_lang] for ex in pairs]
    targets = [ex[target_lang] for ex in pairs]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Tokenize the targets in target-tokenizer mode and keep the result:
    # its input_ids become the training labels.
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Apply preprocess_function to every split, a batch of examples at a time.
tokenized_datasets = raw_datasets.map(function=preprocess_function, batched=True)

In [None]:
# Display the mapped datasets to inspect the columns added by preprocess_function.
tokenized_datasets

In [None]:
# Fine-tuning the model

In [None]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
# Start fine-tuning from the pretrained base checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

In [None]:
batch_size = 8

# Training configuration. Checkpoints go to a directory named after the model,
# relative to the working directory (not model_path).
args = Seq2SeqTrainingArguments(
    output_dir=model_name,
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=False,
    # Evaluation: run once per epoch and generate text so metrics can be computed.
    evaluation_strategy="epoch",
    predict_with_generate=True,
    # Keep at most three checkpoints on disk.
    save_total_limit=3,
)

In [None]:
# Pads inputs and labels per batch; the model is passed so the collator can use it.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
import numpy as np
# sacreBLEU scorer used by compute_metrics.
metric = load_metric(path="sacrebleu")

def postprocess_text(preds, labels):
    """Normalise decoded strings into the shapes passed to ``metric.compute``.

    Args:
        preds: decoded prediction strings.
        labels: decoded reference strings.

    Returns:
        A ``(predictions, references)`` pair. Predictions are the stripped
        strings; each reference is stripped and wrapped in its own one-element
        list (one reference per prediction).
    """
    hypotheses = [text.strip() for text in preds]
    references = [[text.strip()] for text in labels]
    return hypotheses, references

def compute_metrics(eval_preds):
    """Compute sacreBLEU and mean generated length for a Seq2SeqTrainer evaluation.

    Args:
        eval_preds: ``(predictions, label_ids)`` from the trainer. ``predictions``
            may itself be a tuple whose first element holds the generated ids.

    Returns:
        ``{"bleu": ..., "gen_len": ...}`` with values rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 marks padding that the tokenizer cannot decode. It appears in the
    # labels, and in generated predictions when batches of different lengths
    # are concatenated during evaluation. Map it to the pad token everywhere.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Strip whitespace and wrap references as sacrebleu expects.
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Generated length excludes padding; -100 was already mapped to pad above.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [None]:
# Combine the model, arguments, data and collator into a trainer. compute_metrics
# is passed so that the per-epoch evaluation (evaluation_strategy="epoch" with
# predict_with_generate=True) reports BLEU and generation length, not only loss.
# NOTE(review): generation during evaluation is capped by the model's generation
# max length; confirm it is long enough for the SPARQL targets.
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune for args.num_train_epochs epochs, evaluating on the test split each epoch.
trainer.train()

In [None]:
# Persist the fine-tuned model (and the tokenizer given to the trainer) to model_path.
trainer.save_model(output_dir=model_path)

In [None]:
from transformers import pipeline

# Build an inference pipeline from the fine-tuned model saved at model_path.
# The previous reference, `model2`, was never defined. Loading from disk also
# exercises the saved artifact, and the pipeline then places the model on its
# own (default) device.
translator = pipeline(
    "translation_xx_to_yy",
    model=model_path,
    tokenizer=tokenizer
)

In [None]:
# Inputs must carry the same task prefix used during training.
translator(prefix + 'Who is marlin?')