### In-Context Cross-lingual Transfer.
Training example notebook.

In [None]:
#### REMOVE LATER
# Workaround cell: without these variables the trainer raises an error when run
# inside HPC notebooks, so they are exported before anything else executes.
import os

# Rendezvous settings describing a single process (rank 0 out of world size 1).
os.environ.update({
    "MASTER_ADDR": "localhost",
    "MASTER_PORT": "9996",  # modify if RuntimeError: Address already in use
    "RANK": "0",
    "LOCAL_RANK": "0",
    "WORLD_SIZE": "1",
})

# GPU selection: enumerate devices by PCI bus id and expose only device '0'.
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '0',
})

In [None]:
## import libraries
import pandas as pd
from transformers import AutoTokenizer, MT5ForConditionalGeneration, TrainingArguments
from datasets import Dataset

from src.data_handling import get_class_objects
from src.ic_xlt_utils import train_lora, preprocess_function

In [None]:
## language whose train/test splits are loaded
source_language = 'english'

## root folder of the datasets saved on disk
data_dir = 'data/massive' ## or 'data/acd'

## load the saved train and test splits as Dataset objects
## (paths have the form <data_dir>/<split>/<source_language>)
dataset_train = Dataset.load_from_disk(f'{data_dir}/train/{source_language}')
dataset_test = Dataset.load_from_disk(f'{data_dir}/test/{source_language}')

## class-related helpers derived from both splits
## (presumably the label set plus label->id and id->label lookups -- confirm in src.data_handling)
class_set, lbl2id_class, id2lbl_class = get_class_objects(dataset_train, dataset_test)

We employ an mT5 model.

In [None]:
## import model and tokenizer

## single checkpoint id shared by both loaders so they always refer to the same model
checkpoint_name = 'google/mt5-large'
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
base_model = MT5ForConditionalGeneration.from_pretrained(checkpoint_name)

$M$ is the number of examples prepended to the context.<br>
If $M=0$ or $M$ is set to `None`, training is done as prompt-based fine-tuning (FT) with the input-output mapping $x_i\to y_i$.<br>
If $M\geq1$, training is done through In-Context Tuning with $X^{src},x_i\to y_i$, where $X^{src}$ are the context examples drawn from the training dataset.<br>

In [None]:
## preprocess and tokenize text

## number of context examples requested from preprocess_function (passed as ict_n)
M = 10

def preprocess_wrapper_icl(sample):
    '''
    Adapter handed to Dataset.map.

    Forwards the incoming data to preprocess_function together with the
    notebook-level tokenizer, asking for M context examples (ict_n = M).
    Because the map call below uses batched = True, `sample` is a batch of
    rows rather than a single row.
    '''
    return preprocess_function(sample, tokenizer, ict_n = M)

## raw columns removed from the mapped dataset
raw_columns = ["text", 'label']

tokenized_dataset_train = dataset_train.map(
    preprocess_wrapper_icl,
    remove_columns = raw_columns,
    batched = True,
)

In [None]:
print('Training data sample:')

## Select the row first and the column second: this reads only the first example,
## whereas indexing the 'input_ids' column first selects that column for every row
## just to keep one entry. The decoded text is the same either way.
tokenizer.decode(tokenized_dataset_train[0]['input_ids'], skip_special_tokens = True)

In [None]:
## hyper-parameters handed to the trainer
training_args = TrainingArguments(
    ## checkpointing
    output_dir = 'checkpoints_trained',  # directory to save the checkpoint
    save_strategy = 'epoch',
    ## optimisation schedule
    learning_rate = 4e-4,
    num_train_epochs = 10,
    ## batch sizes (starting values; auto_find_batch_size is enabled)
    per_device_train_batch_size = 8,
    per_device_eval_batch_size = 8,
    auto_find_batch_size = True,
    ## reproducibility
    seed = 1,
    data_seed = 1,
    ## distributed data parallel
    ddp_find_unused_parameters = False,
)

In [None]:
## run LoRA training on the tokenized training set with the arguments defined above
model = train_lora(
    base_model = base_model,
    dataset_train = tokenized_dataset_train,
    peft_training_args = training_args,
    lora_checkpoint = None, ## provide to continue to fine-tune an already trained LoRA
    lora_config = None, ## to load a LoRA with custom parameters (LoraConfig object)
)