In [1]:
%reset -f
%matplotlib inline
from itertools import combinations, permutations
import matplotlib.pyplot as plt
import numpy as np

from aux import Generic
from ntwk import LIFNtwk
from plot import raster, set_font_size

import PARAMS as P_GLOBAL

# PARAMETERS AND CROSS-SIMULATION SUBROUTINES

## Parameters

In [2]:
# Global parameter object shared by the subroutines and simulations below.
# NOTE(review): N is presumably the number of neurons in the network -- confirm.
P = Generic(N=500)

## Network structure

### Item assignment

In [3]:
def assign_items(n, q, q_star, test=False, n_items=4):
    """
    Probabilistically assign item labels to n neurons.

    Each neuron independently draws one combination of items. A combination
    containing k >= 1 items has probability

        q * q_star**(k - 1) * (1 - q_star)**(n_items - k),

    and the empty combination receives the remaining probability mass. With
    this choice the marginal probability of being in any given item group is
    q, and the probability of being in item group j given membership in a
    different group i is q_star.

    :param n: number of neurons
    :param q: prob. of neuron being in any given item group
    :param q_star: prob. of neuron being in any specific second
        item group, given that it is in another first item group
    :param test: if True, also return empirical estimates of q and q_star
    :param n_items: number of item groups (default 4, the value this
        function was originally hard-coded for)

    :return: n x n_items boolean array of item assignments if test is False;
        otherwise a Generic with fields `labels` (that array), `q_hat`
        (overall membership rate) and `q_star_hat` (conditional membership
        rate for every ordered pair of distinct items; NaN for pairs whose
        first item has no members)

    :raises ValueError: if q or q_star lies outside [0, 1], or if the
        non-empty combination probabilities sum to more than 1
    """
    if not 0 <= q <= 1:
        raise ValueError('q must lie in [0, 1], got {}'.format(q))
    if not 0 <= q_star <= 1:
        raise ValueError('q_star must lie in [0, 1], got {}'.format(q_star))

    # Enumerate all item combinations ordered by size, then lexicographically.
    # For n_items=4 this is exactly [], [0], [1], [2], [3], [0, 1], ...,
    # [0, 1, 2, 3], so seeded results match the original hard-coded list.
    item_combos = [()]
    p_combos = [None]  # slot for the empty combination, filled in below
    total = 0
    for k in range(1, n_items + 1):
        combos_k = list(combinations(range(n_items), k))
        p_k = q * (q_star ** (k - 1)) * ((1 - q_star) ** (n_items - k))
        item_combos.extend(combos_k)
        p_combos.extend([p_k] * len(combos_k))
        total += len(combos_k) * p_k

    # A total of exactly 1 is valid (empty combination gets probability 0);
    # allow a hair of floating-point slack above 1 as well.
    if total > 1 + 1e-12:
        raise ValueError(
            'q={} and q_star={} give combination probabilities summing to '
            '{} > 1'.format(q, q_star, total))
    p_combos[0] = max(0.0, 1 - total)

    # Boolean lookup table: row r marks the items contained in item_combos[r].
    combo_table = np.zeros((len(item_combos), n_items), bool)
    for r, combo in enumerate(item_combos):
        if combo:
            combo_table[r, list(combo)] = True

    # Sample one combination *index* per neuron. Passing the ragged list of
    # combinations itself to np.random.choice fails on recent numpy, which
    # refuses to build arrays from ragged sequences. Indices draw the same
    # uniforms, so seeded results are unchanged.
    combo_idxs = np.random.choice(len(item_combos), size=n, p=p_combos)
    labels = combo_table[combo_idxs]

    if not test:
        return labels

    # estimate q
    q_hat = labels.mean()

    # estimate q_star for every ordered pair of distinct items; a group with
    # no members would give an empty mean, so report NaN explicitly instead
    q_star_hat = np.array([
        labels[labels[:, i], j].mean() if labels[:, i].any() else np.nan
        for i, j in permutations(range(n_items), 2)
    ])

    return Generic(
        labels=labels, q_hat=q_hat, q_star_hat=q_star_hat)

#### Test

In [4]:
def _check_assign_items(test_name, q_star, expected_counts,
                        n=10000, q=0.2, q_tol=0.02, q_star_tol=0.05):
    """
    Run `assign_items` in test mode, print true vs. estimated parameters,
    and assert that the estimates agree with the true values.

    :param test_name: name printed (upper-cased) as the section header
    :param q_star: true conditional overlap probability
    :param expected_counts: list of the allowed numbers of items per neuron
        for this q_star (printed verbatim, then checked)
    :param n: number of neurons to sample
    :param q: true item-membership probability
    :param q_tol: max allowed |q_hat - q|
    :param q_star_tol: max allowed |mean(q_star_hat) - q_star|
    """
    print('\n>> {}...\n'.format(test_name).upper())

    rslt = assign_items(n=n, q=q, q_star=q_star, test=True)

    # ensure correct item probability
    print('q = {}'.format(q))
    print('q_hat = {}\n'.format(rslt.q_hat))
    assert abs(rslt.q_hat - q) < q_tol, 'q_hat too far from q'

    # ensure correct overlap
    q_star_hat = rslt.q_star_hat.mean()
    print('q* = {}'.format(q_star))
    print('q*_hat = {}\n'.format(q_star_hat))
    assert abs(q_star_hat - q_star) < q_star_tol, 'q*_hat too far from q*'

    # ensure neurons only belong to an allowed number of items
    counts = set(rslt.labels.sum(1))
    print('unique items per neuron ({}) = {}\n'.format(expected_counts, counts))
    assert counts <= set(expected_counts), 'unexpected items-per-neuron count'


def test_assign_items_no_overlap():
    """q* = 0: every neuron belongs to at most one item."""
    _check_assign_items(
        'test_assign_items_no_overlap', q_star=0.0, expected_counts=[0, 1])


def test_assign_items_complete_overlap():
    """q* = 1: every neuron belongs to no items or to all four."""
    _check_assign_items(
        'test_assign_items_complete_overlap', q_star=1.0, expected_counts=[0, 4])


def test_assign_items_partial_overlap():
    """0 < q* < 1: neurons may belong to any number of items."""
    _check_assign_items(
        'test_assign_items_partial_overlap', q_star=0.5,
        expected_counts=[0, 1, 2, 3, 4])
            

# Seed numpy's global RNG so that the printed estimates below are reproducible.
np.random.seed(0)

for run_check in (
    test_assign_items_no_overlap,
    test_assign_items_complete_overlap,
    test_assign_items_partial_overlap,
):
    run_check()


>> TEST_ASSIGN_ITEMS_NO_OVERLAP...

q = 0.2
q_hat = 0.1985

q* = 0.0
q*_hat = 0.0

unique items per neuron ([0, 1]) = {0, 1}


>> TEST_ASSIGN_ITEMS_COMPLETE_OVERLAP...

q = 0.2
q_hat = 0.2014

q* = 1.0
q*_hat = 1.0

unique items per neuron ([0, 4]) = {0, 4}


>> TEST_ASSIGN_ITEMS_PARTIAL_OVERLAP...

q = 0.2
q_hat = 0.200525

q* = 0.5
q*_hat = 0.5046751920174396

unique items per neuron ([0, 1, 2, 3, 4]) = {0, 1, 2, 3, 4}



### Weight matrix

In [None]:
def make_w(labels, p_0, p_1, P):
    """
    Build the recurrent weight matrix from item labels (not yet implemented).

    Presumably `labels` is the n x n_items boolean array from `assign_items`.
    Per the section headers further down, `p_0` is the interconnectivity and
    `p_1` the intraconnectivity parameter -- TODO confirm and implement.

    :raises NotImplementedError: always, until implemented. The previous stub
        returned an undefined name `w`, which either raised NameError or
        silently returned a stale `w` left over in the kernel.
    """
    raise NotImplementedError('make_w is not implemented yet')

#### Test

### Top-level build

In [None]:
def make_ntwk(n, q, q_star, p_0, p_1, P):
    """
    Build the network for one simulation (not yet implemented).

    `xmpl` expects the returned object to provide `.run(spks_up, dt)`, an
    `.order` attribute (neuron ordering passed to `raster`) and a `.labels`
    attribute (presumably the item assignments from `assign_items`).

    :raises NotImplementedError: always, until implemented. The previous stub
        assigned attributes on an undefined `ntwk`, which either raised
        NameError or mutated a stale `ntwk` left over in the kernel.
    """
    raise NotImplementedError('make_ntwk is not implemented yet')

## Stimulation paradigm

In [None]:
def make_spks_up(ntwk, items, P):
    """
    Build the upstream stimulation for the given items (not yet implemented).

    `xmpl` passes the result to `ntwk.run(spks_up, P.DT)`.

    :raises NotImplementedError: always, until implemented. The previous stub
        returned an undefined name `spks_up`, which either raised NameError
        or silently returned a stale value left over in the kernel.
    """
    raise NotImplementedError('make_spks_up is not implemented yet')

## Recall accuracy

In [None]:
def calc_acc(rsp, items, P):
    """
    Compute recall accuracy of a network response (not yet implemented).

    `xmpl` passes the result of `ntwk.run` as `rsp` and formats the returned
    value with '{0:.4f}', so a scalar is expected.

    :raises NotImplementedError: always, until implemented. The previous stub
        returned an undefined name `acc`, which either raised NameError or
        silently returned a stale value left over in the kernel.
    """
    raise NotImplementedError('calc_acc is not implemented yet')

# SIMULATIONS

## Examples

### Simulation 

In [None]:
def xmpl(q_star, p_0, p_1, items, q=None, n=None):
    """
    Build a network, stimulate `items`, run it, and plot raster + weights.

    :param q_star: conditional item-overlap probability (see assign_items)
    :param p_0: passed to make_ntwk (interconnectivity, per section headers)
    :param p_1: passed to make_ntwk (intraconnectivity, per section headers)
    :param items: items to stimulate; passed to make_spks_up and calc_acc
    :param q: item-membership probability; defaults to P.Q if P defines it
    :param n: number of neurons; defaults to P.N

    :return: Generic with fields ntwk, items, spks_up, rsp, acc, fig, ax
    :raises ValueError: if q is not given and P defines no Q
    """
    # N and Q are never defined as globals in this notebook (and `%reset -f`
    # clears the kernel), so take them from the parameter object P instead.
    if n is None:
        n = P.N
    if q is None:
        q = getattr(P, 'Q', None)
    if q is None:
        raise ValueError(
            'q (item-membership probability) must be passed or defined as P.Q')

    # build ntwk and run smln
    ntwk = make_ntwk(n=n, q=q, q_star=q_star, p_0=p_0, p_1=p_1, P=P)
    spks_up = make_spks_up(ntwk, items, P)
    rsp = ntwk.run(spks_up, P.DT)
    acc = calc_acc(rsp, items, P)

    # show raster and CTL->OUT weights
    fig, axs = plt.subplots(2, 1, figsize=(15, 8), tight_layout=True)

    ## raster
    raster(axs[0], rsp.ts, rsp.spks, order=ntwk.order)

    ## CTL->OUT weights
    # TODO: compute these from rsp/ntwk; Ellipsis placeholders cannot be plotted
    w_out_ctl_items = ...
    w_out_ctl_all = ...

    axs[1].plot(rsp.ts, w_out_ctl_all, color='k', lw=2, label='all')
    axs[1].plot(rsp.ts, w_out_ctl_items, color='r', lw=2, label='items')
    axs[1].set_ylabel('CTL->OUT weight')
    axs[1].legend()

    axs[0].set_title('ACC = {0:.4f}'.format(acc))
    for ax in axs:
        set_font_size(ax, 16)

    return Generic(
        ntwk=ntwk,
        items=items,
        spks_up=spks_up,
        rsp=rsp,
        acc=acc,
        fig=fig,
        ax=axs)

### Single-item, no overlap

In [None]:
xmpl_single_item_no_overlap = xmpl(q_star=0, p_0=0.1, p_1=0.2, items=[1])

### Single-item, modest overlap

In [None]:
xmpl_single_item_modest_overlap = xmpl(q_star=0.2, p_0=0.1, p_1=0.2, items=[1])

### Multi-item, no overlap

In [None]:
xmpl_multi_item_no_overlap = xmpl(q_star=0, p_0=0.1, p_1=0.2, items=[1, 3])

### Multi-item, modest overlap

In [None]:
xmpl_multi_item_modest_overlap = xmpl(q_star=0.2, p_0=0.1, p_1=0.2, items=[1, 3])

## Parameter dependence

### Simulation

In [None]:
def param_variation(q_stars, p_0s, p_1s):
    """
    Sweep overlap and connectivity parameters (not yet implemented).

    Per the section headers below, this is intended to vary the
    interconnectivity (p_0) and intraconnectivity (p_1) jointly with the
    overlap (q_star) -- TODO confirm and implement.

    :raises NotImplementedError: always, until implemented. The previous stub
        ignored its arguments and returned `Generic(...)` built from a literal
        Ellipsis.
    """
    raise NotImplementedError('param_variation is not implemented yet')

### Varying interconnectivity (p_0) and overlap (q_star)

### Varying intraconnectivity (p_1) and overlap (q_star)