In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
import sys
import os
import bz2
import glob
from mpl_toolkits.axes_grid1 import make_axes_locatable
from brian2.units import *

In [2]:
def thresholded_difference(img, ref, thresh):
    """Return ``img - ref`` with sub-threshold entries set to zero.

    Parameters
    ----------
    img, ref : arrays combined with ``img - ref``.
    thresh : entries whose absolute difference is strictly below this
        value are zeroed.

    Returns
    -------
    The array produced by the subtraction, edited in place; ``img`` and
    ``ref`` themselves are left untouched.
    """
    delta = img - ref
    below_threshold = np.abs(delta) < thresh
    delta[below_threshold] = 0
    return delta

def diff_to_count(diff, thresh, ms_per_frame):
    """Convert a brightness difference into a signed number of threshold steps.

    The magnitude of ``diff`` is divided by ``thresh`` and rounded *toward
    zero*, then the sign of ``diff`` is restored. Plain floor division
    (``diff // thresh``) rounds toward minus infinity, which made a negative
    change count one step more than a positive change of the same size
    (``-130 // 128 == -2`` but ``130 // 128 == 1``).

    Parameters
    ----------
    diff : array of differences.
    thresh : size of one step (scalar or array broadcastable to ``diff``).
    ms_per_frame : the result is clipped to ``[-ms_per_frame, ms_per_frame]``.

    Returns
    -------
    Array of signed step counts with the same shape as ``diff``.
    """
    whole_steps = np.sign(diff) * (np.abs(diff) // thresh)
    return np.clip(whole_steps, -ms_per_frame, ms_per_frame)


# Bit layout of an encoded index, from most- to least-significant bits:
#   X (col) | Y (row) | P (polarity)
def encode_index(row, col, pol, width_bits, height_bits, pol_bits):
    """Pack column, row and polarity into a single integer index.

    Each input is converted with ``np.uint32`` and masked to its bit width;
    the column occupies the highest bits, the row the middle bits and the
    polarity the lowest ``pol_bits`` bits. Scalars and arrays are both
    accepted.
    """
    col_field = np.uint32(col) & ((1 << width_bits) - 1)
    row_field = np.uint32(row) & ((1 << height_bits) - 1)
    pol_field = np.uint32(pol) & ((1 << pol_bits) - 1)

    shifted_col = col_field << (height_bits + pol_bits)
    shifted_row = row_field << pol_bits
    return shifted_col + shifted_row + pol_field
    
def neg(diff):
    """Return the index tuple of the strictly negative entries of ``diff``."""
    is_negative = diff < 0.
    return np.nonzero(is_negative)

def pos(diff):
    """Return the index tuple of the strictly positive entries of ``diff``."""
    is_positive = diff > 0
    return np.nonzero(is_positive)

def count_to_spikes(count, start_t, ms_per_frame, width_bits, height_bits, pol_bits):
    """Turn a signed per-pixel count image into spikes grouped by time step.

    A pixel with count ``c`` emits ``|c|`` spikes, one per consecutive time
    step starting at ``start_t`` and never reaching
    ``start_t + ms_per_frame``. Positive counts are encoded with polarity 1,
    negative counts with polarity 0 (see ``encode_index``).

    Parameters
    ----------
    count : 2-D array of signed counts, indexed ``[row, col]``.
    start_t : first time step of the frame.
    ms_per_frame : number of time steps available to the frame.
    width_bits, height_bits, pol_bits : bit widths forwarded to
        ``encode_index``.

    Returns
    -------
    dict mapping each time step to the list of encoded indices spiking at
    that step; positive-polarity entries precede negative-polarity ones.
    """
    max_t = start_t + ms_per_frame
    spikes = {}

    # Positive pixels are processed first so the order inside every
    # per-time-step list matches the order events were emitted in before.
    for ids, polarity_fill in ((pos(count), np.ones), (neg(count), np.zeros)):
        rows, cols = ids[0], ids[1]
        encoded = encode_index(rows, cols, polarity_fill(rows.shape),
                               width_bits, height_bits, pol_bits)
        for row, col, enc in zip(rows, cols, encoded):
            # |count| spikes for this pixel, capped at the end of the frame.
            end_t = int(min(max_t, start_t + abs(count[row, col])))
            for t in np.arange(start_t, end_t):
                spikes.setdefault(t, []).append(enc)

    return spikes

def del_file(fname):
    """Delete the file at path ``fname`` from disk."""
    # os.unlink is the same operation as os.remove under its POSIX name.
    os.unlink(fname)

def write_to_file(spikes, output):
    """Write one ``"<id> <time>"`` line per spike to ``output``.

    ``spikes`` maps a time to the list of ids spiking at that time;
    ``output`` is any object with a ``write`` method. Ids are formatted as
    integers and times as fixed-point floats.
    """
    line_format = "%d %f\n"
    for t, neuron_ids in spikes.items():
        for nid in neuron_ids:
            output.write(line_format % (nid, t))

            

def get_img_filenames(in_path, pattern="*.png"):
    """Return the paths inside ``in_path`` matching ``pattern``, sorted.

    Sorting is plain string ordering of the full paths.
    """
    matches = glob.glob(os.path.join(in_path, pattern))
    matches.sort()
    return matches

def same_size_cbar(im, ax):
    """Draw a colorbar for ``im`` in a new axes appended to the right of ``ax``.

    The extra axes is created through ``make_axes_locatable`` with 5% of the
    parent's size and a 0.05 pad.
    """
    colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=colorbar_axes)
    
def plot_cycle(tmp, ref, diff, count):
    """Show the four arrays of one encoding step side by side.

    Panels, left to right: ``tmp`` (titled "img"), ``ref``, ``diff`` and
    ``count``. All use the reversed grey colormap; only the count panel has
    a fixed colour range of [-5, 5]. Every panel gets its own colorbar and
    has its ticks removed.
    """
    cm = "Greys_r"
    panels = (
        ("img", tmp, {}),
        ("ref", ref, {}),
        ("diff", diff, {}),
        ("count", count, {"vmin": -5, "vmax": 5}),
    )

    plt.figure()
    for position, (title, data, limits) in enumerate(panels, start=1):
        ax = plt.subplot(1, len(panels), position)
        ax.set_title(title)
        im = plt.imshow(data, cmap=cm, **limits)
        ax.set_xticks([])
        ax.set_yticks([])
        same_size_cbar(im, ax)
    plt.tight_layout()
    plt.show()

def encode_images(filenames, thresh, fps, output_filename):
    """Encode an image sequence as ON/OFF spike trains and save them as .npz.

    Each frame is compared with a running reference image; the thresholded
    difference is converted to a signed spike count per pixel, the reference
    is advanced by ``count * thresh`` and the counts are expanded into spike
    times (see ``thresholded_difference``, ``diff_to_count`` and
    ``count_to_spikes``).

    Parameters
    ----------
    filenames : sequence of image paths, processed in the given order. Every
        image is copied into a buffer shaped like the first one.
    thresh : brightness change that corresponds to one spike.
    fps : frames per second; one frame spans ``round(1000 / fps)`` time steps.
    output_filename : target passed to ``np.savez_compressed``.

    Saved arrays
    ------------
    ``spikes_on`` / ``spikes_off`` hold the spike times of the odd / even
    encoded indices (polarity bit 1 / 0). When all trains have the same
    length they form a 2-D float array, otherwise a 1-D object array of
    float arrays. ``time_ms_per_frame``, ``width``, ``height``,
    ``width_bits``, ``height_bits``, ``polarity_bits`` and ``threshold`` are
    stored alongside.

    Raises
    ------
    ValueError : if ``filenames`` is empty.
    """
    def _pack_spike_trains(trains):
        # Build the container explicitly: np.asarray on ragged input is
        # deprecated since NumPy 1.19 and an error since 1.24. The result
        # matches what the implicit conversion used to produce.
        arrays = [np.asarray(times, dtype=np.double) for times in trains]
        if len(set(len(arr) for arr in arrays)) <= 1:
            return np.asarray(arrays)
        packed = np.empty(len(arrays), dtype=object)
        for slot, arr in enumerate(arrays):
            packed[slot] = arr
        return packed

    if len(filenames) == 0:
        raise ValueError("encode_images needs at least one image filename")

    # NOTE(review): scipy.misc.imread was deprecated in SciPy 1.0 and removed
    # in SciPy 1.2, so this function needs an older SciPy (with PIL) to run.
    img = misc.imread(filenames[0], flatten=True)
    img_h, img_w = img.shape
    ms_per_frame = np.round(1000./fps)
    ref = np.zeros(img.shape)
    wbits = int(np.ceil(np.log2(img_w)))
    hbits = int(np.ceil(np.log2(img_h)))
    pbits = 1
    start_t = 0
    total_imgs = len(filenames)
    # One spike train per encoded index (column, row and polarity combined).
    out_spikes = [[] for _ in range(1<<(wbits+hbits+pbits))]
    for img_idx, fname in enumerate(filenames):
        sys.stdout.write("\r%06d / %06d"%(img_idx+1, total_imgs))
        sys.stdout.flush()

        img[:] = misc.imread(fname, flatten=True)
        diff = thresholded_difference(img, ref, thresh)
        count = diff_to_count(diff, thresh, ms_per_frame)
        # Move the reference only by the part of the change that was
        # actually turned into spikes.
        ref[:] = ref + count*thresh

        # Uncomment to inspect each step visually:
        # plot_cycle(img, ref, diff, count)
        spikes = count_to_spikes(count, start_t, ms_per_frame,
                                 wbits, hbits, pbits)
        for t, neuron_ids in spikes.items():
            for nid in neuron_ids:
                out_spikes[nid].append(t)

        start_t += ms_per_frame

    np_spikes = _pack_spike_trains(out_spikes)

    # The polarity bit is the lowest bit of the index: odd entries are ON
    # events, even entries are OFF events.
    np.savez_compressed(output_filename,
        spikes_on=np_spikes[1::2],
        spikes_off=np_spikes[0::2],
        time_ms_per_frame=ms_per_frame,
        width=img_w, height=img_h,
        width_bits=wbits, height_bits=hbits, polarity_bits=pbits,
        threshold=thresh)



In [3]:
# Root folder of the raw frames; each of its sub-directories is collected
# here and later encoded as one image sequence.
in_folder = "raw_moving_bar_pngs"
folders = [f for f in os.listdir(in_folder) 
                   if os.path.isdir(os.path.join(in_folder, f))]
# for f in folders:
#     print f
# print folders

In [4]:
# Destination folder: one compressed .npz archive is written per input sequence.
out_folder = "spiking_moving_bar_motif_bank"
if not os.path.exists(out_folder):
    os.mkdir(out_folder)
    
# Brightness change that corresponds to one spike (forwarded to encode_images).
thresh = 128
# Frame rate expressed as a quantity with units; `second` is presumably
# provided by the brian2.units star import at the top of the notebook.
fps = 200/ second
for folder in folders:
    in_path = os.path.join(in_folder, folder)
    out_path = os.path.join(out_folder, folder + ".npz")
    fnames = get_img_filenames(in_path)
    # Multiplying by `second` cancels the unit so encode_images gets a plain int.
    encode_images(fnames, thresh, int(fps * second), out_path)

000040 / 000040

# Check the angle-0 .npz file

In [5]:
# Re-open the archive written for the 0-degree bar to inspect its contents.
data = np.load(out_folder + "/moving_bar_res_32x32__w_3__angle_000__fps_200__cycles_0.20s.npz")

In [6]:
# List the names of the arrays stored in the archive.
data.files

['spikes_off',
 'polarity_bits',
 'height',
 'width',
 'threshold',
 'time_ms_per_frame',
 'spikes_on',
 'height_bits',
 'width_bits']

In [7]:
# Pull the ON and OFF spike-time arrays out of the archive.
spikes_on = data['spikes_on']
spikes_off = data['spikes_off']

In [8]:
# Display the ON spike times to check their layout.
spikes_on

array([[   0.],
       [   0.],
       [   0.],
       ..., 
       [ 150.],
       [ 150.],
       [ 150.]])

In [14]:
# Mean ON rate per neuron: number of stored spike times divided by
# (latest spike time in ms * 1024 neurons). Note the duration used is the
# time of the last spike, not the nominal length of the motif.
spikes_on.size / (np.max(spikes_on)*ms * 1024)

6.66666667 * hertz

In [9]:
# Close the archive opened with np.load above.
data.close()

In [10]:
import collections

In [12]:
# Sanity-check every motif archive: both polarities must hold 1024 spike
# trains and the first ON train must contain at most one spike. Archives whose
# trains are stored as a 1-D (object) array rather than a 2-D array are listed.
for angle in np.arange(0, 360, 5):
    namen = ("/moving_bar_res_32x32__w_3__angle_%03d__fps_200__cycles_0.20s.npz"%angle)
    # Ragged spike trains are saved as object arrays, which NumPy >= 1.16.3
    # only loads with allow_pickle=True. Unpickling can execute code, so use
    # this on archives you generated yourself only.
    data = np.load(out_folder + namen, allow_pickle=True)
    try:
        spikes_on = data['spikes_on']
        spikes_off = data['spikes_off']

        assert spikes_on.shape[0] == 1024
        assert spikes_off.shape[0] == 1024

        # np.ndim replaces the collections.Iterable test (that alias was
        # removed in Python 3.10): a per-neuron train has ndim >= 1.
        if np.ndim(spikes_on[0]) > 0:
            assert len(spikes_on[0]) in (0, 1), spikes_on[0]
            if len(spikes_on.shape)==1:
                print(namen)
        else:
            # Rows are bare scalars, so there is no second axis to inspect;
            # fail with an explanation instead of an IndexError.
            raise AssertionError(
                "%s: spikes_on rows are scalars, expected one array per neuron" % namen)
    finally:
        # Release the archive even when a check fails.
        data.close()

/moving_bar_res_32x32__w_3__angle_010__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_015__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_020__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_025__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_030__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_035__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_040__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_045__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_050__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_055__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_060__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_065__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_070__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_075__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_3__angle_080__fps_200__cycles_0.20s.npz
/moving_bar_res_32x32__w_

# Generate the spike-times matrix for a spike source array

In [59]:
# Total duration to fill with motifs; generate_bar_input uses
# int(simtime / chunk) motifs in total.
simtime = 300000 * ms
# Time offset between consecutive motifs; presumably matches the
# "cycles_0.20s" part of the archive names.
chunk = 200*ms
# Number of spike trains per polarity; presumably one per pixel of the
# 32x32 frames named in the archive files.
N_layer = 32**2

In [60]:
def generate_bar_input(simtime, chunk, N_layer, angles = np.arange(0, 360, 5), in_folder="spiking_moving_bar_motif_bank", seed=None):
    """Concatenate randomly chosen bar motifs into long ON/OFF spike trains.

    For each of the ``int(simtime / chunk)`` chunks an angle is drawn at
    random from ``angles``; the spike times stored for that angle are shifted
    by ``chunk_no * chunk`` (in ms) and appended to the per-neuron trains.
    Every spike of every neuron is kept, whether the archive stores its
    trains as a 2-D array or as a 1-D object array of per-neuron arrays.

    Parameters
    ----------
    simtime, chunk : durations with units (``ms`` / ``second`` are used to
        strip them).
    N_layer : number of spike trains per polarity.
    angles : angles for which an archive exists in ``in_folder``.
    in_folder : folder holding the per-angle ``.npz`` archives.
    seed : optional seed for a private ``np.random.RandomState``; with the
        default ``None`` the global NumPy random state is used.

    Returns
    -------
    ``(actual_angles, on_spikes, off_spikes)`` where ``actual_angles`` is the
    drawn angle of each chunk and the other two are lists of ``N_layer``
    lists of spike times in ms.

    Side effects
    ------------
    Saves the same data plus ``chunk`` and ``simtime`` (in ms) to
    ``spiking_moving_bar_motif_bank_simtime_<seconds>s.npz`` in the working
    directory.
    """
    def _as_saveable(trains):
        # np.savez converts its arguments with NumPy's implicit list -> array
        # rules, which reject ragged lists since NumPy 1.24. Build the same
        # result explicitly: 2-D when uniform, 1-D object array otherwise.
        if len(set(len(train) for train in trains)) <= 1:
            return np.asarray(trains)
        packed = np.empty(len(trains), dtype=object)
        for slot, train in enumerate(trains):
            packed[slot] = train
        return packed

    n_chunks = int(simtime/chunk)
    # Plain float offset, so the arithmetic below stays unit-free.
    chunk_ms = float(chunk/ms)

    rng = np.random if seed is None else np.random.RandomState(seed)
    actual_angles = rng.choice(angles, n_chunks)

    on_files = []
    off_files = []
    for angle in angles:
        fname = ("/moving_bar_res_32x32__w_3__angle_%03d__fps_200__cycles_0.20s.npz"%angle)
        # Ragged trains are stored as object arrays, which NumPy >= 1.16.3
        # only loads with allow_pickle=True. Unpickling can execute code, so
        # point this at self-generated archives only.
        data = np.load(in_folder + fname, allow_pickle=True)
        try:
            on_files.append(data['spikes_on'])
            off_files.append(data['spikes_off'])
        finally:
            data.close()

    # Position of every angle inside on_files / off_files.
    angle_to_index = dict((angle, idx) for idx, angle in enumerate(angles))

    on_spikes = [[] for _ in range(N_layer)]
    off_spikes = [[] for _ in range(N_layer)]

    for chunk_no, current_angle in enumerate(actual_angles):
        angle_index = angle_to_index[current_angle]
        offset = chunk_no * chunk_ms

        for source, target in ((on_files, on_spikes), (off_files, off_spikes)):
            entries = source[angle_index] + offset
            for neuron, train in enumerate(entries):
                # np.ravel turns a scalar, a row of a 2-D array or a
                # per-neuron array into a flat sequence, so trains with any
                # number of spikes (including none) are copied in full.
                target[neuron].extend(np.ravel(train))

    np.savez("spiking_moving_bar_motif_bank_simtime_%ds"%int(simtime/second), 
             actual_angles = actual_angles,
             on_spikes = _as_saveable(on_spikes),
             off_spikes = _as_saveable(off_spikes),
             chunk=chunk/ms, simtime=simtime/ms
            )
    return actual_angles, on_spikes, off_spikes

In [61]:
# Build the full-length input; this also writes the combined .npz archive to
# the working directory (see generate_bar_input).
actual_angles, on_spikes, off_spikes = generate_bar_input(simtime, chunk, N_layer)

In [54]:
# Print how many ON spikes each neuron received over the whole run.
for i in range(N_layer):
    print(len(on_spikes[i]))

4639
5161
5314
5559
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
5559
5308
5152
5158
5413
5650
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
5625
5387
5341
5646
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
5620
5581
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000
6000


In [55]:
# off_spikes is a plain list with one list of spike times per neuron, so it
# has no ``.shape`` attribute; report the number of neurons instead.
len(off_spikes)

1024