In [8]:
import os
import pickle
import pandas as pd
from typing import List
import math
import os.path
import pandas as pd
import seaborn as sns

import json
import matplotlib as mpl
from matplotlib import pyplot as plt

In [9]:
# Publication-style figure defaults: serif (Times New Roman) body text, Computer
# Modern math glyphs, a 16 pt base font, bold axis labels (both x and y), and
# slightly smaller axis-label / x-tick text than the base size.
mpl.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
    'mathtext.fontset': 'cm',
    'font.size': 16,              # base size for all text
    'axes.labelweight': 'bold',   # applies to x- and y-axis labels alike
    'axes.labelsize': 14,         # axis labels a bit smaller than the base font
    'xtick.labelsize': 12,
})

In [10]:
# Root directory of the per-run result folders ("{model}_{size}_{task}/") that
# parse_model_task_results reads; the per-task CSV/LaTeX tables are written here too.
RESULT_DIR = "xtreme_up_results"

# Metric / language definitions

In [23]:

# Languages evaluated for each XTREME-UP task. Each entry is used verbatim as the
# `{lang}` component of the metrics file names built in parse_model_task_results,
# and becomes one column of that task's results table.
# NOTE(review): transliteration keys look like `{lang}_{sourceScript}_{targetScript}`
# and translation keys like `{src}_{tgt}` (English -> target) — confirm.
task_langs = {'qa_in_lang': ['ar', 'bn', 'en', 'fi', 'id', 'ko', 'ru', 'sw', 'te'],
              'qa_cross_lang': ['ar', 'as', 'bho', 'bn', 'brx', 'fi', 'gbm', 'gom', 'gu', 'hi', 'hne',
                                'kn', 'ko', 'mai', 'ml', 'mni', 'mr', 'mwr', 'or', 'pa', 'ps', 'ru', 'sa', 'ta', 'te', 'ur'],
              'ner': ['am', 'bbj', 'bm', 'ee', 'ha', 'ig', 'lg', 'luo', 'mos', 'ny', 'pcm', 'rw', 'sn',
                      'sw','tn', 'tw', 'wo', 'xh', 'yo', 'zu'],
              'semantic_parsing': ['am', 'be', 'bn', 'de', 'en', 'es', 'fi', 'fr', 'ha', 'hi', 'ja',
                                    'pt_br', 'ru', 'sw', 'ta', 'th', 'tr', 'yo', 'zu'],
              'xnli': ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh'],
              'transliteration': ['am_Ethi_Latn', 'am_Latn_Ethi', 'bn_Beng_Latn', 'bn_Latn_Beng',
                                  'gu_Gujr_Latn', 'gu_Latn_Gujr', 'hi_Deva_Latn', 'hi_Latn_Deva',
                                  'kn_Knda_Latn', 'kn_Latn_Knda', 'ml_Latn_Mlym', 'ml_Mlym_Latn',
                                  'mr_Deva_Latn', 'mr_Latn_Deva', 'pa_Arab_Guru', 'pa_Arab_Latn',
                                  'pa_Guru_Arab', 'pa_Guru_Latn', 'pa_Latn_Arab', 'pa_Latn_Guru',
                                  'sd_Arab_Latn', 'sd_Latn_Arab', 'si_Latn_Sinh', 'si_Sinh_Latn',
                                  'ta_Latn_Taml', 'ta_Taml_Latn', 'te_Latn_Telu', 'te_Telu_Latn',
                                  'ur_Arab_Latn', 'ur_Latn_Arab'],
              'translation': ['en_af', 'en_am', 'en_as', 'en_az', 'en_be', 'en_bg', 'en_bn', 'en_bs', 
                              'en_ceb', 'en_ckb', 'en_cy', 'en_da', 'en_el', 'en_et', 
                              'en_ff', 'en_fil', 'en_ga', 'en_gl', 'en_gu', 'en_ha', 'en_hy', 'en_id', 'en_ig', 'en_is', 'en_jv', 'en_ka',
                              'en_kk', 'en_km', 'en_kn', 'en_ky', 'en_lb', 'en_lg', 'en_ln', 'en_lo', 'en_lt', 
                              'en_lv', 'en_mi', 'en_mk', 'en_ml', 'en_mn', 'en_mr', 'en_ms', 'en_mt', 'en_my', 'en_ne', 
                              'en_no', 'en_nso', 'en_ny', 'en_om', 'en_or', 'en_pa', 'en_ps',
                              'en_ro', 'en_sd', 'en_sk', 'en_sl', 'en_sn', 'en_so', 'en_sw', 
                              'en_ta', 'en_te', 'en_tg', 'en_th','en_uk', 'en_ur', 'en_uz', 'en_xh',
                              'en_yo', 'en_zu']}
# Key of the metric read from each metrics JSON line, per task.
# NOTE(review): metric scales differ — the NER `span_f1` values in the tables below
# lie in [0, 1], while the other reported metrics appear on a 0-100 scale. `cer` is
# presumably character error rate (lower is better) — confirm.
task_metric = {'qa_in_lang': 'f1', 'qa_cross_lang': 'f1', 'ner': 'span_f1',
               'semantic_parsing': 'sequence_accuracy', 'transliteration': 'cer', 'translation': 'chrf',
              'xnli': 'accuracy'}

In [20]:
def parse_model_task_results(model_sized, task, missing_value=0.0):
    """Collect the per-language score of one fine-tuned run on one XTREME-UP task.

    For every language in ``task_langs[task]`` the metrics file is looked up in
    ``RESULT_DIR/{model_sized}_{task}/``. The task-specific (dotted) file-name pattern
    is tried first, then the underscore pattern
    ``xtreme_up_{task}_{lang}_{prefix}-metrics.jsonl``. The ``task_metric[task]``
    value is read from the last non-blank JSON line of the first file that exists.

    Args:
        model_sized: Run name such as ``"byt5_small"``. The part before the first
            underscore (``"byt5"``) is the model prefix embedded in the file names.
        task: Key into ``task_langs`` and ``task_metric``.
        missing_value: Score recorded when no metrics file exists, or when the file
            has no non-blank lines. Defaults to 0.0 (the historical behaviour). Pass
            ``float("nan")`` to keep missing runs distinguishable from real zero scores.

    Returns:
        Dict mapping ``'index'`` to ``model_sized``, plus one ``lang -> score``
        entry per language of the task.
    """
    import warnings

    # Split on "_" instead of slicing 4 chars so prefixes of other lengths also work.
    model_prefix = model_sized.split("_", 1)[0]
    run_dir = os.path.join(RESULT_DIR, f"{model_sized}_{task}")
    eval_row = {'index': model_sized}
    for lang in task_langs[task]:
        if task == "xnli":
            primary = f"{model_prefix}_{task}_dev_test.{lang}-metrics.jsonl"
        else:
            primary = f"xtreme_up_{task}.{lang}_{model_prefix}-metrics.jsonl"
        fallback = f"xtreme_up_{task}_{lang}_{model_prefix}-metrics.jsonl"
        candidates = (os.path.join(run_dir, primary), os.path.join(run_dir, fallback))
        res_file = next((path for path in candidates if os.path.isfile(path)), None)
        if res_file is None:
            eval_row[lang] = missing_value
            continue

        with open(res_file, "r", encoding="utf-8") as in_file:
            # Ignore blank lines so a trailing newline cannot break json.loads.
            records = [line for line in in_file if line.strip()]
        if not records:
            warnings.warn(f"{res_file} contains no metrics; recording {missing_value}")
            eval_row[lang] = missing_value
            continue

        # Only the last record is used.
        # NOTE(review): presumably the final evaluation of the run; confirm.
        eval_row[lang] = json.loads(records[-1])[task_metric[task]]
    return eval_row

In [21]:
def parse_all_task_results(task, sizes=("small", "base", "large"), models=("byt5", "myt5"), latex_scale=1.0):
    """Build, save and display the run x language score table for one task.

    One row is produced per ``f"{model}_{size}"`` run. Sizes form the outer loop and
    models the inner one, so the models of one size are adjacent rows.

    Writes ``{task}_results.csv`` and ``{task}_results.tex`` into ``RESULT_DIR``,
    then shows the table with ``display``. Returns ``None``, so the calling cell does
    not render the table a second time.

    Args:
        task: Key into ``task_langs`` / ``task_metric``.
        sizes: Model sizes to collect (default: small, base, large).
        models: Model families to collect (default: byt5, myt5).
        latex_scale: Multiplier applied only to the LaTeX export; the CSV keeps the
            raw values. Useful for fractional metrics such as NER ``span_f1``, which
            would otherwise be rounded to a single decimal (0.6, 0.8, ...) by the
            ``%.1f`` format.
    """
    evals = [
        parse_model_task_results(f"{model}_{size}", task)
        for size in sizes
        for model in models
    ]

    df = pd.DataFrame(evals).set_index('index')
    df.to_csv(os.path.join(RESULT_DIR, f"{task}_results.csv"))

    # Skip the multiplication at the default scale so the dtypes (and thus the
    # LaTeX formatting) stay exactly as before.
    latex_df = df if latex_scale == 1 else df * latex_scale
    latex_df.to_latex(
        os.path.join(RESULT_DIR, f"{task}_results.tex"),
        float_format="%.1f".__mod__,
        label=f"{task}_results",
    )

    display(df)

## In-language QA

In [14]:
# In-language QA: F1 per language for every ByT5 / MyT5 size.
parse_all_task_results(task='qa_in_lang')

  df.to_latex(os.path.join(RESULT_DIR, f"{task}_results.tex"), float_format="%.1f".__mod__, label=f"{task}_results")


Unnamed: 0_level_0,ar,bn,en,fi,id,ko,ru,sw,te
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
byt5_small,78.131159,64.606878,72.528835,81.444931,73.243717,68.110302,77.52614,69.054093,77.843651
myt5_small,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
byt5_base,83.676614,72.690396,77.756327,81.281116,79.620133,77.622872,78.059827,76.68464,85.28362
myt5_base,83.075396,69.925234,76.962359,81.984063,78.183006,76.551542,78.742481,76.29302,84.663867
byt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Cross-language QA

In [15]:
# Cross-language QA: F1 per language for every ByT5 / MyT5 size.
parse_all_task_results(task='qa_cross_lang')

  df.to_latex(os.path.join(RESULT_DIR, f"{task}_results.tex"), float_format="%.1f".__mod__, label=f"{task}_results")


Unnamed: 0_level_0,ar,as,bho,bn,brx,fi,gbm,gom,gu,hi,...,mr,mwr,or,pa,ps,ru,sa,ta,te,ur
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
byt5_small,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_small,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
byt5_base,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_base,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
byt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## NER


In [16]:
# NER: span F1 per language (reported on a 0-1 scale, see the table below).
parse_all_task_results(task='ner')

  df.to_latex(os.path.join(RESULT_DIR, f"{task}_results.tex"), float_format="%.1f".__mod__, label=f"{task}_results")


Unnamed: 0_level_0,am,bbj,bm,ee,ha,ig,lg,luo,mos,ny,pcm,rw,sn,sw,tn,tw,wo,xh,yo,zu
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
byt5_small,0.563189,0.758157,0.816915,0.885732,0.870765,0.868212,0.851163,0.790698,0.764116,0.898421,0.87155,0.795818,0.926751,0.894668,0.862302,0.77708,0.808734,0.857221,0.81914,0.850565
myt5_small,0.563208,0.681319,0.786979,0.868152,0.852596,0.865014,0.82596,0.800499,0.725664,0.890173,0.836015,0.774137,0.916609,0.894394,0.848269,0.78577,0.768154,0.836673,0.82153,0.830688
byt5_base,0.621878,0.742229,0.805769,0.887224,0.90108,0.874384,0.84965,0.804878,0.763636,0.900868,0.873673,0.801757,0.937313,0.903086,0.870596,0.789639,0.828309,0.869565,0.831424,0.869261
myt5_base,0.596522,0.71866,0.790743,0.873198,0.89824,0.867786,0.839202,0.798042,0.736241,0.890005,0.854309,0.790318,0.932719,0.870593,0.866731,0.774845,0.772408,0.853605,0.821299,0.847694
byt5_large,0.607639,0.725107,0.799807,0.881342,0.881266,0.84318,0.845748,0.771392,0.735124,0.891426,0.852018,0.767131,0.900186,0.884758,0.856404,0.776213,0.801958,0.834448,0.787199,0.85203
myt5_large,0.621129,0.68926,0.791786,0.87051,0.875439,0.833435,0.836118,0.753512,0.753684,0.879617,0.852277,0.778531,0.902354,0.888792,0.847391,0.77708,0.749889,0.8198,0.794227,0.838455


## Semantic parsing

In [17]:
# Semantic parsing: sequence accuracy per language.
parse_all_task_results(task='semantic_parsing')

  df.to_latex(os.path.join(RESULT_DIR, f"{task}_results.tex"), float_format="%.1f".__mod__, label=f"{task}_results")


Unnamed: 0_level_0,am,be,bn,de,en,es,fi,fr,ha,hi,ja,pt_br,ru,sw,ta,th,tr,yo,zu
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
byt5_small,18.491484,32.495269,30.846175,30.575831,32.657475,30.095541,28.899703,29.703342,24.439038,21.173469,30.00811,32.522303,35.955664,25.601514,25.844823,22.853426,29.170046,19.248446,22.898081
myt5_small,16.815356,27.169505,22.979184,29.548527,28.845634,27.348726,24.033523,25.309801,22.70884,18.792517,20.329819,28.872668,32.414166,21.08678,25.168965,20.511709,28.521222,13.544201,16.84239
byt5_base,23.222493,40.713706,36.469316,36.982968,40.767775,39.012739,36.30711,38.265115,34.144363,29.12415,34.739119,40.524466,43.011625,31.819411,34.766153,30.355594,37.30738,24.060557,28.953771
myt5_base,18.383347,28.494188,23.249527,30.765072,31.143552,29.578025,27.142471,29.778445,23.357664,22.44898,19.626926,31.603136,33.333333,23.168424,24.871587,24.544666,28.359016,13.895647,17.356042
byt5_large,18.599622,31.657205,30.711003,34.522844,35.144634,33.121019,29.981076,34.847916,25.736686,25.680272,31.44093,34.712084,35.739389,26.358475,26.439578,24.631396,32.06272,18.599622,22.816978
myt5_large,16.490943,26.196269,20.627197,31.63017,31.63017,28.144904,25.736686,28.05107,21.735604,18.707483,18.085969,30.359557,32.765612,21.411192,21.167883,19.167389,25.655583,12.97648,16.680184


## XNLI

In [24]:
# XNLI: accuracy per language.
parse_all_task_results(task='xnli')

  df.to_latex(os.path.join(RESULT_DIR, f"{task}_results.tex"), float_format="%.1f".__mod__, label=f"{task}_results")


Unnamed: 0_level_0,ar,bg,de,el,en,es,fr,hi,ru,sw,th,tr,ur,vi,zh
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
byt5_small,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_small,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
byt5_base,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_base,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
byt5_large,68.722555,72.255489,71.776447,71.257485,77.165669,73.732535,73.512974,66.167665,70.459082,67.025948,65.369261,69.580838,63.772455,70.0,69.421158
myt5_large,68.323353,72.275449,71.656687,71.936128,77.54491,74.0,68.666667,68.666667,68.0,64.0,62.666667,69.333333,63.333333,71.333333,69.333333


## Translation



In [142]:
# Translation: chrF per language pair.
parse_all_task_results(task='translation')

Unnamed: 0_level_0,en_af,en_am,en_as,en_az,en_be,en_bg,en_bn,en_bs,en_ceb,en_ckb,...,en_ta,en_te,en_tg,en_th,en_uk,en_ur,en_uz,en_xh,en_yo,en_zu
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
byt5_small,4.148148,7.991803,4.217059,3.728735,3.645833,4.178674,4.521964,4.229323,3.66135,4.117196,...,4.007434,4.924376,3.597122,4.195804,4.071543,4.586347,3.513629,3.856252,4.163997,3.695492
myt5_small,3.508332,7.668109,4.345687,3.428327,3.733049,4.09894,5.039079,4.189164,3.722211,4.016439,...,3.518972,3.975014,4.046063,4.547689,4.106174,4.731574,3.283909,3.868313,4.616477,3.544423
byt5_base,4.425783,6.792059,4.165675,4.171754,3.780004,3.453881,4.279279,4.417506,3.982364,4.130435,...,3.557137,3.891439,4.154079,3.164557,4.125781,4.206879,3.768624,4.125258,3.320216,3.970617
myt5_base,4.506587,7.107709,4.346746,3.753891,4.1511,4.36233,3.751234,4.098361,3.921569,3.115265,...,3.81255,3.260324,4.187286,4.143646,4.250919,4.36233,3.846154,3.986952,4.340836,3.824823
byt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Transliteration

In [140]:
# Transliteration: CER per language/script pair.
parse_all_task_results(task='transliteration')

Unnamed: 0_level_0,am_Ethi_Latn,am_Latn_Ethi,bn_Beng_Latn,bn_Latn_Beng,gu_Gujr_Latn,gu_Latn_Gujr,hi_Deva_Latn,hi_Latn_Deva,kn_Knda_Latn,kn_Latn_Knda,...,sd_Arab_Latn,sd_Latn_Arab,si_Latn_Sinh,si_Sinh_Latn,ta_Latn_Taml,ta_Taml_Latn,te_Latn_Telu,te_Telu_Latn,ur_Arab_Latn,ur_Latn_Arab
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
byt5_small,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_small,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
byt5_base,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_base,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
byt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
myt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Scores vs. inference time plot

In [58]:
# Shared styling for the aggregated task categories. The names must match the
# display names assigned to agg_df["task"] (QA, NER, Semantic Parsing, Translation).
TASK_CATEGORIES = ['Translation', 'NER', 'QA', 'Semantic Parsing']
CATEGORY_LINESTYLES = ['-', '--', '-.', ':']
# zip() would silently drop categories without a line style, so fail loudly instead.
assert len(CATEGORY_LINESTYLES) == len(TASK_CATEGORIES), "one line style per category"

# Size the palette from the category list so adding a category cannot run out of colours.
palette = sns.color_palette("tab10", len(TASK_CATEGORIES))
category_colors = dict(zip(TASK_CATEGORIES, palette))
category_linestyles = dict(zip(TASK_CATEGORIES, CATEGORY_LINESTYLES))

In [98]:
# Plot scores against the inference times.
# In this version the inference times are already relative to ByT5 (byt5_time == 1.00).
# NOTE(review): the next cell reassigns both `aggregate_results` and `agg_df` before
# either is used, so these numbers (e.g. NER 82.8 / 83.1) never reach the table or plot.
aggregate_results = [
    dict(task="qa_in_lang", myt5_score=75.3, byt5_score=73.2, myt5_time=0.6464, byt5_time=1.00),
    dict(task="ner", myt5_score=82.8, byt5_score=83.1, myt5_time=0.9224, byt5_time=1.00),
    dict(task="semantic_parsing", myt5_score=19.6, byt5_score=25.1, myt5_time=0.7443, byt5_time=1.00),
    dict(task="translation", myt5_score=20.42, byt5_score=20.10, myt5_time=0.6547, byt5_time=1.00),
]
agg_df = pd.DataFrame(aggregate_results)



In [None]:
# Aggregated scores and absolute inference times per task. These values feed the
# LaTeX table and the plot below.
# NOTE(review): the time unit is not recorded here (presumably seconds) — confirm.
# TODO: add XNLI.
aggregate_results = [
    dict(task="qa_in_lang", myt5_score=75.3, byt5_score=73.2, myt5_time=35.6, byt5_time=36.2),
    dict(task="ner", myt5_score=80.8, byt5_score=81.5, myt5_time=12.6, byt5_time=13.8),
    dict(task="semantic_parsing", myt5_score=19.6, byt5_score=25.1, myt5_time=12.4, byt5_time=13.2),
    dict(task="translation", myt5_score=20.42, byt5_score=20.10, myt5_time=12.6, byt5_time=15.9),
]

agg_df = pd.DataFrame(aggregate_results).assign(
    # Rename tasks to the display names used as table headers and legend labels.
    task=lambda frame: frame["task"].replace({
        "qa_in_lang": "QA",
        "ner": "NER",
        "semantic_parsing": "Semantic Parsing",
        "translation": "Translation",
    })
)




In [None]:
# LaTeX table with one column per task and one row per score/time field.
# Moving "task" into the index before transposing keeps the values float64.
# Transposing the frame that still contained the string "task" column produced
# object-dtype columns, which to_latex aligns as text ('l') rather than numbers ('r').
# NOTE: run this before the next cell, which overwrites the time columns of agg_df
# in place with relative times.
agg_df_t = agg_df.set_index("task").T

os.makedirs("output", exist_ok=True)  # fresh checkouts may lack the output folder
agg_df_t.to_latex("output/xtreme_results_lr.tex", float_format="%.1f".__mod__, label="xtreme_results_lr")


In [None]:
# Express both inference-time columns relative to ByT5 (ByT5 becomes 1.0), dividing
# each by the original byt5_time in a single vectorised step.
# Mutates agg_df in place: re-running the LaTeX export cell afterwards would export
# relative times instead of the absolute ones.
agg_df[["myt5_time", "byt5_time"]] = agg_df[["myt5_time", "byt5_time"]].div(agg_df["byt5_time"], axis=0)

In [None]:
# Score vs. relative inference time. Each task is drawn as one line from its MyT5
# point (X marker) to its ByT5 point (square marker). Expects agg_df with the time
# columns already normalised by the previous cell, and the category_colors /
# category_linestyles lookups keyed by the task display names.
fig, ax = plt.subplots(figsize=(6.5, 5.5))

# drop_duplicates keeps the first row per task, in order of first appearance.
for row in agg_df.drop_duplicates("task").itertuples(index=False):
    color = category_colors[row.task]
    ax.plot([row.myt5_time, row.byt5_time],
            [row.myt5_score, row.byt5_score],
            label=row.task,
            color=color,
            linestyle=category_linestyles[row.task],
            lw=2., alpha=0.7)
    # Mark the two endpoints with different markers: X = MyT5, square = ByT5.
    ax.scatter(row.myt5_time, row.myt5_score, color=color, marker='X', s=80.)
    ax.scatter(row.byt5_time, row.byt5_score, color=color, marker='s', s=80.)

# Legend-only proxy markers (neutral colour) identifying the two models.
ax.scatter([], [], color='k', alpha=0.5, marker='X', s=80., label='MyT5')
ax.scatter([], [], color='k', alpha=0.5, marker='s', s=80., label='ByT5')

ax.set_xlabel('Inference Time (relative)')
ax.set_ylabel('Score')
ax.legend(loc='upper left')

fig.tight_layout()
os.makedirs("output", exist_ok=True)  # savefig fails if the folder is missing
fig.savefig("output/xtreme_up_comparison.png", dpi=300)
