In [1]:
import os
import numpy
import torch
import matplotlib.pyplot as plt
import sklearn
import pickle
import pandas
import gzip
import sys
import plotly_express as px
import plotly.graph_objects as go
import functools
sys.path.insert(0, os.path.abspath('../../'))
from marrovision.cortex.data.bone_marrow.utilities import get_results_comparison_table

def get_stats_for_exp(outputs_repo: str):
    """Load the latest statistics file stored in an experiment's outputs directory.

    ``stats_latest-rank0.pth`` is preferred; when it is not present,
    ``stats_latest.pth`` is used instead.

    Parameters
    ----------
    outputs_repo : str
        Directory that holds the outputs of one experiment.

    Returns
    -------
    object
        Whatever ``torch.load`` deserializes from the chosen file, with
        storages mapped to the CPU.

    Raises
    ------
    FileNotFoundError
        If neither candidate file exists inside ``outputs_repo``.
    """
    candidate_names = ('stats_latest-rank0.pth', 'stats_latest.pth')
    candidate_paths = [os.path.join(outputs_repo, name) for name in candidate_names]

    for stats_filepath in candidate_paths:
        if os.path.isfile(stats_filepath):
            # NOTE: torch.load unpickles the file, so only point this at
            # directories whose contents are trusted.
            return torch.load(stats_filepath, map_location='cpu')

    # Name every path that was probed so a wrong `outputs_repo` is easy to spot.
    raise FileNotFoundError(
        'No stats file found; looked for: ' + ', '.join(candidate_paths)
    )

In [2]:
# Root directory under which each experiment keeps its outputs.
warehouse_repo = '.../warehouse/marrovision/bone_marrow_cell_classification/'

# Experiment name -> outputs directory, or a list of per-fold directories
# when the experiment was run as 5 folds.
exp_warehouse_repo = {
    'resnext50_32x4d': os.path.join(warehouse_repo, 'resnext50_32x4d'),
    'resnext50_32x4d_100ep': [
        os.path.join(warehouse_repo, 'resnext50_32x4d_100ep', f'fold_{i}')
        for i in range(5)
    ],
    'resnext50_32x4d_cutmix_and_mixup': os.path.join(
        warehouse_repo, 'resnext50_32x4d_cutmix_and_mixup'
    ),
    'resnext50_32x4d_mixup': os.path.join(warehouse_repo, 'resnext50_32x4d_mixup'),
}

# Experiment name -> loaded stats (a list of stats, one per fold, when the
# registry entry is a list of directories).
exp_stats = {}
for exp_name, filepath in exp_warehouse_repo.items():
    exp_stats[exp_name] = (
        [get_stats_for_exp(fold_repo) for fold_repo in filepath]
        if isinstance(filepath, list)
        else get_stats_for_exp(filepath)
    )

In [4]:
# Comparison table for the mixup experiment, ordered by class abbreviation.
mname = 'resnext50_32x4d_mixup'
df1 = get_results_comparison_table(exp_stats[mname]['test'])
df1 = df1.sort_values(by='class_abbreviation')
# Swap the generic 'model_' column prefix for the experiment name so the
# columns say which model produced them.
df1 = df1.rename(
    columns={
        col: f"{mname}_{col[len('model_'):]}"
        for col in df1.columns
        if col.startswith('model_')
    },
    errors='raise',
)

In [5]:
# Display the comparison table built from the 'resnext50_32x4d_mixup' test stats.
df1

Unnamed: 0,class_name,class_abbreviation,baseline_precision_strict,baseline_recall_strict,baseline_support,baseline_f1_strict,resnext50_32x4d_mixup_precision,resnext50_32x4d_mixup_recall,resnext50_32x4d_mixup_f1,test_support,f1_difference_from_mateketal
18,Abnormal eosinophils,ABE,0.02,0.2,8,0.036364,0.285714,1.0,0.444444,2,0.408081
13,Artefacts,ART,0.82,0.74,19630,0.777949,0.875542,0.720326,0.790386,3926,0.012437
5,Basophils,BAS,0.14,0.64,441,0.229744,0.084171,0.752809,0.151412,89,-0.078331
9,Blasts,BLA,0.75,0.65,11973,0.696429,0.742373,0.640084,0.687444,2395,-0.008985
16,Erythroblasts,EBO,0.88,0.82,27395,0.848941,0.931737,0.85198,0.890075,5479,0.041134
4,Eosinophils,EOS,0.85,0.91,5883,0.878977,0.944186,0.862362,0.901421,1177,0.022444
20,Faggot cells,FGC,0.17,0.63,47,0.26775,0.053846,0.7,0.1,10,-0.16775
17,Hairy cells,HAC,0.35,0.8,409,0.486957,0.227273,0.792683,0.353261,82,-0.133696
11,Smudge cells,KSC,0.28,0.9,42,0.427119,0.291667,0.777778,0.424242,9,-0.002876
19,Immature lymphocytes,LYI,0.08,0.53,65,0.139016,0.151515,0.384615,0.217391,13,0.078375


In [6]:
# Build one comparison table per fold of the 100-epoch experiment and stack
# them, tagging every row with the fold it came from.
mname = 'resnext50_32x4d_100ep'
df2 = []
for fold_index, stats in enumerate(exp_stats[mname]):
    tmp_df = get_results_comparison_table(stats['test'])
    tmp_df = tmp_df.sort_values(by='class_abbreviation')
    # Swap the generic 'model_' column prefix for the experiment name.
    tmp_df = tmp_df.rename(
        columns={
            col: f"{mname}_{col[len('model_'):]}"
            for col in tmp_df.columns
            if col.startswith('model_')
        },
        errors='raise',
    )
    tmp_df['fold_index'] = fold_index
    df2.append(tmp_df.copy())

# From here on df2 is a single DataFrame holding the rows of all folds.
df2 = pandas.concat(df2)

In [7]:
# Mean of every non-key column across folds, one row per class.
# The baseline columns are part of the grouping key so they are carried
# through unchanged (presumably constant per class - verify against the
# comparison-table helper). NOTE: fold_index is not a key, so it is averaged
# as well and its value in the result carries no meaning.
(
    df2
    .groupby(
        by=[
            'class_name',
            'class_abbreviation',
            'baseline_precision_strict',
            'baseline_recall_strict',
            'baseline_support',
            'baseline_f1_strict',
        ]
    )
    .mean()
    .reset_index()
    .sort_values(by='class_abbreviation')
)

Unnamed: 0,class_name,class_abbreviation,baseline_precision_strict,baseline_recall_strict,baseline_support,baseline_f1_strict,resnext50_32x4d_100ep_precision,resnext50_32x4d_100ep_recall,resnext50_32x4d_100ep_f1,test_support,f1_difference_from_mateketal,fold_index
0,Abnormal eosinophils,ABE,0.02,0.2,8,0.036364,0.4,0.2,0.266667,1.6,0.230303,2.0
1,Artefacts,ART,0.82,0.74,19630,0.777949,0.90945,0.831228,0.868576,3926.0,0.090627,2.0
3,Basophils,BAS,0.14,0.64,441,0.229744,0.570175,0.607686,0.58699,88.2,0.357246,2.0
4,Blasts,BLA,0.75,0.65,11973,0.696429,0.849763,0.79337,0.820444,2394.6,0.124015,2.0
6,Erythroblasts,EBO,0.88,0.82,27395,0.848941,0.94704,0.923161,0.934939,5479.0,0.085998,2.0
5,Eosinophils,EOS,0.85,0.91,5883,0.878977,0.937316,0.965835,0.951361,1176.6,0.072383,2.0
7,Faggot cells,FGC,0.17,0.63,47,0.26775,0.214688,0.253333,0.227106,9.4,-0.040644,2.0
8,Hairy cells,HAC,0.35,0.8,409,0.486957,0.667957,0.792231,0.723603,81.8,0.236647,2.0
20,Smudge cells,KSC,0.28,0.9,42,0.427119,0.660784,0.863889,0.735843,8.4,0.308724,2.0
9,Immature lymphocytes,LYI,0.08,0.53,65,0.139016,0.405808,0.276923,0.320574,13.0,0.181558,2.0


In [8]:
# Standard deviation of every non-key column across folds, one row per class.
# The baseline columns are part of the grouping key so they are carried
# through unchanged (presumably constant per class - verify against the
# comparison-table helper). NOTE: fold_index is not a key, so its standard
# deviation is computed as well and carries no meaning.
(
    df2
    .groupby(
        by=[
            'class_name',
            'class_abbreviation',
            'baseline_precision_strict',
            'baseline_recall_strict',
            'baseline_support',
            'baseline_f1_strict',
        ]
    )
    .std()
    .reset_index()
    .sort_values(by='class_abbreviation')
)

Unnamed: 0,class_name,class_abbreviation,baseline_precision_strict,baseline_recall_strict,baseline_support,baseline_f1_strict,resnext50_32x4d_100ep_precision,resnext50_32x4d_100ep_recall,resnext50_32x4d_100ep_f1,test_support,f1_difference_from_mateketal,fold_index
0,Abnormal eosinophils,ABE,0.02,0.2,8,0.036364,0.547723,0.273861,0.365148,0.547723,0.365148,1.581139
1,Artefacts,ART,0.82,0.74,19630,0.777949,0.005039,0.00248,0.00295,0.0,0.00295,1.581139
3,Basophils,BAS,0.14,0.64,441,0.229744,0.041805,0.027704,0.020618,0.447214,0.020618,1.581139
4,Blasts,BLA,0.75,0.65,11973,0.696429,0.009946,0.017843,0.007759,0.547723,0.007759,1.581139
6,Erythroblasts,EBO,0.88,0.82,27395,0.848941,0.005632,0.003321,0.003259,0.0,0.003259,1.581139
5,Eosinophils,EOS,0.85,0.91,5883,0.878977,0.002551,0.003019,0.002563,0.547723,0.002563,1.581139
7,Faggot cells,FGC,0.17,0.63,47,0.26775,0.084371,0.113963,0.092941,0.547723,0.092941,1.581139
8,Hairy cells,HAC,0.35,0.8,409,0.486957,0.038515,0.038219,0.022299,0.447214,0.022299,1.581139
20,Smudge cells,KSC,0.28,0.9,42,0.427119,0.133277,0.144471,0.100063,0.547723,0.100063,1.581139
9,Immature lymphocytes,LYI,0.08,0.53,65,0.139016,0.226506,0.11666,0.14017,0.0,0.14017,1.581139
