# This is a notebook for learning the codebase

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys, shelve

import torch
import torch.nn as nn

## Dataset
- Rikhye et al. (2018) cueing-context task (implemented below as `RihkyeTask`)

In [4]:
class RihkyeTask():
    def __init__(self, num_cueingcontext, num_cue, num_rule, rule, blocklen, block_cueingcontext, tsteps, cuesteps, batch_size):
        '''Generate the Rikhye cueing-context task dataset.

        Training is split into consecutive blocks; every sample drawn while
        the counter is inside block ``b`` uses cueing context
        ``block_cueingcontext[b]``.

        Parameters:
            num_cueingcontext: int, number of cueing contexts
            num_cue: int, number of cues in each cueing context
            num_rule: int, number of rules (e.g. "attend to audition" is a rule)
            rule: list of int, rule index for each (cueing context, cue) pair,
                indexed as ``cueingcontext * num_cue + cue``
            blocklen: list of int, number of training samples in each block
            block_cueingcontext: list of int, cueing context used in each block
            tsteps: int, length of a trial, equal to cuesteps + delay steps
            cuesteps: int, number of steps at the start of a trial during
                which the cue is shown
            batch_size: int, number of samples returned per call
        '''
        self.num_cueingcontext = num_cueingcontext
        self.num_cue = num_cue
        self.num_rule = num_rule
        self.rule = rule

        # Exclusive end index of each block, e.g. [500, 1000, 1200].
        self.blockrange = np.cumsum(blocklen)

        self.block_cueingcontext = block_cueingcontext
        self.tsteps = tsteps
        self.cuesteps = cuesteps
        self.batch_size = batch_size

        # Number of samples generated so far (across all calls).
        self.traini = 0

    def __call__(self, *args, **kwargs):
        """Return the input stimulus and target output for one cycle.

        One cycle shows every cue of the current block's cueing context once,
        in random order, with one trial of ``tsteps`` steps per cue. Each call
        advances the sample counter by ``batch_size``.

        Parameters:
            No parameter

        Returns:
            inputs: (tsteps * num_cue, batch_size, num_cue * num_cueingcontext),
                one-hot cue during the first ``cuesteps`` steps of each trial
            targets: (tsteps * num_cue, batch_size, num_rule),
                one-hot rule held for the whole trial

        Raises:
            ValueError: once every block has been consumed.
        """
        n_time = self.tsteps * self.num_cue
        inputs = np.zeros((n_time, self.batch_size, self.num_cue * self.num_cueingcontext))
        targets = np.zeros((n_time, self.batch_size, self.num_rule))

        for i in range(self.batch_size):
            # Index of the first block whose end lies beyond the current sample.
            blocki = int(np.searchsorted(self.blockrange, self.traini, side='right'))
            if blocki >= len(self.blockrange):
                raise ValueError("the end")
            cueingcontext = self.block_cueingcontext[blocki]

            cueList = self.get_cue_list(cueingcontext)
            cues_order = np.random.permutation(cueList)

            t_start = 0
            for context, cuei in cues_order:
                cue, target = self.get_cue_target(context, cuei)
                inputs[t_start:t_start + self.cuesteps, i, :] = cue
                targets[t_start:t_start + self.tsteps, i, :] = target
                t_start += self.tsteps

            self.traini += 1

        return inputs, targets

    # Helper functions
    def get_cue_list(self, cueingcontext):
        '''Return the (cueingcontext, cuei) pairs for one cycle.

        Returns:
            int array (num_cue, 2); row ``k`` is ``(cueingcontext, k)``.
        '''
        cueList = np.dstack((np.repeat(cueingcontext, self.num_cue),
                             np.arange(self.num_cue)))

        return cueList[0]

    def get_cue_target(self, cueingcontext, cuei):
        '''Return the one-hot cue vector and one-hot rule target for one cue.'''
        flat_index = cueingcontext * self.num_cue + cuei

        cue = np.zeros(self.num_cue * self.num_cueingcontext)
        cue[flat_index] = 1.

        target = np.zeros(self.num_rule)
        target[self.rule[flat_index]] = 1

        return cue, target

In [5]:
RNGSEED = 5
np.random.seed([RNGSEED])

num_cueingcontext = 2
num_cue = 2
num_rule = 2
rule = [0, 1, 0, 1]
blocklen = [500, 500, 200]
block_cueingcontext = [0, 1, 0]
tsteps = 200
cuesteps = 100
batch_size = 10

dataset = RihkyeTask(num_cueingcontext=num_cueingcontext, num_cue=num_cue, num_rule=num_rule, rule=rule, blocklen=blocklen, block_cueingcontext=block_cueingcontext, tsteps=tsteps, cuesteps=cuesteps, batch_size=batch_size)

# Named `inputs`/`targets` to avoid shadowing the built-in `input`.
inputs, targets = dataset()
print(inputs.shape, targets.shape, '\n')
# print(inputs)
# print(targets)

(400, 10, 4) (400, 10, 2) 



## Model

In [None]:
class SensoryInputLayer():
    def __init__(self, n_sub, n_cues, n_output):
        '''Fixed random projection from cue vectors to PFC input currents.

        Cue ``c`` drives only PFC neurons ``[n_sub * c, n_sub * (c + 1))``
        with weights drawn uniformly from ``[lowcue, highcue) * cueFactor``.

        Parameters:
            n_sub: int, number of PFC neurons driven by each cue
            n_cues: int, number of cue channels
            n_output: int, number of PFC neurons
        '''
        # TODO: Hard-coded for now
        self.Ncues = n_cues
        self.Nsub = n_sub
        self.Nneur = n_output
        self.positiveRates = True

        self.wIn = np.zeros((self.Nneur, self.Ncues))
        self.cueFactor = 1.5
        if self.positiveRates:
            lowcue, highcue = 0.5, 1.
        else:
            lowcue, highcue = -1., 1
        for cuei in range(self.Ncues):
            self.wIn[self.Nsub * cuei:self.Nsub * (cuei + 1), cuei] = \
                np.random.uniform(lowcue, highcue, size=self.Nsub) * self.cueFactor

        self._use_torch = False

    def __call__(self, input):
        '''Project a cue vector (n_cues,) to PFC input currents (n_output,).

        In torch mode the input must be a CPU tensor and a float32 tensor is
        returned; otherwise numpy in, numpy (float64) out.
        '''
        if self._use_torch:
            input = input.numpy()

        output = np.dot(self.wIn, input)

        if self._use_torch:
            # torch tensors have no .astype(); .float() casts to float32.
            output = torch.from_numpy(output).float()

        return output

    def torch(self, use_torch=True):
        '''Switch __call__ between numpy arrays (False) and torch tensors (True).'''
        self._use_torch = use_torch


class PFC():
    def __init__(self, n_neuron, n_neuron_per_cue, positiveRates=True, MDeffect=True):
        '''Recurrent rate network standing in for prefrontal cortex (PFC).

        Parameters:
            n_neuron: int, number of PFC neurons
            n_neuron_per_cue: int, number of PFC neurons per cue; also sets
                the recurrent weight scale G / sqrt(2 * n_neuron_per_cue)
            positiveRates: bool, if True rates are tanh clipped at 0,
                otherwise plain tanh
            MDeffect: bool, if False the recurrent gain is raised so that the
                pure reservoir matches the MD-inactivated recurrence
        '''
        self.Nneur = n_neuron
        self.Nsub = n_neuron_per_cue
        self.useMult = True
        self.noisePresent = False
        self.noiseSD = 1e-3
        self.tau = 0.02
        self.dt = 0.001

        self.positiveRates = positiveRates
        if self.positiveRates:
            # only +ve rates
            self.activation = lambda inp: np.clip(np.tanh(inp), 0, None)
        else:
            # both +ve/-ve rates as in Miconi
            self.activation = lambda inp: np.tanh(inp)

        self.G = 0.75  # determines also the cross-task recurrence
        # With MDeffect = True and MDstrength = 0, i.e. MD inactivated
        #  PFC recurrence is (1+PFC_G_off)*Gbase = (1+1.5)*0.75 = 1.875
        # So with MDeffect = False, ensure the same PFC recurrence for the pure reservoir
        if not MDeffect:
            self.G = 1.875

        self.init_activity()
        self.init_weights()

    def init_activity(self):
        '''Reset the state to small random values in [0, 0.1) and recompute rates.'''
        self.xinp = np.random.uniform(0, 0.1, size=(self.Nneur))
        self.activity = self.activation(self.xinp)

    def init_weights(self):
        '''Draw Gaussian recurrent weights and remove each row's mean.'''
        self.Jrec = np.random.normal(size=(self.Nneur, self.Nneur)) * self.G / np.sqrt(self.Nsub * 2)
        # Make the mean input to each row zero; this helps avoid saturation
        # (both sides) for positive-only rates (see Nicola & Clopath 2016).
        # The row mean is expanded with np.newaxis so broadcasting subtracts
        # it along rows rather than columns.
        self.Jrec -= np.mean(self.Jrec, axis=1)[:, np.newaxis]

    def __call__(self, input, input_x=None, *args, **kwargs):
        """Advance the PFC by one Euler step of length ``dt``.

        Args:
            input: array (n_neuron,), external (e.g. sensory) input current
            input_x: array (n_neuron,), optional extra current (the caller
                passes its MD-derived input here); defaults to zeros. It is
                added to the drive -- any multiplicative MD effect must
                already be folded in by the caller.

        Returns:
            output: array (n_neuron,), new firing rates (also stored in
                ``self.activity``)
        """
        if input_x is None:
            input_x = np.zeros_like(input)

        xadd = np.dot(self.Jrec, self.activity)
        xadd += input_x + input

        self.xinp += self.dt / self.tau * (-self.xinp + xadd)

        if self.noisePresent:
            self.xinp += np.random.normal(size=(self.Nneur)) * self.noiseSD * np.sqrt(self.dt) / self.tau

        rout = self.activation(self.xinp)
        self.activity = rout
        return rout

    def update_weights(self, input, activity, output):
        """Placeholder for PFC plasticity -- not implemented.

        The previous body read ``self.trace``, ``self.w_input`` and
        ``self.w_output``, none of which are ever created, and discarded the
        weights it computed, so it could only fail with AttributeError. Fail
        explicitly instead until a real learning rule is defined.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("PFC.update_weights has no learning rule yet")


class MD():
    """Mediodorsal thalamus (MD) population with Hebbian PFC<->MD plasticity.

    Every call integrates PFC drive into ``MDinp``, binarises it with a
    winner-take-all threshold and updates the three PFC/MD weight matrices.
    """

    def __init__(self, Nneur, Num_MD, num_active=1, positiveRates=True, dt=0.001):
        '''
        Parameters:
            Nneur: int, number of PFC neurons connected to MD
            Num_MD: int, number of MD units
            num_active: int, number of MD units meant to be active per context
                (sets the winner-take-all threshold)
            positiveRates: bool, whether PFC rates are non-negative
            dt: float, integration time step
        '''
        self.Nneur = Nneur
        self.Num_MD = Num_MD
        self.positiveRates = positiveRates
        self.num_active = num_active

        self.tau = 0.02
        self.tau_times = 4
        self.dt = dt
        self.tsteps = 200
        self.Hebb_learning_rate = 1e-4
        Gbase = 0.75  # determines also the cross-task recurrence

        # All three weight matrices start from the same Gaussian spread.
        init_sd = 1 / np.sqrt(self.Num_MD * self.Nneur)
        self.wPFC2MD = np.random.normal(0, init_sd, size=(self.Num_MD, self.Nneur))
        self.wMD2PFC = np.random.normal(0, init_sd, size=(self.Nneur, self.Num_MD))
        self.wMD2PFCMult = np.random.normal(0, init_sd, size=(self.Nneur, self.Num_MD))
        self.MDpreTrace = np.zeros(self.Nneur)
        self.MDpostTrace = np.zeros(self.Num_MD)
        self.MDpreTrace_threshold = 0

        # The gain is Gbase under both rate conventions; signed rates get a
        # 10x longer MD time constant (and an MDthreshold attribute).
        self.G = Gbase
        if self.positiveRates:
            self.tauMD = self.tau * self.tau_times
        else:
            self.MDthreshold = 0.4
            self.tauMD = self.tau * 10 * self.tau_times
        self.init_activity()

    def init_activity(self):
        '''Reset the integrated MD input to zero.'''
        self.MDinp = np.zeros(self.Num_MD)

    def __call__(self, input, *args, **kwargs):
        """Advance MD by one time step and apply plasticity.

        Args:
            input: array (Nneur,), PFC activity driving MD

        Returns:
            output: array (Num_MD,), binary (0/1) MD activity after
                winner-take-all
        """
        # Signed PFC rates are shifted by +1/2 so the drive has a non-zero mean.
        pfc_drive = input if self.positiveRates else input + 1. / 2
        # MD integrates with time constant tauMD (slower than PFC's tau).
        self.MDinp += (self.dt / self.tauMD) * (-self.MDinp + np.dot(self.wPFC2MD, pfc_drive))

        md_activity = self.winner_take_all(self.MDinp)
        self.update_weights(input, md_activity)
        return md_activity

    def update_trace(self, rout, MDout):
        """Low-pass filter pre/post activity; return the WTA of the post trace.

        Both traces relax with rate 1 / (5 * tsteps) per step.
        """
        trace_rate = 1. / self.tsteps / 5.
        self.MDpreTrace += trace_rate * (-self.MDpreTrace + rout)
        self.MDpostTrace += trace_rate * (-self.MDpostTrace + MDout)
        return self.winner_take_all(self.MDpostTrace)

    def update_weights(self, rout, MDout):
        """Hebbian update of the PFC->MD and MD->PFC weights.

        The update is the outer product of (post trace WTA - 0.5) and
        (pre trace - its mean); MD->PFC matrices get 0.1x its transpose.
        All matrices are clipped to fixed ranges afterwards.

        Args:
            rout: input to MD (PFC activity)
            MDout: activity of MD
        """
        post_wta = self.update_trace(rout, MDout)
        self.MDpreTrace_threshold = np.mean(self.MDpreTrace)
        post_threshold = 0.5
        pre_centered = self.MDpreTrace - self.MDpreTrace_threshold
        hebb = 0.5 * self.Hebb_learning_rate * np.outer(post_wta - post_threshold, pre_centered)
        back_delta = 0.1 * (hebb.T)

        # Update and clip the weights
        self.wPFC2MD = np.clip(self.wPFC2MD + hebb, 0., 1.)
        self.wMD2PFC = np.clip(self.wMD2PFC + back_delta, -10., 0.)
        self.wMD2PFCMult = np.clip(self.wMD2PFCMult + back_delta, 0., 7. / self.G)

    def winner_take_all(self, MDinp):
        '''Binarise MDinp: 1 where it reaches the mean of its 2*num_active largest values.'''
        top_values = np.sort(MDinp)[-int(self.num_active) * 2:]
        cutoff = np.mean(top_values)
        winners = np.zeros(self.Num_MD)
        winners[MDinp >= cutoff] = 1
        return winners


class OutputLayer():
    def __init__(self, n_input, n_out, dt):
        '''Leaky readout from PFC, trained online with an error-driven rule.

        Parameters:
            n_input: int, number of presynaptic (PFC) neurons
            n_out: int, number of output units
            dt: float, integration time step
        '''
        self.dt = dt
        self.tau = 0.02
        self.tauError = 0.001
        self.Nout = n_out
        self.Nneur = n_input
        self.learning_rate = 5e-6
        weight_shape = (self.Nout, self.Nneur)
        self.wOut = np.random.uniform(-1, 1, size=weight_shape) / self.Nneur
        self.state = np.zeros(self.Nout)
        self.error_smooth = np.zeros(self.Nout)
        self.activation = lambda inp: np.clip(np.tanh(inp), 0, None)

    def __call__(self, input, target, *args, **kwargs):
        '''Advance the readout one step, learn from target, return the output rates.'''
        drive = np.dot(self.wOut, input)
        self.state += self.dt / self.tau * (-self.state + drive)
        rates = self.activation(self.state)
        self.update_weights(input, rates, target)
        return rates

    def update_weights(self, input, output, target):
        """Error-driven (error * presynaptic, perceptron-like) weight update."""
        mismatch = output - target
        self.error_smooth += self.dt / self.tauError * (-self.error_smooth + mismatch)
        self.wOut -= self.learning_rate * np.outer(self.error_smooth, input)

In [None]:
class FullNetwork():
    def __init__(self, Num_PFC, n_neuron_per_cue, Num_MD, num_active, MDeffect=True,
                 tsteps=200, n_cues=4, n_out=2):
        '''Sensory input -> PFC (optionally modulated by MD) -> output readout.

        Parameters:
            Num_PFC: int, number of PFC neurons
            n_neuron_per_cue: int, number of PFC neurons driven by each cue
            Num_MD: int, number of MD units
            num_active: int, number of MD units initially set active
            MDeffect: bool, include the MD layer and its input to PFC
            tsteps: int, trial length in steps; PFC/MD activity is reset at
                the start of every trial (default 200, as before)
            n_cues: int, size of the cue input vector (default 4, as before)
            n_out: int, number of output units (default 2, as before)
        '''
        dt = 0.001
        self.tsteps = tsteps
        self.sensory2pfc = SensoryInputLayer(n_sub=n_neuron_per_cue, n_cues=n_cues, n_output=Num_PFC)
        self.pfc = PFC(Num_PFC, n_neuron_per_cue, MDeffect=MDeffect)
        self.pfc2out = OutputLayer(n_input=Num_PFC, n_out=n_out, dt=dt)
        self.pfc_output_t = np.array([])

        self.MDeffect = MDeffect
        if self.MDeffect:
            self.md = MD(Nneur=Num_PFC, Num_MD=Num_MD, num_active=num_active, dt=dt)

            self.md_output = np.zeros(Num_MD)
            index = np.random.permutation(Num_MD)
            self.md_output[index[:num_active]] = 1  # randomly set num_active indices of md_output to 1
            self.md_output_t = np.array([])

    def __call__(self, input, target, *args, **kwargs):
        """Run the network over one cycle.

        Args:
             input: (n_time, n_input)
             target: (n_time, n_output)

        Returns:
            output: (n_time, n_output), readout activity at each step.
            Per-step PFC (and MD) activity is stored in ``self.pfc_output_t``
            (and ``self.md_output_t``) with shape (n_time, n_units).
        """
        self._check_shape(input, target)
        n_time = input.shape[0]
        tsteps = self.tsteps

        self.pfc.init_activity()  # Reinit PFC activity
        pfc_output = self.pfc.activity
        if self.MDeffect:
            self.md.init_activity()  # Reinit MD activity

        output = np.zeros((n_time, target.shape[-1]))
        # Preallocate the recordings instead of growing them with
        # np.concatenate every step (which was O(n_time^2)).
        self.pfc_output_t = np.zeros((n_time, self.pfc.Nneur))
        if self.MDeffect:
            self.md_output_t = np.zeros((n_time, self.md.Num_MD))
            # Constant normalisation of the MD->PFC weights.
            # NOTE(review): equals 0 when Num_MD == 1 (np.round(0.5) == 0).
            md_norm = np.round(self.md.Num_MD / 2)

        for i in range(n_time):
            input_t = input[i]
            target_t = target[i]

            if i % tsteps == 0:  # Reinit activity for each trial
                self.pfc.init_activity()  # Reinit PFC activity
                pfc_output = self.pfc.activity
                if self.MDeffect:
                    self.md.init_activity()  # Reinit MD activity

            input2pfc = self.sensory2pfc(input_t)
            if self.MDeffect:
                self.md_output = self.md(pfc_output)

                # Multiplicative MD effect: scales the recurrent PFC input ...
                self.md.MD2PFCMult = np.dot(self.md.wMD2PFCMult, self.md_output)
                rec_inp = np.dot(self.pfc.Jrec, self.pfc.activity)
                md2pfc_weights = (self.md.MD2PFCMult / md_norm)
                md2pfc = md2pfc_weights * rec_inp
                # ... plus an additive MD->PFC current.
                md2pfc += np.dot(self.md.wMD2PFC / md_norm, self.md_output)
                pfc_output = self.pfc(input2pfc, md2pfc)
                self.md_output_t[i] = self.md_output
            else:
                pfc_output = self.pfc(input2pfc)

            self.pfc_output_t[i] = pfc_output
            output[i] = self.pfc2out(pfc_output, target_t)

        return output

    def _check_shape(self, input, target):
        '''Check that input and target are 2-D and have the same number of time steps.'''
        assert len(input.shape) == 2, f"input must be 2-D (n_time, n_input), got {input.shape}"
        assert len(target.shape) == 2, f"target must be 2-D (n_time, n_output), got {target.shape}"
        assert input.shape[0] == target.shape[0], "input and target must have the same n_time"

In [None]:
class PytorchPFC(nn.Module):
    def __init__(self, n_neuron, n_neuron_per_cue, positiveRates=True):
        '''PyTorch port of the PFC rate network (Euler-integrated).

        Parameters:
            n_neuron: int, number of PFC neurons
            n_neuron_per_cue: int, number of PFC neurons per cue; sets the
                recurrent weight scale
            positiveRates: bool, if True rates are tanh clipped at 0,
                otherwise plain tanh
        '''
        super().__init__()
        self.Nneur = n_neuron
        self.Nsub = n_neuron_per_cue
        self.useMult = True
        self.noisePresent = False
        self.noiseSD = 1e-1  # 1e-3
        self.tau = 0.02
        self.dt = 0.001

        self.positiveRates = positiveRates
        if self.positiveRates:
            # only +ve rates
            self.activation = lambda inp: torch.clip(torch.tanh(inp), 0, None)
        else:
            # both +ve/-ve rates as in Miconi
            self.activation = lambda inp: torch.tanh(inp)

        self.G = 0.75  # determines also the cross-task recurrence

        self.init_activity()
        self.init_weights()

    def init_activity(self):
        '''Reset the state to uniform random values in [0, 0.1) and recompute rates.'''
        self.xinp = torch.rand(self.Nneur) * 0.1
        self.activity = self.activation(self.xinp)

    def init_weights(self):
        '''Draw Gaussian recurrent weights and remove each row's mean.

        NOTE(review): the std here is twice the numpy PFC's scale -- confirm intended.
        '''
        self.Jrec = torch.normal(mean=0, std=self.G / np.sqrt(self.Nsub * 2)*2, size=(self.Nneur, self.Nneur))
        # Make the mean input to each row zero; this helps avoid saturation
        # (both sides) for positive-only rates (see Nicola & Clopath 2016).
        # The row mean is unsqueezed so broadcasting subtracts it along rows.
        self.Jrec -= torch.mean(self.Jrec, dim=1).unsqueeze_(dim=1)

    def forward(self, input, input_x=None):
        """Advance the PFC by one Euler step of length ``dt``.

        Args:
            input: tensor (n_neuron,), external (e.g. sensory) input current
            input_x: tensor (n_neuron,), optional extra current (e.g. MD input),
                added to the drive; defaults to zeros

        Returns:
            output: tensor (n_neuron,), new firing rates (also stored in
                ``self.activity``)
        """
        if input_x is None:
            # Match input's dtype/device (torch.zeros(shape) is always CPU float32).
            input_x = torch.zeros_like(input)

        xadd = torch.matmul(self.Jrec, self.activity)
        xadd += input_x + input
        self.xinp += self.dt / self.tau * (-self.xinp + xadd)

        # Honour noisePresent, as the numpy PFC does.
        if self.noisePresent:
            self.xinp += torch.randn_like(self.xinp) * self.noiseSD * np.sqrt(self.dt) / self.tau

        rout = self.activation(self.xinp)
        self.activity = rout
        return rout

## Training

In [None]:
RNGSEED = 5  # set random seed
np.random.seed([RNGSEED])

Ntrain = 500            # number of training cycles for each context
Nextra = 200            # extra cycles at the end, back in the first context
Ncontexts = 2           # number of cueing contexts (e.g. auditory cueing context)
inpsPerConext = 2       # in a cueing context, there are <inpsPerConext> kinds of stimuli
                        # (e.g. auditory cueing context contains high-pass noise and low-pass noise)
n_output = 2            # number of rules / output units

# generate trainset
# RihkyeTask takes an explicit block schedule (it has no Ntrain/Nextra/
# Ncontexts/inpsPerConext/blockTrain keywords): context 0 for Ntrain cycles,
# context 1 for Ntrain cycles, then Nextra cycles back in context 0.
_task = RihkyeTask(
    num_cueingcontext=Ncontexts,
    num_cue=inpsPerConext,
    num_rule=n_output,
    rule=[0, 1, 0, 1],              # rule per (context, cue); written for 2 contexts x 2 cues
    blocklen=[Ntrain, Ntrain, Nextra],
    block_cueingcontext=[0, 1, 0],
    tsteps=200,                     # must match FullNetwork's trial length
    cuesteps=100,
    batch_size=1,                   # FullNetwork processes one cycle at a time
)


def dataset():
    """Return one cycle as 2-D arrays: input (n_time, n_input), target (n_time, n_output).

    FullNetwork expects un-batched arrays, so the singleton batch axis is dropped.
    """
    inputs, targets = _task()
    return inputs[:, 0, :], targets[:, 0, :]


# model parameters
n_neuron = 1000
n_neuron_per_cue = 200
Num_MD = 10
num_active = 5  # num MD active per context
MDeffect = True

model = FullNetwork(n_neuron, n_neuron_per_cue, Num_MD, num_active, MDeffect=MDeffect)

In [None]:
import pickle
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm

log = defaultdict(list)

num_cycle_train = Ntrain*Ncontexts+Nextra
MDpreTraces = np.zeros(shape=(num_cycle_train, n_neuron))
MDouts_all = np.zeros(shape=(num_cycle_train, Num_MD))
PFCouts_all = np.zeros(shape=(num_cycle_train, n_neuron))

for i in tqdm(range(num_cycle_train)):
    input, target = dataset()
    output = model(input, target)
    # NOTE(review): scaled by Ncontexts per the original "one cycle has Ncontexts" note -- confirm intent
    mse = np.mean((output - target)**2)*Ncontexts

    # Record the PFC/MD state at the end of each cycle.
    PFCouts_all[i, :] = model.pfc.activity
    log['mse'].append(mse)
    if MDeffect:
        MDouts_all[i, :] = model.md_output
        MDpreTraces[i, :] = model.md.MDpreTrace

# write MD weights
if MDeffect:
    log['wPFC2MD'] = model.md.wPFC2MD
    log['wMD2PFC'] = model.md.wMD2PFC
    log['wMD2PFCMult'] = model.md.wMD2PFCMult

# write model
filename = Path('files')
# `os` is never imported in this notebook; pathlib (already imported) creates the directory.
filename.mkdir(parents=True, exist_ok=True)
file_training = f'train_numMD{Num_MD}_numContext{Ncontexts}_MD{MDeffect}_R{RNGSEED}.pkl'
with open(filename / file_training, 'wb') as f:
    pickle.dump(log, f)