In [None]:
import sys, random
from itertools import product
from datetime import datetime
import torch, torch.nn as nn, torch.optim as optim
import numpy as np
import cv2
import scipy.stats as stats
import matplotlib.pyplot as plt
import pingouin as pg
plt.rcParams["font.family"] = "Arial"

# Import custom modules
sys.path.append("../")
from models.network_hierarchical_recurrent import NetworkHierarchicalRecurrent
from plotting_functions import *

# Checkpoint location; passed as `model_path` to NetworkHierarchicalRecurrent.load()
# wherever a model is loaded in this notebook.
# NOTE(review): left empty here — presumably it has to be set before running; confirm.
MODEL_PATH = ''

## Diagram

In [None]:
def _show_symmetric_gray (image):
    """Show `image` in grayscale with a colour range symmetric around zero, axes hidden."""
    limit = np.abs(image).max()
    plt.imshow(image, vmax=limit, vmin=-limit, cmap='gray')
    plt.axis('off')
    plt.show()

# First five images from ./mnist/raw_images.npy, each z-scored (zero mean, unit std)
raw_images = np.load('./mnist/raw_images.npy')
for digit in raw_images[:5]:
    _show_symmetric_gray((digit - digit.mean()) / digit.std())

# Five 32x32 frames of standard-normal noise, displayed the same way
noise = np.random.normal(size=(5, 32, 32))
for frame in noise:
    _show_symmetric_gray(frame)

## Load models

In [None]:
# Load the same checkpoint twice: `model` is left intact, `model_no_fb` gets
# two blocks of its hidden-to-hidden weight matrix zeroed.
load_kwargs = dict(model_path=MODEL_PATH, device='cpu', plot_loss_history=False)

model, hyperparameters, loss_history       = NetworkHierarchicalRecurrent.load(**load_kwargs)
model_no_fb, hyperparameters, loss_history = NetworkHierarchicalRecurrent.load(**load_kwargs)

# Work on a NumPy copy of the recurrent weights (taken from `model`, i.e. the same checkpoint)
weights = model.rnn.weight_hh_l0.detach().cpu().numpy().copy()

# Blocks [0:800, 800:1600] and [800:1600, 1600:2400] are set to zero.
# NOTE(review): presumably these are the feedback connections between consecutive
# 800-unit groups (the ablated model is labelled 'No feedback' below) — confirm.
for rows, cols in (
    (slice(0, 800),    slice(800, 1600)),
    (slice(800, 1600), slice(1600, 2400)),
):
    weights[rows, cols] = 0

# Install the edited matrix as the ablated model's recurrent weight
model_no_fb.rnn.weight_hh_l0 = torch.nn.Parameter(torch.Tensor(weights).to('cpu'))

# Functions

In [None]:
# Function to accumulate the running (i.e., per-minibatch) loss, to be averaged per epoch
def append_running_loss (running_loss_history, loss, loss_components):
    """Add one minibatch's loss and loss components to the running totals.

    Args:
        running_loss_history: dict, mutated in place. It must already hold the
            keys "i" (minibatch counter) and "loss" (summed loss); every other
            key is created the first time it shows up in `loss_components`.
        loss: object supporting `.detach().cpu().numpy()` (e.g. a torch tensor).
        loss_components: dict mapping a name to either a tensor-like object
            (converted via `.detach().cpu().numpy()`) or a plain number
            (accumulated as it is).

    Returns:
        None.
    """
    running_loss_history["i"] += 1
    running_loss_history["loss"] += loss.detach().cpu().numpy()

    for k, v in loss_components.items():
        # Convert only tensor-like values. An explicit check replaces the former
        # bare `except: pass`, which also hid genuine conversion errors and
        # swallowed KeyboardInterrupt.
        if hasattr(v, "detach"):
            v = v.detach().cpu().numpy()

        if k in running_loss_history:
            running_loss_history[k] += v
        else:
            running_loss_history[k] = v

# Function to record the per-epoch average of every accumulated quantity
def append_epoch_loss (epoch_loss, running_loss_history, epoch):
    """Append the epoch number and the minibatch-averaged running totals to `epoch_loss`.

    Args:
        epoch_loss: dict of lists, mutated in place. Must already contain an
            'epochs' list; a list is created for any other key on first use.
        running_loss_history: dict holding the minibatch count under "i" and
            summed quantities under every other key.
        epoch: value appended to `epoch_loss['epochs']`.

    Returns:
        None.
    """
    n_batches = running_loss_history["i"]

    epoch_loss['epochs'].append(epoch)

    for name, total in running_loss_history.items():
        if name == 'i':
            # the counter itself is not a quantity to average
            continue
        epoch_loss.setdefault(name, []).append(total / n_batches)

class MNISTClassifier (nn.Module):
    """Linear read-out trained on the final hidden state of a recurrent model.

    Every image is placed at time step 4 of a sequence whose other frames
    (4 before, `n_back` after) are standard-normal noise. The sequences are
    passed once through `model`, the hidden state at the last time step is
    kept, and a single linear layer (`fc0`, 2400 -> 10) is trained on those
    fixed features.

    NOTE: `train()` below runs the optimisation loop and therefore shadows
    `nn.Module.train(mode)`. Because `nn.Module.eval()` is implemented as
    `self.train(False)`, do not call `.eval()` on an instance of this class;
    the mode of the read-out layer is handled internally via `self.fc0`.
    """

    N_NOISE_BEFORE = 4      # noise frames preceding the image
    FRAME_SIZE     = 20*40  # flattened size of one input frame
    BATCH_SIZE     = 100    # samples per minibatch
    TRAIN_FRACTION = 0.8    # fraction of minibatches used for training

    def __init__ (self, model, n_back):
        """Store the backbone and load images/labels from ./mnist/.

        Args:
            model: callable module returning `(output, hidden)` for a batch of
                sequences; only `hidden[:, -1]` is used.
            n_back: number of noise frames appended after the image.
        """
        super(MNISTClassifier, self).__init__()

        self.device      = 'cpu'
        self.n_back      = n_back
        self.x_          = np.load('./mnist/res_images.npy')
        self.y_          = np.load('./mnist/labels.npy')

        self.model       = model

        self.fc_in_size  = 2400
        self.fc_out_size = 10

        self.fc0         = nn.Linear(
            in_features=self.fc_in_size,
            out_features=self.fc_out_size
        )

    def preprocess_data (self):
        """Build noise-padded sequences, batch them and split into train/test.

        Returns:
            Tuple `(x_train, x_test, y_train, y_test)`. The x tensors are
            float with shape (n_batches, 100, 4+1+n_back, 800); the y tensors
            are long with shape (n_batches, 100, 1). The first 80% of the
            minibatches form the training split.
        """
        res_images_, labels_, n_back = self.x_, self.y_, self.n_back

        n_pre    = self.N_NOISE_BEFORE
        n_frames = n_pre+1+n_back

        res_images_processed = np.zeros(
            (res_images_.shape[0], n_frames, self.FRAME_SIZE),
            dtype=res_images_.dtype
        )
        # Noise before the image, the image itself, then `n_back` noise frames
        res_images_processed[:, :n_pre] = np.random.normal(size=(res_images_processed[:, :n_pre].shape))
        res_images_processed[:, n_pre] = res_images_
        res_images_processed[:, n_pre+1:] = np.random.normal(size=(res_images_processed[:, n_pre+1:].shape))

        # The sample count has to be a multiple of the batch size for these reshapes
        res_images_batched = res_images_processed.reshape(-1, self.BATCH_SIZE, n_frames, self.FRAME_SIZE)
        labels_batched     = labels_.reshape(-1, self.BATCH_SIZE, 1)

        res_images_batched = torch.from_numpy(res_images_batched).type(torch.FloatTensor)
        labels_batched     = torch.from_numpy(labels_batched).type(torch.LongTensor)

        train_idxs = int(res_images_batched.shape[0]*self.TRAIN_FRACTION)

        return (
            res_images_batched[:train_idxs], res_images_batched[train_idxs:],
            labels_batched[:train_idxs]    , labels_batched[train_idxs:]
        )

    def get_model_outputs (self):
        """Run the backbone once over all data and cache last-step hidden states.

        Sets `self.x_train` / `self.x_test` (lists with one feature tensor per
        minibatch) and `self.y_train` / `self.y_test`.
        """
        x_raw_train, x_raw_test, y_train, y_test = self.preprocess_data()

        self.y_train = y_train
        self.y_test  = y_test

        x_processed_train = []
        x_processed_test  = []

        # The backbone is only used for inference here. Putting it in eval mode
        # explicitly makes feature extraction independent of whatever mode an
        # earlier cell or classifier left it in.
        self.model.eval()

        with torch.no_grad():
            for batch in x_raw_train:
                _, h = self.model(batch)
                x_processed_train.append(h[:, -1])

            for batch in x_raw_test:
                _, h = self.model(batch)
                x_processed_test.append(h[:, -1])

        self.x_train = x_processed_train
        self.x_test  = x_processed_test

    def forward (self, x):
        """Map cached features to class logits with the linear read-out."""
        return self.fc0(x)

    def get_loss (self, y, y_hat):
        """Return the cross-entropy loss and `{'accuracy': fraction correct}`.

        Args:
            y: labels; the class index is read from column 0.
            y_hat: logits with one row per sample.
        """
        cross_entropy = nn.functional.cross_entropy(y_hat, y[:, 0].long())

        classes = torch.argmax(y_hat, dim=1)
        accuracy = torch.sum(classes == y[:, 0]).item()/len(y)

        return cross_entropy, { 'accuracy': accuracy }

    def train (self, n_epochs=100, lr=10**-4, log_every=10):
        """Train the linear read-out on the cached features.

        `get_model_outputs()` must have been called first.

        Args:
            n_epochs: number of passes over the training minibatches.
            lr: Adam learning rate.
            log_every: print progress every this many epochs (0/None disables).

        Returns:
            `(train_history, valid_history)`: dicts of per-epoch lists with the
            keys 'epochs', 'loss' and 'accuracy'.
        """
        # Only the read-out is optimised. The backbone's features were computed
        # under no_grad, so its parameters never receive gradients.
        optimizer = optim.Adam(self.fc0.parameters(), lr=lr)

        train_history = {'epochs': []}
        valid_history = {'epochs': []}

        for epoch in range(1, n_epochs+1):
            # Validation runs first, so the validation numbers stored for an
            # epoch are measured before that epoch's parameter updates.
            for mode in ['validation', 'train']:
                is_training = mode == 'train'

                if is_training:
                    x_data, y_data = self.x_train, self.y_train
                else:
                    x_data, y_data = self.x_test, self.y_test

                # Switch the mode of the read-out only. The original code toggled
                # the notebook-global `model` here, which is not an attribute of
                # this class and left that backbone in train mode afterwards.
                self.fc0.train(is_training)

                running_loss_history = { "i": 0, "loss": 0 }

                for x, y in zip(x_data, y_data):
                    if is_training:
                        optimizer.zero_grad()
                        y_hat = self(x)
                        loss, loss_components = self.get_loss(y, y_hat)
                        loss.backward()
                        optimizer.step()
                    else:
                        with torch.no_grad():
                            y_hat = self(x)
                            loss, loss_components = self.get_loss(y, y_hat)

                    append_running_loss(running_loss_history, loss, loss_components)

                if is_training:
                    append_epoch_loss(train_history, running_loss_history, epoch)
                else:
                    append_epoch_loss(valid_history, running_loss_history, epoch)

            if log_every and epoch%log_every==0:
                print('Epoch: {}/{}.............'.format(epoch, n_epochs), end=' ')
                print("Loss: {:.4f}.............".format(train_history['loss'][-1]), end=' ')
                print("Val accuracy: {:.4f}.............".format(valid_history['accuracy'][-1]), end=' ')
                print(datetime.now().strftime("%H:%M:%S"))

        return train_history, valid_history

    def get_accuracy (self):
        """Return per-sample correctness on the test split.

        Returns:
            List with one float per test sample: 1.0 if the predicted class
            equals the label, otherwise 0.0.
        """
        # One batched forward pass replaces the former per-sample loop; the
        # unused optimizer and the calls on the global `model` are gone.
        x_data = torch.cat(self.x_test, dim=0)
        labels = self.y_test.reshape(-1)

        self.fc0.eval()
        with torch.no_grad():
            predicted = torch.argmax(self(x_data), dim=1)

        return (predicted == labels).double().tolist()

# Train

In [None]:
# Backbones to compare and the numbers of trailing noise frames to test
models = {
    'full' : model,
    'no_fb' : model_no_fb,
}
accuracy_data = {key: [] for key in models}

n_back_arr = list(range(11))

# Train one linear read-out per (backbone, n_back) pair and keep its
# last-epoch validation accuracy, in percent.
for model_key, model_i in models.items():
    print(model_key)

    for n_back in n_back_arr:
        print(n_back)

        MNIST_classifier = MNISTClassifier(model_i, n_back=n_back)
        MNIST_classifier.get_model_outputs()

        valid_history = MNIST_classifier.train()[1]
        accuracy = valid_history['accuracy'][-1]
        accuracy_data[model_key].append(accuracy*100)

        print('\n')


# Accuracy against n-back for both backbones, with the 10% chance level
fig = plt.figure()
for curve_key, curve_label in (('full', 'Full model'), ('no_fb', 'No feedback')):
    plt.plot(n_back_arr, accuracy_data[curve_key], label=curve_label)
plt.plot([n_back_arr[0], n_back_arr[-1]], [10, 10], '--', c='black', label='Chance')
plt.xlabel('n-back')
plt.ylabel('MNIST accuracy (%)')
plt.ylim(0, 100)
format_plot(fontsize=20)
fig.set_size_inches(4, 4)
plt.gca().get_legend().set_bbox_to_anchor((1, 1))
plt.show()

In [None]:
# Load previously saved accuracies (a pickled dict with 'full' and 'no_fb' lists).
# NOTE: allow_pickle=True executes pickle on load — only use with a trusted file.
accuracy_data = np.load('./nback_data.npy', allow_pickle=True).item()

n_back_arr = list(range(11))

# Accuracy against n-back for both backbones, with the 10% chance level
fig = plt.figure()
for curve_key, curve_label in (('full', 'Full model'), ('no_fb', 'No feedback')):
    plt.plot(n_back_arr, accuracy_data[curve_key], label=curve_label)
plt.plot([n_back_arr[0], n_back_arr[-1]], [10, 10], '--', c='black', label='Chance')
plt.xlabel('n-back')
plt.ylabel('MNIST accuracy (%)')
plt.ylim(0, 100)
format_plot(fontsize=20)
fig.set_size_inches(4, 4)
plt.gca().get_legend().set_bbox_to_anchor((1, 1))
plt.show()

# Two-sided pooled two-proportion z-test, one per n-back value:
# full model accuracy vs. no-feedback accuracy.
# NOTE(review): 140*100 is presumably the number of held-out samples behind each
# accuracy value (140 minibatches of 100) — confirm against the data split.
n_test = 140*100

for p_full, p_nofb in zip(accuracy_data['full'], accuracy_data['no_fb']):
    # Stored accuracies are percentages; convert to proportions
    p_full = p_full/100
    p_nofb = p_nofb/100

    # Number of correct classifications in each condition
    x_full = p_full*n_test
    x_nofb = p_nofb*n_test

    # Pooled proportion under the null hypothesis of equal accuracy
    p = (x_full+x_nofb)/(2*n_test)

    # Pooled standard error is sqrt(p*(1-p)*(1/n + 1/n)) = sqrt(2*p*(1-p)/n).
    # The previous code raised p to the power (1-p) instead of multiplying.
    z = (p_full-p_nofb)/np.sqrt((2*p*(1-p))/n_test)

    pval = stats.norm.sf(abs(z))*2

    print(z, pval)

# Shuffle weights

In [None]:
def shuffle_weights (W_, pct):
    """Return a copy of `W_` in which `pct` percent of the entries are permuted.

    A random subset of `int(W_.size * pct / 100)` positions is drawn without
    replacement and the values at those positions are shuffled among
    themselves; all other entries keep their place. The input is never
    modified and the result is always a new array (also for `pct == 0`).

    Args:
        W_: NumPy array of weights.
        pct: percentage of entries to shuffle, between 0 and 100.

    Returns:
        NumPy array with the same shape, dtype and multiset of values as `W_`.

    Raises:
        ValueError: if `pct` is outside [0, 100].
    """
    if not 0 <= pct <= 100:
        raise ValueError('pct must be between 0 and 100, got {}'.format(pct))

    W = W_.copy()

    n_weights = int(W.size*pct/100)

    # Permuting fewer than two values cannot change anything. Returning here
    # also covers the case where a small `pct` rounds down to zero positions,
    # which used to fail when unpacking an empty index list.
    if n_weights < 2:
        return W

    # Draw flat positions directly instead of materialising every index pair
    flat_idxs = np.random.choice(W.size, size=n_weights, replace=False)
    idxs = np.unravel_index(flat_idxs, W.shape)

    W[idxs] = np.random.permutation(W[idxs])

    return W

# Previously saved accuracies (pickled dicts).
# NOTE: allow_pickle=True executes pickle on load — only use with trusted files.
trained_accuracy_data = np.load('./nback_data.npy', allow_pickle=True).item()
random_accuracy_data  = np.load('./nback_data_randomised.npy', allow_pickle=True).item()

n_back_arr      = [5]
shuffle_pct_arr = [0, 25, 50, 75, 100]

shuffled_accuracy_data   = []

for shuffle_pct in shuffle_pct_arr:
    print(shuffle_pct, '% shuffled feedback')

    # Fresh copy of the checkpoint for every shuffle level
    shuffled_model, hyperparameters, loss_history = NetworkHierarchicalRecurrent.load(
        model_path=MODEL_PATH,
        device='cpu',
        plot_loss_history=False
    )

    # Recurrent weights are read from the already loaded `model`, then the same
    # two blocks that are zeroed for the no-feedback model are shuffled
    # independently of each other.
    weights = model.rnn.weight_hh_l0.detach().cpu().numpy().copy()
    for fb_block in (np.s_[0:800, 800:1600], np.s_[800:1600, 1600:2400]):
        weights[fb_block] = shuffle_weights(weights[fb_block], shuffle_pct)
    shuffled_model.rnn.weight_hh_l0 = torch.nn.Parameter(torch.Tensor(weights).to('cpu'))

    model_accuracy = []

    for n_back in n_back_arr:
        print('\t', n_back)

        MNIST_classifier = MNISTClassifier(shuffled_model, n_back=n_back)
        MNIST_classifier.get_model_outputs()

        # Keep the last-epoch validation accuracy, in percent
        valid_history = MNIST_classifier.train()[1]
        accuracy = valid_history['accuracy'][-1]
        model_accuracy.append(accuracy*100)

    shuffled_accuracy_data.extend(model_accuracy)

fig = plt.figure()

plt.plot(shuffle_pct_arr, shuffled_accuracy_data, label='Shuffled feedback', c='tab:green')

# Reference point: saved no-feedback accuracy at list index 5, drawn at x = 100
plt.plot([100], [trained_accuracy_data['no_fb'][5]], '*', markersize=15, label='No feedback', c='tab:orange')

plt.xlabel('% shuffled feedback')
plt.ylabel('5-back MNIST accuracy (%)')
format_plot(fontsize=20)
fig.set_size_inches(4, 4)

# The legend is positioned and styled, then removed from the final figure
legend = plt.gca().get_legend()
legend.set_bbox_to_anchor((1, 1))
legend.get_title().set(fontsize=20)
legend.remove()
plt.show()

In [None]:
# Two-sided pooled two-proportion z-test: last entry of `shuffled_accuracy_data`
# (the 100%-shuffled condition) vs. the saved no-feedback accuracy at index 5.
# NOTE(review): 140*100 is presumably the number of held-out samples behind each
# accuracy value (140 minibatches of 100) — confirm against the data split.
n_test = 140*100

# Stored accuracies are percentages; convert to proportions
p_full = np.mean(shuffled_accuracy_data[-1])/100
p_nofb = trained_accuracy_data['no_fb'][5]/100

# Number of correct classifications in each condition
x_full = p_full*n_test
x_nofb = p_nofb*n_test

# Pooled proportion under the null hypothesis of equal accuracy
p = (x_full+x_nofb)/(2*n_test)

# Pooled standard error is sqrt(p*(1-p)*(1/n + 1/n)) = sqrt(2*p*(1-p)/n).
# The previous code raised p to the power (1-p) instead of multiplying.
z = (p_full-p_nofb)/np.sqrt((2*p*(1-p))/n_test)

pval = stats.norm.sf(abs(z))*2

print(z, pval)