In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import sys
# setting path
sys.path.append('..')
from utils.eval_utils import get_temp_df, compute_mcc
from scipy.stats import bootstrap
from sklearn.metrics import matthews_corrcoef
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from collections import defaultdict
import pickle
import json

In [None]:
# Test-set MCC (epoch 11, decision threshold 0.5) with bootstrap confidence
# intervals (scipy's default 95% level) for PsyRoBERTa fine-tuned on 10%..50%
# of the data. The 100% point is appended after the loop.
psy_mccs, psy_lower, psy_higher = [], [], []

for frac in np.arange(1, 6):
    results = pd.read_csv(f"../../result_files/psyroberta_frac{frac}0_results.csv")
    test_df = get_temp_df(results[results.epoch == 11], split="test")

    threshold = 0.5
    y_pred = [1 if p >= threshold else 0 for p in test_df.p_mean]
    y_true = test_df.target

    mcc = np.round(matthews_corrcoef(y_true, y_pred), 3)
    psy_mccs.append(mcc)

    # Paired resampling keeps each (target, prediction) pair together.
    ci = bootstrap((y_true, y_pred), compute_mcc, vectorized=False, paired=True,
                   random_state=22, n_resamples=1000).confidence_interval
    low_conf, high_conf = np.round(ci.low, 3), np.round(ci.high, 3)
    psy_lower.append(low_conf)
    psy_higher.append(high_conf)
    print(f"MCC = {mcc} [{low_conf},{high_conf}]")

# 100 % discharge summaries result -- hard-coded, presumably from a separate
# full-data run (TODO: record where these numbers come from).
for values, extra in ((psy_mccs, 0.285), (psy_lower, 0.269), (psy_higher, 0.304)):
    values.append(extra)

In [None]:
# Test-set MCC (epoch 11, decision threshold 0.5) with bootstrap confidence
# intervals (scipy's default 95% level) for RoBERTa fine-tuned on 10%..50%
# of the data. The 100% point is appended after the loop.
ro_mccs, ro_lower, ro_higher = [], [], []

for frac in np.arange(1, 6):
    results = pd.read_csv(f"../../result_files/roberta_frac{frac}0_results.csv")
    test_df = get_temp_df(results[results.epoch == 11], split="test")

    threshold = 0.5
    y_pred = [1 if p >= threshold else 0 for p in test_df.p_mean]
    y_true = test_df.target

    mcc = np.round(matthews_corrcoef(y_true, y_pred), 3)
    ro_mccs.append(mcc)

    # Paired resampling keeps each (target, prediction) pair together.
    ci = bootstrap((y_true, y_pred), compute_mcc, vectorized=False, paired=True,
                   random_state=22, n_resamples=1000).confidence_interval
    low_conf, high_conf = np.round(ci.low, 3), np.round(ci.high, 3)
    ro_lower.append(low_conf)
    ro_higher.append(high_conf)
    print(f"MCC = {mcc} [{low_conf},{high_conf}]")

# 100 % discharge summaries result -- hard-coded, presumably from a separate
# full-data run (TODO: record where these numbers come from).
for values, extra in ((ro_mccs, 0.226), (ro_lower, 0.208), (ro_higher, 0.244)):
    values.append(extra)

In [None]:
# Test-set MCC (epoch 11, decision threshold 0.5) with bootstrap confidence
# intervals (scipy's default 95% level) for MeDa-BERT fine-tuned on 10%..50%
# of the data. The 100% point is appended after the loop.
meda_mccs, meda_lower, meda_higher = [], [], []

for frac in np.arange(1, 6):
    results = pd.read_csv(f"../../result_files/medabert_frac{frac}0_results.csv")
    test_df = get_temp_df(results[results.epoch == 11], split="test")

    threshold = 0.5
    y_pred = [1 if p >= threshold else 0 for p in test_df.p_mean]
    y_true = test_df.target

    mcc = np.round(matthews_corrcoef(y_true, y_pred), 3)
    meda_mccs.append(mcc)

    # Paired resampling keeps each (target, prediction) pair together.
    ci = bootstrap((y_true, y_pred), compute_mcc, vectorized=False, paired=True,
                   random_state=22, n_resamples=1000).confidence_interval
    low_conf, high_conf = np.round(ci.low, 3), np.round(ci.high, 3)
    meda_lower.append(low_conf)
    meda_higher.append(high_conf)
    print(f"MCC = {mcc} [{low_conf},{high_conf}]")

# 100 % discharge summaries result -- hard-coded, presumably from a separate
# full-data run (TODO: record where these numbers come from).
for values, extra in ((meda_mccs, 0.264), (meda_lower, 0.248), (meda_higher, 0.283)):
    values.append(extra)

In [None]:
# Test-set MCC (epoch 11, decision threshold 0.5) with bootstrap confidence
# intervals (scipy's default 95% level) for BERT fine-tuned on 10%..50%
# of the data. The 100% point is appended after the loop.
bert_mccs, bert_lower, bert_higher = [], [], []

for frac in np.arange(1, 6):
    results = pd.read_csv(f"../../result_files/bert_frac{frac}0_results.csv")
    test_df = get_temp_df(results[results.epoch == 11], split="test")

    threshold = 0.5
    y_pred = [1 if p >= threshold else 0 for p in test_df.p_mean]
    y_true = test_df.target

    mcc = np.round(matthews_corrcoef(y_true, y_pred), 3)
    bert_mccs.append(mcc)

    # Paired resampling keeps each (target, prediction) pair together.
    ci = bootstrap((y_true, y_pred), compute_mcc, vectorized=False, paired=True,
                   random_state=22, n_resamples=1000).confidence_interval
    low_conf, high_conf = np.round(ci.low, 3), np.round(ci.high, 3)
    bert_lower.append(low_conf)
    bert_higher.append(high_conf)
    print(f"MCC = {mcc} [{low_conf},{high_conf}]")

# 100 % discharge summaries result -- hard-coded, presumably from a separate
# full-data run (TODO: record where these numbers come from).
for values, extra in ((bert_mccs, 0.215), (bert_lower, 0.198), (bert_higher, 0.231)):
    values.append(extra)

In [None]:
sns.set_palette('mako_r', n_colors=4)
sns.set_style("white", {"axes.edgecolor": ".8"})

# Broken x-axis figure: the wide left panel shows the 10%-50% points (x=0..4),
# the narrow right panel shows only the 100% point (x=5). Both panels draw the
# same curves; only their x-limits and tick labels differ.
fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, gridspec_kw={"width_ratios": [6, 0.8]})
x = np.arange(0, 6)

curves = [
    ("PsyRoBERTa", psy_mccs, psy_lower, psy_higher),
    ("MeDa-BERT", meda_mccs, meda_lower, meda_higher),
    ("RøBÆRTa", ro_mccs, ro_lower, ro_higher),
    ("BERT", bert_mccs, bert_lower, bert_higher),
]
for panel in (ax1, ax2):
    for name, scores, lo, hi in curves:
        panel.plot(x, scores, "-o", label=name, ms=4)
        panel.fill_between(x, lo, hi, alpha=0.15)

ax1.set_ylim(bottom=0, top=0.35)
ax1.set_xticks(np.arange(0, 5))
ax1.set_xticklabels(["10%", "20%", "30%", "40%", "50%"])
ax1.set_ylabel("MCC")

ax2.set_ylim(bottom=0, top=0.35)
ax2.set_xticks(np.arange(0, 6))
ax2.set_xticklabels(["10%", "20%", "30%", "40%", "50%", "100%"])

ax1.set_xlim(0, 4.5)
ax2.set_xlim(4.5, 5)

# Hide the facing spines so the two panels read as one broken axis.
ax1.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)

ax1.legend(loc="upper left");

# Diagonal break marks, drawn in axes coordinates outside the clip region.
d = .017  # half-extent of each diagonal mark
break_style = dict(color='gray', clip_on=False, zorder=10)
for y0 in (0, 1):  # right edge of the left panel: bottom, then top
    ax1.plot((1.022 - d, 1 + d), (y0 - d, y0 + d), transform=ax1.transAxes, **break_style)
for y0 in (1, 0):  # left edge of the right panel: top, then bottom
    ax2.plot((-0.022 - d, d), (y0 - d, y0 + d), transform=ax2.transAxes, **break_style)

plt.subplots_adjust(wspace=0.15)

plt.savefig("../../output/data_req_plot.pdf", bbox_inches="tight")
plt.savefig("../../output/data_req_plot.png", bbox_inches="tight")
plt.show()

In [None]:
# Test-set MCC with bootstrap confidence intervals at EVERY epoch, for every
# model and training-data fraction. The nested result is cached to
# 'data_epochs_dict.pkl', which the plotting cell below loads.
# data[f"freq{i}"][model] has keys "mccs", "mccs_lower", "mccs_upper" (one value per epoch).
models = ["psyroberta", "roberta", "medabert", "bert"]
epochs = np.arange(0,12)
data_fraq = np.arange(1,6)
threshold = 0.5  # decision threshold on p_mean, same as the single-epoch cells above
fig, ax = plt.subplots(1,5, figsize=(15,2))  # quick preview, one panel per fraction

data_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

for i in tqdm(data_fraq):
    for m in tqdm(models):
        mccs = []
        mcc_lower = []
        mcc_upper = []
        # Same result-file location as the other cells (a bare "result_files/..."
        # path does not resolve from this notebook's working directory).
        df = pd.read_csv(f"../../result_files/{m}_frac{i}0_results.csv")
        for j in tqdm(epochs):
            res = df[df.epoch==j]
            temp_test = get_temp_df(res, split="test")

            # `p`, not `i`: avoid reusing the outer fraction index name.
            preds = [1 if p >= threshold else 0 for p in temp_test.p_mean]
            targets = temp_test.target

            mcc = np.round(matthews_corrcoef(targets, preds),3)
            conf = bootstrap((targets, preds), compute_mcc, vectorized=False, paired=True,random_state=22, n_resamples=1000)
            low_conf, high_conf = np.round(conf.confidence_interval[0],3), np.round(conf.confidence_interval[1],3)
            mccs.append(mcc)
            mcc_lower.append(low_conf)
            mcc_upper.append(high_conf)
        data_dict[f"freq{i}"][m]["mccs"] = mccs
        data_dict[f"freq{i}"][m]["mccs_lower"] = mcc_lower
        data_dict[f"freq{i}"][m]["mccs_upper"] = mcc_upper
        ax[i-1].plot(epochs, mccs, "-o", label=m)
        ax[i-1].fill_between(epochs, mcc_lower, mcc_upper, alpha=0.15)


# The JSON round-trip turns the nested defaultdicts (whose lambda factories
# cannot be pickled) into plain dicts, and numpy floats into Python floats.
data = json.loads(json.dumps(data_dict))

# Persist the result so the plotting cell below works after Restart & Run All.
with open('data_epochs_dict.pkl', 'wb') as file:
    pickle.dump(data, file)

In [None]:
sns.set_palette('mako_r', n_colors=4)
sns.set_style("white", {"axes.edgecolor": ".8"})

# Load the per-epoch MCC results cached by the computation cell above.
# NOTE: pickle.load can execute arbitrary code -- only load a file this notebook wrote.
with open('data_epochs_dict.pkl', 'rb') as file:
    data = pickle.load(file)

models = ["psyroberta","medabert", "roberta", "bert"]
# Legend names, matching the labels used in the data-fraction figure.
model_labels = {
    "psyroberta": "PsyRoBERTa",
    "medabert": "MeDa-BERT",
    "roberta": "RøBÆRTa",
    "bert": "BERT",
}
epochs = np.arange(0,12)
data_fraq = np.arange(1,6)
fig, ax = plt.subplots(1,5, figsize=(22,4), sharey=True)

# One panel per training-data fraction; one curve (+ CI band) per model.
for i in data_fraq:
    panel = ax[i-1]
    for m in models:
        scores = data[f"freq{i}"][m]
        panel.plot(epochs, scores["mccs"], "-", label=model_labels[m])
        panel.fill_between(epochs, scores["mccs_lower"], scores["mccs_upper"], alpha=0.15)
    # Decoration depends only on the panel, so set it once after all curves are drawn.
    panel.legend(fontsize=10)
    panel.set_xlabel("Epoch")
    panel.set_ylabel("MCC")
    panel.set_xticks(np.arange(0,12))
    panel.set_title(f"{i}0%")

plt.savefig("../../output/data_req_epochs_plot.pdf", bbox_inches="tight")
plt.savefig("../../output/data_req_epochs_plot.png", bbox_inches="tight")
plt.show()