In [1]:
import os
import pickle
import pandas as pd
from typing import List
import math
import os.path

import json

In [2]:
RESULT_DIR = "xtreme_up_results"

# Metric / langs defineition

In [3]:

task_langs = {'qa_in_lang': ['ar', 'bn', 'en', 'fi', 'id', 'ko', 'ru', 'sw', 'te'],
              'qa_cross_lang': ['ar', 'as', 'bho', 'bn', 'brx', 'fi', 'gbm', 'gom', 'gu', 'hi', 'hne',
                                'kn', 'ko', 'mai', 'ml', 'mni', 'mr', 'mwr', 'or', 'pa', 'ps', 'ru', 'sa', 'ta', 'te', 'ur'],
              'ner': ['am', 'bbj', 'bm', 'ee', 'ha', 'ig', 'lg', 'luo', 'mos', 'ny', 'pcm', 'rw', 'sn',
                      'sw','tn', 'tw', 'wo', 'xh', 'yo', 'zu'],
              'semantic_parsing': ['am', 'be', 'bn', 'de', 'en', 'es', 'fi', 'fr', 'ha', 'hi', 'ja',
                                    'pt_br', 'ru', 'sw', 'ta', 'th', 'tr', 'yo', 'zu'],
              'transliteration': ['am_Ethi_Latn', 'am_Latn_Ethi', 'bn_Beng_Latn', 'bn_Latn_Beng',
                                  'gu_Gujr_Latn', 'gu_Latn_Gujr', 'hi_Deva_Latn', 'hi_Latn_Deva',
                                  'kn_Knda_Latn', 'kn_Latn_Knda', 'ml_Latn_Mlym', 'ml_Mlym_Latn',
                                  'mr_Deva_Latn', 'mr_Latn_Deva', 'pa_Arab_Guru', 'pa_Arab_Latn',
                                  'pa_Guru_Arab', 'pa_Guru_Latn', 'pa_Latn_Arab', 'pa_Latn_Guru',
                                  'sd_Arab_Latn', 'sd_Latn_Arab', 'si_Latn_Sinh', 'si_Sinh_Latn',
                                  'ta_Latn_Taml', 'ta_Taml_Latn', 'te_Latn_Telu', 'te_Telu_Latn',
                                  'ur_Arab_Latn', 'ur_Latn_Arab']}
task_metric = {'qa_in_lang': 'f1', 'qa_cross_lang': 'f1', 'ner': 'span_f1',
               'semantic_parsing': 'sequence_accuracy', 'transliteration': 'cer'}

In [4]:
def parse_model_task_results(model_sized, task):
    model_prefix = model_sized[:4]
    eval_row = {'index': model_sized}
    for lang in task_langs[task]:
        res_file = f"{RESULT_DIR}/{model_sized}_{task}/xtreme_up_{task}.{lang}_{model_prefix}-metrics.jsonl"
        if not os.path.isfile(res_file):
            res_file = f"{RESULT_DIR}/{model_sized}_{task}/xtreme_up_{task}_{lang}_{model_prefix}-metrics.jsonl"
        if not os.path.isfile(res_file):
            eval_row[lang] = 0.0
        else:
            with open(res_file, "r") as in_file:
                line = in_file.readlines()[0]
                eval_row[lang] = json.loads(line)[task_metric[task]]
    return eval_row

In [5]:
def parse_all_task_results(task):
    evals = []
    for model in ("byt5", "myt5"):
        for size in ("small", "base", "large"):
            evals.append(parse_model_task_results(f"{model}_{size}", task))

    df = pd.DataFrame(evals)
    df.set_index('index', inplace=True)
    df.to_csv(f"{RESULT_DIR}/{task}_results.csv")

    display(df)

## In lang QA

In [6]:
parse_all_task_results('qa_in_lang')

Unnamed: 0_level_0,ar,bn,en,fi,id,ko,ru,sw,te
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
byt5_small,78.083873,53.043482,72.710871,75.398729,74.151818,69.798301,70.545182,73.880529,76.249349
byt5_base,82.007528,68.224339,76.3734,81.219868,78.032602,76.502224,76.339183,77.592784,81.787151
byt5_large,81.546073,59.33545,76.294668,80.743523,77.83315,75.660073,75.894278,77.100421,78.71988
myt5_small,77.409724,53.332339,70.012895,73.042036,67.721824,65.274577,69.61321,63.854888,77.031417
myt5_base,82.795277,69.178673,75.729562,81.019042,77.404821,77.246726,78.043532,74.972135,84.301735
myt5_large,82.29031,67.178066,74.896505,80.509612,76.104474,74.777607,76.610998,74.169717,83.611667


## QA Cross lang

In [7]:
parse_all_task_results('qa_cross_lang')

Unnamed: 0_level_0,ar,as,bho,bn,brx,fi,gbm,gom,gu,hi,...,mr,mwr,or,pa,ps,ru,sa,ta,te,ur
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
byt5_small,56.915722,34.047995,32.74644,39.175363,29.568194,55.701087,29.650744,30.317965,31.007181,55.402878,...,31.178064,32.202792,31.773563,31.817654,31.898276,50.238015,30.743799,32.847854,41.158399,31.60386
byt5_base,69.055925,42.761833,43.35129,56.641513,33.640154,64.328708,41.03422,38.371847,44.84761,65.391852,...,46.437757,42.667868,40.302864,45.522719,42.632626,64.131233,39.128778,46.339536,61.995586,44.702289
byt5_large,67.676048,39.033885,38.095984,54.046699,29.150175,67.556242,39.730632,36.117588,40.99943,63.455632,...,38.802186,37.437869,36.478317,43.721512,43.325763,63.200404,37.633314,41.648501,59.230546,43.173684
myt5_small,56.077281,32.275417,31.950042,39.145223,30.616148,55.723843,30.944495,30.69006,30.71684,56.387056,...,32.04497,31.413536,32.069303,29.931191,30.79936,52.010287,31.261939,30.81888,44.294422,31.264709
myt5_base,66.535974,34.12382,34.631563,52.437162,32.485875,59.104266,35.099132,32.819375,36.139197,60.986478,...,36.816859,35.63236,34.474992,37.681222,37.088255,63.420728,33.508984,39.347878,52.521126,37.433125
myt5_large,67.735054,35.731439,35.084101,54.990979,29.914883,59.213784,34.347444,32.495573,34.64729,60.96263,...,35.605162,35.08759,31.723383,37.609823,35.855462,62.512697,33.035667,39.808514,52.319554,37.788562


## NER


In [8]:
parse_all_task_results('ner')

Unnamed: 0_level_0,am,bbj,bm,ee,ha,ig,lg,luo,mos,ny,pcm,rw,sn,sw,tn,tw,wo,xh,yo,zu
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
byt5_small,0.561955,0.745731,0.776091,0.876461,0.873141,0.840553,0.832655,0.746567,0.743871,0.869519,0.845836,0.761933,0.890905,0.876933,0.846772,0.746933,0.784622,0.821339,0.782965,0.838062
byt5_base,0.611684,0.737689,0.797671,0.884841,0.878177,0.855198,0.837318,0.766917,0.74518,0.878307,0.859716,0.776498,0.894664,0.882623,0.854113,0.770686,0.790985,0.833194,0.781134,0.853287
byt5_large,0.598628,0.737195,0.788527,0.879042,0.884211,0.841294,0.843686,0.764331,0.721573,0.883151,0.853972,0.771589,0.895668,0.882855,0.858524,0.771791,0.777728,0.830367,0.79718,0.845181
myt5_small,0.547697,0.426934,0.768714,0.813439,0.854783,0.836609,0.807692,0.760563,0.714211,0.858278,0.816214,0.753837,0.883364,0.867365,0.837435,0.751938,0.742358,0.810374,0.359578,0.831588
myt5_base,0.586621,0.45283,0.771845,0.816989,0.868387,0.83528,0.826495,0.72956,0.719544,0.864257,0.834777,0.767091,0.904116,0.878828,0.857977,0.737589,0.740806,0.813862,0.367901,0.835821
myt5_large,0.613757,0.421704,0.770625,0.822255,0.86496,0.820988,0.826316,0.733591,0.73191,0.866578,0.833622,0.761538,0.889149,0.870557,0.846053,0.744548,0.734458,0.806028,0.367599,0.821413


## Semantic parsing

In [9]:
parse_all_task_results('semantic_parsing')

Unnamed: 0_level_0,am,be,bn,de,en,es,fi,fr,ha,hi,ja,pt_br,ru,sw,ta,th,tr,yo,zu
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
byt5_small,13.949716,24.195729,21.465261,28.440119,30.00811,28.582803,25.979995,28.651896,20.7083,16.539116,26.736956,29.683698,27.764261,21.221952,16.599081,17.866435,26.034063,14.571506,18.842931
byt5_base,18.626656,31.927548,25.304136,32.549338,37.280346,35.589172,31.413896,33.834022,28.413085,23.426871,31.035415,35.847526,34.090295,27.277643,23.925385,24.718127,31.413896,19.27548,24.925656
byt5_large,15.950257,27.683158,23.114355,34.441741,34.576913,32.085987,27.764261,32.669921,25.925926,21.386054,29.81887,32.468235,30.413625,23.060287,19.221411,20.555074,31.170587,16.869424,21.059746
myt5_small,11.300351,19.599892,17.274939,23.14139,23.14139,23.089172,17.03163,19.827262,14.59854,14.923469,16.355772,22.060016,24.736415,15.166261,14.571506,16.825672,18.788862,9.9216,12.327656
myt5_base,11.327386,20.464991,18.27521,23.709111,24.871587,26.194268,18.788862,24.483665,17.356042,18.835034,16.112463,25.57448,23.95242,16.977562,15.923222,15.871639,19.680995,9.516085,12.246553
myt5_large,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Transliteration

In [10]:
parse_all_task_results('transliteration')

Unnamed: 0_level_0,am_Ethi_Latn,am_Latn_Ethi,bn_Beng_Latn,bn_Latn_Beng,gu_Gujr_Latn,gu_Latn_Gujr,hi_Deva_Latn,hi_Latn_Deva,kn_Knda_Latn,kn_Latn_Knda,...,sd_Arab_Latn,sd_Latn_Arab,si_Latn_Sinh,si_Sinh_Latn,ta_Latn_Taml,ta_Taml_Latn,te_Latn_Telu,te_Telu_Latn,ur_Arab_Latn,ur_Latn_Arab
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
byt5_small,27.645732,52.949731,20.797342,56.29969,30.794946,57.556379,22.929014,56.260412,23.978567,58.866014,...,36.747779,45.363223,60.344884,23.280207,57.776029,22.308871,52.695728,21.41897,31.325997,41.46808
byt5_base,28.209744,52.255073,20.790585,56.209038,30.827519,57.2113,22.693709,55.732314,23.837183,58.484455,...,37.018661,45.053359,59.867986,23.193175,57.428859,22.047152,52.315918,21.05443,31.455946,41.221261
byt5_large,28.450162,56.263902,20.751168,57.337286,30.256639,57.523665,22.549966,55.931576,23.96686,58.781107,...,36.778789,45.840884,61.788583,23.217653,58.4993,22.219503,52.675877,21.234422,31.600334,42.480039
myt5_small,51.576956,78.771858,52.94217,69.87419,49.828565,64.61519,44.859703,66.681911,56.310505,74.686052,...,52.551942,59.335078,69.804797,45.481578,71.217294,48.707364,74.609602,54.613702,49.164097,55.840357
myt5_base,53.817055,74.889642,52.641478,69.055873,51.795786,64.371419,39.235995,51.804789,57.493809,72.503616,...,53.487715,58.883147,69.636799,47.611147,70.138752,50.190438,74.515642,55.924905,50.252163,55.599709
myt5_large,53.124288,74.04955,50.83507,68.65529,50.894036,63.974631,35.97519,49.483335,56.040344,71.721629,...,52.536437,58.818265,69.787455,46.557695,70.009142,49.482946,73.905563,55.356451,48.844381,55.210969
