In [1]:
# # header
import sys
sys.path.append(r"../")
from src.curves_utils import *

%load_ext autoreload
%autoreload 2

In [None]:
# Folder holding the pretrained checkpoint and its saved configuration dicts.
start_folder = "../pretrained/curves"
# startup_folders hands back the output folder for this run and a logger.
results_folder, logger = startup_folders(start_folder, name="exp_curves")

In [None]:
# Restore the three configuration dictionaries stored in `start_folder`.
model_params, tasks, train_params = (
    load_dicts(start_folder, dict_name)
    for dict_name in ("model_params", "tasks", "train_params")
)
DeVice, num_workers, pin_memory = get_device()

# Record the restored configuration in the run log.
for label, loaded in (
    ("model_params", model_params),
    ("tasks", tasks),
    ("train_params", train_params),
):
    logger.info(f"{label}: {loaded}")


In [5]:
# Set up the CurveTracing task entry: attach the dataset class and start with
# empty dataset / dataloader containers that later cells fill in.
curve_task = tasks["CurveTracing"]
curve_task["composer"] = CurveTracing
for container_key in ("datasets", "dataloaders"):
    curve_task[container_key] = []


In [None]:
# Build three datasets per task and wrap them in dataloaders.
# The first dataset has 2**14 samples; the two that follow have 2**10 samples
# each and are finalised with build_valid_test() right after construction.
DeVice, num_workers, pin_memory = get_device()
for task_name in tasks:
    task = tasks[task_name]
    splits = task["datasets"]
    splits.append(task["composer"](n_samples=2**14, **task["params"]))
    for _ in range(2):
        held_out = task["composer"](n_samples=2**10, **task["params"])
        splits.append(held_out)
        held_out.build_valid_test()
    task["dataloaders"] = build_loaders(
        splits,
        batch_size=train_params["batch_size"],
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


In [None]:
# Instantiate the network from the saved hyper-parameters and wrap it in the
# AttentionTrain helper (the two None arguments are left unset for this
# evaluation-only notebook).
model = AttentionModel(**model_params)
conductor = AttentionTrain(model, None, None, tasks, logger, results_folder)

# Restore the pretrained weights from <start_folder>/model.pth.
model_dir = os.path.join(start_folder, "model.pth")
assert os.path.exists(model_dir), "Could not find the model.pth in the given dir!"
pretrained_state = torch.load(model_dir, map_location=DeVice)
model.load_state_dict(pretrained_state)


In [None]:
# Evaluate the loaded model through the AttentionTrain wrapper.
# NOTE(review): the arguments are the device, presumably the split name
# ("test"), and a boolean flag whose meaning is defined by AttentionTrain.eval
# (not visible in this notebook) — confirm before changing them.
conductor.eval(DeVice, "test", False)

In [None]:
# Decision accuracy: wrap the model and tasks in the Roelfsema helper and run
# its accuracy-curve test on the selected device.
roelfsema_ = Roelfsema(model, tasks, logger)
roelfsema_.test_accuracy_curve(DeVice)

In [10]:
# Produce the project's standard plots into `results_folder`.
# NOTE(review): the leading 10, the "_test" suffix, the False flag and the
# trailing "test" are interpreted by plot_all (defined in src.curves_utils,
# not visible here) — presumably sample count, file suffix and split name.
plot_all(10, model, tasks, results_folder, "_test", DeVice, logger, False, "test")

# Curve Tracing

## Prep

In [None]:
# Take one un-shuffled batch from the last CurveTracing dataset, with its
# `training` flag switched off.
batch_size = 256
curve_test_set = tasks["CurveTracing"]["datasets"][-1]
curve_test_set.training = False
this_dl = DataLoader(curve_test_set, batch_size=batch_size, shuffle=False)

# The batch unpacks into five tensors; move each of them to the compute device.
target_composites, distractor_composites, masks, rec_fields, components = (
    part.to(DeVice) for part in next(iter(this_dl))
)
# Sum of component channels 1 and 4, clipped to [0, 1].
both_composites = torch.clamp(components[:, 1:2] + components[:, 4:5], min=0.0, max=1.0)

print(
    *(
        tensor.shape
        for tensor in (
            target_composites,
            distractor_composites,
            masks,
            rec_fields,
            components,
            both_composites,
        )
    )
)


In [12]:
# Number of iterations in each phase (fixation, attention, saccade) as stored
# in the task parameters, plus the derived totals used by the analysis cells.
fix_attend_saccade = tasks["CurveTracing"]["params"]["fix_attend_saccade"]
n_fix, n_att, n_sac = fix_attend_saccade
n_fix_att = n_fix + n_att
n_iter = sum(fix_attend_saccade)
# Number of layers whose activity is analysed.
n_layers = model.n_convs

In [13]:
# Collect layer activity for every stimulus variant, plus the response to the
# receptive-field probe, with the model in eval mode and gradients disabled.
# The statement order is kept as is because the model is driven through
# initiate_forward / for_forward, which look stateful.
model.eval()
with torch.no_grad():
    # list_ind, base_act = get_rec_field_act(model, rec_fields)
    
    # get_activity results are indexed [iteration][layer] by the cells below.
    targets_ = get_activity(model, target_composites)
    distractors_ = get_activity(model, distractor_composites)
    tar_cue_ = get_activity(model, components[:, 0:1])
    dis_cue_ = get_activity(model, components[:, 3:4])
    # Only the first iteration is kept here, so both_ is indexed [layer].
    both_ = get_activity(model, both_composites)[0]
    
    # Response to the receptive-field probe; the last value returned by
    # for_forward is indexed [layer] and thresholded at 0 by later cells.
    model.initiate_forward(batch_size=rec_fields.size(0))
    *_, receptive_ = model.for_forward(rec_fields)
    
    # First value returned by the full forward pass; plotted as masks below.
    tmasks_, *_ = model(target_composites)
    dmasks_, *_ = model(distractor_composites)

## Plot doubles

In [14]:
# For the first `how_many` samples, save a 4 x n_iter grid:
# target input, target mask, distractor input, distractor mask (top to bottom),
# one column per iteration.
how_many = 16
panel_rows = (
    (target_composites, "gray"),
    (tmasks_, "plasma"),
    (distractor_composites, "gray"),
    (dmasks_, "plasma"),
)
for sample in range(how_many):
    plt.figure(figsize=(n_iter*3, 4*3))
    for step in range(n_iter):
        for row, (frames, colormap) in enumerate(panel_rows):
            plt.subplot(4, n_iter, row*n_iter + step + 1)
            plt.imshow(frames[sample, step, 0].cpu(), cmap=colormap)
            plt.axis("off")
    plt.savefig(f"{results_folder}/plot_bunch_{sample}.svg", format="svg")
    plt.close()

## Percentile

In [15]:
# Percentile curves of the activity of units inside the receptive field, one
# figure per (iteration, layer). Red = target condition, blue = distractor
# condition; the arrows mark the median of each condition.
q = torch.linspace(0, 0.99, 100)
q_dev = q.to(DeVice)  # the same quantile levels are reused for every figure
for i in range(n_iter):
    for j in range(n_layers):
        # Units that respond (> 0) to the receptive-field probe in this layer.
        in_rf = receptive_[j] > 0.0
        tar_vals = targets_[i][j][in_rf].ravel()
        dis_vals = distractors_[i][j][in_rf].ravel()

        tar_q = torch.quantile(tar_vals, q_dev).cpu()
        dis_q = torch.quantile(dis_vals, q_dev).cpu()
        # The 0.5 quantile, i.e. the median of each condition.
        median_tar = torch.quantile(tar_vals, 0.5).item()
        median_dis = torch.quantile(dis_vals, 0.5).item()

        plt.figure(figsize=(6, 4))
        plt.title(f"Layer {j}, Iteration {i}, {median_tar:.2f}, {median_dis:.2f}")
        plt.plot(tar_q, 100.0*q, c="r")
        plt.plot(dis_q, 100.0*q, c="b")
        plt.arrow(median_tar, 50, 0.0, -45, color='r', head_width=0.05, head_length=5, alpha=1.0, width=0.01)
        plt.arrow(median_dis, 50, 0.0, -45, color='b', head_width=0.05, head_length=5, alpha=1.0, width=0.01)
        plt.ylim(0, 100)
        # The x-range must cover the largest value of BOTH curves; using the
        # distractor median here would clip the blue curve whenever it extends
        # beyond the red one.
        plt.xlim(0, max(tar_q.max().item(), dis_q.max().item()))
        plt.savefig(os.path.join(results_folder, f"Percentile_layer_{j}_iter_{i}.svg"), format="svg")
        plt.close()

In [16]:
# Percentile curves pooled over the attention iterations (n_fix .. n_fix_att-1),
# one figure per layer. Red = target condition, blue = distractor condition;
# the arrows mark the median of each condition.
q = torch.linspace(0, 1.0, 51)
q_dev = q.to(DeVice)

# Pool the receptive-field activity of every attention iteration per layer.
tar_m_ = []
dis_m_ = []
for j in range(n_layers):
    in_rf = receptive_[j] > 0.0
    tar_m_.append(torch.cat([targets_[i][j][in_rf].ravel() for i in range(n_fix, n_fix_att)], dim=0))
    dis_m_.append(torch.cat([distractors_[i][j][in_rf].ravel() for i in range(n_fix, n_fix_att)], dim=0))

for j in range(n_layers):
    tar_q = torch.quantile(tar_m_[j], q_dev).cpu()
    dis_q = torch.quantile(dis_m_[j], q_dev).cpu()
    # The 0.5 quantile, i.e. the median of each condition.
    median_tar = torch.quantile(tar_m_[j], 0.5).item()
    median_dis = torch.quantile(dis_m_[j], 0.5).item()

    plt.figure(figsize=(6, 4))
    plt.title(f"Layer {j}, Iteration {n_fix}-{n_fix_att}, {median_tar:.2f}, {median_dis:.2f}")
    plt.plot(tar_q, 100.0*q, c="r")
    plt.plot(dis_q, 100.0*q, c="b")
    plt.arrow(median_tar, 50, 0.0, -45, color='r', head_width=0.05, head_length=5, alpha=1.0, width=0.01)
    plt.arrow(median_dis, 50, 0.0, -45, color='b', head_width=0.05, head_length=5, alpha=1.0, width=0.01)
    plt.ylim(0, 100)
    # The x-range must cover the largest value of BOTH curves (the distractor
    # median is not an upper bound of the blue curve); a fifth of that range
    # is added on the left as margin.
    x_max = max(tar_q.max().item(), dis_q.max().item())
    plt.xlim(-x_max/5, x_max)
    plt.savefig(os.path.join(results_folder, f"Percentile_layer_{j}.svg"), format="svg")
    plt.close()


## Curves

In [17]:
# Mean activity of receptive-field units (receptive_ > 0) per layer (rows)
# and iteration (columns), for the target and distractor conditions.
curve_tar_act = torch.tensor([
    [targets_[step][layer][receptive_[layer] > 0.0].mean().clone() for step in range(n_iter)]
    for layer in range(n_layers)
])
curve_dis_act = torch.tensor([
    [distractors_[step][layer][receptive_[layer] > 0.0].mean().clone() for step in range(n_iter)]
    for layer in range(n_layers)
])
plot_curves(n_layers, curve_tar_act, curve_dis_act, results_folder, "Curve_layer")


In [18]:
# Same per-layer / per-iteration means as the plain curves, but with the
# activity evoked by the combined stimulus (both_) subtracted first.
curve_tar_act = torch.tensor([
    [(targets_[step][layer] - both_[layer])[receptive_[layer] > 0.0].mean().clone() for step in range(n_iter)]
    for layer in range(n_layers)
])
curve_dis_act = torch.tensor([
    [(distractors_[step][layer] - both_[layer])[receptive_[layer] > 0.0].mean().clone() for step in range(n_iter)]
    for layer in range(n_layers)
])
plot_curves(n_layers, curve_tar_act, curve_dis_act, results_folder, "CurveDeBoth_layer")


In [19]:
# both_-subtracted mean activity per layer (rows) and iteration (columns) ...
curve_tar_act = torch.tensor([
    [(targets_[step][layer] - both_[layer])[receptive_[layer] > 0.0].mean().clone() for step in range(n_iter)]
    for layer in range(n_layers)
])
curve_dis_act = torch.tensor([
    [(distractors_[step][layer] - both_[layer])[receptive_[layer] > 0.0].mean().clone() for step in range(n_iter)]
    for layer in range(n_layers)
])
# ... re-referenced to each layer's own mean over the first two iterations.
tar_baseline = curve_tar_act[:, :2].mean(dim=1, keepdim=True)
dis_baseline = curve_dis_act[:, :2].mean(dim=1, keepdim=True)
curve_tar_act = curve_tar_act - tar_baseline
curve_dis_act = curve_dis_act - dis_baseline
plot_curves(n_layers, curve_tar_act, curve_dis_act, results_folder, "CurveDeBothDe_layer")


In [20]:
# Normalised response curves: both_-subtracted mean activity of the
# receptive-field units per layer (rows) and iteration (columns), re-referenced
# to the first two iterations and scaled by the peak of the target curve.
curve_tar_act = [[] for _ in range(model.n_convs)]
curve_dis_act = [[] for _ in range(model.n_convs)]
for j in range(n_layers):
    in_rf = receptive_[j] > 0.0
    for i in range(n_iter):
        curve_tar_act[j].append((targets_[i][j] - both_[j])[in_rf].mean().clone())
        curve_dis_act[j].append((distractors_[i][j] - both_[j])[in_rf].mean().clone())
curve_tar_act = torch.tensor(curve_tar_act)
curve_dis_act = torch.tensor(curve_dis_act)

# Subtract each layer's own mean over the first two iterations.
curve_tar_act = curve_tar_act - curve_tar_act[:, :2].mean(dim=1, keepdim=True)
curve_dis_act = curve_dis_act - curve_dis_act[:, :2].mean(dim=1, keepdim=True)

# Both conditions share one scale: the per-layer peak of the target curve.
# The peak has to be taken BEFORE the target curve is divided by it; taken
# afterwards it is 1 and the distractor curve would stay unnormalised.
tar_peak = curve_tar_act.max(dim=1, keepdim=True).values
curve_tar_act = curve_tar_act / tar_peak
curve_dis_act = curve_dis_act / tar_peak
plot_curves(n_layers, curve_tar_act, curve_dis_act, results_folder, "CurveDeBothDeNorm_layer")


## Modulation

In [21]:
# Modulation index between target and distractor activity of the
# receptive-field units, stored as modulation[iteration][layer].
modulation = []
for step in range(n_iter):
    per_layer = []
    for layer in range(n_layers):
        in_rf = receptive_[layer] > 0.0
        tar_resp = targets_[step][layer][in_rf]
        dis_resp = distractors_[step][layer][in_rf]
        # Only units whose two responses differ take part.
        differs = (tar_resp - dis_resp).abs() > 0.0
        index = torch.nan_to_num(
            modulation_index(tar_resp[differs], dis_resp[differs]),
            nan=0.0, posinf=0.0, neginf=0.0,
        )
        # Drop (near-)zero indices and anything with magnitude >= 1.
        magnitude = index.abs()
        per_layer.append(index[(magnitude > 1e-6) & (magnitude < 1.0)])
    modulation.append(per_layer)

# One histogram per attention iteration and layer; the red arrow marks the median.
for step in range(n_fix, n_fix_att):
    for layer in range(n_layers):
        values = modulation[step][layer]
        plt.figure(figsize=(6, 4))
        median = values.median().cpu().item()
        plt.hist(values.cpu(), bins=20, range=(-1, 1))
        ymax = plt.gca().get_ylim()[-1]
        plt.arrow(median, ymax, 0.0, -ymax/20, color='r', head_width=0.05, head_length=ymax/40, alpha=1.0, width=0.01)
        plt.title(f"Layer {layer} Iter {step} Median: {median:.2f}")
        plt.xlim(-1, 1)
        plt.xlabel("Modulation Index")
        plt.ylabel("Number of Neurons")
        plt.savefig(os.path.join(results_folder, f"Modulation_layer_{layer}_iter_{step}.svg"), format="svg")
        plt.close()

In [22]:
# Per layer, stack the receptive-field activity of every attention iteration
# (rows = iterations, columns = units) and keep only entries that pass all
# thresholds below. tar_m_[j] / dis_m_[j] end up as flat tensors of survivors.
e = 1e-3
tar_m_ = []
dis_m_ = []
for j in range(n_layers):
    in_rf = receptive_[j] > 0.0
    tar_stack = torch.stack([targets_[i][j][in_rf] for i in range(n_fix, n_fix_att)], dim=0)
    dis_stack = torch.stack([distractors_[i][j][in_rf] for i in range(n_fix, n_fix_att)], dim=0)

    gap = (tar_stack - dis_stack).abs()
    keep = (
        # entry-wise: both responses above threshold, difference within (e, 1)
        (tar_stack > e) & (dis_stack > e) & (gap > e) & (gap < 1.0)
        # unit-wise: products across iterations also above threshold
        & (tar_stack.prod(dim=0, keepdim=True) > e)
        & (dis_stack.prod(dim=0, keepdim=True) > e)
        & (gap.prod(dim=0, keepdim=True) > e)
    )
    # Every kept entry is > e > 0, so selecting with `keep` directly yields the
    # same elements as zeroing the rejected ones and filtering on "> 0".
    tar_m_.append(tar_stack[keep])
    dis_m_.append(dis_stack[keep])


In [23]:
# Modulation index per layer computed from the pooled attention-phase
# activity (tar_m_ / dis_m_), keeping indices with 1e-3 < |index| < 1.
modulation = []
for layer in range(n_layers):
    index = torch.nan_to_num(
        modulation_index(tar_m_[layer], dis_m_[layer]),
        nan=0.0, posinf=0.0, neginf=0.0,
    )
    magnitude = index.abs()
    modulation.append(index[(magnitude > 1e-3) & (magnitude < 1.0)])

# One histogram per layer; the red arrow marks the median.
for layer in range(n_layers):
    values = modulation[layer]
    plt.figure(figsize=(6, 4))
    median = values.median().cpu().item()
    plt.hist(values.cpu(), bins=20, range=(-1, 1))
    ymax = plt.gca().get_ylim()[-1]
    plt.arrow(median, ymax, 0.0, -ymax/20, color='r', head_width=0.05, head_length=ymax/40, alpha=1.0, width=0.01)
    plt.title(f"Layer {layer} Iter {n_fix}-{n_fix_att} Median: {median:.2f}")
    plt.xlim(-1, 1)
    plt.xlabel("Modulation Index")
    plt.ylabel("Number of Neurons")
    plt.savefig(os.path.join(results_folder, f"Modulation_layer_{layer}_iter_{n_fix}-{n_fix_att}.svg"), format="svg")
    plt.close()

In [24]:
# Modulation index between target and distractor activity after subtracting
# the response to the combined stimulus (both_), stored as
# modulation[iteration][layer].
modulation = []
for step in range(n_iter):
    per_layer = []
    for layer in range(n_layers):
        in_rf = receptive_[layer] > 0.0
        tar_resp = (targets_[step][layer] - both_[layer])[in_rf]
        dis_resp = (distractors_[step][layer] - both_[layer])[in_rf]
        # Only units whose two responses differ take part.
        differs = (tar_resp - dis_resp).abs() > 0.0
        index = torch.nan_to_num(
            modulation_index(tar_resp[differs], dis_resp[differs]),
            nan=0.0, posinf=0.0, neginf=0.0,
        )
        # Drop (near-)zero indices and anything with magnitude >= 1.
        magnitude = index.abs()
        per_layer.append(index[(magnitude > 1e-6) & (magnitude < 1.0)])
    modulation.append(per_layer)

# One histogram per attention iteration and layer; the red arrow marks the median.
for step in range(n_fix, n_fix_att):
    for layer in range(n_layers):
        values = modulation[step][layer]
        plt.figure(figsize=(6, 4))
        median = values.median().cpu().item()
        plt.hist(values.cpu(), bins=20, range=(-1, 1))
        ymax = plt.gca().get_ylim()[-1]
        plt.arrow(median, ymax, 0.0, -ymax/20, color='r', head_width=0.05, head_length=ymax/40, alpha=1.0, width=0.01)
        plt.title(f"Layer {layer} Iter {step} Median: {median:.2f}")
        plt.xlim(-1, 1)
        plt.xlabel("Modulation Index")
        plt.ylabel("Number of Neurons")
        plt.savefig(os.path.join(results_folder, f"ModulationDe_layer_{layer}_iter_{step}.svg"), format="svg")
        plt.close()


In [25]:
# Pool, per layer, the both_-subtracted receptive-field activity of every
# attention iteration, keeping only units whose target and distractor
# responses differ.
tar_m_ = []
dis_m_ = []
for layer in range(n_layers):
    in_rf = receptive_[layer] > 0.0
    tar_parts = []
    dis_parts = []
    for step in range(n_fix, n_fix_att):
        tar_resp = (targets_[step][layer] - both_[layer])[in_rf]
        dis_resp = (distractors_[step][layer] - both_[layer])[in_rf]
        differs = (tar_resp - dis_resp).abs() > 0.0
        tar_parts.append(tar_resp[differs])
        dis_parts.append(dis_resp[differs])
    # With no attention iterations the entry stays an empty list.
    tar_m_.append(torch.cat(tar_parts, dim=0) if tar_parts else [])
    dis_m_.append(torch.cat(dis_parts, dim=0) if dis_parts else [])

# Modulation index per layer, keeping indices with 1e-6 < |index| < 1.
modulation = []
for layer in range(n_layers):
    index = torch.nan_to_num(
        modulation_index(tar_m_[layer], dis_m_[layer]),
        nan=0.0, posinf=0.0, neginf=0.0,
    )
    magnitude = index.abs()
    modulation.append(index[(magnitude > 1e-6) & (magnitude < 1.0)])

# One histogram per layer; the red arrow marks the median.
for layer in range(n_layers):
    values = modulation[layer]
    plt.figure(figsize=(6, 4))
    median = values.median().cpu().item()
    plt.hist(values.cpu(), bins=20, range=(-1, 1))
    ymax = plt.gca().get_ylim()[-1]
    plt.arrow(median, ymax, 0.0, -ymax/20, color='r', head_width=0.05, head_length=ymax/40, alpha=1.0, width=0.01)
    plt.title(f"Layer {layer} Iter {n_fix}-{n_fix_att} Median: {median:.2f}")
    plt.xlim(-1, 1)
    plt.xlabel("Modulation Index")
    plt.ylabel("Number of Neurons")
    plt.savefig(os.path.join(results_folder, f"ModulationDe_layer_{layer}_iter_{n_fix}-{n_fix_att}.svg"), format="svg")
    plt.close()


# Invariant Tuning

In [26]:
# Build the tuning stimuli and assemble the input sequences:
# n_fix cue frames followed by n_att stimulus frames, once with the target
# cue (tccsss) and once with the distractor cue (dccsss).
theta_res = 180
stimulie, tar_cues, dis_cues, tar_stim = make_stimuli(128, 128, theta_res, bar=True, ds=this_dl.dataset)

# Insert a singleton axis at position 1 so frames can be concatenated along it.
stimulie = stimulie.unsqueeze(1)
tar_cues = tar_cues.unsqueeze(1)
dis_cues = dis_cues.unsqueeze(1)
tccsss = torch.cat([tar_cues] * n_fix + [stimulie] * n_att, dim=1)
dccsss = torch.cat([dis_cues] * n_fix + [stimulie] * n_att, dim=1)

# Move the model and every stimulus tensor to the compute device.
model.to(DeVice)
stimulie, tar_cues, dis_cues, tar_stim, tccsss, dccsss = (
    tensor.to(DeVice)
    for tensor in (stimulie, tar_cues, dis_cues, tar_stim, tccsss, dccsss)
)
# Last frame of the target-cued sequence, kept as a one-frame sequence.
both_stim = tccsss[:, -1:]

In [27]:
# Run the model on the target-cued and distractor-cued tuning sequences with
# gradients disabled, keeping only the first returned value of each call.
# Note: this rebinds tmasks_ / dmasks_, which the curve-tracing section also
# assigned, so the plots there must be made before this cell runs.
with torch.no_grad():
    tmasks_, *_ = model(tccsss)
    dmasks_, *_ = model(dccsss)

In [None]:
# Preview one stimulus (index `bi`): full stimulus, target cue, distractor
# cue and the single target stimulus, side by side.
bi = 45
preview_panels = (
    stimulie[bi][0][0],
    tar_cues[bi][0][0],
    dis_cues[bi][0][0],
    tar_stim[bi][0],
)
plt.figure(figsize=(12, 3))
for slot, image in enumerate(preview_panels, start=1):
    plt.subplot(1, 4, slot)
    plt.imshow(image.cpu())
    plt.axis("off")
plt.show()

In [None]:
# Show the model output for stimulus `bi` over the fixation + attention
# iterations: top row target-cued, bottom row distractor-cued.
bi = 45
plt.figure(figsize=(n_fix_att*2, 2*2))
for step in range(n_fix_att):
    for row, mask_sequence in enumerate((tmasks_, dmasks_)):
        plt.subplot(2, n_fix_att, row*n_fix_att + step + 1)
        plt.imshow(mask_sequence[bi, step, 0].detach().cpu(), vmax=1.0, vmin=-1.0, cmap="plasma")
        plt.axis("off")
plt.show()

In [30]:
# Collect layer activity for the tuning stimuli.
# NOTE: this rebinds targets_, distractors_, both_, rec_fields and receptive_,
# which the curve-tracing section also uses; re-run that section's activity
# cell before re-running any of its analysis cells.
# Everything runs under torch.no_grad(): only the activations are needed, so
# no autograd graph should be kept alive for the stored tensors.
with torch.no_grad():
    targets_ = get_activity(model, tccsss)  # iter x layer x batch x channel x h x w
    distractors_ = get_activity(model, dccsss)
    # Unlike the curve-tracing section, both_ keeps its iteration axis here
    # (later cells index it as both_[0][layer]).
    both_ = get_activity(model, both_stim)
    single_ = get_activity(model, tar_stim)[0]
    cue_tar = get_activity(model, tccsss[:, :n_fix])
    cue_dis = get_activity(model, dccsss[:, :n_fix])

    # Receptive-field probe: the single target stimulus minus the first frame
    # of the target-cued sequence, clipped to [0, 1].
    rec_fields = (tar_stim - tccsss[:, 0]).clamp(0.0, 1.0)
    model.initiate_forward(batch_size=rec_fields.size(0))
    *_, receptive_ = model.for_forward(rec_fields)

# Per-layer activity averaged over the n_fix cue iterations.
cue_tar_tns = [
    torch.stack([cue_tar[iter_][layer_] for iter_ in range(n_fix)], dim=0).mean(dim=0)
    for layer_ in range(n_layers)
]
cue_dis_tns = [
    torch.stack([cue_dis[iter_][layer_] for iter_ in range(n_fix)], dim=0).mean(dim=0)
    for layer_ in range(n_layers)
]

fit_gaussian = FitBellCurve()


In [None]:
# Tuning analysis for every attention iteration and layer:
#   1. select units whose activity varies along dim 0 in the target, distractor
#      and receptive-field conditions, and differs between target and distractor;
#      (dim 0 is presumably the stimulus-orientation axis produced by
#      make_stimuli — TODO confirm)
#   2. smooth each unit's curve along that axis and fit it with fit_gaussian;
#   3. histogram the modulation index of each fitted parameter for the units
#      that were fit well in both conditions.
e = 1e-3  # minimum summed variation / difference for a unit to be analysed
for iter_ in range(n_fix, n_fix_att):
    for layer_ in range(model.n_convs):
        tar_i_l = targets_[iter_][layer_]  # - cue_tar_tns[layer_]
        dis_i_l = distractors_[iter_][layer_]  # - cue_tar_tns[layer_]
        base_i_l = receptive_[layer_]

        # Total absolute change along dim 0, per unit.
        s_tar_i_l = tar_i_l.diff(dim=0).abs().sum(dim=0, keepdim=True)
        s_dis_i_l = dis_i_l.diff(dim=0).abs().sum(dim=0, keepdim=True)
        s_base_i_l = base_i_l.diff(dim=0).abs().sum(dim=0, keepdim=True)
        # Total absolute target/distractor difference, per unit.
        s_tar_dis_i_l = (tar_i_l - dis_i_l).abs().sum(dim=0, keepdim=True)
        s_i_l = (s_tar_i_l > e) & (s_dis_i_l > e) & (s_base_i_l > e) & (s_tar_dis_i_l > e)

        # Move dim 0 last so boolean indexing yields one row per selected unit,
        # then smooth every row along that axis (Gaussian, sigma = 2).
        unit_mask = s_i_l.permute(1, 2, 3, 0)[..., 0]
        p_tar_i_l = torch.tensor(gaussian_filter1d(tar_i_l.permute(1, 2, 3, 0)[unit_mask].detach().cpu(), 2, axis=-1))
        p_dis_i_l = torch.tensor(gaussian_filter1d(dis_i_l.permute(1, 2, 3, 0)[unit_mask].detach().cpu(), 2, axis=-1))

        good_tar_ = []
        good_dis_ = []
        good_ones = []
        for unit in range(p_tar_i_l.size(0)):
            (ta, tb, tc, td), te = fit_gaussian(p_tar_i_l[unit])
            (da, db, dc, dd), de = fit_gaussian(p_dis_i_l[unit])
            # A unit counts as well fit when both fit errors are below 0.01.
            well_fit = te < 0.01 and de < 0.01
            if well_fit:
                good_ones.append(unit)
                good_tar_.append([ta, tb, tc, td])
                good_dis_.append([da, db, dc, dd])
            # Save an example polar plot for every 100th selected unit.
            if unit % 100 == 0:
                tag = "good" if well_fit else "bad"
                polar_plot(p_tar_i_l[unit], p_dis_i_l[unit], 180, results_folder, f"_{tag}_L{layer_}_i{iter_}_N{unit}")
        logger.info(f"iter: {iter_}, layer: {layer_}, all: {s_i_l.numel()}, good: {s_i_l.sum().item()}, fit: {len(good_ones)}")

        if not good_ones:
            # torch.tensor([]) is 1-D, so the column indexing below would raise;
            # with no well-fit unit there is nothing to histogram anyway.
            continue

        good_tar_ = torch.tensor(good_tar_)
        good_dis_ = torch.tensor(good_dis_)
        # One histogram per fitted parameter (columns of the fit result).
        for col, n in enumerate(("Amp", "Asymp", "Width", "Pref")):
            mi_vals = modulation_index(good_tar_[:, col], good_dis_[:, col])
            mi_median = mi_vals.median().item()
            plt.figure(figsize=(4, 3))
            plt.title(f"Layer: {layer_}, Iter: {iter_} Plot: {n} Median: {mi_median:.2f}")
            plt.hist(mi_vals, bins=20, range=(-1, 1))
            ymax = plt.gca().get_ylim()[-1]
            plt.arrow(mi_median, ymax, 0.0, -ymax/20, color='r', head_width=0.05, head_length=ymax/40, alpha=1.0, width=0.01)

            plt.savefig(os.path.join(results_folder, f"AAA_TuningNorm_{n}_layer_{layer_}_iter_{iter_}.svg"), format="svg")
            plt.close()


In [32]:
# Same unit selection, smoothing and fitting as the tuning loop, but for one
# fixed iteration / layer and with a denser sampling of example polar plots
# (every 50th selected unit, file tag "_c_").
e = 1e-3
iter_ = 4
layer_ = 6
tar_i_l = targets_[iter_][layer_]  # - cue_tar_tns[layer_]
dis_i_l = distractors_[iter_][layer_]  # - cue_tar_tns[layer_]
sin_i_l = single_[layer_]
both_i_l = both_[0][layer_]
base_i_l = receptive_[layer_]

# Total absolute change along dim 0, per unit, for each condition.
s_tar_i_l, s_dis_i_l, s_sin_i_l, s_both_i_l, s_base_i_l = (
    activity.diff(dim=0).abs().sum(dim=0, keepdim=True)
    for activity in (tar_i_l, dis_i_l, sin_i_l, both_i_l, base_i_l)
)
# Total absolute target/distractor difference, per unit.
s_tar_dis_i_l = (tar_i_l - dis_i_l).abs().sum(dim=0, keepdim=True)
# s_i_l = (s_tar_i_l > e) & (s_dis_i_l > e) & (s_sin_i_l > e) & (s_tar_dis_i_l > e)
s_i_l = (s_tar_i_l > e) & (s_dis_i_l > e) & (s_base_i_l > e) & (s_tar_dis_i_l > e)

# Move dim 0 last, keep one row per selected unit, then smooth each row along
# the last axis (Gaussian, sigma = 2).
p_s_i_l = s_i_l.permute(1, 2, 3, 0)
selected_units = p_s_i_l[..., 0]
p_tar_i_l, p_dis_i_l, p_sin_i_l, p_both_i_l, p_base_i_l = (
    torch.tensor(gaussian_filter1d(activity.permute(1, 2, 3, 0)[selected_units].detach().cpu(), 2, axis=-1))
    for activity in (tar_i_l, dis_i_l, sin_i_l, both_i_l, base_i_l)
)

good_tar_, good_dis_, good_ones = [], [], []
for i in range(p_tar_i_l.size(0)):
    (ta, tb, tc, td), te = fit_gaussian(p_tar_i_l[i])
    (da, db, dc, dd), de = fit_gaussian(p_dis_i_l[i])
    # A unit counts as well fit when both fit errors are below 0.01.
    well_fit = te < 0.01 and de < 0.01
    if well_fit:
        good_ones.append(i)
        good_tar_.append([ta, tb, tc, td])
        good_dis_.append([da, db, dc, dd])
    if i % 50 == 0:
        tag = "good" if well_fit else "bad"
        polar_plot(p_tar_i_l[i], p_dis_i_l[i], 180, results_folder, f"_c_{tag}_L{layer_}_i{iter_}_N{i}")
