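# Compute song information retained in a small population of neurons

Greedily select, one neuron at a time, the model neuron whose normalized firing rate most increases the joint entropy (in bits, 16 bins per neuron) of the selected population, up to 12 neurons.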

In [1]:
%matplotlib inline
from itertools import product as cproduct
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats
import sys

from disp import set_plot


cc = np.concatenate

FPS = 30.03  # sampling rate of behavioral data
DT = 1/FPS

STRAINS = ['NM91', 'ZH23']
STRAIN_KEY = '_'.join(STRAINS).lower()

# Strain table: one row per trial (the row index is used as the trial index `itr`
# in the recording file names loaded below). Read it once and build one boolean
# mask per requested strain.
FSTRAIN = 'data/simple/strains.csv'
strain_col = pd.read_csv(FSTRAIN)['STRAIN']
MSTRAINS = [(strain_col == strain) for strain in STRAINS]
MSTRAIN = np.any(MSTRAINS, axis=0)  # True where the trial belongs to any requested strain
ISTRAIN = MSTRAIN.nonzero()[0]  # indices of the selected trials

# path prefix of the per-trial neural recordings (`{PFX_NRL}_tr_{itr}.npy`)
PFX_NRL = 'data/simple/mlv/ma_built/nrl/ma_built'

In [2]:
# neural params: every (TAU_R, TAU_A, X_S) combination defines one model neuron
TAU_R = np.array([.1, .5, 1, 2, 5, 10, 30, 60, 120, 180, 240, 480, 600])
TAU_A = np.array([.1, .5, 1, 2, 5, 10, 30, 60, np.inf])
X_S = np.array([0, .5, 1])

# Full parameter grid, one row per neuron. X_S varies fastest, then TAU_A, then TAU_R
# (the same row order as itertools.product(TAU_R, TAU_A, X_S)).
tau_r_tau_a_x_s = np.column_stack(
    [grid.ravel() for grid in np.meshgrid(TAU_R, TAU_A, X_S, indexing='ij')]
)

# per-neuron parameter vectors (columns of the grid)
tau_rs, tau_as, x_ss = tau_r_tau_a_x_s.T
x_ps = 1 - x_ss  # pulse selectivity

nr = tau_r_tau_a_x_s.shape[0]  # number of model neurons
R_COLS = ['R_' + str(ir) for ir in range(nr)]  # firing-rate column name of each neuron

In [3]:
# load all neural recordings
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects (which can run code);
# only point PFX_NRL at trusted, locally generated files.
dfs_tr = [np.load(f'{PFX_NRL}_tr_{itr}.npy', allow_pickle=True)[0]['df'] for itr in ISTRAIN]

# only keep neural activity after first song onset, taken as the first sample where Q == 0
it_song_starts = []
for itr, df_tr in zip(ISTRAIN, dfs_tr):
    it_q0 = np.flatnonzero(np.asarray(df_tr['Q']) == 0)
    if it_q0.size == 0:
        raise ValueError(f'trial {itr}: no sample with Q == 0, cannot locate first song onset')
    it_song_starts.append(it_q0[0])

frs = cc([np.array(df_tr[R_COLS])[it_song_start:, :] for df_tr, it_song_start in zip(dfs_tr, it_song_starts)])

# Normalize so that each firing rate lives between 0 and 1 (assumes non-negative rates):
# divide every column by its own max. Integer data is promoted first, because in-place
# division would otherwise truncate (or fail); columns whose max is 0 are left unchanged
# instead of turning into NaN.
if not np.issubdtype(frs.dtype, np.floating):
    frs = frs.astype(float)
fr_maxs = frs.max(axis=0)
frs /= np.where(fr_maxs == 0, 1, fr_maxs)

dfs_tr = None  # release the full recordings; only `frs` is used downstream

In [None]:
max_nrns = 12  # number of neurons to select greedily
n_cand = frs.shape[1]  # every model neuron is a candidate at every step
# hs[k, i]: joint entropy (bits) of the neurons chosen before step k plus candidate i
hs = np.full((max_nrns, n_cand), np.nan)  # each row corresponds to addition of new neuron

inrns_chosen = []
hs_chosen = []

nbin = 16  # bins per neuron over [0, 1]
bins = np.linspace(0, 1, nbin+1)

frs_keep = np.nan*np.zeros((max_nrns, len(frs)))  # rates of the chosen neurons, one row per step


def _bin_like_histogramdd(x, edges):
    """Assign each sample of `x` to a bin using np.histogramdd's edge rules.

    Bins are half-open [edges[j], edges[j+1]) except the last, which also contains
    edges[-1]. Samples outside [edges[0], edges[-1]] (including NaN) are out of range;
    np.histogramdd drops a sample if it is out of range in any dimension.

    Returns:
        idx: int array of bin indices in [0, len(edges) - 2] (meaningless where not valid).
        valid: bool array, True where the sample lies inside the histogram range.
    """
    idx = np.searchsorted(edges, x, side='right')
    idx[x == edges[-1]] -= 1  # right edge belongs to the last bin, not the overflow bin
    valid = (idx >= 1) & (idx <= len(edges) - 1)
    return idx - 1, valid


# A dense joint histogram over k neurons has (nbin+2)**k cells inside np.histogramdd
# (~1e15 for k=12), which cannot be allocated. Instead, each time point carries a dense
# integer label of its joint bin under the chosen neurons; only occupied joint bins get
# labels, so memory stays O(#samples) regardless of k.
labels_keep = np.zeros(len(frs), dtype=np.int64)  # single shared state before any choice
n_labels_keep = 1
valid_keep = np.ones(len(frs), dtype=bool)  # in range for every chosen neuron

for cnrn in range(max_nrns):
    sys.stdout.write('\n>')
    
    # compute new entropy after adding each possible next candidate neuron
    for inrn in range(n_cand):
        sys.stdout.write('.')
        
        idx_cand, valid_cand = _bin_like_histogramdd(frs[:, inrn], bins)
        valid = valid_keep & valid_cand
        
        # joint bin = (joint bin of chosen neurons, bin of candidate); these counts equal
        # the nonzero cells of np.histogramdd over the same cnrn+1 neurons
        cts_fr = np.bincount(labels_keep[valid]*nbin + idx_cand[valid], minlength=n_labels_keep*nbin)
        hs[cnrn, inrn] = stats.entropy(cts_fr, base=2)  # stats.entropy normalizes counts itself
    
    # find best neuron to add; never re-select a chosen neuron (it adds no information but
    # can win ties) and skip NaN entropies (np.argmax would return the index of a NaN)
    hs_cand = hs[cnrn, :].copy()
    hs_cand[inrns_chosen] = -np.inf
    inrn_best = np.nanargmax(hs_cand)
    inrns_chosen.append(inrn_best)
    hs_chosen.append(hs[cnrn, inrn_best])
    
    frs_keep[cnrn, :] = frs[:, inrn_best]
    
    # fold the chosen neuron into the joint-bin labels, relabeling densely (np.unique keeps
    # lexicographic order) so label codes never overflow however many neurons are chosen
    idx_best, valid_best = _bin_like_histogramdd(frs[:, inrn_best], bins)
    valid_keep &= valid_best
    codes = labels_keep[valid_keep]*nbin + idx_best[valid_keep]
    codes_uniq, labels_new = np.unique(codes, return_inverse=True)
    labels_keep[valid_keep] = labels_new.ravel()
    n_labels_keep = len(codes_uniq)
    
    print('\nBest neurons:')
    for inrn in inrns_chosen:
        print(f'{inrn}: TAU_R={tau_rs[inrn]}, TAU_A={tau_as[inrn]}, X_S={x_ss[inrn]}, X_P={x_ps[inrn]}')
    print('Joint entropies:', hs_chosen)
    
    # checkpoint after every step so partial results survive an interrupted run
    np.save('data/simple/compression/hs_greedy.npy', np.array([{
        'HS_ALL': np.array(hs),
        'INRNS_CHOSEN': np.array(inrns_chosen),
        'HS_CHOSEN': np.array(hs_chosen),
        'TAU_R': tau_rs,
        'TAU_A': tau_as,
        'X_S': x_ss,
        'X_P': x_ps,
        'NBIN': nbin,
        'FRS': frs_keep.T,
    }]))


>...............................................................................................................................................................................................................................................................................................................................................................
Best neurons:
206: TAU_R=60.0, TAU_A=10.0, X_S=1.0, X_P=0.0
Joint entropies: [3.9855149848222235]

>...............................................................................................................................................................................................................................................................................................................................................................
Best neurons:
206: TAU_R=60.0, TAU_A=10.0, X_S=1.0, X_P=0.0
91: TAU_R=2.0, TAU_A=2.0, X_S=0.5, X_P=0.5
Joint entropies: [3.9855149848222235, 7.430837702047649]

>................................