## Figure: bibiplots

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns

import sys
sys.path.append('../code')
import sparseRRR

In [2]:
def preprocess(data):
    X = data['Firing rate']
    X = X - np.mean(X, axis=0)
    X = X / np.std(X, axis=0)

    Y = data['Locomotion']
    Y = Y - np.mean(Y, axis=0)
    Y = Y / np.std(Y, axis=0)
    
    return X,Y

In [3]:
def adjustlabels(fig, labels, max_iter=1000, eps=0.01, delta=0.1):
    N = len(labels)
    widths = np.zeros(N)
    heights = np.zeros(N)
    centers = np.zeros((N, 2))
    for i,l in enumerate(labels):
        bb = l.get_window_extent(renderer=fig.canvas.get_renderer())
        bb = bb.transformed(plt.gca().transData.inverted())
        widths[i] = bb.width
        heights[i] = bb.height
        centers[i] = (bb.min + bb.max)/2

    for i in range(max_iter):
        stop = True
        for a in range(N):
            for b in range(N):
                if ((a!=b) and
                    (np.abs(centers[a,0]-centers[b,0]) < (widths[a]+widths[b])/2 + delta) and
                    (np.abs(centers[a,1]-centers[b,1]) < (heights[a]+heights[b])/2 +  delta)):
                    
                    d = centers[a] - centers[b]
                    centers[a] += d * eps
                    centers[b] -= d * eps
                    labels[a].set_position(centers[a])
                    labels[b].set_position(centers[b])
                    stop = False
        if stop:
            break

In [4]:
slice = np.linspace(0, 10000, 10000, dtype=int)

In [5]:
data = pickle.load(open('../data/purkinje.pickle', 'rb'))
X,Y = preprocess(data)
print('Shape of X:', X.shape, '\nShape of Y:', Y.shape)
w,v = sparseRRR.relaxed_elastic_rrr(X[slice,:], Y[slice,:], rank=2, alpha=.26, l1_ratio=1)
print('{} neurons selected:'.format(np.sum(w[:,0]!=0)))

Shape of X: (2691968, 105) 
Shape of Y: (2691968, 4)
13 neurons selected:


### Movie

In [None]:
def adjustlabels(fig, ax, labels, max_iter=1000, eps=0.01, delta=0.1):
    N = len(labels)
    widths = np.zeros(N)
    heights = np.zeros(N)
    centers = np.zeros((N, 2))
    for i,l in enumerate(labels):
        bb = l.get_window_extent(renderer=fig.canvas.get_renderer())
        bb = bb.transformed(ax.transData.inverted())
        widths[i] = bb.width
        heights[i] = bb.height
        centers[i] = (bb.min + bb.max)/2

    for i in range(max_iter):
        stop = True
        for a in range(N):
            for b in range(N):
                if ((a!=b) and
                    (np.abs(centers[a,0]-centers[b,0]) < (widths[a]+widths[b])/2 + delta) and
                    (np.abs(centers[a,1]-centers[b,1]) < (heights[a]+heights[b])/2 +  delta)):
                    
                    d = centers[a] - centers[b]
                    centers[a] += d * eps
                    centers[b] -= d * eps
                    labels[a].set_position(centers[a])
                    labels[b].set_position(centers[b])
                    stop = False
        if stop:
            break

In [52]:
xylim=3.9
scaleFactor=3.5
s=2
time_samples=slice[::10][:5]
L = np.corrcoef(np.concatenate((Zy[:,:2], Y[slice,:]), axis=1), rowvar=False)[2:,:2]

In [68]:
for i, tp in enumerate(slice[::100][1:]):
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    ax1 = axes[0]
    ax2 = axes[1]
    
    # Plotting on ax1
    ax1.scatter(Zx[0:tp, 0], Zx[0:tp, 1], c=range(tp), cmap='viridis', s=1)
    labels = []
    L = np.corrcoef(np.concatenate((Zx[:, :2], X[slice, :]), axis=1), rowvar=False)[2:, :2]
    for j in np.where(w[:, 0] != 0)[0]:
        ax1.plot([0, scaleFactor * L[j, 0]], [0, scaleFactor * L[j, 1]], linewidth=.75, color=[.4, .4, .4], zorder=1)
        t = ax1.text(scaleFactor * L[j, 0], scaleFactor * L[j, 1], data['cell_names'][j], 
                     ha='center', va='center', color='k', fontsize=6,
                     bbox=dict(facecolor='w', edgecolor='#777777', boxstyle='round', linewidth=.5, pad=.2))
        labels.append(t)
    adjustlabels(fig, ax1, labels)
    circ = plt.Circle((0, 0), radius=scaleFactor, color=[.4, .4, .4], fill=False, linewidth=.5)
    ax1.add_patch(circ)
    
    # Plotting on ax2
    ax2.scatter(Zy[0:tp, 0], Zy[0:tp, 1], c=range(tp), cmap='viridis', s=1)
    labels = []
    L = np.corrcoef(np.concatenate((Zy[:, :2], Y[slice, :]), axis=1), rowvar=False)[2:, :2]
    for j in range(Y.shape[1]):
        ax2.plot([0, scaleFactor * L[j, 0]], [0, scaleFactor * L[j, 1]], linewidth=.75, color=[.4, .4, .4], zorder=1)
        t = ax2.text(scaleFactor * L[j, 0], scaleFactor * L[j, 1], data['locomotion_names'][j], 
                     ha='center', va='center', color='k', fontsize=6,
                     bbox=dict(facecolor='w', edgecolor='#777777', boxstyle='round', linewidth=.5, pad=.2))
        labels.append(t)
    adjustlabels(fig, ax2, labels)
    circ = plt.Circle((0, 0), radius=scaleFactor, color=[.4, .4, .4], fill=False, linewidth=.5)
    ax2.add_patch(circ)

    # Setting limits and aspect ratio
    for ax in axes:
        ax.set_xlim([-xylim, xylim])
        ax.set_ylim([-xylim, xylim])
        ax.set_aspect('equal', adjustable='box')
        ax.set_xticks([])
        ax.set_yticks([])
    
    sns.despine(left=True, bottom=True)
    
    # Save and close the figure
    plt.savefig(f'./movie/bibiplot{i}.png')
    plt.close(fig)

In [69]:
# Directory where images are saved
image_dir = './movie/'

# List of image filenames
images = []
for i in range(len(slice[::100][1:])):
    filename = f'bibiplot{i}.png'
    images.append(imageio.imread(os.path.join(image_dir, filename)))

# Save as GIF
output_file = 'bibiplot_animation.gif'
imageio.mimsave(output_file, images, duration=0.5)  # Adjust duration as needed

  images.append(imageio.imread(os.path.join(image_dir, filename)))
