# EBH model of associative memory

# Imports

In [1]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import time
import pickle
import math
import sparse_weights as sw
import corr2 as c2
import etile as et

# Initialize

In [2]:
# Shift and Next modifiers
# consonant vs vowel discrimination

# Start things up: fix the global NumPy seed before any random structure is built
np.random.seed(1)
nodetype = 'int16'

# Parameters of the letter images; font, jitter and density select the pickle file
font, n_pix, n_let = 'HN', 100, 26
xjitter, yjitter = 30, 15
density = 15 # in thousandths

# read in stack of letters
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
stack_path = 'letters/letterstack_f{:s}_jx{:d}_jy{:d}_d{:d}.pkl'.format(font,xjitter,yjitter,density)
with open(stack_path, 'rb') as f:
    newlist = pickle.load(f)
letterstack = newlist[0].astype(nodetype)
symbols = newlist[1].astype(nodetype)

# Flatten every image into one row: 2*n_let letter rows and 6 symbol rows
lettermat = np.reshape(letterstack,(2*n_let,n_pix*n_pix))
symbolmat = np.reshape(symbols,(6,n_pix*n_pix))

# Row index of each symbol in symbolmat, in stored order:
# forward, shift, lowercase, lower, querycase, upper
forward, shift, lowercase, lower, querycase, upper = range(6)

# Initialize network: one node per pixel
n_nodes = n_pix * n_pix
weight_density = 0.1        # 0.2 for 1% steady
weight_std = 0.0001           # 0.05 for 1% steady
learning_rate = 0.02
activation_threshold = 0.5  # 0.5 for 1% steady

weights = sw.sparse_weights(
    n_nodes,
    density=weight_density,
    weight_std=weight_std,
)
    
def stepfunc(a,thresh):
    """Hard-threshold activations.

    Returns an array that is 1 where `a` is strictly greater than `thresh`
    and 0 elsewhere, cast to the module-level `nodetype`.
    """
    fired = a > thresh
    return fired.astype(nodetype)


# Go: train and test the network

In [3]:
# Learn: sequence for all 52
#        upshift for lower 26, downshift for upper 26
#        query case for all 52
#
# NOTE: training adds to `weights` in place, so re-running this cell stacks
# further increments on top of the previous ones.

# Go
reps = 1

# Buffers for the probe inputs and the raw network responses, one row per letter
n_patterns = 2 * n_let
n_units = n_pix * n_pix
seqin, seqout, shiftin, shiftout, casein, caseout = (
    np.zeros((n_patterns, n_units)) for _ in range(6)
)

clip_weights = False
normalize_weights = False


def _associate(cue, target):
    """Add learning_rate * target[row] * cue[col] to every stored weight.

    Relies on `weights` exposing .data/.row/.col (presumably a COO-style
    sparse matrix from sw.sparse_weights -- TODO confirm).
    """
    weights.data += learning_rate * target[weights.row] * cue[weights.col]


# Train
for _ in range(reps):
    # Sequence: letter + forward symbol -> next letter of the same case
    for k in range(n_let - 1):  # -1 because z has no next letter
        for base in (0, n_let):  # lower case block first, then upper case block
            _associate(lettermat[base + k] + symbolmat[forward], lettermat[base + k + 1])
    print('done training sequence')
    # Upshift and Downshift: switch a letter to the other case
    for k in range(n_let):
        _associate(lettermat[k] + symbolmat[shift], lettermat[k + n_let])
        _associate(lettermat[k + n_let] + symbolmat[lowercase], lettermat[k])
    print('done training shift')
    # Case query: letter + query symbol -> 'lower' or 'upper' symbol
    for k in range(n_let):
        _associate(lettermat[k] + symbolmat[querycase], symbolmat[lower])
        _associate(lettermat[k + n_let] + symbolmat[querycase], symbolmat[upper])
    print('done training query')

# Test
# Each entry: (modifier for lower-case rows, modifier for upper-case rows,
#              buffer for the cues, buffer for the responses)
probes = (
    (forward,   forward,   seqin,   seqout),    # Sequence
    (shift,     lowercase, shiftin, shiftout),  # Upshift and Downshift
    (querycase, querycase, casein,  caseout),   # Case query
)
for lower_mod, upper_mod, cue_store, response_store in probes:
    for k in range(n_let):
        for row, mod in ((k, lower_mod), (k + n_let, upper_mod)):
            cue = lettermat[row] + symbolmat[mod]
            cue_store[row, :] = cue
            response_store[row, :] = weights.dot(cue)
print('Done testing')
            

done training sequence
done training shift
done training query
Done testing


In [4]:
# Correlate output against all letters (and, for the case query, against the
# 'lower' and 'upper' symbols only, in that order).
# Assumes c2.corr2 returns one row per output and one column per template,
# which is what the axis=1 argmax below relies on -- TODO confirm.
corrseq   = c2.corr2(seqout,lettermat)
corrshift = c2.corr2(shiftout,lettermat)
corrcase  = c2.corr2(caseout,symbolmat[(lower,upper),:])
# Winner Take All output: index of the best-matching template per probe
outseq    = np.argmax(corrseq,axis=1)
outshift  = np.argmax(corrshift,axis=1)
outcase   = np.argmax(corrcase,axis=1)

# Expected answers are derived from n_let rather than hard-coded 25/26/51/52,
# so the scoring stays correct if the alphabet size changes.
lower_rows = np.arange(n_let, dtype=int)   # rows holding lower-case letters
upper_rows = lower_rows + n_let            # rows holding upper-case letters

# Number correct
# Sequence: letter k should map to k+1 within its case; the last letter of
# each case has no successor and is not scored.
correct_seq   = (np.sum(outseq[lower_rows[:-1]] == lower_rows[1:])
                 + np.sum(outseq[upper_rows[:-1]] == upper_rows[1:]))
# Shift: each letter should map to the same letter in the other case.
correct_shift = (np.sum(outshift[lower_rows] == upper_rows)
                 + np.sum(outshift[upper_rows] == lower_rows))
# Case: column 0 of corrcase is the 'lower' symbol, column 1 the 'upper' symbol.
correct_case  = np.sum(outcase[lower_rows] == 0) + np.sum(outcase[upper_rows] == 1)

# fraction correct
frac_seq   = correct_seq/float(2*(n_let-1))
frac_shift = correct_shift/float(2*n_let)
frac_case  = correct_case/float(2*n_let)
print(frac_seq,frac_shift,frac_case)
fdata = (frac_seq,frac_shift,frac_case)

0.56 0.8461538461538461 0.8846153846153846


In [5]:
# Example cue (left) and raw network response (right) for the first letter in
# each task, plus one panel showing the first letter with all symbols overlaid.
fig,ax = plt.subplots(4, 2, figsize=(8,16))
# (cue buffer, response buffer, colour-scale ceiling for the response panel)
panels = (
    (seqin,   seqout,   1.2),
    (shiftin, shiftout, 1.2),
    (casein,  caseout,  4),
)
for axrow, (cues, responses, vmax) in zip(ax, panels):
    axrow[0].imshow(np.reshape(cues[0,:],(n_pix,n_pix)))
    im = axrow[1].imshow(np.reshape(responses[0,:],(n_pix,n_pix)),vmax=vmax)
im=ax[3][0].imshow(np.reshape(letterstack[0,:]+np.sum(symbols,axis=0),(n_pix,n_pix)))

# Save through the figure handle BEFORE showing: on backends that close the
# figure in plt.show(), a later plt.savefig() would write a blank image.
fig.savefig('associativeAB.png')
plt.show()

FigureCanvasNbAgg()

In [6]:
# Bar chart of the fraction correct for the three tasks (values from fdata)
fig,ax = plt.subplots(figsize=(3,3))
bar_positions = np.arange(3)
b = ax.bar(bar_positions, fdata, 0.6)
#ax.set_ylabel('Fraction Correct')
ax.set_xticks(bar_positions)
ax.set_xticklabels(('Seq', 'Shift', 'Case'))

# Save through the figure handle BEFORE showing: on backends that close the
# figure in plt.show(), a later plt.savefig() would write a blank page.
fig.savefig('associativeC.pdf')
plt.show()

FigureCanvasNbAgg()

In [7]:
# Top row: correlation matrices; bottom row: winner-take-all index per probe.
# Columns are the sequence, shift and case tasks, in that order.
fig,ax = plt.subplots(2, 3, figsize=(10,6))
task_results = ((corrseq, outseq), (corrshift, outshift), (corrcase, outcase))
for col, (corr, winners) in enumerate(task_results):
    im = ax[0][col].imshow(corr)
    ax[1][col].plot(winners)
#fig.colorbar(im,ax=ax[0])
plt.show()

FigureCanvasNbAgg()

In [8]:
# One tiled overview per buffer: cues and responses for the sequence, shift
# and case tasks. Each buffer is reshaped to a stack of n_pix x n_pix images
# and passed through et.etile before display.
fig,ax = plt.subplots(6, 1, figsize=(8,16))
stacks = (seqin, seqout, shiftin, shiftout, casein, caseout)
for panel, stack in zip(ax, stacks):
    im = panel.imshow(et.etile(np.reshape(stack,(-1,n_pix,n_pix))))
#fig.colorbar(im,ax=ax[0])
plt.show()

FigureCanvasNbAgg()