### Install dependencies

In [None]:
# Each `!` command runs in its own subshell, so a separate
# `!conda activate py39` would not carry over to the installs below.
# Target the environment explicitly with `-n py39` instead, and pass `-y`
# because conda's interactive confirmation prompt cannot be answered from
# a notebook cell.
!conda install -y -n py39 --file ./requirements.txt
!conda install -y -n py39 --file ./examples/requirements.txt


### Download GLUE dataset


In [None]:
# Fetch every GLUE task into ./glue_data, the directory PATH_TO_DATA points
# at in the GLUE variables cell. Assumes download_glue_data.py sits in the
# current working directory -- TODO confirm.
!python download_glue_data.py --data_dir glue_data --tasks all


### Variables (run only one of the two cells below: GLUE or CoNLL NER)


In [None]:
# GLUE benchmark settings.
PATH_TO_DATA = "glue_data"  # target of download_glue_data.py --data_dir
MODEL_TYPE = "bert"         # "bert" or "roberta"
MODEL_SIZE = "base"         # "base" or "large"
DATASETS = [
    "CoLA",   # acceptability
    "SST-2",  # sentiment
    "MRPC",   # paraphrase
    "STS-B",  # sentence similarity -- FIXME: doesn't work
    "QQP",    # paraphrase
    "MNLI",   # NLI
    "QNLI",   # QA/NLI
    "RTE",    # NLI
    "WNLI",   # coreference/NLI
]


In [None]:
# Alternative settings for CoNLL NER. Running this cell after the GLUE
# cell overwrites all four variables.
PATH_TO_DATA = "ner_data"
MODEL_TYPE, MODEL_SIZE = "bert", "base"  # "bert"/"roberta", "base"/"large"
DATASETS = ["CoNLL"]


### Read and plot the saved `.npy` results under `plotting/`


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

# Eight-colour palettes for the stacked per-layer bars: ``plot`` picks one
# colour per result file, wrapping around modulo the palette length.
# Greens colour the correct part of each bar, pinks/reds the incorrect part.
correct_colors = (
    '#99e2b4 #78c6a3 #56AB91 #469D89 '
    '#358F80 #248277 #14746F #036666'
).split()
incorrect_colors = (
    '#FFF0F3 #FFB3C1 #FF758F #FF4D6D '
    '#C9184A #A4133C #800F2F #590D22'
).split()


def set_style(y_lim, num_layers=None):
    """Apply titles, labels, ticks and limits to the module-level axes.

    Reads the globals ``ax1`` (per-layer correct/incorrect bars), ``ax2``
    (per-layer score line, twin of ``ax1``), ``ax3`` (performance vs.
    time-saving scatter), ``DATASET`` (used in the titles) and, when
    ``num_layers`` is omitted, ``MODEL_SIZE``.

    Args:
        y_lim: Upper y-limit of ``ax1``; also shown in its y-label as the
            total. A non-positive value leaves the top limit to matplotlib's
            autoscaling instead of collapsing the axis to [0, 0].
        num_layers: Number of layers on the shared x-axis. Defaults to 24
            when ``MODEL_SIZE == "large"`` and 12 otherwise (BERT/RoBERTa
            large vs. base), instead of always assuming 12.
    """
    if num_layers is None:
        num_layers = 24 if MODEL_SIZE == "large" else 12

    ax1.set_title(DATASET + ' early exit evaluation')
    ax1.set_xlabel('Layer')
    ax1.set_ylabel('Number (total: {})'.format(y_lim))
    ax1.xaxis.set_major_locator(plt.MultipleLocator(1))
    # Half-unit margins so the first and last layer's bars are not clipped.
    ax1.set_xlim(left=0.5, right=num_layers + 0.5)
    if y_lim > 0:
        ax1.set_ylim(bottom=0, top=y_lim)
    else:
        # bottom == top would be a singular range; let the top autoscale.
        ax1.set_ylim(bottom=0)

    ax2.set_xlabel('Layer')
    ax2.set_ylabel('Score')
    ax2.xaxis.set_major_locator(plt.MultipleLocator(1))
    ax2.yaxis.set_major_locator(plt.MultipleLocator(0.1))
    ax2.set_ylim(bottom=0, top=1)

    ax3.set_title(DATASET + ' performance V.S. time-saving')
    ax3.set_xlabel('time-saving')
    ax3.set_ylabel('performance')
    ax3.xaxis.set_major_locator(plt.MultipleLocator(0.1))
    ax3.yaxis.set_major_locator(plt.MultipleLocator(0.1))
    # Small padding so points at exactly 0 or 1 stay visible.
    ax3.set_xlim(left=-0.05, right=1.05)
    ax3.set_ylim(bottom=-0.05, top=1.05)


def plot(fname, data, cnt):
    """Draw one saved result array onto the module-level axes.

    ``each_layer.npy`` is drawn as a line on ``ax2`` (per-layer score).
    Any other file is read as a record where ``data[0]`` / ``data[1]`` are
    dicts of total / correct counts with one value per layer (assumed to
    be stored in layer order -- TODO confirm), ``1 - data[3]`` is plotted
    as time-saving and ``data[4]`` as performance (skipped when it is a
    dict). Such a record becomes stacked correct/incorrect bars on ``ax1``
    and one annotated point on ``ax3``.

    Uses the globals ``ax1``, ``ax2``, ``ax3``, ``correct_colors`` and
    ``incorrect_colors``.

    Args:
        fname: Base name of the loaded file.
        data: Array returned by ``np.load(..., allow_pickle=True)``.
        cnt: Position of the file in sorted order. It selects the bar
            colour and horizontal offset, and the point label
            ``0.1 * (cnt - 1)``.

    Returns:
        The sum of the total counts for record files, or None for
        ``each_layer.npy``.
    """
    if fname == 'each_layer.npy':
        # Derive the layer axis from the data, so 24-layer (large) models
        # work too instead of mismatching a hard-coded 12.
        ax2.plot(np.arange(1, len(data) + 1), data)
        return None

    total = list(data[0].values())
    correct = list(data[1].values())
    incorrect = [t - c for t, c in zip(total, correct)]
    x = np.arange(1, len(total) + 1)
    width = 0.12
    # Offsets place files cnt = 1..8 symmetrically around each layer tick.
    bias = (cnt - 4.5) * width
    ax1.bar(x + bias, correct, width, label='correct',
            color=correct_colors[(cnt - 1) % len(correct_colors)])
    ax1.bar(x + bias, incorrect, width, bottom=correct, label='incorrect',
            color=incorrect_colors[(cnt - 1) % len(incorrect_colors)])

    time_prop = 1 - data[3]
    performance = data[4]
    if not isinstance(performance, dict):
        ax3.scatter(time_prop, performance)
        # Format instead of slicing str(): the old [0:3] slice produced
        # '-0.' for cnt == 0 and '10.' once cnt exceeded 100.
        ax3.annotate(f"{0.1 * (cnt - 1):.1f}", (time_prop, performance))
    return np.sum(total)


# For every results directory under
# plotting/saved_models/<MODEL_TYPE>-<MODEL_SIZE>/<DATASET>/two_stage,
# build two figures: ax1 (with twin ax2) for per-layer bars/scores and ax3
# for the performance/time-saving trade-off. ax1/ax2/ax3 and DATASET stay
# module-level names because plot() and set_style() read them as globals.
for DATASET in DATASETS:
    print(DATASET)
    relative_path = os.path.join("plotting", "saved_models",
                                 MODEL_TYPE + "-" + MODEL_SIZE,
                                 DATASET, "two_stage")
    if not os.path.isdir(relative_path):
        # os.walk silently yields nothing for a missing directory.
        print("No results found at", relative_path)
        continue
    for path, _dirs, files in os.walk(relative_path):
        # The sorted order fixes each file's cnt (bar colour, offset and
        # label). Only .npy arrays are loaded, and directories without
        # any are skipped so they do not produce empty figures.
        npy_files = sorted(f for f in files if f.endswith(".npy"))
        if not npy_files:
            continue
        cnt = 0
        intergrate_fig = plt.figure(figsize=(10, 6))
        ax1 = intergrate_fig.add_subplot(1, 1, 1)
        ax2 = ax1.twinx()
        trade_off_fig = plt.figure(figsize=(10, 6))
        ax3 = trade_off_fig.add_subplot(1, 1, 1)
        y_lim = 0
        for fname in npy_files:
            # allow_pickle=True unpickles object arrays, which can execute
            # arbitrary code -- only load locally produced result files.
            data = np.load(os.path.join(path, fname), allow_pickle=True)
            print(data)
            y = plot(fname, data, cnt)
            if y is not None:
                y_lim = max(y_lim, y)
            cnt += 1
        set_style(y_lim)
        plt.show()
