In [None]:
import csv
import os
import pickle
from datetime import datetime

import numpy as np
import sklearn
import torch
from matplotlib import pyplot as plt
from numpy.array_api import int32
from scipy.ndimage import gaussian_filter1d
from sklearn import svm, metrics
from sklearn.decomposition import PCA
from torch.linalg import inv, eig, pinv

from tools import whiten, adap_whitening, adap_whitening_2
from tools import load, split, estimate_derivative, plot_two_intervals

In [None]:
def find_blocks(labels, max_len=20, ignore=0):
    """
    Split a 1D array of integer labels into contiguous blocks.

    - Runs whose label equals `ignore` (default 0) are skipped.
    - Longer runs are split into consecutive chunks of at most `max_len` samples.

    Parameters
    ----------
    labels : array-like
        1D sequence of labels; each kept label is converted with `int()`.
    max_len : int
        Maximum chunk length; must be >= 1.
    ignore : int
        Label value whose runs are not reported.

    Returns
    -------
    list of (start_idx, end_idx, label) tuples with end_idx exclusive.

    Raises
    ------
    ValueError
        If `max_len` is smaller than 1 (the chunking could never advance).
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")

    labels = np.asarray(labels)
    n = len(labels)
    if n == 0:
        return []

    blocks = []
    start = 0
    prev = labels[0]

    # Walk the labels and flush the current run on every change; the extra
    # iteration at i == n uses a None sentinel to flush the final run.
    for i in range(1, n + 1):
        cur = labels[i] if i < n else None
        if cur != prev:
            if prev != ignore:
                lab = int(prev)
                # chunk the run [start, i) to respect max_len
                for s in range(start, i, max_len):
                    blocks.append((s, min(s + max_len, i), lab))
            start = i
            prev = cur

    return blocks

In [None]:
from matplotlib import font_manager as fm, rcParams

# Optional custom font: register it with matplotlib only when the file exists,
# so the notebook still runs (with the default font) on other machines.
path = r"/home/p308270/.local/share/fonts/Helvetica.ttf"  # or .otf
if os.path.isfile(path):
    fm.fontManager.addfont(path)
    rcParams["font.family"] = fm.FontProperties(fname=path).get_name()
else:
    print(f"Font file not found: {path}; keeping matplotlib's default font family.")

# Seeded generator used by the stochastic steps in the cells below.
rng = np.random.default_rng(42)  # for reproducibility

In [None]:
# ----- model / experiment constants -----
n_hd = 10000          # number of rows of the random projection W_hd (HD units)
n_out = 3             # number of rows of the read-out matrix W_out
k = 25                # threshold used in the sparse HD coding; read-out weights are set to 1/k
n_train = 225         # number of leading sequence entries used to build training labels
w_teacher = 1.        # not referenced by the cells visible in this notebook
filename = '1_600_20'

# ----- grid-search axes -----
grid_k = np.arange(10,50,5)
# p is used as a per-synapse Bernoulli probability (`rng.random(...) < p`), so
# every value >= 1 behaves exactly like p = 1. The original grid 1..9 therefore
# collapsed to a single effective setting; it is rescaled here to 0.1..0.9.
# TODO confirm the intended range (the single-run cell uses p = 1/k).
grid_p = np.arange(1,10,1) / 10.

grid_n_fold = 2       # repetitions per (k, p) setting

# ----- data -----
sensor_data, sequence, times_sec, sequence_sec = load(filename, reduced=True)
# Append a derivative estimate of every column as additional features.
d_sensor_data = np.apply_along_axis(estimate_derivative, axis=0, arr=sensor_data)
sensor_data = np.hstack((sensor_data, d_sensor_data))

# baseline = np.mean(sensor_data[:300], axis=0)  # Add baseline subtraction
# sensor_data = (sensor_data - baseline)

In [None]:
# Grid search over the HD coding parameter `k` and the potentiation probability
# `p`, repeated `grid_n_fold` times per setting with a fresh random projection.
params = {'k': [], 'p': [], 'n_fold': []}
results = {'train_acc': [], 'test_acc': [], 'y_pred': [], 'y_true': []}

x_dense = sensor_data
n_dense = x_dense.shape[1]
n_seq = len(sequence_sec)

# ----- loop-invariant bookkeeping (independent of k, p and the fold) -----

# Sample indices of each stimulus: strictly after its onset and strictly before
# the next onset (open-ended for the last stimulus).
stim_samples = []
for i in range(n_seq):
    flag = times_sec > sequence_sec[i]
    if i + 1 < n_seq:
        flag &= times_sec < sequence_sec[i + 1]
    stim_samples.append(np.flatnonzero(flag))

# Per-sample training labels for the first n_train stimuli (0 = unlabelled).
labels = np.zeros_like(times_sec)
for i in range(min(n_train, n_seq)):
    labels[stim_samples[i]] = int(sequence[i][1])

# Train on every labelled sample. The previous slice `[:idx_last_flag]` stopped
# one sample early and silently skipped the last labelled sample.
train_samples = np.flatnonzero(labels)

# Ground-truth class of every stimulus.
z_true = np.zeros_like(sequence_sec)
for i in range(n_seq):
    z_true[i] = sequence[i][1]

for k in grid_k:
    for p in grid_p:
        for n_fold in range(grid_n_fold):
            # Record the setting explicitly instead of looking names up in locals().
            params['k'].append(k)
            params['p'].append(p)
            params['n_fold'].append(n_fold)

            # Random sparse binary projection, drawn from the seeded generator
            # so that the whole grid search is reproducible.
            W_hd = rng.binomial(n=1, p=0.05, size=(n_hd, n_dense))
            x_hd = x_dense @ W_hd.T
            # NOTE(review): np.argsort returns the sorting permutation, not the
            # rank of each unit, so this marks the sorted positions whose source
            # unit index is < k (k ones per row) rather than the k most/least
            # activated units. Kept unchanged; confirm whether a rank-based
            # k-WTA (e.g. via np.argsort(np.argsort(x_hd))) was intended.
            z_hd = np.where(np.argsort(x_hd)<k, 1., 0)

            # One-shot stochastic potentiation of the read-out weights.
            W_out = np.zeros((n_out, n_hd))
            for i in train_samples:
                active_idx = np.flatnonzero(z_hd[i])
                # Bernoulli(p) per active index; p must lie in (0, 1) for this to
                # be stochastic (any p >= 1 flips every active index).
                to_flip = active_idx[rng.random(active_idx.size) < p]
                W_out[int(labels[i])-1, to_flip] = 1./k

            # Read-out activity of every sample with the final weights.
            z_out_acc = z_hd @ W_out.T

            # Predicted class per stimulus: read-out unit with the largest
            # activity summed over the stimulus window (1-based class index).
            z_pred = np.zeros_like(sequence_sec)
            for i, samples in enumerate(stim_samples):
                z_pred[i] = np.argsort(np.sum(z_out_acc[samples], axis=0))[-1] + 1

            train_acc = sklearn.metrics.accuracy_score(z_true[:n_train], z_pred[:n_train])
            test_acc = sklearn.metrics.accuracy_score(z_true[n_train:], z_pred[n_train:])
            results['train_acc'].append(train_acc)
            results['test_acc'].append(test_acc)
            results['y_pred'].append(z_pred.copy())
            results['y_true'].append(z_true.copy())

            print(f'k: {k}, p: {p}, n_fold: {n_fold}')
            print(f'Train accuracy: {train_acc:.4f}, Test accuracy: {test_acc:.4f}')

# Convert the collected lists to arrays. A separate loop variable is used so
# the hyper-parameter `k` is not overwritten with a dictionary key.
for key in params:
    params[key] = np.array(params[key])
for key in ('train_acc', 'test_acc'):
    results[key] = np.array(results[key])
# keep y_pred / y_true as lists-of-arrays (ragged OK) or cast to object arrays:
results['y_pred'] = np.array(results['y_pred'], dtype=object)
results['y_true'] = np.array(results['y_true'], dtype=object)

data = {'params': params, 'results': results}

# Make sure the output directory exists so a long run is not lost on saving.
os.makedirs('data', exist_ok=True)
with open('data/gridsearch_stochastic.pkl', 'wb') as f:
    pickle.dump(data, f)

In [None]:
# Single training / evaluation run with fixed hyper-parameters.
n_hd = 10000
n_out = 3
k = 25
p = 1./k
n_train = 225
w_teacher = 1.
# Delay (in the units of times_sec) added to each stimulus onset before samples
# are assigned to it. It was referenced below without ever being defined
# (NameError on a fresh kernel); 0 matches the grid-search cell, which labels
# from the onset itself. TODO confirm the intended value.
t_training_delay = 0.

x_dense = sensor_data
n_dense = x_dense.shape[1]
n_seq = len(sequence_sec)

# Per-sample training labels for the first n_train stimuli (0 = unlabelled).
labels = np.zeros_like(times_sec)
for i in range(min(n_train, n_seq)):
    flag = times_sec > sequence_sec[i] + t_training_delay
    if i + 1 < n_seq:
        flag &= times_sec < sequence_sec[i + 1]
    labels[flag] = int(sequence[i][1])

idx_last_flag = np.where(labels != 0)[0][-1]
# Exclusive end of the training range. Slicing with `[:idx_last_flag]` dropped
# the last labelled sample, so the end index is one past it.
idx_train_end = idx_last_flag + 1

# Random sparse binary projection, drawn from the seeded generator so the run
# is reproducible.
W_hd = rng.binomial(n=1, p=0.05, size=(n_hd, n_dense))
x_hd = x_dense @ W_hd.T
# NOTE(review): np.argsort returns the sorting permutation, not the rank of each
# unit, so this marks the sorted positions whose source unit index is < k rather
# than the k most/least activated units. Kept unchanged; confirm the intent.
z_hd = np.where(np.argsort(x_hd)<k, 1., 0)
W_out = np.zeros((n_out, n_hd))
W = np.zeros((n_out, n_hd))

# Online pass over the training range: potentiate on labelled samples and
# record the read-out produced by the weights learned so far.
z_out_train = np.zeros((z_hd.shape[0],  n_out))
for i, row in enumerate(z_hd[:idx_train_end]):
    if labels[i] != 0:
        active_idx = np.flatnonzero(row)
        # Bernoulli(p) per active index
        to_flip = active_idx[rng.random(active_idx.size) < p]
        W_out[int(labels[i])-1, to_flip] = 1./k

    z_out_train[i] = row @ W_out.T

# Read-out of every sample with the final weights.
z_out_acc = z_hd @ W_out.T

# NOTE(review): same argsort-vs-rank caveat as for z_hd above.
z_wta = np.where(np.argsort(z_out_acc, axis=1)<1, 1., 0)

# Per-stimulus prediction: read-out unit with the largest summed activity in the
# stimulus window (1-based class index), next to the ground-truth class.
z_pred = np.zeros_like(sequence_sec)
z_true = np.zeros_like(sequence_sec)
for i in range(n_seq):
    flag = times_sec > sequence_sec[i] + t_training_delay
    if i + 1 < n_seq:
        flag &= times_sec < sequence_sec[i + 1]
    z_pred[i] = np.argsort(np.sum(z_out_acc[flag], axis=0))[-1] + 1
    z_true[i] = sequence[i][1]

train_acc = sklearn.metrics.accuracy_score(z_true[:n_train], z_pred[:n_train])
test_acc = sklearn.metrics.accuracy_score(z_true[n_train:], z_pred[n_train:])

# Trace used for plotting: online read-out during training, final-weight
# read-out afterwards.
z_out = np.empty((z_hd.shape[0],  n_out))
z_out[:idx_train_end] = z_out_train[:idx_train_end]
z_out[idx_train_end:] = z_out_acc[idx_train_end:]

# `n_pot` was never defined in this notebook; report the potentiation
# probability `p` that this cell actually uses.
print(f'k: {k}, p: {p}, t_training_delay: {t_training_delay}')
print(f'Train accuracy: {train_acc:.4f}, Test accuracy: {test_acc:.4f}')

In [None]:
# ----- config -----
# Sequence-index windows shown in the left and middle panels.
top_intervals_idx = [(0, 10), (215, 225)]
# Sequence-index window shown in the right ("Test") panel.
j0, j1 = 235, 245
max_len = 20            # maximum length of a shaded block, in samples
sigma = 2.              # width passed to the Gaussian smoothing of each trace
savepath = 'figs/hd_out'

# connector + slash styling (right slashes are NOT mirrored)
connector_color, connector_lw, connector_ls = 'black', 1.5, ':'   # dotted connector
slash_color, slash_lw = 'black', 1.5
slash_len_fig, slash_angle_deg, slash_inset = 0.015, 80, 0.003

# ----- style -----
plt.rcParams["font.size"] = 12
plt.rcParams["axes.labelsize"] = 12
plt.rcParams["axes.titlesize"] = 12
plt.rcParams["legend.fontsize"] = 10
cm = plt.get_cmap('tab10')

# ----- per-sample block labels over full recording -----
# Each sample gets the label of the sequence entry whose interval
# [onset, next onset) contains it; the last interval is open-ended.
colour = np.zeros_like(times_sec, dtype=int)
onset_ends = list(sequence_sec[1:]) + [np.inf]
for i, (t_start, t_end) in enumerate(zip(sequence_sec, onset_ends)):
    mask = (times_sec >= t_start) & (times_sec < t_end)
    colour[mask] = int(sequence[i][1])

# ----- helpers -----
def seq_idx_window_to_sample_idx(a_idx, b_idx):
    """
    Convert a window given in sequence indices into sample indices.

    Both sequence indices are clipped to the valid range of `sequence_sec`;
    each is then mapped to the index of the sample in `times_sec` closest to
    the corresponding sequence time.

    Returns (t0, t1). If the second index does not lie after the first, t1 is
    set to t0 + 1, capped at len(times_sec).
    """
    last_seq_idx = len(sequence_sec) - 1

    def nearest_sample(seq_idx):
        # clip, look up the sequence time, then take the closest sample
        seq_idx = int(np.clip(seq_idx, 0, last_seq_idx))
        return int(np.abs(times_sec - sequence_sec[seq_idx]).argmin())

    t0 = nearest_sample(a_idx)
    t1 = nearest_sample(b_idx)
    if t1 > t0:
        return t0, t1
    return t0, min(t0 + 1, len(times_sec))

def find_blocks(labels, max_len=20, ignore=0):
    """
    Split a 1D array of labels into contiguous runs of equal label.

    Runs whose label equals `ignore` are skipped; longer runs are cut into
    consecutive chunks of at most `max_len` samples.

    Returns a list of (start, end, label) tuples with `end` exclusive.
    Raises ValueError if `max_len` < 1, because the chunking loop could then
    never advance.
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    labels = np.asarray(labels)
    n = len(labels)
    if n == 0:
        return []
    blocks, start, prev = [], 0, labels[0]
    # The extra iteration at i == n (None sentinel) flushes the final run.
    for i in range(1, n + 1):
        cur = labels[i] if i < n else None
        if cur != prev:
            if prev != ignore:
                run_start, run_end, lab = start, i, int(prev)
                s = run_start
                while s < run_end:
                    e = min(s + max_len, run_end)
                    blocks.append((s, e, lab))  # [s, e) end-exclusive
                    s = e
            start, prev = i, cur
    return blocks

# Use midpoints between samples as span edges -> no visual gaps
def block_edges_from_indices(x, s, e):
    """
    Return the (left, right) span edges for the end-exclusive block [s, e).

    Interior edges are placed halfway between the neighbouring samples of `x`;
    a block touching the start of `x` uses x[0] as left edge and a block
    touching the end uses x[-1] as right edge.
    """
    left = x[0] if s <= 0 else 0.5 * (x[s-1] + x[s])
    right = x[-1] if e >= len(x) else 0.5 * (x[e-1] + x[e])
    return left, right

# ----- figure (1x3) -----
# One row of three panels with independent y axes.
fig, ax = plt.subplots(
    1, 3,
    figsize=(12, 1.5),
    sharey=False,
    gridspec_kw={'wspace': 0.25},
)
ax_l, ax_m, ax_r = ax

# windows (sample indices) for the left, middle and right panel
(t0_a, t1_a), (t0_b, t1_b) = (
    seq_idx_window_to_sample_idx(*window) for window in top_intervals_idx[:2]
)
t0_test, t1_test = seq_idx_window_to_sample_idx(j0, j1)

# GLOBAL origin: first time stamp of the LEFT window, shared by all panels
x0_ref = times_sec[t0_a]

def plot_interval_time_x(ah, t0, t1, title=None, hide_y=False):
    """
    Draw the smoothed output traces for samples [t0, t1) on axes `ah`.

    The x axis is time relative to the shared origin `x0_ref`. Blocks of equal
    non-zero label in `colour` are shaded in the colour of their label.

    ah     : matplotlib axes to draw on
    t0, t1 : sample-index window (end-exclusive)
    title  : optional text placed above the top-right corner of the axes
    hide_y : if True, remove y ticks, y label and the left spine
    """
    # time axis shifted by the SAME origin (x0_ref) in every panel
    x = times_sec[t0:t1] - x0_ref

    # one smoothed trace per column of z_out
    for i in range(z_out.shape[1]):
        trace = gaussian_filter1d(z_out[t0:t1, i], sigma=sigma)
        ah.plot(x, trace, label=f'Neuron {i+1}', color=cm(i % cm.N), lw=1.5, alpha=0.85)

    # shaded label blocks; midpoint edges avoid visual gaps between blocks
    for s, e, lab in find_blocks(colour[t0:t1], max_len=max_len):
        if lab == 0:
            continue
        L, R = block_edges_from_indices(x, s, e)
        if R > L:
            ah.axvspan(L, R, facecolor=cm((lab-1) % cm.N), alpha=0.15, linewidth=0)

    # axes styling: only the left and bottom spines stay visible
    shown_sides = ('left', 'bottom')
    for side in ('top', 'right', 'left', 'bottom'):
        spine = ah.spines[side]
        spine.set_visible(side in shown_sides)
        spine.set_linewidth(1.5)
    ah.tick_params(axis='both', which='both', labelsize=12, width=1.5)
    ah.set_xlim(x[0], x[-1])
    ah.set_ylim(0., 1.)

    if hide_y:
        ah.set_yticks([])
        ah.set_ylabel("")
        ah.spines['left'].set_visible(False)
    else:
        ah.set_yticks([0, 1])

    if title:
        ah.text(0.99, 1.06, title, transform=ah.transAxes,
                va="center", ha="right", fontsize=14)

# ----- plot panels (single row) -----
plot_interval_time_x(ax_l, t0_a,   t1_a,   title=None,   hide_y=False)  # left
plot_interval_time_x(ax_m, t0_b,   t1_b,   title="Train", hide_y=True)  # middle (hide y)
plot_interval_time_x(ax_r, t0_test,t1_test,title="Test",  hide_y=True) # right

# labels
ax_l.set_ylabel("Activity (a.u.)", fontsize=12)
ax_m.set_xlabel("Time (s)", fontsize=12)

# legend (from left), outside on the right
# The legend text is bound to `legend_labels` so the notebook-level `labels`
# array (per-sample training labels) is not overwritten by this cell.
handles, legend_labels = ax_l.get_legend_handles_labels()
fig.legend(handles, legend_labels, frameon=False, loc="center left",
           bbox_to_anchor=(0.92, 0.80), ncol=1)

# ----- black dotted connectors + slashes (centered on axis; equal above/below) -----
from matplotlib.lines import Line2D

# helper: bottom-spine anchor at x∈[0,1] in *axes* coords → figure coords
def bottom_anchor_in_fig(ax, x_in_axes):
    """
    Return, in figure coordinates, the point on the bottom edge of `ax` at the
    horizontal axes-fraction `x_in_axes` (0 = left end, 1 = right end).
    """
    display_xy = ax.transAxes.transform((x_in_axes, 0.0))   # axes -> display
    display_to_figure = fig.transFigure.inverted()
    return display_to_figure.transform(display_xy)          # display -> figure

# exact anchors at the ends of the bottom spines (figure coordinates)
pL_right, pM_left, pM_right, pR_left, pR_right = (
    bottom_anchor_in_fig(panel, x_frac)
    for panel, x_frac in (
        (ax_l, 1.0),
        (ax_m, 0.0),
        (ax_m, 1.0),
        (ax_r, 0.0),
        (ax_r, 1.0),
    )
)

# geometry (identical for every slash)
theta = np.deg2rad(slash_angle_deg)
# total horizontal / vertical span of one slash, in figure coordinates
dx_fig, dy_fig = slash_len_fig * np.cos(theta), slash_len_fig * np.sin(theta)
# length of the dotted continuation after the last panel (figure coordinates)
stub = 0.04

def draw_centered_slash(anchor_xy):
    """Add a slash to the figure, centred on `anchor_xy` (figure coordinates),
    extending dx_fig/2 and dy_fig/2 to either side of the anchor."""
    cx, cy = anchor_xy
    half_dx, half_dy = dx_fig/2, dy_fig/2
    slash = Line2D([cx - half_dx, cx + half_dx],
                   [cy - half_dy, cy + half_dy],
                   transform=fig.transFigure, lw=slash_lw,
                   color=slash_color, clip_on=False)
    fig.add_artist(slash)

def draw_gap(p_left, p_right):
    """Mark an axis break between two panels: one centred slash at each anchor
    and a dotted connector between the slashes, drawn at the left anchor's y."""
    (x_left, y_left), (x_right, y_right) = p_left, p_right

    # slashes (centred, same orientation on both sides)
    for anchor in ((x_left, y_left), (x_right, y_right)):
        draw_centered_slash(anchor)

    # dotted connector from the right end of the left slash to the left end of
    # the right slash; the baseline is the left anchor's y
    connector = Line2D([x_left + dx_fig/2, x_right - dx_fig/2], [y_left, y_left],
                       transform=fig.transFigure, linestyle=connector_ls,
                       lw=connector_lw, color=connector_color, clip_on=False)
    fig.add_artist(connector)

# Left ↔ Middle
draw_gap(pL_right, pM_left)

# Middle ↔ Right
draw_gap(pM_right, pR_left)

# Right-edge: centered slash at the exact right end + dotted continuation
xE, yE = pR_right
draw_centered_slash((xE, yE))
fig.add_artist(Line2D([xE + dx_fig/2, xE + dx_fig/2 + stub], [yE, yE],
                      transform=fig.transFigure, linestyle=connector_ls,
                      lw=connector_lw, color=connector_color, clip_on=False))

# Create the output directory first so saving does not fail when it is missing.
out_dir = os.path.dirname(savepath)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)

plt.savefig(f'{savepath}.pdf', bbox_inches='tight')
plt.savefig(f'{savepath}.png', dpi=200, bbox_inches='tight')

plt.show()
