In [None]:
import csv
import os

from IPython.display import Image
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

RESULT_DIR = '/content/drive/MyDrive/dna-nn/results/'
FIGURE_DIR = '/content/drive/MyDrive/dna-nn/figures/'

mpl.style.use('seaborn-white')
plt.rcParams['font.size'] = 12

%matplotlib inline

In [None]:
# A "<model>-<dataset>-dynamics.csv" file in RESULT_DIR marks one finished run;
# keep only the (model, dataset) part of each such file name.
model_files = [
    name.split('.')[0].split('-')[:2]
    for name in os.listdir(RESULT_DIR)
    if name.endswith('dynamics.csv')
]

# Long table of finished runs, indexed by (dataset, model).
df = (
    pd.DataFrame(model_files, columns=['model', 'dataset'])
    .assign(done=True)
    .sort_values(['dataset', 'model'])
    .set_index(['dataset', 'model'])
)

# Expand to every dataset x model combination; combinations without a
# dynamics file get done=False.
index = pd.MultiIndex.from_product([df.index.levels[0], df.index.levels[1]])
df_multi_idx = df.reindex(index, fill_value=False)

# Wide overview: one row per model, one column per dataset.
df_long = df_multi_idx.reset_index()
df = df_long.pivot(index='model', columns='dataset', values='done')
# Optional: df.style.applymap(lambda val: 'color:black' if val else 'color:red')
# renders the missing runs in red.
df

dataset,histone,motif_discovery,splice
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
cnn_nguyen_2_conv2d,True,True,True
cnn_nguyen_conv1d_2_conv2d,True,True,True
cnn_zeng_2_conv2d,True,True,True
cnn_zeng_3_conv2d,True,True,True
cnn_zeng_4_conv2d,True,True,True


In [None]:
# Order the runs model-major (all datasets of the first model, then the next
# model, ...) and renumber the rows 0..n-1.
# Built as a single chained expression instead of in-place calls:
# `set_axis(..., inplace=True)` is no longer accepted by pandas >= 2.0, and
# reassigning keeps this cell safe to re-run.
df_long = (
    df_long
    .sort_values(['model', 'dataset'])
    .reindex(columns=['model', 'dataset', 'done'])
    .reset_index(drop=True)
)
df_long

Unnamed: 0,model,dataset,done
0,cnn_nguyen_2_conv2d,histone,True
1,cnn_nguyen_2_conv2d,motif_discovery,True
2,cnn_nguyen_2_conv2d,splice,True
3,cnn_nguyen_conv1d_2_conv2d,histone,True
4,cnn_nguyen_conv1d_2_conv2d,motif_discovery,True
5,cnn_nguyen_conv1d_2_conv2d,splice,True
6,cnn_zeng_2_conv2d,histone,True
7,cnn_zeng_2_conv2d,motif_discovery,True
8,cnn_zeng_2_conv2d,splice,True
9,cnn_zeng_3_conv2d,histone,True


In [None]:
def plot_dynamics(file, ax):
    """Plot the 'accuracy' and 'val_accuracy' columns of a CSV file on ``ax``.

    Both columns are drawn against the CSV row index; no legend is added here.
    """
    history = pd.read_csv(file)
    curves = history[['accuracy', 'val_accuracy']]
    curves.plot(ax=ax, legend=False)

def plot_roc(file, ax):
    """Plot the ROC curve(s) stored in a CSV file on ``ax``.

    The file must provide 'fpr' and 'tpr' columns. If it also has an 'ovr'
    column, one curve is drawn per distinct 'ovr' value; otherwise a single
    curve is drawn. A dashed grey diagonal from (0, 0) to (1, 1) is added last.
    """
    tmp = pd.read_csv(file)
    if 'ovr' in tmp.columns:
        # groupby yields the classes in sorted order, so a class keeps the same
        # position in the colour cycle on every run (iterating a set of string
        # labels is hash-order dependent). Rows with a missing 'ovr' are skipped.
        for _, curve in tmp.groupby('ovr', sort=True):
            ax.plot(curve['fpr'], curve['tpr'])
    else:
        ax.plot(tmp['fpr'], tmp['tpr'])
    ax.plot([0, 1], [0, 1], color='grey', linestyle='dashed')

def plot_pr(file, ax):
    """Plot the precision-recall curve(s) stored in a CSV file on ``ax``.

    The file must provide 'recall' and 'precision' columns. If it also has an
    'ovr' column, one curve is drawn per distinct 'ovr' value; otherwise a
    single curve is drawn.
    """
    tmp = pd.read_csv(file)
    if 'ovr' in tmp.columns:
        # groupby yields the classes in sorted order, so a class keeps the same
        # position in the colour cycle on every run (iterating a set of string
        # labels is hash-order dependent). Rows with a missing 'ovr' are skipped.
        for _, curve in tmp.groupby('ovr', sort=True):
            ax.plot(curve['recall'], curve['precision'])
    else:
        ax.plot(tmp['recall'], tmp['precision'])

def plot_results(r, c, file_type, plot_func, xlabel, ylabel, save_to):
    """Draw one panel per (model, dataset) run and save the grid as an image.

    Reads the notebook-level ``df_long`` (one row per run) and
    ``df_multi_idx`` (whose index levels are datasets, then models).

    Parameters
    ----------
    r, c : int
        Number of grid rows (models) and columns (datasets).
    file_type : str
        Result kind; the panel for a run is read from
        ``RESULT_DIR/<model>-<dataset>-<file_type>.csv``.
    plot_func : callable
        Called as ``plot_func(file, ax)`` for every result file that exists.
    xlabel, ylabel : str
        Axis labels for the bottom row / left column of the grid.
    save_to : str
        File name of the saved figure, appended to ``FIGURE_DIR``.

    Runs whose result file is missing keep an empty panel.
    """
    datasets = df_multi_idx.index.levels[0]
    models = df_multi_idx.index.levels[1]

    # squeeze=False keeps `axs` two-dimensional even for a single row/column,
    # so the indexing below also works for a 1x1 grid.
    fig, axs = plt.subplots(r, c, sharex=True, sharey=True, figsize=(12, 16),
                            squeeze=False)

    # Grid decoration does not depend on the individual runs, so do it once.
    for col, dataset in enumerate(datasets[:c]):
        top = axs[0, col]
        top.text(0.5, 1.05, dataset, transform=top.transAxes, ha='center')
    for row, model in enumerate(models[:r]):
        right = axs[row, -1]
        right.text(1.05, 0.5, model, transform=right.transAxes)
        axs[row, 0].set_ylabel(ylabel)
    for ax in axs[-1]:
        ax.set_xlabel(xlabel)
    for ax in axs.flat:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    for _, run in df_long.iterrows():
        model = run['model']
        dataset = run['dataset']
        # Place each run by what it is rather than by df_long's row label, so
        # the layout stays correct even if df_long is not sorted/renumbered.
        row = models.get_loc(model)
        col = datasets.get_loc(dataset)
        if row >= r or col >= c:
            continue
        file = f'{RESULT_DIR}{model}-{dataset}-{file_type}.csv'
        if not os.path.exists(file):
            continue
        plot_func(file, axs[row, col])

    if file_type == 'dynamics':
        axs[0, 0].legend()

    out_path = FIGURE_DIR + save_to
    # Make sure the target directory exists so the figure is not lost.
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    fig.savefig(out_path, bbox_inches='tight')
    plt.close(fig)

In [None]:
# Grid shape: one row per model, one column per dataset.
c = len(df_multi_idx.index.levels[0])
r = len(df_multi_idx.index.levels[1])

# (file type, plotting function, x label, y label); each figure is saved as
# "<file type>.png".
for file_type, plot_func, xlabel, ylabel in (
    ('dynamics', plot_dynamics, 'epoch', 'accuracy'),
    ('roc', plot_roc, 'fpr', 'tpr'),
    ('pr', plot_pr, 'recall', 'precision'),
):
    plot_results(r, c, file_type, plot_func, xlabel, ylabel, f'{file_type}.png')

In [None]:
# Show the saved ROC figure inline, constrained to a width of 450.
Image(FIGURE_DIR + 'roc.png', width=450)

In [None]:
# Collect the first data row of every "<model>-<dataset>-accuracy.csv" file.
accuracy_files = sorted(file for file in os.listdir(RESULT_DIR)
                        if file.endswith('accuracy.csv'))
accuracies = []
for a in accuracy_files:
    model, dataset = a.split('-')[:2]
    # newline='' is what the csv module expects for files it parses.
    with open(RESULT_DIR + a, 'r', newline='') as f:
        # Default of None avoids a StopIteration when a file has no data row.
        d = next(csv.DictReader(f), None)
    if d is None:
        print(f'skipping {a}: no data row')
        continue
    d = {k: float(v) for k, v in d.items()}
    d['model'] = model
    d['dataset'] = dataset
    accuracies.append(d)

# One row per (dataset, model) with the three accuracy columns, built without
# in-place mutation so the cell can be re-run safely.
acc = (
    pd.DataFrame(accuracies)
    .reindex(columns=['dataset', 'model', 'accuracy', 'val_accuracy', 'test_accuracy'])
    .sort_values(['dataset', 'model'])
    .set_index(['dataset', 'model'])
)
acc

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy,val_accuracy,test_accuracy
dataset,model,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
histone,cnn_nguyen_2_conv2d,0.930488,0.787082,0.788864
histone,cnn_nguyen_conv1d_2_conv2d,0.989974,0.800445,0.799109
histone,cnn_zeng_2_conv2d,0.782011,0.769265,0.751002
histone,cnn_zeng_3_conv2d,0.838155,0.763474,0.746548
histone,cnn_zeng_4_conv2d,0.959324,0.799109,0.797327
motif_discovery,cnn_nguyen_2_conv2d,0.692906,0.672174,0.658528
motif_discovery,cnn_nguyen_conv1d_2_conv2d,0.762113,0.721449,0.712494
motif_discovery,cnn_zeng_2_conv2d,0.647117,0.645652,0.643501
motif_discovery,cnn_zeng_3_conv2d,0.767124,0.696087,0.686865
motif_discovery,cnn_zeng_4_conv2d,0.923379,0.685217,0.680595


In [None]:
# One bar-chart panel per dataset, with one group of bars per model.
r = len(acc.index.levels[0])
c = len(acc.index.levels[1])
# squeeze=False + ravel() always yields a 1-D array of axes, so the loop and
# axs[0] below also work when there is only a single dataset.
fig, axs = plt.subplots(r, 1, sharex=True, sharey=True, figsize=(12, 8),
                        squeeze=False)
axs = axs.ravel()
for i0, ax in zip(acc.index.levels[0], axs):
    acc.loc[i0].plot.bar(ax=ax, legend=False)
    ax.set_ylabel('accuracy')
    # Dataset name to the right of its panel.
    ax.text(1.01, 0.5, i0, transform=ax.transAxes)
    ax.tick_params('x', rotation=-10)
    # Dashed reference line at an accuracy of 0.8.
    ax.hlines(0.8, -1, c + 1, color='grey', linestyles='dashed')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
axs[0].legend(bbox_to_anchor=(1.1, 1), loc='upper left')
# Save through the figure object rather than the implicit current figure.
fig.savefig(FIGURE_DIR + 'accuracy.png', bbox_inches='tight')
plt.close(fig)

In [None]:
# Show the saved accuracy bar chart inline.
Image(FIGURE_DIR + 'accuracy.png')