# Test results graphing

## Mean scores

In [5]:
import pandas as pd
import plotly.graph_objects as go


def plot_mean_bar_chart(df: pd.DataFrame,
                        lowest_score: int,
                        highest_score: int,
                        tick_step: int = 5):
    """Show a grouped bar chart of mean scores, one bar trace per column.

    Each column of *df* becomes one bar trace named after the column, and
    the index of *df* supplies the x-axis categories (titled 'Metric').

    Args:
        df: Scores to plot; index = x-axis categories, columns = series.
        lowest_score: Lower bound of the y-axis range and the first tick.
        highest_score: Upper bound of the y-axis range; it is also the last
            tick when ``highest_score - lowest_score`` is a multiple of
            ``tick_step``.
        tick_step: Spacing between y-axis ticks (default 5).

    Raises:
        ValueError: If ``tick_step`` is not positive.
    """
    if tick_step <= 0:
        # range() rejects a zero step and yields no ticks for a negative one.
        raise ValueError(f"tick_step must be positive, got {tick_step}")

    fig = go.Figure()

    for column in df.columns:
        fig.add_trace(go.Bar(x=df.index, y=df[column], name=column))

    # + 1 so highest_score itself is included as a tick when it lies on the step grid.
    tick_values = list(range(lowest_score, highest_score + 1, tick_step))

    fig.update_layout(barmode='group',
                      xaxis=dict(title='Metric'),
                      yaxis=dict(title='Mean score', range=[lowest_score, highest_score],
                                 tickvals=tick_values),
                      width=800,
                      legend=dict(x=0, y=1))

    fig.show()

In [6]:
# Mean score of each model for each metric: one row per model.
# NOTE(review): TER is normally an error rate (lower is better), yet these
# values rise together with SacreBLEU — confirm whether this column is raw
# TER or an inverted/different metric before reading the chart.
mean_results = pd.DataFrame([['FT All', 51.7, 63.5, 94.4],
                             ['FT CL', 49.0, 61.2, 94.1],
                             ['Helsinki-NLP', 46.2, 55.4, 93.6],
                             ['NLLB 3.3B', 44.7, 51.3, 92.8],
                             ['MADLAD-400', 48.7, 55.6, 93.5]],
                            columns=['Model', 'SacreBLEU', 'TER', 'Semantic similarity'])
# Transpose so the metrics become the index (x-axis groups) and each model
# becomes a column, i.e. one bar trace per model.
mean_results = mean_results.set_index('Model').T

plot_mean_bar_chart(mean_results, 40, 95)

## SacreBLEU scores

In [7]:
def plot_sacrebleu_bar_chart(df: pd.DataFrame):
    """Show a grouped bar chart of per-dataset SacreBLEU scores.

    *df* must contain a ``'Dataset'`` column holding the x-axis categories.
    Every other column is treated as one model's scores and becomes one bar
    trace named after that column, in column order.

    Args:
        df: One row per dataset, one score column per model.

    Raises:
        KeyError: If *df* has no ``'Dataset'`` column.
    """
    # Derive the models from the frame instead of hard-coding their names, so
    # a frame that lacks (or adds) a model no longer raises KeyError. This
    # matches how plot_mean_bar_chart builds one trace per column.
    models = [column for column in df.columns if column != 'Dataset']
    data = [go.Bar(x=df['Dataset'], y=df[model], name=model) for model in models]

    layout = go.Layout(title='SacreBLEU Scores for Machine Translation Models',
                       xaxis=dict(title='Dataset'),
                       yaxis=dict(title='SacreBLEU Score'),
                       barmode='group')

    fig = go.Figure(data=data, layout=layout)
    fig.show()

In [8]:
# Per-dataset SacreBLEU scores. Each row has only four model scores, so the
# header must name five columns (Dataset + four models); six names made
# pd.DataFrame raise ValueError. The four columns were identified by their
# means, which match the SacreBLEU means in the table above:
# 51.7 (FT All), 46.2 (Helsinki-NLP), 44.7 (NLLB 3.3B), 48.7 (MADLAD-400).
# NOTE(review): confirm this mapping against the original evaluation output.
sacrebleu_results = pd.DataFrame([['clinspen-te', 54.5, 39.0, 34.9, 37.6],
                                  ['hpo', 48.7, 47.8, 44.5, 53.8],
                                  ['khresmoi-te', 47.9, 49.5, 49.3, 50.0],
                                  ['orphanet-definitions-te', 61.4, 46.3, 45.8, 50.7],
                                  ['pubmed-te', 45.9, 48.4, 49.0, 51.6]],
                                 columns=['Dataset', 'FT All', 'Helsinki-NLP', 'NLLB 3B', 'MADLAD-400'])
# FT CL has no per-dataset scores here. Keep it as an all-NaN column in its
# usual position, so the trace order (and therefore the default bar colours)
# matches the mean chart and an 'FT CL' column is still present for plotting.
sacrebleu_results.insert(2, 'FT CL', float('nan'))

plot_sacrebleu_bar_chart(sacrebleu_results)