In [1]:
###
# The goal of this notebook is to:
# - Take two neural populations
# - Compute the JS divergence between stimulus pairs for each population (using the same stimulus pairs for both)
# - Compute the mutual information between the distributions of JS divergences

import datetime
import glob
import os
import pickle
from importlib import reload

import h5py as h5
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
from joblib import Parallel, delayed
from tqdm import tqdm_notebook as tqdm
%matplotlib inline

import neuraltda.topology2 as tp2
import neuraltda.spectralAnalysis as sa
import neuraltda.simpComp as sc
import pycuslsa as pyslsa

# Dated output folder for everything this notebook saves: /home/brad/DailyLog/YYYYMMDD/
daystr = '{:%Y%m%d}'.format(datetime.datetime.now())
figsavepth = '/home/brad/DailyLog/{}/'.format(daystr)
print(figsavepth)

  from ._conv import register_converters as _register_converters


/home/brad/DailyLog/20190314/


In [8]:
# Set up birds and block_paths.
# `bps` maps a recording label to the block path that is binned further down.
# Note: 'B1083-5' has a block path but is not listed in `birds`.
birds = ['B1083', 'B1056', 'B1235', 'B1075']
bps = {
    'B1083': '/home/brad/krista/B1083/P03S03/',
    'B1075': '/home/brad/krista/B1075/P01S03/',
    'B1235': '/home/brad/krista/B1235/P02S01/',
    # Listed once: the previous literal repeated this key with the same value.
    'B1056': '/home/brad/krista/B1056/klusta/phy020516/Pen01_Lft_AP100_ML1300__Site03_Z2500__B1056_cat_P01_S03_1/',
    'B1083-5': '/home/brad/krista/B1083/P03S05/',
}


# Stimulus names per recording label, one entry per line.
# `learned_stimuli` has no entry for 'B1083-5' and empty lists for 'B1235' and 'B1075'.
learned_stimuli = {
    'B1083': ['M_scaled_burung', 'N_scaled_burung', 'O_scaled_burung', 'P_scaled_burung'],
    'B1056': ['A_scaled_burung', 'B_scaled_burung', 'C_scaled_burung', 'D_scaled_burung'],
    'B1235': [],
    'B1075': [],
}

# Stimuli split into an 'L' group and an 'R' group for each recording label.
peck_stimuli = {
    'B1083': {'L': ['N_40k', 'P_40k'], 'R': ['M_40k', 'O_40k']},
    'B1056': {'L': ['B_scaled_burung', 'D_scaled_burung'], 'R': ['A_scaled_burung', 'C_scaled_burung']},
    'B1235': {'L': ['F_scaled_burung', 'H_scaled_burung'], 'R': ['E_scaled_burung', 'G_scaled_burung']},
    'B1075': {'L': ['F_40k', 'H_40k'], 'R': ['E_40k', 'G_40k']},
    'B1083-5': {'L': ['N_40k', 'P_40k'], 'R': ['M_40k', 'O_40k']},
}

# Stimuli treated as unfamiliar for each recording label.
unfamiliar_stimuli = {
    'B1083': ['I_40k', 'J_40k', 'K_40k', 'L_40k'],
    'B1083-5': ['I_40k', 'J_40k', 'K_40k', 'L_40k'],
    'B1235': ['A_scaled_burung', 'B_scaled_burung', 'C_scaled_burung', 'D_scaled_burung'],
    'B1075': ['A_40k', 'B_40k', 'C_40k', 'D_40k'],
    'B1056': ['E_scaled_burung', 'F_scaled_burung', 'G_scaled_burung', 'H_scaled_burung'],
}

# Alternative block paths (different home directory), kept for reference:
#bps =  {'B1056': '/home/AD/btheilma/krista/B1056/klusta/phy020516/Pen01_Lft_AP100_ML1300__Site03_Z2500__B1056_cat_P01_S03_1/',
#        'B1235': '/home/AD/btheilma/krista/B1235/P02S01/'}

# Recordings analysed in this run. Every label must be a key of `bps` and of the stimulus dictionaries.
# Previously tried selections: ['B1056', 'B1235'], ['B1075', 'B1235'], ['B1056', 'B1083'], ['B1083'],
# ['B1083', 'B1083-5'] and the full set ['B1056', 'B1235', 'B1083', 'B1083-5'].
test_birds = ['B1056']

# Binning Parameters
windt = 10.0                      # milliseconds
dtovr = 0.5*windt                 # milliseconds
segment_info = [0, 0]             # use full Trial
cluster_group = ['Good']          # use just good clusters
comment = 'JS_MI_TEST'            # tag passed to the binning call as `comment`
bdfs = {}                         # Dictionary to store bdf

In [9]:
# Loop through each bird in our list and bin the data.
# For every bird the path of one '*.binned' file found in the 'raw' output folder is stored in `bdfs`.
for bird in test_birds:
    block_path = bps[bird]
    bfdict = tp2.dag_bin(block_path, windt, segment_info, cluster_group=cluster_group, dt_overlap=dtovr, comment=comment)
    # glob returns matches in arbitrary order, so sort to make the choice repeatable
    # when more than one binned file is present.
    binned_files = sorted(glob.glob(os.path.join(bfdict['raw'], '*.binned')))
    if not binned_files:
        # Fail with the folder name instead of a bare IndexError on [0].
        raise FileNotFoundError('No .binned file found in {}'.format(bfdict['raw']))
    bdf = binned_files[0]
    print(bdf)
    bdfs[bird] = bdf

/home/brad/krista/B1056/klusta/phy020516/Pen01_Lft_AP100_ML1300__Site03_Z2500__B1056_cat_P01_S03_1/binned_data/win-10.0_dtovr-5.0_seg-0-0-JS_MI_TEST/20180525T180834Z-10.0-5.0.binned


In [10]:
# Extract the population tensors for the left/right (familiar) stimuli of each bird.
# Stimuli are NOT sorted: the order is all 'L' stimuli followed by all 'R' stimuli.
# Result: population_tensors_familiar[bird] is a list of [tensor, stimulus_name] pairs.
population_tensors_familiar = {}
stimuli = []

for bird in test_birds:
    stimuli = peck_stimuli[bird]['L'] + peck_stimuli[bird]['R']
    print(stimuli)
    bdf = bdfs[bird]
    # open the binned data file
    with h5.File(bdf, 'r') as f:
        print(list(f.keys()))
        # No temporary named `poptens` here: that name is used for the
        # familiar/unfamiliar dictionary in a later cell, and re-running this
        # cell would otherwise overwrite it with a single array.
        population_tensors_familiar[bird] = [
            [np.array(f[stim]['pop_tens']), stim] for stim in stimuli
        ]

['B_scaled_burung', 'D_scaled_burung', 'A_scaled_burung', 'C_scaled_burung']
['A_scaled_burung', 'B_scaled_burung', 'C_scaled_burung', 'D_scaled_burung', 'E_scaled_burung', 'F_scaled_burung', 'G_scaled_burung', 'H_scaled_burung', 'I_scaled_burung', 'J_scaled_burung', 'K_scaled_burung', 'L_scaled_burung', 'M_scaled_burung', 'N_scaled_burung', 'O_scaled_burung', 'P_scaled_burung']


In [11]:
# Extract the population tensors for the unfamiliar stimuli of each bird.
# Stimuli are NOT sorted: they are read in the order listed in `unfamiliar_stimuli`.
# Result: population_tensors_unfamiliar[bird] is a list of [tensor, stimulus_name] pairs.
population_tensors_unfamiliar = {}
stimuli = []

for bird in test_birds:
    stimuli = unfamiliar_stimuli[bird]
    print(stimuli)
    bdf = bdfs[bird]
    # open the binned data file
    with h5.File(bdf, 'r') as f:
        print(list(f.keys()))
        # No temporary named `poptens` here: that name is used for the
        # familiar/unfamiliar dictionary in a later cell, and re-running this
        # cell would otherwise overwrite it with a single array.
        population_tensors_unfamiliar[bird] = [
            [np.array(f[stim]['pop_tens']), stim] for stim in stimuli
        ]

['E_scaled_burung', 'F_scaled_burung', 'G_scaled_burung', 'H_scaled_burung']
['A_scaled_burung', 'B_scaled_burung', 'C_scaled_burung', 'D_scaled_burung', 'E_scaled_burung', 'F_scaled_burung', 'G_scaled_burung', 'H_scaled_burung', 'I_scaled_burung', 'J_scaled_burung', 'K_scaled_burung', 'L_scaled_burung', 'M_scaled_burung', 'N_scaled_burung', 'O_scaled_burung', 'P_scaled_burung']


In [26]:
# Threshold multiplier passed to threshold_poptens by the JS-divergence cells:
# a bin is marked active when its value exceeds `threshold` times the mean of
# that cell over the windows of the same trial.
threshold = 6

def threshold_poptens(tens, thresh):
    """Binarize a population tensor against a per-cell, per-trial cutoff.

    Parameters
    ----------
    tens : ndarray with three axes, unpacked as (cells, windows, trials)
    thresh : scalar multiplier applied to the mean over the window axis

    Returns
    -------
    Integer array with the same shape as `tens`: 1 where an entry is strictly
    greater than `thresh` times the mean of its (cell, trial) column over
    windows, 0 elsewhere.
    """
    # Unpacking also rejects inputs that do not have exactly three axes.
    n_cells, n_windows, n_trials = tens.shape
    mean_over_windows = tens.mean(axis=1)
    # Re-insert the window axis so the cutoff broadcasts against every window.
    cutoff = (thresh * mean_over_windows).reshape(n_cells, 1, n_trials)
    return 1 * (tens > cutoff)

def shuffle_binmat(binmat):
    """Randomly permute the entries of every row of a 2-D matrix, IN PLACE.

    Each row is permuted independently with np.random.permutation, so the
    values in a row are preserved but their column positions are not.
    The input array itself is modified and also returned; pass a copy if the
    original must be kept.
    """
    # Unpacking rejects inputs that are not two-dimensional.
    n_rows, _ = binmat.shape
    for row in binmat:
        # `row` is a view, so slice assignment writes back into `binmat`.
        row[:] = np.random.permutation(row)
    return binmat

def get_JS(i, j, Li, Lj, speci, specj, beta):
    """Compute one JS divergence and tag it with its pair indices.

    Calls sc.sparse_JS_divergence2_fast(Li, Lj, speci, specj, beta), prints
    the pair (i, j) once the value is available, and returns the tuple
    (i, j, divergence) so the caller can place the value in a matrix.
    """
    divergence = sc.sparse_JS_divergence2_fast(Li, Lj, speci, specj, beta)
    print((i, j))
    return (i, j, divergence)

def get_Lap(trial_matrix, sh):
    """Build the sparse Laplacian for one trial's binary matrix.

    Parameters
    ----------
    trial_matrix : 2-D binary array for a single trial (one row per cell;
        shuffle_binmat unpacks it as cells x windows).
    sh : str
        'shuffled' permutes every row independently before the simplicial
        complex is built; any other value uses the matrix as given.

    Returns
    -------
    The result of sc.sparse_laplacian for the chain groups of the maximal
    simplices of the (possibly shuffled) matrix.

    Notes
    -----
    Reads the notebook-level variable `dim`, which must be set before this
    function is called.
    """
    if sh == 'shuffled':
        # shuffle_binmat works in place. Shuffle a private copy so the
        # caller's array is left untouched and read-only inputs still work.
        mat = shuffle_binmat(np.array(trial_matrix))
    else:
        mat = trial_matrix
    # Build the complex from `mat`; the previous code passed `trial_matrix`
    # and only saw the shuffle through the in-place side effect.
    ms = sc.binarytomaxsimplex(mat, rDup=True)
    scg1 = sc.simplicialChainGroups(ms)
    L = sc.sparse_laplacian(scg1, dim)
    return L



In [13]:

# Population tensors grouped first by familiarity, then by bird.
poptens = dict(
    familiar=population_tensors_familiar,
    unfamiliar=population_tensors_unfamiliar,
)

In [None]:
# Compute JS popA:
# Left vs right
#
# For every bird, shuffle condition and familiarity set, build one density
# spectrum per trial (pooled over the stimuli in listed order), compute the JS
# divergence for every pair (i <= j) and pickle the resulting matrix.

dim = 1
beta = 1.0
ntrials = 10  # trials 0..ntrials-1 of each stimulus (original note: "only do half the trials for each stim")

# The dated output folder is not created anywhere else in the notebook.
os.makedirs(figsavepth, exist_ok=True)

for bird in test_birds:
    for sh in ['original', 'shuffled']:
        for fam in ['familiar', 'unfamiliar']:
            bird_tensors = poptens[fam][bird]
            SCG = []
            spectra = []
            for bird_tensor, stim in bird_tensors:
                print(bird, stim)
                bin_tensor = threshold_poptens(bird_tensor, threshold)
                for trial in tqdm(range(ntrials)):
                    binmat = bin_tensor[:, :, trial]
                    if sh == 'shuffled':
                        # shuffle_binmat works in place. Shuffle an explicit copy and
                        # build the complex from the returned matrix, rather than
                        # relying on the mutation showing up in a re-sliced view.
                        binmat = shuffle_binmat(binmat.copy())
                    ms = sc.binarytomaxsimplex(binmat, rDup=True)
                    scg1 = sc.simplicialChainGroups(ms)
                    SCG.append(scg1)
                    L = sc.sparse_laplacian(scg1, dim)
                    rho = sc.sparse_density_matrix(L, beta)
                    r = sc.sparse_density_spectrum(rho)
                    spectra.append(r)
            N = len(SCG)

            # Only the upper triangle (including the diagonal) is filled; the rest stays 0.
            jsmat = np.zeros((N, N))
            for i in tqdm(range(N)):
                for j in range(i, N):
                    jsmat[i, j] = sc.sparse_reconciled_spectrum_JS(spectra[i], spectra[j])
            with open(os.path.join(figsavepth, 'JSpop_{}-{}-{}-{}_LvsR-{}-{}'.format(bird, dim, beta, ntrials, fam, sh)), 'wb') as f:
                pickle.dump(jsmat, f)
            


In [None]:
# mirroring cuda code
# Left vs right
#
# Parallel version: Laplacians, spectra and pairwise JS divergences are each
# computed with joblib. One matrix is pickled per (bird, shuffle, familiarity, beta).
reload(sc)
dim = 1        # read by get_Lap as a notebook-level variable
betas = [1]    # one JS matrix is computed and saved per value
ntrials = 20   # trials 0..ntrials-1 of each stimulus (original note: "only do half the trials for each stim")
n_jobs = 23    # worker count used for every Parallel call in this cell

# The dated output folder is not created anywhere else in the notebook.
os.makedirs(figsavepth, exist_ok=True)

for bird in test_birds:
    for sh in ['original', 'shuffled']:
        for fam in ['familiar', 'unfamiliar']:
            bird_tensors = poptens[fam][bird]
            laplacians = []
            print('Computing Laplacians for {} {} {}...'.format(bird, sh, fam))
            for bird_tensor, stim in bird_tensors:
                print(bird, stim)
                bin_tensor = threshold_poptens(bird_tensor, threshold)
                # extend keeps one flat list across stimuli (no sum(list, []) flatten needed).
                laplacians.extend(
                    Parallel(n_jobs=n_jobs)(
                        delayed(get_Lap)(bin_tensor[:, :, trial], sh) for trial in range(ntrials)
                    )
                )
            N = len(laplacians)
            # compute spectra
            print('Computing Spectra... {} total laplacians'.format(N))
            spectra = Parallel(n_jobs=n_jobs)(delayed(sc.sparse_spectrum)(L) for L in laplacians)

            # Pairs with i <= j: only the upper triangle (including the diagonal) of jsmat is filled.
            pairs = [(i, j) for i in range(N) for j in range(i, N)]
            for beta in betas:
                print('Computing JS Divergences with beta {}...'.format(beta))
                jsmat = np.zeros((N, N))

                jsdat = Parallel(n_jobs=n_jobs)(
                    delayed(get_JS)(i, j, laplacians[i], laplacians[j], spectra[i], spectra[j], beta)
                    for (i, j) in pairs
                )
                for d in jsdat:
                    jsmat[d[0], d[1]] = d[2]

                with open(os.path.join(figsavepth, 'JSpop_fast_{}-{}-{}-{}_LvsR-{}-{}.pkl'.format(bird, dim, beta, ntrials, fam, sh)), 'wb') as f:
                    pickle.dump(jsmat, f)


Computing Laplacians for B1056 original familiar...
B1056 B_scaled_burung
B1056 D_scaled_burung
B1056 A_scaled_burung
B1056 C_scaled_burung
Computing Spectra... 80 total laplacians
Computing JS Divergences with beta 1...
(0, 1)
(0, 3)
(0, 0)
(0, 15)
(0, 13)
(0, 8)
(0, 11)
(0, 14)
(0, 17)
(0, 4)
(0, 18)
(0, 19)
(0, 6)
(0, 16)
(0, 9)
(0, 2)
(0, 10)
(0, 7)
(0, 5)
(0, 12)
(0, 22)
(0, 21)
(0, 36)
(0, 35)
(0, 23)
(0, 40)
(0, 26)
(0, 37)
(0, 25)
(0, 39)
(0, 30)
(0, 38)
(0, 41)
(0, 24)
(0, 33)
(0, 31)
(0, 47)
(0, 27)
(0, 52)
(0, 28)
(0, 20)
(0, 42)
(0, 32)
(0, 45)
(0, 44)
(0, 51)
(0, 34)
(0, 58)
(0, 54)
(0, 57)
(0, 53)
(0, 59)
(0, 29)
(0, 49)
(0, 50)
(0, 48)
(0, 55)
(0, 43)
(0, 56)
(0, 46)
(0, 61)
(0, 71)
(1, 1)
(0, 77)
(0, 73)
(0, 75)
(0, 69)
(0, 79)
(1, 2)
(0, 74)
(0, 60)
(0, 65)
(0, 64)
(1, 3)
(0, 62)
(1, 4)
(0, 72)
(0, 68)
(0, 63)
(0, 70)
(0, 78)
(1, 9)
(0, 67)
(0, 76)
(1, 8)
(1, 6)
(1, 10)
(1, 14)
(1, 13)
(1, 11)
(1, 7)
(1, 19)
(1, 18)
(0, 66)
(1, 16)
(1, 17)
(1, 15)
(1, 5)
(1, 12)
(1, 35

(13, 40)
(13, 26)
(13, 38)
(13, 31)
(13, 41)
(13, 39)
(13, 51)
(13, 37)
(13, 27)
(13, 24)
(13, 33)
(13, 52)
(13, 32)
(13, 42)
(13, 58)
(13, 34)
(13, 45)
(13, 44)
(13, 47)
(13, 28)
(13, 59)
(13, 53)
(13, 57)
(13, 54)
(13, 20)
(13, 50)
(13, 29)
(13, 43)
(13, 55)
(13, 49)
(13, 48)
(13, 56)
(13, 46)
(13, 61)
(13, 71)
(14, 14)
(13, 65)
(13, 79)
(13, 60)
(13, 73)
(13, 74)
(13, 77)
(13, 72)
(13, 69)
(13, 70)
(13, 63)
(13, 75)
(14, 17)
(13, 62)
(14, 19)
(13, 64)
(14, 18)
(14, 16)
(13, 68)
(13, 78)
(13, 67)
(14, 15)
(14, 21)
(13, 76)
(13, 66)
(14, 22)
(14, 36)
(14, 23)
(14, 35)
(14, 39)
(14, 40)
(14, 25)
(14, 38)
(14, 33)
(14, 37)
(14, 41)
(14, 27)
(14, 26)
(14, 30)
(14, 28)
(14, 31)
(14, 24)
(14, 34)
(14, 32)
(14, 52)
(14, 42)
(14, 20)
(14, 47)
(14, 57)
(14, 54)
(14, 51)
(14, 59)
(14, 58)
(14, 53)
(14, 45)
(14, 44)
(14, 50)
(14, 29)
(14, 48)
(14, 43)
(14, 49)
(14, 55)
(14, 46)
(14, 56)
(14, 69)
(15, 16)
(14, 71)
(14, 67)
(14, 61)
(14, 79)
(14, 60)
(14, 62)
(15, 15)
(14, 77)
(14, 65)
(14, 75)
(

(28, 55)
(28, 53)
(28, 54)
(28, 56)
(28, 58)
(28, 57)
(28, 59)
(28, 65)
(28, 60)
(28, 62)
(28, 61)
(28, 63)
(28, 69)
(28, 64)
(28, 66)
(28, 68)
(28, 67)
(28, 70)
(29, 30)
(28, 71)
(28, 73)
(28, 74)
(28, 72)
(28, 75)
(28, 78)
(28, 76)
(28, 77)
(28, 79)
(29, 29)
(29, 33)
(29, 37)
(29, 32)
(29, 34)
(29, 35)
(29, 36)
(29, 31)
(29, 39)
(29, 41)
(29, 38)
(29, 42)
(29, 40)
(29, 45)
(29, 43)
(29, 46)
(29, 44)
(29, 48)
(29, 47)
(29, 49)
(29, 51)
(29, 50)
(29, 53)
(29, 55)
(29, 52)
(29, 54)
(29, 56)
(29, 58)
(29, 57)
(29, 61)
(29, 60)
(29, 64)
(30, 31)
(29, 59)
(29, 67)
(29, 62)
(29, 63)
(29, 65)
(29, 66)
(29, 70)
(29, 71)
(29, 68)
(30, 30)
(29, 69)
(29, 73)
(30, 35)
(30, 33)
(29, 74)
(29, 72)
(30, 37)
(30, 36)
(29, 76)
(29, 75)
(30, 38)
(29, 79)
(30, 32)
(30, 39)
(29, 77)
(29, 78)
(30, 40)
(30, 34)
(30, 41)
(30, 44)
(30, 42)
(30, 45)
(30, 47)
(30, 48)
(30, 51)
(30, 52)
(30, 50)
(30, 49)
(30, 53)
(30, 43)
(30, 55)
(30, 59)
(30, 54)
(30, 57)
(30, 56)
(30, 60)
(30, 46)
(30, 58)
(30, 62)
(30, 61)
(

(50, 70)
(50, 79)
(50, 77)
(51, 51)
(50, 68)
(51, 52)
(51, 53)
(51, 54)
(50, 76)
(50, 67)
(50, 66)
(51, 59)
(51, 58)
(51, 55)
(51, 57)
(51, 56)
(51, 71)
(51, 73)
(51, 61)
(51, 62)
(51, 60)
(51, 75)
(51, 63)
(52, 52)
(51, 77)
(51, 69)
(51, 70)
(51, 65)
(51, 74)
(51, 72)
(51, 64)
(51, 68)
(51, 78)
(51, 79)
(51, 67)
(52, 58)
(52, 53)
(51, 66)
(52, 54)
(52, 57)
(52, 59)
(51, 76)
(52, 71)
(52, 55)
(52, 75)
(52, 61)
(52, 65)
(53, 54)
(52, 73)
(52, 56)
(52, 69)
(52, 66)
(52, 70)
(52, 78)
(52, 60)
(52, 79)
(52, 77)
(52, 74)
(52, 72)
(52, 63)
(53, 53)
(52, 62)
(52, 68)
(52, 64)
(52, 76)
(52, 67)
(53, 55)
(53, 58)
(53, 71)
(53, 57)
(53, 59)
(53, 56)
(54, 54)
(53, 69)
(53, 75)
(53, 73)
(53, 61)
(53, 62)
(53, 60)
(53, 72)
(53, 65)
(53, 70)
(53, 77)
(53, 63)
(53, 79)
(53, 74)
(53, 78)
(53, 64)
(54, 57)
(53, 68)
(54, 55)
(53, 67)
(54, 59)
(54, 58)
(53, 76)
(53, 66)
(54, 56)
(54, 77)
(54, 71)
(54, 70)
(54, 73)
(54, 79)
(55, 55)
(54, 75)
(54, 69)
(54, 74)
(54, 65)
(54, 61)
(54, 63)
(54, 67)
(54, 62)
(

In [15]:
from joblib import Parallel, delayed


In [None]:
# NOTE(review): scratch cell. `test` is only assigned in the cell that follows,
# so on a fresh kernel (Restart & Run All) this line raises NameError.
# Consider deleting this cell.
test

In [29]:
# Scratch check: summing lists onto an empty list concatenates them,
# i.e. sum(list_of_lists, []) flattens one level of nesting.
test = [[1,2], [3,4]]
sum(test, [])

[1, 2, 3, 4]