In [1]:
import numpy as np
%matplotlib inline
import pylab
import seaborn
import nengo
import nengo.spa as spa




In [77]:
class RIFModel(object):
    """Category -> item recall network (nengo.spa) with online PES learning.

    A ``category`` state projects to an ``items`` state through a learned
    connection whose initial function maps each category in ``mapping`` to
    the sum of its item pointers.  An ``error`` state receives
    ``items - correct`` and drives the PES rule on that connection.
    ``practice`` runs the simulator with the error signal active; ``test``
    runs it with the error ensembles inhibited and returns the similarity of
    the recalled vector to every item pointer.

    Parameters
    ----------
    mapping : dict
        Category name -> list of item names (each parsed into a vocabulary).
    D_category, D_items : int
        Dimensionality of the category and item vocabularies.
    threshold : float
        Minimum dot product with a known category pointer for the initial
        function to output that category's items; below it the output is the
        zero vector.
    learning_rate : float
        PES learning rate on the category -> items connection.
    seed : int or None
        When given, seeds vocabulary generation, the network and the
        simulator so runs are reproducible.  ``None`` (default) keeps the
        previous unseeded behaviour.
    """

    def __init__(self, mapping, D_category=16, D_items=64, threshold=0.4,
                 learning_rate=1e-4, seed=None):
        self.mapping = mapping

        # A shared RNG makes every semantic pointer reproducible for a fixed seed.
        rng = None if seed is None else np.random.RandomState(seed)
        self.vocab_category = spa.Vocabulary(D_category, rng=rng)
        self.vocab_items = spa.Vocabulary(D_items, rng=rng)
        for category, items in mapping.items():
            self.vocab_category.parse(category)
            for item in items:
                self.vocab_items.parse(item)

        model = spa.SPA(seed=seed)
        self.model = model
        with model:
            # Keep the whole category vector in ONE ensemble: the learned
            # connection below reads all_ensembles[0], and `learned` dots its
            # input with full-length category vectors.  With the State default
            # subdimensions (16 in nengo.spa), D_category > 16 would feed only
            # the first chunk, and non-multiples of 16 are rejected outright.
            model.category = spa.State(D_category, subdimensions=D_category,
                                       vocab=self.vocab_category)
            model.items = spa.State(D_items, vocab=self.vocab_items)

            def learned(x):
                # Initial decoded function: nearest known category -> sum of its
                # items; inputs not close enough to any category -> zero vector.
                sims = np.dot(self.vocab_category.vectors, x)
                best = np.argmax(sims)
                if sims[best] < threshold:
                    return self.vocab_items.parse('0').v
                items = self.mapping[self.vocab_category.keys[best]]
                if not items:
                    # '+'.join([]) == '' is not a parseable pointer expression.
                    return self.vocab_items.parse('0').v
                return self.vocab_items.parse('+'.join(items)).v

            c = nengo.Connection(
                model.category.all_ensembles[0], model.items.input,
                function=learned,
                learning_rule_type=nengo.PES(learning_rate=learning_rate))

            # error = items - correct; the correct-answer node is added with
            # transform=-1 further down.
            model.error = spa.State(D_items, vocab=self.vocab_items)
            nengo.Connection(model.items.output, model.error.input)
            nengo.Connection(model.error.output, c.learning_rule)

            # Each input Node wraps the bound stim_* method; the Node object then
            # replaces that method name as an instance attribute.
            self.stim_category_value = np.zeros(D_category)
            self.stim_category = nengo.Node(self.stim_category)
            nengo.Connection(self.stim_category, model.category.input,
                             synapse=None)

            self.stim_correct_value = np.zeros(D_items)
            self.stim_correct = nengo.Node(self.stim_correct)
            nengo.Connection(self.stim_correct, model.error.input,
                             synapse=None, transform=-1)

            # stoplearn = 1 injects -10 into every error neuron, intended to
            # silence the error signal (and with it PES learning) during test().
            self.stim_stoplearn_value = np.zeros(1)
            self.stim_stoplearn = nengo.Node(self.stim_stoplearn)
            for ens in model.error.all_ensembles:
                nengo.Connection(self.stim_stoplearn, ens.neurons,
                                 synapse=None,
                                 transform=-10 * np.ones((ens.n_neurons, 1)))

            self.probe_items = nengo.Probe(model.items.output, synapse=0.01)

        self.sim = nengo.Simulator(self.model, seed=seed)

    def stim_category(self, t):
        """Node output: the category pointer currently presented."""
        return self.stim_category_value

    def stim_correct(self, t):
        """Node output: the target item vector (fed negated into the error)."""
        return self.stim_correct_value

    def stim_stoplearn(self, t):
        """Node output: 1 inhibits the error ensembles, 0 lets learning run."""
        return self.stim_stoplearn_value

    def test(self, category, T=0.5):
        """Present `category` for T seconds with learning inhibited; read recall.

        Returns
        -------
        ndarray
            Dot product of the last probed ``items`` sample with each row of
            ``self.vocab_items.vectors``.
        """
        # Same shape as the (1,)-sized value the Node was built with.
        self.stim_stoplearn_value = np.ones(1)
        self.stim_category_value = self.vocab_category.parse(category).v
        self.stim_correct_value = self.vocab_items.parse('0').v
        self.sim.run(T)
        final_sample = self.sim.data[self.probe_items][-1]
        return np.dot(self.vocab_items.vectors, final_sample)

    def practice(self, category, item, T=0.5):
        """Present `category` with `item` as the target for T seconds, learning on."""
        self.stim_stoplearn_value = np.zeros(1)
        self.stim_category_value = self.vocab_category.parse(category).v
        self.stim_correct_value = self.vocab_items.parse(item).v
        self.sim.run(T)
        
        
        
        
        
        
        
        

In [78]:
# Category -> items lookup; every name is parsed into the model's vocabularies.
mapping = {
    'ANIMAL': ['DOG', 'CAT', 'RAT'],
    'COLOR': ['RED', 'BLUE', 'GREEN'],
}

# Dimensions and threshold spelled out at their default values.
m = RIFModel(mapping=mapping, D_category=16, D_items=64, threshold=0.4,
             learning_rate=1e-5)

In [79]:
# Presumably the same order as the rows of vocab_items.vectors, i.e. of m.test() values.
print(m.vocab_items.keys)

['RED', 'BLUE', 'GREEN', 'DOG', 'CAT', 'RAT']


In [80]:
# Recall for COLOR before any practice cell has run.
m.test(category='COLOR')

Simulation finished in 0:00:02.                                                 


array([ 0.99992329,  1.00532139,  0.97266045, -0.03696156, -0.57618808,
       -0.07140686])

In [81]:
# Recall for ANIMAL before any practice cell has run.
m.test(category='ANIMAL')

Simulation finished in 0:00:02.                                                 


array([-0.2985458 , -0.10978881, -0.22965498,  0.69084005,  0.78006378,
        0.73885816])

In [82]:
# Learning enabled: cue COLOR with RED as the target.
m.practice(category='COLOR', item='RED')

Simulation finished in 0:00:02.                                                 


In [84]:
# Learning enabled: cue COLOR with BLUE as the target.
m.practice(category='COLOR', item='BLUE')

Simulation finished in 0:00:02.                                                 


In [85]:
# Recall for COLOR after the RED and BLUE practice cells.
m.test(category='COLOR')


Simulation finished in 0:00:03.                                                 


array([ 0.94128963,  0.97134279,  0.88818391, -0.0059253 , -0.54358464,
       -0.06306801])

In [74]:
# Recall for ANIMAL after the COLOR practice cells.
# NOTE(review): this cell's execution count (In [74]) is lower than the class
# definition cell (In [77]), so the saved output below comes from an earlier
# model instance; restart the kernel and run all cells to refresh it.
m.test('ANIMAL')

Simulation finished in 0:00:02.                                                 


array([ 0.08751561, -0.34436001, -0.26907267,  0.6062511 ,  0.5593552 ,
        0.57091676])