### In-Context Cross-lingual Transfer.
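Example notebook: inference and evaluation of In-Context Cross-lingual Transfer (IC-XLT) and of the zero-shot cross-lingual transfer baseline (ZS-XLT) on several target languages.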

In [1]:
# TODO: remove this cell once the notebook is no longer run on the HPC cluster.
#
# Single-process "distributed" setup: declare this kernel as rank 0 of a world
# of size 1 listening on localhost. On HPC notebook nodes the trainer failed
# with an opaque error when these variables were missing.
import os

os.environ.update(
    MASTER_ADDR="localhost",
    MASTER_PORT="9997",  # pick another port on "RuntimeError: Address already in use"
    RANK="0",
    LOCAL_RANK="0",
    WORLD_SIZE="1",
)

# Number GPUs by PCI bus ID and expose only GPU 0. This is set here, before the
# library imports in the next cell, so it is in place before CUDA is initialised.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0",
)

In [None]:
## import libraries

import os

from datasets import Dataset
import numpy as np
from peft import PeftModel
from transformers import AutoTokenizer, MT5ForConditionalGeneration, TrainingArguments

from src.data_handling import get_class_set, get_kshot_dataset, get_class_objects
from src.ic_xlt_utils import create_icl_dataset, run_inference, compute_metrics

In [None]:
## Instantiate the tokenizer and the base seq2seq model from the same mT5 checkpoint.

MODEL_NAME = 'google/mt5-large'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model = MT5ForConditionalGeneration.from_pretrained(MODEL_NAME)

### Load data
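Load the test split of each target language, and build a one-shot adaptation set (one example per label) from its training split. This example is later used as the in-context demonstration.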

In [None]:
## Load each target language's test split (used for evaluation) and its training
## split reduced to a k-shot adaptation set (the source of the demonstration).

data_dir = 'data/massive'

target_languages = ['english', 'azeri', 'turkish', 'swahili']

K_SHOT = 1      # shots per label; 1 = one-shot (change to 8 for the 8-shot setting)
SHOT_SEED = 42  # seed passed to get_kshot_dataset to select the shots

tl_datasets = {}

for tl in target_languages:
    # os.path.join instead of '/'.join so paths are built portably
    test_split = Dataset.load_from_disk(os.path.join(data_dir, 'test', tl))
    train_split = Dataset.load_from_disk(os.path.join(data_dir, 'train', tl))

    kshot_train = get_kshot_dataset(
        dataset=train_split,
        k=K_SHOT,
        seed=SHOT_SEED,
    )

    # the third value returned by get_class_objects is not needed here
    class_set, lbl2id_class, _ = get_class_objects(kshot_train, test_split)

    tl_datasets[tl] = {
        'test': test_split,        # to evaluate
        'train': kshot_train,      # k-shot demonstrations
        'class_set': class_set,
        'lbl2id_class': lbl2id_class,
    }


### Inference on target languages / IC-XLT

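In-Context Cross-lingual Transfer (IC-XLT), evaluated on each target language. The one-shot demonstration from that language is prepended to every test input, and the LoRA trained with ICT then makes the predictions.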

In [None]:
## Wrap the base model with the LoRA adapter trained with ICT (the IC-XLT model).
# For the ACD-trained adapter use 'trained_loras/acd/ict_m10/' instead.
path_lora = 'trained_loras/massive/ict_m10/'

# NOTE(review): PEFT appears to inject the LoRA layers into `base_model` itself,
# so re-running this cell without reloading `base_model` may adapt it twice.
model = PeftModel.from_pretrained(
    base_model,
    path_lora,
)

In [None]:
## IC-XLT: for each target language, prepend the one-shot demonstration to every
## test input, predict with the ICT-adapted model, and score against gold labels.
## Results are stored in `metrics_icxlt`, so the ZS-XLT cell (which reassigns
## `metrics_per_lang`) cannot overwrite them.

metrics_icxlt = {}

for tl in target_languages:

    print(f'Running inference on {tl}')
    lang_data = tl_datasets[tl]

    # prepend the one-shot demonstration (taken from the k-shot train split) in context
    icl_dataset_test = create_icl_dataset(
        dataset_test=lang_data['test'],
        dataset_train=lang_data['train'],
    )

    # predict samples
    predicted_icxlt = run_inference(
        model=model,
        tokenizer=tokenizer,
        test_texts=icl_dataset_test['text'],
        class_set=lang_data['class_set'],
    )

    # evaluate samples
    metrics_icxlt[tl] = compute_metrics(
        predicted_labels=predicted_icxlt,
        truth_labels=lang_data['test']['label'],
        class_set=lang_data['class_set'],
        lbl2id_class=lang_data['lbl2id_class'],
    )

# Backward-compatible alias for code that reads `metrics_per_lang` right after this cell.
metrics_per_lang = metrics_icxlt

## print f1-micro per language
for tl in target_languages:
    print(f"{tl} : {metrics_icxlt[tl]['f1_score_micro']}")

### Inference on target languages / ZS-XLT

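Zero-shot cross-lingual transfer (ZS-XLT). The LoRA trained with prompt-based fine-tuning (PBT) is evaluated directly on each target language, with no demonstration in context.
The few-shot variants (1S-XLT / 8S-XLT) are obtained by continuing to fine-tune this checkpoint on the reduced (k-shot) target-language training set. That step is not shown in this notebook.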

In [None]:
## Load the LoRA adapter trained with prompt-based fine-tuning (PBT) onto a
## freshly loaded base model.
##
## `PeftModel.from_pretrained` adapts the model it is given in place, and the
## IC-XLT cell has already injected the ICT adapter into `base_model`. Reusing
## `base_model` here would re-wrap that already-adapted model: PEFT warns about
## modifying a model a second time, and depending on the PEFT version and the
## adapter configs, ICT layers can leak into this evaluation. A clean copy keeps
## the two evaluations independent and makes this cell safe to re-run.
path_lora_pbt = 'trained_loras/massive/pbt/'  # or 'trained_loras/acd/pbt/' to load the ACD trained model

# same checkpoint as the base model loaded at the top of the notebook
base_model_pbt = MT5ForConditionalGeneration.from_pretrained('google/mt5-large')
# NOTE(review): `base_model` (ICT-adapted) stays in memory; `del base_model` if memory is tight.
model = PeftModel.from_pretrained(base_model_pbt, path_lora_pbt)

In [None]:
## ZS-XLT: for each target language, predict the raw test inputs (no in-context
## demonstration) with the PBT-adapted model and score against gold labels.
## Results are stored in `metrics_zsxlt`, separately from the IC-XLT metrics,
## so both can be compared after running the notebook.

metrics_zsxlt = {}

for tl in target_languages:

    print(f'Running inference on {tl}')
    lang_data = tl_datasets[tl]

    # predict samples
    predicted_zsxlt = run_inference(
        model=model,
        tokenizer=tokenizer,
        test_texts=lang_data['test']['text'],
        class_set=lang_data['class_set'],
    )

    # evaluate samples
    metrics_zsxlt[tl] = compute_metrics(
        predicted_labels=predicted_zsxlt,
        truth_labels=lang_data['test']['label'],
        class_set=lang_data['class_set'],
        lbl2id_class=lang_data['lbl2id_class'],
    )

# Backward-compatible alias for code that reads `metrics_per_lang` right after this cell.
metrics_per_lang = metrics_zsxlt

## print f1-micro per language
for tl in target_languages:
    print(f"{tl} : {metrics_zsxlt[tl]['f1_score_micro']}")