In [None]:
import os
import subprocess

# Clone the repository once and make its root the working directory.
# Guarded so that re-running this cell neither re-clones (git refuses once the
# target exists) nor changes into a nested ``food/food`` checkout, which the
# bare ``!git clone`` + ``%cd ./food`` pair could do on a second run.
REPO_URL = 'https://github.com/paolo-gajo/food.git'
REPO_DIR = 'food'

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        # check=True surfaces a failed clone immediately instead of letting
        # later cells fail with confusing import/path errors.
        subprocess.run(['git', 'clone', REPO_URL], check=True)
    os.chdir(REPO_DIR)

os.getcwd()

In [None]:
%%capture
!pip install datasets
!pip install sentencepiece
!pip install accelerate
!pip install evaluate

In [None]:
import os
import sys

# Make the repo's ``src/word_alignment`` modules (e.g. ``utils``) importable.
# Resolved relative to the current working directory, which the setup cell
# leaves at the cloned repo root; this works outside Colab too, unlike the
# previous hardcoded '/content/food/...' path. Guarded so re-running the cell
# does not append duplicate entries to sys.path.
WORD_ALIGNMENT_DIR = os.path.abspath(os.path.join('src', 'word_alignment'))
if WORD_ALIGNMENT_DIR not in sys.path:
    sys.path.append(WORD_ALIGNMENT_DIR)

In [None]:
import os
from utils import prep_tasteset, qa_tokenize
from transformers import AutoTokenizer
import warnings
import argparse

def build_tasteset(data_path,
                   shuffle_ents=False,
                   shuffle_languages=None,
                   src_lang='en',
                   tokenizer_name='bert-base-multilingual-cased',
                   dev_size=0.2,
                   shuffled_size=1,
                   unshuffled_size=1,
                   drop_duplicates=True,
                   ):
    """Prepare and tokenize the TASTEset data for extractive QA training.

    Loads the JSON dataset at ``data_path`` through ``utils.prep_tasteset``,
    shuffles it with a fixed seed, tokenizes it with ``utils.qa_tokenize`` and
    sets a torch output format on the model-input columns.

    Args:
        data_path: path of the input JSON dataset.
        shuffle_ents: forwarded to ``prep_tasteset``; whether to shuffle
            entities in samples.
        shuffle_languages: list of target-language codes to shuffle.
            Defaults to ``['it']`` (a ``None`` sentinel is used so the default
            list is never shared between calls).
        src_lang: code of the dataset source language.
        tokenizer_name: Hugging Face tokenizer identifier.
        dev_size: size of the dev split, forwarded to ``prep_tasteset``.
        shuffled_size: multiplier for the number of shuffled instances.
        unshuffled_size: multiplier for the number of unshuffled instances.
        drop_duplicates: whether ``prep_tasteset`` drops duplicate rows.

    Returns:
        The tokenized dataset object returned by ``prep_tasteset(...)
        .shuffle(...).map(...)``, formatted as torch tensors.
    """
    if shuffle_languages is None:
        shuffle_languages = ['it']

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    dataset = prep_tasteset(data_path,
                            tokenizer,
                            shuffle_ents=shuffle_ents,
                            shuffle_languages=shuffle_languages,
                            src_lang=src_lang,
                            dev_size=dev_size,
                            shuffled_size=shuffled_size,
                            unshuffled_size=unshuffled_size,
                            drop_duplicates=drop_duplicates,
                            ).shuffle(seed=42)

    # batch_size=None passes each split to qa_tokenize as a single batch.
    dataset = dataset.map(lambda batch: qa_tokenize(batch, tokenizer),
                          batched=True,
                          batch_size=None,
                          )

    # Only the model-input / label columns are returned as torch tensors.
    # NOTE(review): 'token_type_ids' must be produced by the chosen tokenizer.
    dataset.set_format('torch', columns=['input_ids',
                                         'token_type_ids',
                                         'attention_mask',
                                         'start_positions',
                                         'end_positions'])
    return dataset

def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's ``type=bool`` would treat any non-empty string (including
    ``'False'``) as True, so boolean options use this converter instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _parse_args(argv=None):
    """Parse the dataset-preparation options from ``argv`` (default: sys.argv).

    ``parse_known_args`` is used so that arguments this parser does not define
    (such as the ``-f <connection file>`` a Jupyter kernel carries in
    ``sys.argv``) are ignored instead of aborting with SystemExit.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', default='./src/word_alignment/EW-TASTE_en-it_DEEPL.json', help='path of the input json dataset')
    parser.add_argument('-o', '--output', default='', help='directory where the processed dataset is saved (default: derived from the input path)')
    parser.add_argument('-sh', '--shuffle', type=_str2bool, default=True, help='if True (default=True), shuffle entities in samples')
    parser.add_argument('-l', '--shuffle_languages', default='it', help='space-separated 2-character codes of the dataset target languages to shuffle')
    parser.add_argument('-src', '--src_lang', default='en', help='2-character code of the dataset source language')
    parser.add_argument('-t', '--tokenizer_name', default='bert-base-multilingual-cased', help='tokenizer to use')
    parser.add_argument('-d', '--drop_duplicates', type=_str2bool, default=True, help='if True (default=True), drop rows with the same answer')
    parser.add_argument('-ss', '--shuffled_size', type=int, default=1, help='length multiplier for the number of shuffled instances (default=1)')
    parser.add_argument('-us', '--unshuffled_size', type=int, default=1, help='length multiplier for the number of unshuffled instances (default=1)')
    # Default matches both this help text and build_tasteset's own default.
    parser.add_argument('-ds', '--dev_size', type=float, default=0.2, help='size of the dev split (default=0.2)')
    args, _unknown = parser.parse_known_args(argv)
    return args


def main(argv=None):
    """Build the tokenized TASTEset QA dataset and save it to disk.

    Args:
        argv: optional list of command-line arguments; ``None`` reads
            ``sys.argv``.
    """
    args = _parse_args(argv)

    # The option is documented as space-separated codes; split it into a list
    # to match the list type of build_tasteset's default.
    shuffle_languages = args.shuffle_languages.split()

    dataset = build_tasteset(args.input,
                             shuffle_ents=args.shuffle,
                             shuffle_languages=shuffle_languages,
                             src_lang=args.src_lang,
                             tokenizer_name=args.tokenizer_name,
                             drop_duplicates=args.drop_duplicates,
                             shuffled_size=args.shuffled_size,
                             unshuffled_size=args.unshuffled_size,
                             dev_size=args.dev_size,
                             )

    lang_id = '-'.join(shuffle_languages)
    suffix = 'drop_duplicates' if args.drop_duplicates else 'keep_duplicates'

    # Use the explicit --output if given; otherwise derive a directory next to
    # the input file that encodes the preparation settings.
    output_path = args.output
    if not output_path:
        output_path = os.path.join(os.path.dirname(args.input),
                                   f'.{args.src_lang}/{args.tokenizer_name.split("/")[-1]}_{lang_id}_{suffix}_shuffled_size_{args.shuffled_size}')
    os.makedirs(output_path, exist_ok=True)

    dataset.save_to_disk(output_path)

if __name__ == '__main__':
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        main()

In [None]:
import warnings
import os
import torch
from tqdm.auto import tqdm
from datasets import DatasetDict
from evaluate import load
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from utils import data_loader, SquadEvaluator, push_model
from datetime import datetime

def _run_epoch(model, loader, split, epoch, evaluator, optimizer=None):
    """Run one pass over ``loader`` and record its metrics in ``evaluator``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch. Without one it is put in eval mode and run under
    ``torch.inference_mode()``. Each batch's outputs are handed to
    ``evaluator.get_eval_batch``. After the pass, ``evaluator.evaluate`` runs,
    the mean batch loss is stored as ``evaluator.epoch_metrics[f'{split}_loss']``,
    and the split's metrics are printed.

    Returns:
        The mean batch loss (0.0 if ``loader`` yields no batches).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to model.eval()
    n_batches = len(loader)
    running_loss = 0.0
    progbar = tqdm(enumerate(loader),
                   total=n_batches,
                   desc=f"{split} - epoch {epoch + 1}")
    for i, batch in progbar:
        if training:
            outputs = model(**batch)  # ['loss', 'start_logits', 'end_logits']
            # Under DataParallel the per-replica losses are gathered into a
            # vector, hence the mean.
            loss = outputs[0].mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            with torch.inference_mode():
                outputs = model(**batch)
                loss = outputs[0].mean()

        running_loss += loss.item()
        progbar.set_postfix({'Loss': round(running_loss / (i + 1), 4)})

        evaluator.get_eval_batch(outputs, batch, split)

    evaluator.evaluate(model, split, epoch)
    epoch_loss = running_loss / n_batches if n_batches else 0.0
    evaluator.epoch_metrics[f'{split}_loss'] = epoch_loss

    evaluator.print_metrics(current_epoch=epoch, current_split=split)
    return epoch_loss


def main():
    """Fine-tune a question-answering model on the prepared TASTEset data.

    Trains ``model_name`` for up to ``epochs`` epochs, evaluating on the dev
    split after each one with the SQuAD v2 metric via ``SquadEvaluator``. Stops
    early when the evaluator sets ``stop_training``. Saves the collected
    metrics as CSV under ``results_path`` and pushes ``evaluator.best_model``.
    """
    tokenizer_name = 'microsoft/mdeberta-v3-base'
    model_name = 'microsoft/mdeberta-v3-base'
    dt_string = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    suffix = 'TASTEset'
    push_model_description = ''

    # Target languages; other codes used in experiments include
    # ru, nl, pt, et, es, hu, da, bg, sl.
    lang_list = ['it']

    src_lang = 'en'
    lang_id = '-'.join(lang_list)

    # Root of the data/results tree; override with the FOOD_ROOT env var when
    # not running on the original machine.
    food_root = os.environ.get('FOOD_ROOT', '/home/pgajo/working/food')
    # NOTE(review): the preparation script writes to a directory ending in
    # '_shuffled_size_{n}'; confirm this path points at the intended output.
    data_path = f'{food_root}/data/TASTEset/data/EW-TASTE/.{src_lang}/{tokenizer_name.split("/")[-1]}_{lang_id}_drop_duplicates'
    results_path = f'{food_root}/results/tasteset/{lang_id}'

    # Seed before loading the model so the freshly initialised QA head (and
    # dropout) are reproducible across runs.
    torch.manual_seed(42)

    data_splits = DatasetDict.load_from_disk(data_path)  # prepared, tokenized dataset
    batch_size = 2
    # data_loader appears to accept n_rows=... to subsample for quick debug
    # runs — TODO confirm.
    loaders = data_loader(data_splits, batch_size)

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.nn.DataParallel(model).to(device)

    optimizer = torch.optim.AdamW(params=model.parameters(),
                                  lr=2e-5,
                                  eps=1e-8,
                                  )

    evaluator = SquadEvaluator(tokenizer,
                               model,
                               load("squad_v2"),
                               )

    epochs = 10

    for epoch in range(epochs):
        _run_epoch(model, loaders['train'], 'train', epoch, evaluator, optimizer=optimizer)
        _run_epoch(model, loaders['dev'], 'dev', epoch, evaluator)

        evaluator.store_metrics()

        if evaluator.stop_training:
            print(f'Early stopping triggered on epoch {epoch}.\nBest epoch: {evaluator.epoch_best}.')
            break

    evaluator.print_metrics()
    os.makedirs(results_path, exist_ok=True)
    evaluator.save_metrics_to_csv(results_path)

    push_model(
        evaluator.best_model,
        suffix=suffix + dt_string,
        model_description=push_model_description,
    )

if __name__ == '__main__':
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        main()