In [None]:
import sys
sys.path.append("../Synaptic-Flow/")

In [None]:
import os

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes

plt.rcParams["figure.figsize"] = (15,10)


In [None]:
# Display names used in plot legends and titles, keyed by the identifiers that
# appear in result directory / file names.
model_dict = {
    'resnet20': "ResNet-20",
    'wide-resnet20': "WideResNet-20",
    'fc-1000': "FC-1000",
}
pruner_dict = {
    'synflow': "SynFlow",
    'grasp': "GraSP",
    'snip': "SNIP",
    'mag': "Magnitude",
    'rand': "Random",
    'synflow-dist': "SynFlow-Dist",
    'synflow-l2': "SynFlow-L2",
    'synflow-dist-l2': "SynFlow-L2-Dist",
}
dataset_dict = {
    'cifar10': "CIFAR-10",
    'cifar100': "CIFAR-100",
}

In [None]:
def get_dataframe(dim='path_kernel', 
                  model_class_list=['default', 'lottery'],
                  model_list=['fc', 'fc-1000', 'resnet20', 'wide-resnet20'],
                  pruner_list = ['synflow'],
                  dataset_list = ['cifar10', 'cifar100'],
                  seed_list = ['83', '1337'], # later on need 23, 923
                  comp_list = ["0.0", "1.0"], 
                  root_dir="../Results/pruned/"):
    """Collect the first 100 rows of one metric column from every matching results CSV.

    Files are looked up at
    ``{root_dir}{model_class}/{model}/{pruner}/{dataset}_{seed}_{comp}.csv`` for
    every combination of the list arguments. A file is skipped when it does not
    exist, lacks the ``dim`` or ``train_loss`` column, has fewer than 101 rows,
    or looks untrained (``train_loss`` differs by less than 1 between rows 1
    and 100).

    Args:
        dim: name of the CSV column to extract.
        model_class_list, model_list, pruner_list, dataset_list, seed_list,
            comp_list: values substituted into the path template above.
        root_dir: prefix of the results tree. It is concatenated as-is, so it
            should end with a path separator.

    Returns:
        DataFrame with one column per accepted file, named
        ``"{model}--{dataset}--{pruner}--{comp}--{seed}"``, holding the first
        100 values of ``dim``. Empty if no file was accepted.
    """
    # Row 100 is read by the "did it train" check and 100 rows are kept, so shorter
    # runs are skipped; this also guarantees all columns have the same length.
    min_rows = 101
    columns = {}
    for model_class in model_class_list:
        for model in model_list:
            for pruner in pruner_list:
                curr_dir = root_dir + f"{model_class}/{model}/{pruner}/"
                for dataset in dataset_list:
                    for seed in seed_list:
                        for comp in comp_list:
                            file = curr_dir + f"{dataset}_{seed}_{comp}.csv"
                            if not os.path.exists(file):
                                continue
                            df = pd.read_csv(file)
                            if dim not in df.columns or 'train_loss' not in df.columns:
                                continue
                            if len(df) < min_rows:
                                continue
                            # A loss that barely moves between rows 1 and 100 is treated as a run that did not train.
                            if np.abs(df['train_loss'].iloc[1] - df['train_loss'].iloc[100]) < 1:
                                continue
                            columns[f"{model}--{dataset}--{pruner}--{comp}--{seed}"] = df[dim].iloc[:100].values
    return pd.DataFrame(columns)

In [None]:
def plot_dataframes(dims = ['path_kernel', 'weight_movement_norm'],
                    model_class_list_list = [["lottery", "default"]],
                    model_list_list = [['resnet20', 'wide-resnet20']],
                    pruner_list_list = [['synflow']],
                    dataset_list_list = [['cifar10'], ['cifar100']],
                    comp_list_list = [["0.0"], ["1.0"]],
                    seed_list=['1337', '82' ,'821', '23', '923', '83'],
                    root_dir="../Results/pruned/"):
    """Plot per-epoch training curves for each metric in ``dims``.

    One figure is drawn for every combination of ``dims`` and the
    ``*_list_list`` arguments. Each ``*_list_list`` is a list of argument lists
    for ``get_dataframe``. In each figure:

    * row 0 of the loaded curves is dropped and the next row is the reference:
      ``path_kernel`` is plotted as ``(x - x_ref) / x_ref``,
      ``weight_movement_norm`` as ``x / x_ref``, and other metrics as-is;
    * every curve is coloured on the coolwarm scale by its run's
      ``path_kernel`` value in that same reference row (shown as a colour bar).
      Runs with no ``path_kernel`` curve get Matplotlib's default cycle colour;
    * curves that share a model and pruner share a marker and one legend entry.

    The title uses only ``dataset_list[0]`` and ``comp_list[0]``. Figures are
    saved as PNG under ``root_dir`` and then shown. Combinations with fewer
    than two rows of data are skipped.
    """
    marker_list = [r'$1$', r'$2$', r'$3$', r'$4$', r'$5$', r'$6$', r'$7$', r'$8$', r'$9$',
                   r'$a$', r'$b$', r'$c$', r'$d$', r'$e$', r'$f$', r'$g$', r'$h$']
    combos = ((dim, model_class_list, model_list, pruner_list, dataset_list, comp_list)
              for dim in dims
              for model_class_list in model_class_list_list
              for model_list in model_list_list
              for pruner_list in pruner_list_list
              for dataset_list in dataset_list_list
              for comp_list in comp_list_list)
    for dim, model_class_list, model_list, pruner_list, dataset_list, comp_list in combos:
        query = dict(model_class_list=model_class_list,
                     model_list=model_list,
                     pruner_list=pruner_list,
                     dataset_list=dataset_list,
                     comp_list=comp_list,
                     seed_list=seed_list,
                     root_dir=root_dir)
        df = get_dataframe(dim=dim, **query)
        # The colour scale always comes from the path kernel, whatever metric is plotted.
        path_kernel_df = df if dim == 'path_kernel' else get_dataframe(dim='path_kernel', **query)
        # Row 0 is dropped and the next row is the reference, so at least two rows are needed.
        if len(df) < 2 or len(path_kernel_df) < 2:
            continue

        first_epoch_path_kernel_values = path_kernel_df.drop(path_kernel_df.index[0]).iloc[0]
        norm_first_epoch = plt.Normalize(vmin=first_epoch_path_kernel_values.min(),
                                         vmax=first_epoch_path_kernel_values.max())
        rgba = plt.cm.coolwarm(norm_first_epoch(first_epoch_path_kernel_values.values))
        # Colours are keyed by run name, not column position, so they stay attached to
        # the right curve even if the metric and path-kernel frames hold different runs.
        color_dict = {run: mpl.colors.rgb2hex(c[:3])
                      for run, c in zip(first_epoch_path_kernel_values.index, rgba)}

        df = df.drop(df.index[0])
        first_epoch_values = df.iloc[0]
        if dim == "path_kernel":
            df = df.sub(first_epoch_values).div(first_epoch_values)
        elif dim == "weight_movement_norm":
            df = df.div(first_epoch_values)

        fig, ax = plt.subplots()
        marker_plot_dict = {}
        for name, data in df.items():
            # name looks like: resnet20--cifar100--synflow--0.0--1337
            model_name, _, pruner_name, comp_ratio, _ = name.split('--')
            marker_plot_label = f"{model_dict[model_name]}, {pruner_dict[pruner_name]}"
            if comp_ratio.split('.')[0] == '1':
                marker_plot_label += r", 10\% pruned"
            if marker_plot_label in marker_plot_dict:
                # Further seeds of an already-seen model/pruner reuse its marker and stay out of the legend.
                ax.plot(data, color=color_dict.get(name), marker=marker_plot_dict[marker_plot_label],
                        markevery=15, ms=10)
            else:
                marker = marker_list[len(marker_plot_dict) % len(marker_list)]
                marker_plot_dict[marker_plot_label] = marker
                ax.plot(data, color=color_dict.get(name), marker=marker, markevery=15, ms=10,
                        label=marker_plot_label)
        ax.legend(loc="lower right" if dim == "weight_movement_norm" else "upper left")

        cbax = fig.add_axes([0.91, 0.15, 0.01, 0.7])
        fig.colorbar(plt.cm.ScalarMappable(norm=norm_first_epoch, cmap=plt.cm.coolwarm), cax=cbax)

        pruning_str = r"0\%" if comp_list[0] == '0.0' else r"10\%"
        if dim == 'path_kernel':
            ax.set_title(label=f"Relative Change in Path Kernel - {dataset_dict[dataset_list[0]]} - Pruning {pruning_str}", fontsize=20)
            ax.set_ylabel(r"$\text{Tr}(\bm{\Pi}_{\bm{\theta}_t}) - \text{Tr}(\bm{\Pi}_{\bm{\theta}_0})$", fontsize=20)
        elif dim == 'weight_movement_norm':
            ax.set_title(label=f"Relative Change in Parameters - {dataset_dict[dataset_list[0]]} - Pruning {pruning_str}", fontsize=20)
            ax.set_ylabel(r"$\bm{\omega}_t$", fontsize=20)
        ax.set_xlim(0, 100)
        ax.set_xlabel("Epochs", fontsize=20)

        filename = root_dir + (f"{dim}--compression-{','.join(comp_list)}-pruner-{','.join(pruner_list)}"
                               f"--model-{','.join(model_list)}--dataset-{','.join(dataset_list)}.png")
        fig.savefig(filename)
        plt.show()
        # Many figures are created per call; release each one once displayed.
        plt.close(fig)
        print("---" * 25)
                        
                        
                        

# Plot: Path Kernel per dataset {CIFAR10/100}, per compression {0.0/1.0}, per model {fc-1000, ResNet-20, Wide-Resnet-20} for all pruners

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = "\n".join([r"\usepackage{bm}", r"\usepackage{amsmath}"])

plot_dataframes(dims=["path_kernel"], 
                comp_list_list = [["0.0"], ["1.0"]], 
                dataset_list_list = [['cifar100'], ["cifar10"]], 
                model_list_list = [['fc-1000'], ['resnet20'], ['wide-resnet20']],
                pruner_list_list = [['synflow', 'grasp', 'rand', 'snip', 'synflow-dist', 'synflow-l2', 'synflow-dist-l2']])

# Plot: Path Kernel per dataset {CIFAR10/100}, per compression {0.0/1.0}, per pruner {SynFlow, GraSP, Rand, SNIP, SynFlow-L2, SynFlow-Dist, SynFlow-L2-Dist} for all models {fc-1000, ResNet-20, Wide-Resnet-20}

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = "\n".join([r"\usepackage{bm}", r"\usepackage{amsmath}"])

plot_dataframes(dims=["path_kernel"], 
                comp_list_list = [["0.0"], ["1.0"]], 
                dataset_list_list = [['cifar100'], ["cifar10"]], 
                model_list_list = [['fc-1000', 'resnet20', 'wide-resnet20']],
                pruner_list_list = [['synflow'], ['grasp'], ['rand'], ['snip'], ['synflow-dist'], ['synflow-l2'], ['synflow-dist-l2']])

# Plot: Weight Movement per dataset {CIFAR10/100}, per compression {0.0/1.0}, per model {fc-1000, ResNet-20, Wide-Resnet-20} for all pruners

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = r"\usepackage{bm}"

plot_dataframes(dims=["weight_movement_norm"], 
                comp_list_list = [["0.0"], ["1.0"]], 
                dataset_list_list = [['cifar100'], ["cifar10"]], 
                model_list_list = [['fc-1000'], ['resnet20'], ['wide-resnet20']],
                pruner_list_list = [['synflow', 'grasp', 'rand', 'snip', 'synflow-dist', 'synflow-l2', 'synflow-dist-l2']])

# Plot: Weight Movement per dataset {CIFAR10/100}, per compression {0.0/1.0}, per pruner {SynFlow, GraSP, Rand, SNIP, SynFlow-L2, SynFlow-Dist, SynFlow-L2-Dist} for all models {fc-1000, ResNet-20, Wide-Resnet-20}

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = r"\usepackage{bm}"

plot_dataframes(dims=["weight_movement_norm"], 
                comp_list_list = [["0.0"], ["1.0"]], 
                dataset_list_list = [['cifar100'], ["cifar10"]], 
                model_list_list = [['fc-1000', 'resnet20', 'wide-resnet20']],
                pruner_list_list = [['synflow'], ['grasp'], ['rand'], ['snip'], ['synflow-dist'], ['synflow-l2'], ['synflow-dist-l2']])

# Network outputs (mean / norm per training iteration) across different models:

In [None]:
def get_sorted_values(directory):
    """Summarise every saved output array in ``directory`` by iteration.

    Files whose name contains ``"_output"`` are loaded with ``np.load``. The
    integer before the first underscore of the file name is used as the
    iteration number.

    Returns:
        Three dicts ``(mean, sum, norm)`` mapping iteration -> statistic of
        that file's array, each ordered by ascending iteration. All three are
        empty if ``directory`` does not exist.
    """
    means, sums, norms = {}, {}, {}
    if not os.path.exists(directory):
        return {}, {}, {}

    for fname in os.listdir(directory):
        if "_output" not in fname:
            continue
        step = int(fname.split("_")[0])
        arr = np.load(f"{directory}/{fname}")
        means[step] = arr.mean()
        sums[step] = arr.sum()
        norms[step] = np.linalg.norm(arr)

    # Directory listing order is arbitrary; re-key each dict in iteration order.
    return tuple(dict(sorted(stats.items())) for stats in (means, sums, norms))

In [None]:
def get_logit_dataframe(model_class_list=['lottery'],
                        model_list=['resnet20', 'wide-resnet20'],
                        pruner_list = ['synflow'],
                        dataset_list = ['cifar10', 'cifar100'],
                        seed_list = ['83', '1337'],
                        comp_list = ["0.0", "1.0"], 
                        root_dir="../Results/pruned/", op="Mean"):
    """Plot saved network outputs per iteration, one figure per (model class, model, dataset, comp).

    For every pruner and seed, the ``*_output*`` arrays in
    ``{root_dir}{model_class}/{model}/{pruner}/output_{dataset}_{seed}_{comp}``
    are summarised with ``get_sorted_values``. ``op="Norm"`` plots the L2 norm
    of each saved array; any other value plots its mean. Only runs with at
    least 100 saved iterations are plotted, limited to their first 100.

    Each curve is coloured on the coolwarm scale by its run's ``path_kernel``
    value in row 1 of ``get_dataframe``'s output. Curves sharing a model and
    pruner share a marker and one legend entry. Each figure is saved as PNG
    under ``root_dir`` and shown. Combinations without path-kernel data print
    "Continuing..." and are skipped.
    """
    marker_list = [r'$1$', r'$2$', r'$3$', r'$4$', r'$5$', r'$6$', r'$7$', r'$8$', r'$9$',
                   r'$a$', r'$b$', r'$c$', r'$d$', r'$e$', r'$f$', r'$g$', r'$h$']

    for model_class in model_class_list:
        for model in model_list:
            for dataset in dataset_list:
                for comp in comp_list:
                    path_kernel_df = get_dataframe(dim='path_kernel',
                                                   model_class_list=[model_class],
                                                   model_list=[model],
                                                   pruner_list=pruner_list,
                                                   dataset_list=[dataset],
                                                   comp_list=[comp],
                                                   seed_list=seed_list,
                                                   root_dir=root_dir)
                    # Row 0 is dropped and the next row is the reference, so two rows are needed.
                    if len(path_kernel_df) < 2:
                        print("Continuing...")
                        continue

                    first_epoch_path_kernel_values = path_kernel_df.drop(path_kernel_df.index[0]).iloc[0]
                    norm_first_epoch = plt.Normalize(vmin=first_epoch_path_kernel_values.min(),
                                                     vmax=first_epoch_path_kernel_values.max())
                    rgba = plt.cm.coolwarm(norm_first_epoch(first_epoch_path_kernel_values.values))
                    path_kernel_colors = {run: mpl.colors.rgb2hex(c[:3])
                                          for run, c in zip(first_epoch_path_kernel_values.index, rgba)}

                    curves, color_dict = {}, {}
                    for pruner in pruner_list:
                        curr_dir = root_dir + f"{model_class}/{model}/{pruner}/"
                        for seed in seed_list:
                            dir_name = f"output_{dataset}_{seed}_{comp}"
                            mean_dict, _, norm_dict = get_sorted_values(curr_dir + dir_name)
                            values = list((norm_dict if op == "Norm" else mean_dict).values())[:100]
                            if len(values) == 100:
                                col = f"{model}--{pruner}--{dir_name}"
                                curves[col] = values
                                # Look the colour up under get_dataframe's name for the same run; None
                                # (no path kernel for this run) falls back to the default colour cycle.
                                color_dict[col] = path_kernel_colors.get(f"{model}--{dataset}--{pruner}--{comp}--{seed}")
                    if not curves:
                        continue

                    df = pd.DataFrame(curves)
                    fig, ax = plt.subplots()
                    marker_plot_dict = {}
                    for name, data in df.items():
                        # name looks like: resnet20--synflow--output_cifar100_1337_0.0
                        model_name, pruner_name, _ = name.split('--')
                        marker_plot_label = f"{model_dict[model_name]}, {pruner_dict[pruner_name]}"
                        if marker_plot_label in marker_plot_dict:
                            # Further seeds reuse the label's marker and stay out of the legend.
                            ax.plot(data, color=color_dict[name], marker=marker_plot_dict[marker_plot_label],
                                    markevery=15, ms=10)
                        else:
                            marker = marker_list[len(marker_plot_dict) % len(marker_list)]
                            marker_plot_dict[marker_plot_label] = marker
                            ax.plot(data, color=color_dict[name], marker=marker, markevery=15, ms=10,
                                    label=marker_plot_label)
                    ax.legend()

                    cbax = fig.add_axes([0.91, 0.2, 0.01, 0.5])
                    fig.colorbar(plt.cm.ScalarMappable(norm=norm_first_epoch, cmap=plt.cm.coolwarm), cax=cbax)

                    pruning_str = r"0\%" if comp == '0.0' else r"10\%"
                    # NOTE(review): the curves are the raw per-iteration mean/norm, not a difference
                    # from the first iteration as the title and y-label suggest — confirm intent.
                    ax.set_title(label=f"Relative Change in {op} Outputs - Pruning {pruning_str}", fontsize=20)
                    ax.set_xlim(0, 100)
                    ax.set_xlabel("Epochs", fontsize=20)
                    ax.set_ylabel(r"$f(\bm{\mathcal{X}},\bm{\theta}_t) - f(\bm{\mathcal{X}},\bm{\theta}_0)$", fontsize=20)
                    # One file per (model, dataset, comp) so figures in the same call don't overwrite each other.
                    fig.savefig(root_dir + f"output---{op}---{model}---{dataset}---{'-'.join(pruner_list)}---{comp}.png", dpi=fig.dpi)
                    plt.show()
                    plt.close(fig)
                    

In [None]:
def plot_logit_dataframes(model_class_list_list = [["lottery"]],
                          model_list_list = [['resnet20', 'wide-resnet20']],
                          pruner_list_list = [['synflow', 'grasp', 'snip', 'synflow-dist', 'synflow-l2', 'synflow-dist-l2']],
                          dataset_list_list = [['cifar10'], ['cifar100']],
                          comp_list_list = [["0.0"], ["1.0"]],
                          seed_list_list = [['1337', '82' ,'821', '23', '923']],
                          root_dir="../Results/pruned/", op="Mean"):
    """Call ``get_logit_dataframe`` for every combination of the ``*_list_list`` entries.

    Each ``*_list_list`` is a list of argument lists. Their cartesian product
    is walked in the order model class, model, pruner, dataset, compression,
    seed list. ``root_dir`` and ``op`` are forwarded unchanged.
    """
    grid = ((classes, models, pruners, datasets, comps, seeds)
            for classes in model_class_list_list
            for models in model_list_list
            for pruners in pruner_list_list
            for datasets in dataset_list_list
            for comps in comp_list_list
            for seeds in seed_list_list)
    for classes, models, pruners, datasets, comps, seeds in grid:
        get_logit_dataframe(comp_list=comps,
                            dataset_list=datasets,
                            model_list=models,
                            seed_list=seeds,
                            pruner_list=pruners,
                            model_class_list=classes,
                            root_dir=root_dir,
                            op=op)

# Plot Mean Output per dataset {CIFAR10/100}, per compression ratio {0.0/1.0}, per model {ResNet-20/Wide-ResNet-20} for all pruners

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = "\n".join([r"\usepackage{bm}", r"\usepackage{amsmath}"])

plot_logit_dataframes(comp_list_list = [["0.0"], ["1.0"]], 
                      dataset_list_list = [['cifar100'], ["cifar10"]], 
                      model_class_list_list = [["lottery"]], 
                      model_list_list = [['resnet20'], ['wide-resnet20']], op="Mean")

# Plot Norm Output per dataset {CIFAR10/100}, per compression ratio {0.0/1.0}, per model {ResNet-20/Wide-ResNet-20} for all pruners

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = "\n".join([r"\usepackage{bm}", r"\usepackage{amsmath}"])

plot_logit_dataframes(comp_list_list = [["0.0"], ["1.0"]], 
                      dataset_list_list = [['cifar100'], ["cifar10"]], 
                      model_class_list_list = [["lottery"]], 
                      model_list_list = [['resnet20'], ['wide-resnet20']], op="Norm")

# Plot Mean Output per dataset {CIFAR10/100}, per compression ratio {0.0/1.0}, per model {FC-1000} for all pruners

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = "\n".join([r"\usepackage{bm}", r"\usepackage{amsmath}"])
plot_logit_dataframes(comp_list_list = [["0.0"], ["1.0"]], 
                      dataset_list_list = [['cifar100'], ["cifar10"]], 
                      model_class_list_list = [["default"]], 
                      model_list_list = [['fc-1000']], op="Mean")

# Plot Norm Output per dataset {CIFAR10/100}, per compression ratio {0.0/1.0}, per model {FC-1000} for all pruners

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = "\n".join([r"\usepackage{bm}", r"\usepackage{amsmath}"])
plot_logit_dataframes(comp_list_list = [["0.0"], ["1.0"]], 
                      dataset_list_list = [['cifar100'], ["cifar10"]], 
                      model_class_list_list = [["default"]], 
                      model_list_list = [['fc-1000']], op="Norm")

In [None]:
def get_logit_dataframe_per_pruner(model_class_list=['lottery'],
                        model_list=['resnet20', 'wide-resnet20'],
                        pruner_list = ['synflow'],
                        dataset_list = ['cifar10', 'cifar100'],
                        seed_list = ['83', '1337'],
                        comp_list = ["0.0", "1.0"], 
                        root_dir="../Results/pruned/", op="Mean"):
    """Plot saved network outputs per iteration, one figure per (model class, pruner, dataset, comp).

    Like ``get_logit_dataframe``, but each figure gathers all models in
    ``model_list`` for a single pruner. For every model and seed, the
    ``*_output*`` arrays in
    ``{root_dir}{model_class}/{model}/{pruner}/output_{dataset}_{seed}_{comp}``
    are summarised with ``get_sorted_values``. ``op="Norm"`` plots the L2 norm
    of each saved array; any other value plots its mean. Only runs with at
    least 100 saved iterations are plotted, limited to their first 100.

    Each curve is coloured on the coolwarm scale by its run's ``path_kernel``
    value in row 1 of ``get_dataframe``'s output. Curves sharing a model and
    pruner share a marker and one legend entry. Figures are shown but not
    saved. Combinations without path-kernel data print "Continuing..." and are
    skipped.
    """
    marker_list = [r'$1$', r'$2$', r'$3$', r'$4$', r'$5$', r'$6$', r'$7$', r'$8$', r'$9$',
                   r'$a$', r'$b$', r'$c$', r'$d$', r'$e$', r'$f$', r'$g$', r'$h$']

    for model_class in model_class_list:
        for pruner in pruner_list:
            for dataset in dataset_list:
                for comp in comp_list:
                    path_kernel_df = get_dataframe(dim='path_kernel',
                                                   model_class_list=[model_class],
                                                   model_list=model_list,
                                                   pruner_list=[pruner],
                                                   dataset_list=[dataset],
                                                   comp_list=[comp],
                                                   seed_list=seed_list,
                                                   root_dir=root_dir)
                    # Row 0 is dropped and the next row is the reference, so two rows are needed.
                    if len(path_kernel_df) < 2:
                        print("Continuing...")
                        continue

                    first_epoch_path_kernel_values = path_kernel_df.drop(path_kernel_df.index[0]).iloc[0]
                    norm_first_epoch = plt.Normalize(vmin=first_epoch_path_kernel_values.min(),
                                                     vmax=first_epoch_path_kernel_values.max())
                    rgba = plt.cm.coolwarm(norm_first_epoch(first_epoch_path_kernel_values.values))
                    path_kernel_colors = {run: mpl.colors.rgb2hex(c[:3])
                                          for run, c in zip(first_epoch_path_kernel_values.index, rgba)}

                    curves, color_dict = {}, {}
                    for model in model_list:
                        curr_dir = root_dir + f"{model_class}/{model}/{pruner}/"
                        for seed in seed_list:
                            dir_name = f"output_{dataset}_{seed}_{comp}"
                            mean_dict, _, norm_dict = get_sorted_values(curr_dir + dir_name)
                            values = list((norm_dict if op == "Norm" else mean_dict).values())[:100]
                            if len(values) == 100:
                                col = f"{model}--{pruner}--{dir_name}"
                                curves[col] = values
                                # Look the colour up under get_dataframe's name for the same run; None
                                # (no path kernel for this run) falls back to the default colour cycle.
                                color_dict[col] = path_kernel_colors.get(f"{model}--{dataset}--{pruner}--{comp}--{seed}")
                    if not curves:
                        continue

                    df = pd.DataFrame(curves)
                    fig, ax = plt.subplots()
                    marker_plot_dict = {}
                    for name, data in df.items():
                        # name looks like: resnet20--synflow--output_cifar100_1337_0.0
                        model_name, pruner_name, _ = name.split('--')
                        marker_plot_label = f"{model_dict[model_name]}, {pruner_dict[pruner_name]}"
                        if marker_plot_label in marker_plot_dict:
                            # Further seeds reuse the label's marker and stay out of the legend.
                            ax.plot(data, color=color_dict[name], marker=marker_plot_dict[marker_plot_label],
                                    markevery=15, ms=10)
                        else:
                            marker = marker_list[len(marker_plot_dict) % len(marker_list)]
                            marker_plot_dict[marker_plot_label] = marker
                            ax.plot(data, color=color_dict[name], marker=marker, markevery=15, ms=10,
                                    label=marker_plot_label)
                    ax.legend()

                    cbax = fig.add_axes([0.91, 0.2, 0.01, 0.5])
                    fig.colorbar(plt.cm.ScalarMappable(norm=norm_first_epoch, cmap=plt.cm.coolwarm), cax=cbax)

                    pruning_str = r"0\%" if comp == '0.0' else r"10\%"
                    # NOTE(review): the curves are the raw per-iteration mean/norm, not a difference
                    # from the first iteration as the title and y-label suggest — confirm intent.
                    ax.set_title(label=f"Relative Change in {op} Outputs - Pruning {pruning_str}", fontsize=20)
                    ax.set_xlim(0, 100)
                    ax.set_xlabel("Epochs", fontsize=20)
                    ax.set_ylabel(r"$f(\bm{\mathcal{X}},\bm{\theta}_t) - f(\bm{\mathcal{X}},\bm{\theta}_0)$", fontsize=20)
                    # Figures from this per-pruner view are only displayed, not written to disk.
                    plt.show()
                    plt.close(fig)

In [None]:
def plot_logit_dataframes_per_pruner(model_class_list_list = [["lottery"]],
                          model_list_list = [['resnet20', 'wide-resnet20']],
                          pruner_list_list = [['synflow', 'grasp', 'snip', 'synflow-dist', 'synflow-l2', 'synflow-dist-l2']],
                          dataset_list_list = [['cifar10'], ['cifar100']],
                          comp_list_list = [["0.0"], ["1.0"]],
                          seed_list_list = [['1337', '82' ,'821', '23', '923']],
                          root_dir="../Results/pruned/", op="Mean"):
    """Call ``get_logit_dataframe_per_pruner`` for every combination of the ``*_list_list`` entries.

    Each ``*_list_list`` is a list of argument lists. Their cartesian product
    is walked in the order model class, model, pruner, dataset, compression,
    seed list. ``root_dir`` and ``op`` are forwarded unchanged.
    """
    grid = ((classes, models, pruners, datasets, comps, seeds)
            for classes in model_class_list_list
            for models in model_list_list
            for pruners in pruner_list_list
            for datasets in dataset_list_list
            for comps in comp_list_list
            for seeds in seed_list_list)
    for classes, models, pruners, datasets, comps, seeds in grid:
        get_logit_dataframe_per_pruner(comp_list=comps,
                                       dataset_list=datasets,
                                       model_list=models,
                                       seed_list=seeds,
                                       pruner_list=pruners,
                                       model_class_list=classes,
                                       root_dir=root_dir,
                                       op=op)

# Plot Mean Output per dataset {CIFAR10/100}, per compression ratio {0.0/1.0}, per pruner {SynFlow, GraSP, SNIP, SynFlow-L2, SynFlow-Dist, SynFlow-L2-Dist} for all models {FC-1000/ResNet-20/Wide-ResNet-20}

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = "\n".join([r"\usepackage{bm}", r"\usepackage{amsmath}"])

plot_logit_dataframes_per_pruner(comp_list_list = [["0.0"], ["1.0"]], 
                      dataset_list_list = [['cifar100'], ["cifar10"]], 
                      model_class_list_list = [["default", "lottery"]], 
                      pruner_list_list = [['synflow'], ['grasp'], [ 'snip'], ['synflow-dist'], ['synflow-l2'], ['synflow-dist-l2']],
                      model_list_list = [["fc-1000", 'resnet20', 'wide-resnet20']], op="Mean")

# Plot Norm Output per dataset {CIFAR10/100}, per compression ratio {0.0/1.0}, per pruner {SynFlow, GraSP, SNIP, SynFlow-L2, SynFlow-Dist, SynFlow-L2-Dist} for all models {FC-1000/ResNet-20/Wide-ResNet-20}

In [None]:
plt.rcParams["text.usetex"] = True
# The LaTeX preamble must be a single string; recent Matplotlib rejects a list.
plt.rcParams['text.latex.preamble'] = "\n".join([r"\usepackage{bm}", r"\usepackage{amsmath}"])

plot_logit_dataframes_per_pruner(comp_list_list = [["0.0"], ["1.0"]], 
                      dataset_list_list = [['cifar100'], ["cifar10"]], 
                      model_class_list_list = [["default", "lottery"]], 
                      pruner_list_list = [['synflow'], ['grasp'], [ 'snip'], ['synflow-dist'], ['synflow-l2'], ['synflow-dist-l2']],
                      model_list_list = [["fc-1000", 'resnet20', 'wide-resnet20']], op="Norm")