In [1]:
import os  # imported here because this cell runs before the main import cell

# Root of the marrovision code checkout; it is put on sys.path by the import cell.
# Set MARROVISION_CODES_REPO to run the notebook on another machine.
codes_repo = os.environ.get('MARROVISION_CODES_REPO', '/home/shayan/phoenix/marrovision/')
# Root of the experiment outputs; the evaluation checkpoints are read from below it.
# Set MARROVISION_WAREHOUSE_REPO to point at a different location.
warehouse_repo = os.environ.get('MARROVISION_WAREHOUSE_REPO', '/home/shayan/warehouse/marrovision/')

In [2]:
import torch
import numpy
import sklearn
from tqdm import tqdm
import matplotlib.pyplot as plt
import plotly_express as px
import seaborn
import pandas
import os
import sys
import functools
sys.path.insert(0, codes_repo)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Class abbreviations, listed in alphabetical order.
# NOTE(review): these are paired positionally with the per-class metric arrays
# loaded below, so the order is assumed to match the evaluation's class
# indexing -- TODO confirm.
labels = [
    'ABE',
    'ART',
    'BAS',
    'BLA',
    'EBO',
    'EOS',
    'FGC',
    'HAC',
    'KSC',
    'LYI',
    'LYT',
    'MMZ',
    'MON',
    'MYB',
    'NGB',
    'NGS',
    'NIF',
    'OTH',
    'PEB',
    'PLM',
    'PMO',
]

In [4]:
# Load the saved SwAV linear-evaluation statistics onto the CPU.
# NOTE(review): torch.load unpickles the file -- only point this at trusted checkpoints.
swav_stats_path = os.path.join(
    warehouse_repo,
    'ssl/swav/exp1/swav-200ep-bmc-linear-eval/linear_eval_files/linear_epoch=99.pth',
)
stats_swav = torch.load(swav_stats_path, map_location='cpu')

# The entry stored under the `None` key unpacks into four per-class sequences.
precision, recall, f1, support = stats_swav['metrics']['prf'][None]

# One row per class; values are matched to `labels` by position.
swav_df = pandas.DataFrame({
    'class_abbreviation': labels,
    'swav_precision': precision,
    'swav_recall': recall,
    'swav_f1': f1,
    'test_support': support,
})

In [5]:
# Load the saved SupCon linear-evaluation statistics onto the CPU.
# NOTE(review): torch.load unpickles the file -- only point this at trusted checkpoints.
# NOTE(review): this SupCon checkpoint is read from under 'ssl/swav/exp1/' -- confirm the path is intended.
supcon_stats_path = os.path.join(
    warehouse_repo,
    'ssl/swav/exp1/supcon-200ep-bmc-linear-eval/linear_eval_files/linear_epoch=99.pth',
)
stats_supcon = torch.load(supcon_stats_path, map_location='cpu')

# The entry stored under the `None` key unpacks into four per-class sequences.
precision, recall, f1, support = stats_supcon['metrics']['prf'][None]

# One row per class; values are matched to `labels` by position.
supcon_df = pandas.DataFrame({
    'class_abbreviation': labels,
    'supcon_precision': precision,
    'supcon_recall': recall,
    'supcon_f1': f1,
    'test_support': support,
})

In [6]:
def get_f1(x, y):
    """Return the F1 score, i.e. the harmonic mean of precision and recall.

    Args:
        x: Precision, as a scalar or an array-like supporting arithmetic.
        y: Recall, on the same scale as ``x`` (fractions or percentages).

    Returns:
        ``2 * x * y / (x + y)`` on the same scale as the inputs. For scalar
        inputs whose sum is zero the result is ``0.0`` (the convention
        scikit-learn also scores an undefined F1 with) instead of a
        ``ZeroDivisionError`` or NaN. Array-like inputs are computed
        element-wise with no zero guard, exactly as before.
    """
    denominator = x + y
    # Python scalars have no `ndim`; NumPy scalars report ndim == 0. Anything
    # with ndim >= 1 is array-like and skips the guard, because comparing it
    # to 0 would not yield a single truth value.
    if getattr(denominator, 'ndim', 0) == 0 and denominator == 0:
        return 0.0
    return (2 * x * y) / denominator

In [7]:
# Per-class baseline figures, one tuple per class in the column order given below.
# Precision and recall are entered as whole-number percentages (they are
# rescaled by 0.01 at the end of this cell).
# NOTE(review): presumably the baseline is the published Matek et al. result
# (the difference columns computed later are named "..._from_mateketal") -- confirm.
meta_df = pandas.DataFrame(
    [
        ("Band Neutrophils", "NGB", 54, 65, 9968),
        ("Segmented neutrophils", "NGS", 92, 71, 29424),
        ("Lymphocytes", "LYT", 90, 70, 26242),
        ("Monocytes", "MON", 57, 70, 4040),
        ("Eosinophils", "EOS", 85, 91, 5883),
        ("Basophils", "BAS", 14, 64, 441),
        ("Metamyelocytes", "MMZ", 30, 64, 3055),
        ("Myelocytes", "MYB", 52, 59, 6557),
        ("Promyelocytes", "PMO", 76, 72, 11994),
        ("Blasts", "BLA", 75, 65, 11973),
        ("Plasma cells", "PLM", 81, 84, 7629),
        ("Smudge cells", "KSC", 28, 90, 42),
        ("Other cells", "OTH", 22, 84, 294),
        ("Artefacts", "ART", 82, 74, 19630),
        ("Not identifiable", "NIF", 27, 63, 3538),
        ("Proerythroblasts", "PEB", 57, 63, 2740),
        ("Erythroblasts", "EBO", 88, 82, 27395),
        ("Hairy cells", "HAC", 35, 80, 409),
        ("Abnormal eosinophils", "ABE", 2, 20, 8),
        ("Immature lymphocytes", "LYI", 8, 53, 65),
        ("Faggot cells", "FGC", 17, 63, 47),
    ],
    columns=[
        'class_name',
        'class_abbreviation',
        'baseline_precision_strict',
        'baseline_recall_strict',
        'baseline_support',
    ],
)

# Derive the baseline F1 column-wise while precision/recall are still
# percentages, then order the rows by class abbreviation.
meta_df = meta_df.assign(
    baseline_f1_strict=get_f1(
        meta_df['baseline_precision_strict'],
        meta_df['baseline_recall_strict'],
    )
).sort_values(by='class_abbreviation')

# Rescale the three percentage columns by 0.01 so they are fractions.
for metric in ('precision_strict', 'recall_strict', 'f1_strict'):
    meta_df[f'baseline_{metric}'] = meta_df[f'baseline_{metric}'] * 0.01

In [8]:
# Outer-join the baseline, SwAV and SupCon tables on the class abbreviation.
# swav_df and supcon_df both carry a `test_support` column, so the second join
# disambiguates them as `test_support_x` (SwAV) and `test_support_y` (SupCon).
results_df = (
    meta_df
    .merge(swav_df, on=['class_abbreviation'], how='outer')
    .merge(supcon_df, on=['class_abbreviation'], how='outer')
)

In [9]:
# Per-class F1 gain (positive) or loss (negative) of each method relative to
# the baseline F1, computed column-wise.
results_df['swav_f1_difference_from_mateketal'] = results_df['swav_f1'].sub(results_df['baseline_f1_strict'])
results_df['supcon_f1_difference_from_mateketal'] = results_df['supcon_f1'].sub(results_df['baseline_f1_strict'])

In [10]:
# Display the merged per-class comparison table (baseline vs. SwAV vs. SupCon).
results_df

Unnamed: 0,class_name,class_abbreviation,baseline_precision_strict,baseline_recall_strict,baseline_support,baseline_f1_strict,swav_precision,swav_recall,swav_f1,test_support_x,supcon_precision,supcon_recall,supcon_f1,test_support_y,swav_f1_difference_from_mateketal,supcon_f1_difference_from_mateketal
0,Abnormal eosinophils,ABE,0.02,0.2,8,0.036364,1.0,1.0,1.0,2,0.0,0.0,0.0,2,0.963636,-0.036364
1,Artefacts,ART,0.82,0.74,19630,0.777949,0.883333,0.890983,0.887142,3926,0.896613,0.903464,0.900025,3926,0.109193,0.122077
2,Basophils,BAS,0.14,0.64,441,0.229744,0.721311,0.494382,0.586667,89,0.765625,0.550562,0.640523,89,0.356923,0.410779
3,Blasts,BLA,0.75,0.65,11973,0.696429,0.821998,0.817537,0.819761,2395,0.826688,0.848434,0.83742,2395,0.123333,0.140992
4,Erythroblasts,EBO,0.88,0.82,27395,0.848941,0.941326,0.948713,0.945005,5479,0.942593,0.949991,0.946278,5479,0.096064,0.097336
5,Eosinophils,EOS,0.85,0.91,5883,0.878977,0.969697,0.951572,0.960549,1177,0.963027,0.951572,0.957265,1177,0.081572,0.078288
6,Faggot cells,FGC,0.17,0.63,47,0.26775,0.25,0.1,0.142857,10,0.5,0.2,0.285714,10,-0.124893,0.017964
7,Hairy cells,HAC,0.35,0.8,409,0.486957,0.866667,0.634146,0.732394,82,0.789474,0.54878,0.647482,82,0.245438,0.160525
8,Smudge cells,KSC,0.28,0.9,42,0.427119,1.0,0.666667,0.8,9,0.833333,0.555556,0.666667,9,0.372881,0.239548
9,Immature lymphocytes,LYI,0.08,0.53,65,0.139016,0.428571,0.230769,0.3,13,0.5,0.153846,0.235294,13,0.160984,0.096278
