In [58]:
# %matplotlib notebook
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml
from matplotlib import style

def filter_nan(arr):
    """Return the non-NaN entries of ``arr`` as a NumPy array.

    The metrics CSV created for the QNN networks contains NaN entries, so
    columns are passed through this helper before they are indexed.

    Args:
        arr: Numeric array-like (NumPy array, list, tuple, ...). It is
            converted with ``np.asarray`` so plain sequences work as well.

    Returns:
        A NumPy array holding the entries of ``arr`` that are not NaN, in
        their original order.
    """
    values = np.asarray(arr)
    # Boolean mask instead of np.where(...): same selection, one step less.
    return values[~np.isnan(values)]

def return_final_purities(file_list):
    """Collect the final purity values of several runs.

    For every entry of ``file_list`` the CSV at
    ``artifacts_path + entry + metrics_file`` (plain string concatenation of
    the module-level settings) is read, and the last non-NaN value of the
    columns ``obs[7]``, ``obs[8]`` and ``obs[9]`` is recorded.

    Args:
        file_list: Iterable of run directory strings, relative to
            ``artifacts_path``.

    Returns:
        A list of three lists ``[pur95, pur98, pur99]``; each inner list
        holds one value per entry of ``file_list``, in the same order.
    """
    first_purity_key = 7  # obs[7], obs[8], obs[9] are read in this order
    finals = [[], [], []]
    for run_dir in file_list:
        with open(f"{artifacts_path}{run_dir}{metrics_file}") as handle:
            metrics = pd.read_csv(handle)
        for offset, bucket in enumerate(finals):
            column = np.array(metrics[obs[first_purity_key + offset]])
            # Index -1: the last entry that survives the NaN filter.
            bucket.append(filter_nan(column)[-1])
    return finals

def purities_vs_bops(current_file):
    """Sample purities and BOPs of one run right before each pruning step.

    Reads ``artifacts_path + current_file + metrics_file``, finds the
    positions where the NaN-filtered ``"pruned"`` column increases, and for
    each such position records the value one position earlier from the
    NaN-filtered columns ``obs[7]`` .. ``obs[10]``.

    Args:
        current_file: Run directory string, concatenated between the
            module-level ``artifacts_path`` and ``metrics_file``.

    Returns:
        ``[pur95, pur98, pur99, bops]`` -- four lists with one entry per
        detected increase of the ``"pruned"`` column.
    """
    csv_path = f"{artifacts_path}{current_file}{metrics_file}"
    # Only the read needs the open handle; the analysis works on the frame.
    with open(csv_path) as file:
        metrics = pd.read_csv(file)

    pruned_array = filter_nan(np.array(metrics["pruned"]))
    index_list = []
    last_pruned = 0
    for i, pruned in enumerate(pruned_array):
        # An increase at i == 0 has no preceding entry; i - 1 would be -1
        # and silently select the LAST entry, so it is skipped.
        if pruned > last_pruned and i > 0:
            index_list.append(i - 1)
        last_pruned = pruned

    if not index_list:
        # Nothing to sample, so the metric columns are not touched at all.
        return [[], [], [], []]

    # NOTE(review): the positions come from the NaN-filtered "pruned" column
    # and are applied to separately NaN-filtered metric columns; this assumes
    # all of them have NaNs in the same rows -- TODO confirm.
    purs = []
    for key in range(7, 11):
        # Filter each column once instead of once per sampled index.
        column = filter_nan(np.array(metrics[obs[key]]))
        purs.append([column[index] for index in index_list])
    return purs

def purities_vs_config(sweep_dir, num_scans, config_key):
    """Collect final purities and one hyper-parameter for each run of a sweep.

    Run ``i`` (for ``i`` in ``range(num_scans)``) lives in
    ``artifacts_path + sweep_dir + str(i) + "/"``. From its metrics CSV the
    last non-NaN value of ``obs[7]``, ``obs[8]`` and ``obs[9]`` is taken;
    from its YAML file the value stored under ``config_key`` is taken.

    Args:
        sweep_dir: Path prefix of the sweep; the run index is appended to it.
        num_scans: Number of runs to read, starting at index 0.
        config_key: Key to look up in each run's YAML file.

    Returns:
        ``[pur95, pur98, pur99, config]`` -- four lists with one entry per
        run. The same list is also printed.
    """
    collected = [[], [], []]
    config_values = []
    for scan in range(num_scans):
        run_dir = f"{sweep_dir}{scan}/"
        yaml_path = f"{artifacts_path}{run_dir}{yaml_file}"
        csv_path = f"{artifacts_path}{run_dir}{metrics_file}"

        with open(csv_path) as handle:
            metrics = pd.read_csv(handle)
        for key, bucket in enumerate(collected, start=7):
            # Index -1: the last entry that survives the NaN filter.
            bucket.append(filter_nan(np.array(metrics[obs[key]]))[-1])

        with open(yaml_path) as handle:
            all_configs = yaml.load(handle, Loader=yaml.FullLoader)
            config_values.append(all_configs[config_key])

    collected.append(config_values)
    print(collected)
    return collected


# Plot settings shared by the plotting code.
zoom = True

max_epoch = 30
min_xlim = 0
max_bits = 24
# Lower y-limits; with zoom enabled the second entry is raised to 0.95.
min_ylim = [0, 0.95, 1e-3, 1e-3] if zoom else [0, 0, 1e-3, 1e-3]
max_ylim = [0.2, 1, 0.1, 0.1]

# NOTE(review): apart from 'solid' these are not built-in Matplotlib
# linestyle strings; presumably they are mapped to dash tuples where they
# are used -- confirm.
linestyles = [
    'solid',
    'densely dotted',
    'densely dashed',
    'densely dashdotted',
]
markerstyles = ['o', '^', 'v', 's']
labelloc = [4, 4, 1, 1]

# Locations of the training artifacts. Paths are built by plain string
# concatenation: artifacts_path + <run directory> + <file name>.
artifacts_path = "../../artifacts/chep/metric_learning/"

metrics_file = "metrics.csv"
yaml_file = "hparams.yaml"

# Standard MLP runs.
mlp_reference = "reference/version_8/"
mlp_batchnorm = "batchnorm/version_0/"
mlp_batch_noNorm = "batchnorm_no_norm/version_0/"

# Pruning of the standard MLP.
mlp_prune_unstructured_noL1 = "batchnorm_no_norm_prune_unstructured_noL1/version_0/"
# NOTE(review): the name says wL1 but the directory says noL1 -- confirm
# that version_1 of this directory really is the run with L1.
mlp_prune_unstructured_wL1 = "batchnorm_no_norm_prune_unstructured_noL1/version_1/"
mlp_prune_structured_wL1_dim1 = "batchnorm_no_norm_prune_structured_wL1/version_1/"
mlp_prune_structured_noL1 = "batchnorm_no_norm_prune_structured_noL1/version_0/"
mlp_prune_structured_wL1_dim0 = "batchnorm_no_norm_prune_structured_wL1/version_2/"

# Input quantization in the standard MLP. These are sweep prefixes: the
# version index is appended to them, the trailing "/" is added by the reader.
mlp_input_integer = "batchnorm_no_norm_inputQuant_integer/version_"
mlp_input_fractional = "batchnorm_no_norm_inputQuant_frac/version_"


files_mlp = [
    mlp_reference,
    mlp_batchnorm,
    mlp_batch_noNorm,
]
files_mlp_pruned = [
    mlp_prune_unstructured_noL1,
    mlp_prune_unstructured_wL1,
    mlp_prune_structured_wL1_dim1,
    mlp_prune_structured_noL1,
    mlp_prune_structured_wL1_dim0,
]
# Labels keyed by position in files_mlp.
label_files_mlp = dict(enumerate([
    "MLP Layernorm w Normalization",
    "MLP Batchnorm w Normalization",
    "MLP Batchnorm w/o Normalization",
]))
# Labels for the pruned runs are still empty placeholders.
label_files_mlp_prune = dict(enumerate(["", ""]))
# QMLP: quantized weights and activations, no bias!


# Column names of the metrics CSV, keyed by a fixed integer id. The reader
# functions address the purity columns as obs[7]..obs[9] and the BOPs
# column as obs[10], so the order of this list must not change.
obs = dict(enumerate([
    "val_loss",
    "eff",
    "pur",
    "current_lr",
    "R95",
    "R98",
    "R99",
    "pur_95",
    "pur_98",
    "pur_99",
    "total_bops",
    "total_mem_w_bits",
    "total_mem_o_bits",
    "pruned",
    "epoch",
    "step",
    "train_loss",
]))
# Human-readable labels, using the same integer ids as obs.
label_obs = dict(enumerate([
    "Validation loss",
    "Efficiency",
    "Purity",
    "Current learning rate",
    "Radius for 95% efficiency",
    "Radius for 98% efficiency",
    "Radius for 99% efficiency",
    "Purity at 95% efficiency",
    "Purity at 98% efficiency",
    "Purity at 99% efficiency",
    "BOPs per cluster/event",
    "Memory bits (weights)",
    "Memory bits (output, per cluster)",
    "Number of pruning steps",
    "Epoch",
    "Training step",
    "Training loss",
]))

# Final purities of the reference MLP runs: [pur95, pur98, pur99], each with
# one value per entry of files_mlp.
reference_purities = return_final_purities(files_mlp)
print(reference_purities)

# One [pur95, pur98, pur99, bops] result per pruned run.
pur_vs_bops = [purities_vs_bops(run_dir) for run_dir in files_mlp_pruned]
print(pur_vs_bops[0][0])  # indexed as [file][metric]

# Integer-part sweep over run indices 0..7. The return value is not stored
# here (purities_vs_config prints it); pur_vs_int stays empty in this cell.
pur_vs_int = []
purities_vs_config(mlp_input_integer, 8, "integer_part")


[[0.443886786699295, 0.356970340013504, 0.317711740732193], [0.3505945801734924, 0.253142386674881, 0.2227185368537902], [0.2760992050170898, 0.1858686208724975, 0.1613043248653412]]
[0.1715115308761596, 0.1882798969745636, 0.2638836801052093, 0.2757153809070587, 0.2458789497613906, 0.1541450321674347, 0.0379004701972007, 0.0107563911005854, 0.0104131065309047, 0.0101098474115133, 0.0098428940400481, 0.0096173351630568]
[[0.2607817351818084, 0.3135411143302917, 0.3331023454666137, 0.3189413249492645, 0.3114047348499298, 0.3017463684082031, 0.3331411182880401, 0.3127935230731964], [0.17825348675251, 0.2204591631889343, 0.244937777519226, 0.2224880754947662, 0.2154709845781326, 0.2091311514377594, 0.2441044300794601, 0.220982477068901], [0.1308217495679855, 0.1549898833036422, 0.1857315301895141, 0.1597310304641723, 0.1535184830427169, 0.1542944014072418, 0.1841063350439071, 0.1630336493253708], [817889344.0, 0, 817889344.0, 1, 817889344.0, 2, 817889344.0, 5, 817889344.0, 6, 817889344.0,

[[0.2607817351818084,
  0.3135411143302917,
  0.3331023454666137,
  0.3189413249492645,
  0.3114047348499298,
  0.3017463684082031,
  0.3331411182880401,
  0.3127935230731964],
 [0.17825348675251,
  0.2204591631889343,
  0.244937777519226,
  0.2224880754947662,
  0.2154709845781326,
  0.2091311514377594,
  0.2441044300794601,
  0.220982477068901],
 [0.1308217495679855,
  0.1549898833036422,
  0.1857315301895141,
  0.1597310304641723,
  0.1535184830427169,
  0.1542944014072418,
  0.1841063350439071,
  0.1630336493253708],
 [817889344.0,
  0,
  817889344.0,
  1,
  817889344.0,
  2,
  817889344.0,
  5,
  817889344.0,
  6,
  817889344.0,
  7,
  817889344.0,
  3,
  817889344.0,
  4]]