In [1]:
import numpy as np
import pandas as pd
import pickle as p
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import os

In [2]:
# Directory containing the pickled experiment results (loaded further down).
RESULTS_DIR = "../results3"
# NOTE(review): `dir` shadows the builtin dir(); it is kept only because later
# cells read it — prefer RESULTS_DIR in new code.
dir = RESULTS_DIR

In [3]:
# List result files in a deterministic order: os.listdir returns entries in
# arbitrary order, and later cells pick results by position (results[0], ...).
# Hidden entries and sub-directories (e.g. .ipynb_checkpoints) are skipped
# because they are not result pickles and would break loading.
files = sorted(
    f
    for f in os.listdir(dir)
    if not f.startswith(".") and os.path.isfile(os.path.join(dir, f))
)

In [4]:
# Layer widths used to weight per-layer statistics via avg_usage
# (presumably neurons per layer — confirm against the training config).
architecture = np.array((128, 10))

In [5]:
def avg_usage(rates, architecture):
    """Size-weighted mean of per-layer values.

    Each per-layer entry of ``rates`` is weighted by that layer's share of the
    total in ``architecture``, so larger layers count proportionally more.

    Parameters
    ----------
    rates : array_like
        Per-layer values; the last axis must line up with ``architecture``
        (a 2-D input is reduced along its last axis by ``np.dot``).
    architecture : array_like
        Layer sizes, e.g. ``[128, 10]``. Lists and tuples are accepted as
        well as NumPy arrays.

    Returns
    -------
    numpy scalar or ndarray
        ``sum_i rates[..., i] * architecture[i] / sum(architecture)``.

    Raises
    ------
    ValueError
        If the layer sizes do not sum to a positive number (weights would be
        undefined).
    """
    # asarray lets plain Python sequences work; `list / int` would raise TypeError.
    sizes = np.asarray(architecture, dtype=float)
    total = sizes.sum()
    if not total > 0:
        raise ValueError(f"architecture must have a positive total size, got {total}")
    weights = sizes / total
    return np.dot(np.asarray(rates), weights)

In [6]:
def avg_fn(x):
    """Size-weighted mean of per-layer values ``x`` using the notebook's ``architecture``.

    ``architecture`` is looked up when this is called, not when it is defined
    (same late binding as the lambda it replaces), so re-running the
    architecture cell takes effect without redefining this function.
    """
    return avg_usage(x, architecture)

In [7]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted result files.
    with open(path, "rb") as fh:
        return p.load(fh)


# One raw result record per file in `files`, in the same order.
results = [_load_pickle(os.path.join(dir, f)) for f in files]

In [9]:
# Check whether the first result records a random-removal setting
# (dict membership already tests keys, so .keys() is unnecessary).
"random_removal" in results[0]

True

In [10]:
# Same check for the second result; not every result file stores this key.
"random_removal" in results[1]

False

In [14]:
# Inspect one raw result record to see which keys and value types a run stores.
results[0]

{'mask angle': 0.45,
 'cross inhibit': 0.0,
 'random_removal': 0.0,
 'accuracy': DeviceArray(0.261, dtype=float32),
 'matrix usage': array([0.25466648, 0.01686115], dtype=float32),
 'accuracy_spiking': DeviceArray([0.38509998, 0.5864    , 0.6306    , 0.634     , 0.6203    ,
              0.5675    , 0.4554    , 0.3102    , 0.17      , 0.09999999,
              0.09999999], dtype=float32),
 'firing rates': array([0.12215113, 0.05233386])}

In [28]:
def reduce_result(result, avg=None, spike_index=5):
    """Collapse one raw experiment result into a flat row of scalars.

    Parameters
    ----------
    result : dict
        One unpickled result record. Keys read:
        - "mask angle", "cross inhibit", "random_removal", "accuracy":
          optional, converted with ``float``; missing keys become 0.0.
        - "matrix usage", "firing rates": required, reduced with ``avg``.
        - "accuracy_spiking": required, indexed with ``spike_index``.
    avg : callable, optional
        Maps a per-layer array to a single value. Defaults to the notebook's
        global ``avg_fn`` (looked up at call time).
    spike_index : int, default 5
        Position in ``result["accuracy_spiking"]`` reported as
        "accuracy spiking". NOTE(review): what index 5 represents is not
        visible here — presumably one specific spiking setting; confirm.

    Returns
    -------
    dict
        Keys, in this order (it becomes the DataFrame column order):
        "mask angle", "cross inhibit", "random_removal", "accuracy",
        "reduction", "matrix usage", "firing rates", "accuracy spiking".
        "reduction" names the first active method in the priority
        mask angle > cross inhibit > random, or "none".
    """
    if avg is None:
        avg = avg_fn

    reduced = {}
    for key in ("mask angle", "cross inhibit", "random_removal", "accuracy"):
        reduced[key] = float(result.get(key, 0.0))

    # Label the reduction method; earlier checks take priority if several are set.
    if reduced["mask angle"] > 0.0:
        reduced["reduction"] = "mask angle"
    elif reduced["cross inhibit"] > 0.0:
        reduced["reduction"] = "cross inhibit"
    elif reduced["random_removal"] > 0.0:
        reduced["reduction"] = "random"
    else:
        reduced["reduction"] = "none"

    for key in ("matrix usage", "firing rates"):
        reduced[key] = avg(result[key])

    # Index directly: the previous np.max() around a single element was a no-op
    # left over from a "best accuracy over all settings" variant.
    reduced["accuracy spiking"] = float(result["accuracy_spiking"][spike_index])

    return reduced

In [29]:
# Sanity-check reduce_result on a single record before mapping it over all runs.
reduce_result(results[0])

{'mask angle': 0.45,
 'cross inhibit': 0.0,
 'random_removal': 0.0,
 'accuracy': 0.26100000739097595,
 'reduction': 'mask angle',
 'matrix usage': 0.23743420761024606,
 'firing rates': 0.11709190400843882,
 'accuracy spiking': 0.5674999952316284}

In [30]:
# One flat row of scalars per raw result, in the same order as `results`.
reduced_results = [reduce_result(record) for record in results]

In [31]:
# Tabulate the reduced rows; columns follow reduce_result's key order.
df = pd.DataFrame.from_records(reduced_results)

In [32]:
# Boolean masks: True where the corresponding reduction parameter is switched off.
ma_z = df["mask angle"].eq(0)
ci_z = df["cross inhibit"].eq(0)
rnd_z = df["random_removal"].eq(0)

In [33]:
# A run belongs to a method's group when the *other* two reductions are off,
# so the control run (all three off) appears in every group.
# Combine boolean masks with `&` (pandas' element-wise AND) rather than
# arithmetic `*`, which only happens to work on bool dtype.
ma_i = ci_z & rnd_z
ci_i = ma_z & rnd_z
rnd_i = ma_z & ci_z
ctrl_i = ma_z & ci_z & rnd_z

In [34]:
# Per-method subsets of the table, each ordered by spiking accuracy.
# NOTE(review): plot lines connect points in row order; sorting by
# "firing rates" would order them along the sparsity x-axis — confirm intent.
ma, ci, rnd = (
    df[mask].sort_values("accuracy spiking")
    for mask in (ma_i, ci_i, rnd_i)
)

In [35]:
# Control run(s): no reduction method applied.
baseline = df.loc[ctrl_i]

In [36]:
# Reference spiking accuracy of the control run, used to normalise the plots.
# Fail with a clear message instead of a bare IndexError if no control run was loaded.
assert not baseline.empty, "no control run found (all reduction parameters == 0)"
# NOTE(review): if several control runs exist only the first row is used; consider averaging.
spk_a = baseline["accuracy spiking"].iloc[0]

In [37]:
# Baseline (control-run) spiking accuracy used to normalise the relative-accuracy plot.
spk_a

0.8468999862670898

In [38]:
# Relative spiking accuracy vs. sparsity (1 - averaged firing rate), one line
# per reduction method, normalised by the control run's accuracy `spk_a`.
fig3 = go.Figure()

for subset, label in ((ma, "Explicit"), (ci, "Inhibitory"), (rnd, "Random")):
    fig3.add_trace(
        go.Scatter(
            x=1 - subset["firing rates"],
            y=subset["accuracy spiking"] / spk_a,
            name=label,
        )
    )

fig3.update_layout(
    title="Relative Accuracy by Sparsity",
    xaxis_title="Sparsity",
    yaxis_title="Relative Accuracy, Spiking",
    yaxis_range=[0, 1.1],
    xaxis_range=[-0.05, 1.05],
    width=900,
    height=600,
    font=dict(size=16),
)

In [26]:
# Spiking accuracy as a function of the mask-angle parameter.
px.scatter(data_frame=ma, x="mask angle", y="accuracy spiking")

In [27]:
# Spiking accuracy vs. averaged firing rate for every run, coloured by reduction method.
px.scatter(
    data_frame=df,
    x="firing rates",
    y="accuracy spiking",
    color="reduction",
)