In [111]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

In [112]:
settype="250full"
repeats=3
batch_size=3
epochs=200

In [113]:
class MultiRegression(nn.Module):
    def __init__(self):
        super(MultiRegression, self).__init__()
        self.l1 = nn.Linear(250, 10)

    def forward(self, x):
        x = self.l1(x)
        x = F.softmax(x,dim=1).float()
        return x
    

In [114]:
def prediction_accuracy(nn,batch_size):
    correct = 0
    total = 0
    for i in range(int((len(testspikes))/batch_size)):    
        spikes = Variable(torch.from_numpy(testspikes[i:i+batch_size]).float())
        truth=(testtruth[i:i+batch_size])
        outputs = nn.forward(spikes)
        prediction = outputs.data.numpy().argmax(axis=1)
        total += len(truth)
        correct += (prediction == truth).sum()
    #print('Accuracy of the model on the '+str(len(testtruth))+' test images: '+str(format(100 * correct / total, '.3f'))+"%")
    return correct / total

In [115]:
def train(batch_size,epochs,learnr=0.3,feedback_every_epoch=10):
    #Learning properties
    neuralNet =[]
    neuralNet = MultiRegression()
    lossFunc =nn.NLLLoss(weight=weights)
    parameters=neuralNet.parameters()
    optimizer = torch.optim.Adadelta(parameters,lr=learnr)
    results=[]    
    for epoch in range(epochs):
        running_loss = 0.0
        running_correct=0.0
        batches=int(len(trainspikes)/batch_size)
        for i in range(batches):
            spikes=trainspikes[i:i+batch_size]
            categories=traintruth[i:i+batch_size]
            spikes=Variable(torch.from_numpy(spikes).float())
            categories=Variable(torch.from_numpy(categories))

            optimizer.zero_grad()

            outputs = neuralNet(spikes)
            error = lossFunc(outputs, categories)
            error.backward()
            optimizer.step()

            running_loss += error.data[0]
            running_correct +=np.sum(categories.data.numpy()==outputs.data.numpy().argmax(axis=1))
        if feedback_every_epoch!=0 and epoch%feedback_every_epoch==feedback_every_epoch-1:
            print("epoch: "+str(epoch+1)+", running loss: "+str(running_loss)+", running correct: "+str(int(running_correct)))
        results.append(prediction_accuracy(neuralNet,1))

    print('fine')
    classpercentages=[]
    #compute TPR for classes
    for i in range(len(output_class_map)):
        iindices=(testtruth==i)
        stimuli=testspikes[iindices]
        classresult=(neuralNet(Variable(torch.from_numpy(stimuli).float()))).data.numpy().argmax(axis=1)
        classpercentages.append(np.sum(classresult==i)/len(classresult)*100)
    return [results,running_correct/len(trainspikes)*100,classpercentages]


In [116]:
#train test split indices generation
def split_data(parts=11):
    testelements=int(len(truth)/parts)
    testlist=[]
    trainlist=[]
    indices=np.arange(len(truth))
    np.random.shuffle(indices)
    for i in range(0,parts):
        testind=indices[testelements*i:testelements*(i+1)]        
        trainind1=indices[:testelements*i]
        trainind2=indices[testelements*(i+1):]
        trainind=np.concatenate((trainind1,trainind2))
        testlist.append(testind)
        trainlist.append(trainind)
    return [trainlist,testlist]    

In [117]:
#location of computed ganglion activations with class information
data_folder=".../ganglion-activation/"
#load output (all) or selected (easy10+difficult100) file from condition folder
data=np.load(data_folder+condition+"/"+settype+".npz")

activations=data['stimuli']
truth=data['truth']
#order of classes like encoded as ints in the output
output_class_map=['chair','dog','bird','bottle','boat','tvmonitor','horse','aeroplane','person','car']

In [119]:
%%time
types=['250full','250balanced']
for datatype in types:
    for condition in os.listdir(data_folder):
        print(condition)
        data=np.load(data_folder+condition+"/"+datatype+".npz")
        activations=data['stimuli']
        truth=data['truth']
        
        #compute weights to account for different frequency of classes in dataset
        weights=np.zeros(10)
        for i in range(10):
            weights[i]=1-np.sum(truth==i)/len(truth)
        weights=torch.from_numpy(weights).float()

        #multiple dataset configuration averaging
        testaccuracies=[]
        trainaccuracies=[]
        classpercentages=[]
        for k in range(repeats):
            print('split ',k)
            indices=split_data()
            trainindices=indices[0]
            testindices=indices[1]
            #k-fold cross validation
            for j in range(len(trainindices)):
                trainspikes=activations[trainindices[j]]
                testspikes=activations[testindices[j]]
                traintruth=truth[trainindices[j]]
                testtruth=truth[testindices[j]]
                results=train(batch_size,epochs,feedback_every_epoch=epochs)
                testaccuracies.append(results[0])
                trainaccuracies.append(results[1])
                classpercentages.append(results[2])
        #save repeats*cross-validation results for each condition
        np.savez(data_folder+condition+"/learning-batch3"+datatype,testaccuracy=testaccuracies,trainaccuracy=trainaccuracies,classpercentage=classpercentages)

original
split  0
epoch: 200, running loss: -705.5091481688485, running correct: 2211
fine
epoch: 200, running loss: -731.4061663948152, running correct: 2298
fine
epoch: 200, running loss: -716.8418069696249, running correct: 2278
fine
epoch: 200, running loss: -721.1557875811352, running correct: 2271
fine
epoch: 200, running loss: -720.1817283512911, running correct: 2260
fine
epoch: 200, running loss: -724.5626672200235, running correct: 2269
fine
epoch: 200, running loss: -721.3493270327231, running correct: 2244
fine
epoch: 200, running loss: -723.184050637547, running correct: 2262
fine
epoch: 200, running loss: -726.9645628825976, running correct: 2287
fine
epoch: 200, running loss: -723.8664134974099, running correct: 2268
fine
epoch: 200, running loss: -723.0689466683336, running correct: 2270
fine
split  1
epoch: 200, running loss: -733.1517148068124, running correct: 2304
fine
epoch: 200, running loss: -718.0467868843681, running correct: 2252
fine
epoch: 200, running loss:

epoch: 200, running loss: -702.3824064786177, running correct: 2211
fine
epoch: 200, running loss: -707.1057964482254, running correct: 2214
fine
epoch: 200, running loss: -700.4839998576667, running correct: 2201
fine
epoch: 200, running loss: -694.6221061254187, running correct: 2184
fine
epoch: 200, running loss: -697.4965957689614, running correct: 2177
fine
epoch: 200, running loss: -699.3504542879955, running correct: 2212
fine
epoch: 200, running loss: -695.6857362615743, running correct: 2170
fine
epoch: 200, running loss: -697.8230667210889, running correct: 2187
fine
epoch: 200, running loss: -699.5529790452558, running correct: 2201
fine
epoch: 200, running loss: -700.2195027043077, running correct: 2202
fine
split  2
epoch: 200, running loss: -687.6812259602939, running correct: 2163
fine
epoch: 200, running loss: -706.3765576557246, running correct: 2232
fine
epoch: 200, running loss: -698.8094458671468, running correct: 2201
fine
epoch: 200, running loss: -711.71999817872

epoch: 200, running loss: -665.4921394200126, running correct: 2108
fine
epoch: 200, running loss: -666.7404304757335, running correct: 2105
fine
epoch: 200, running loss: -661.7586638927274, running correct: 2093
fine
epoch: 200, running loss: -662.6118906388883, running correct: 2087
fine
epoch: 200, running loss: -665.1893632114491, running correct: 2123
fine
epoch: 200, running loss: -667.6925832754326, running correct: 2113
fine
epoch: 200, running loss: -668.529762955504, running correct: 2120
fine
epoch: 200, running loss: -669.6992408324194, running correct: 2120
fine
epoch: 200, running loss: -666.425396015832, running correct: 2120
fine
background-blurring
split  0
epoch: 200, running loss: -730.434698934043, running correct: 2298
fine
epoch: 200, running loss: -722.3163091963206, running correct: 2275
fine
epoch: 200, running loss: -728.550025822791, running correct: 2304
fine
epoch: 200, running loss: -708.9076743100181, running correct: 2229
fine
epoch: 200, running loss: 

epoch: 200, running loss: -292.5509624183178, running correct: 887
fine
epoch: 200, running loss: -297.6686918735504, running correct: 911
fine
epoch: 200, running loss: -296.4741432070732, running correct: 905
fine
epoch: 200, running loss: -299.220978975296, running correct: 914
fine
epoch: 200, running loss: -297.02669855952263, running correct: 899
fine
epoch: 200, running loss: -299.1882465183735, running correct: 914
fine
epoch: 200, running loss: -298.3634642958641, running correct: 914
fine
epoch: 200, running loss: -296.598026484251, running correct: 907
fine
split  1
epoch: 200, running loss: -303.6489636898041, running correct: 927
fine
epoch: 200, running loss: -299.82828322052956, running correct: 915
fine
epoch: 200, running loss: -299.48974096775055, running correct: 912
fine
epoch: 200, running loss: -301.5171903371811, running correct: 921
fine
epoch: 200, running loss: -297.1290137767792, running correct: 912
fine
epoch: 200, running loss: -296.10401329398155, running

epoch: 200, running loss: -301.5683980316062, running correct: 913
fine
epoch: 200, running loss: -303.64791395015163, running correct: 924
fine
epoch: 200, running loss: -304.8377741377044, running correct: 927
fine
epoch: 200, running loss: -302.3727203819799, running correct: 915
fine
epoch: 200, running loss: -303.18533737230064, running correct: 921
fine
epoch: 200, running loss: -303.725178537622, running correct: 924
fine
split  2
epoch: 200, running loss: -302.11619725804485, running correct: 914
fine
epoch: 200, running loss: -305.34184804558754, running correct: 926
fine
epoch: 200, running loss: -293.34529161453247, running correct: 889
fine
epoch: 200, running loss: -301.6449859966524, running correct: 911
fine
epoch: 200, running loss: -302.42510265111923, running correct: 915
fine
epoch: 200, running loss: -305.42426431179047, running correct: 927
fine
epoch: 200, running loss: -302.9546859264374, running correct: 915
fine
epoch: 200, running loss: -302.39454129338264, ru

In [305]:
print('Testset category occurences:')
for i in range(10):    
    print(output_class_map[i]+": "+str(np.sum(testtruth==i))+"; "+str(np.sum(testtruth==i)/len(testtruth)*100)+"%")

Testset category occurences:
chair: 16; 4.9079754601226995%
dog: 66; 20.245398773006134%
bird: 28; 8.588957055214724%
bottle: 12; 3.6809815950920246%
boat: 25; 7.668711656441718%
tvmonitor: 13; 3.9877300613496933%
horse: 15; 4.601226993865031%
aeroplane: 31; 9.509202453987731%
person: 84; 25.766871165644172%
car: 36; 11.042944785276074%


In [304]:
print('Trainset category occurences:')
for i in range(10):    
    print(output_class_map[i]+": "+str(np.sum(testtruth==i)))

Trainset category occurences:
chair: 16
dog: 66
bird: 28
bottle: 12
boat: 25
tvmonitor: 13
horse: 15
aeroplane: 31
person: 84
car: 36


In [None]:
#ANALYSIS

In [185]:
condition='outline-only'
datatype='250balanced'

In [183]:
results=np.load(data_folder+condition+"/learning"+datatype+".npz")
testaccuracies=results['testaccuracy']
trainaccuracies=results['trainaccuracy']
classpercentages=results['classpercentage']

In [184]:
print('mean '+str(np.mean(testaccuracies)))
print('max '+str(np.max(testaccuracies)))
print('train '+str(np.mean(trainaccuracies)))
for i in range(10):
    print(output_class_map[i]+": "+str(np.mean(classpercentages,axis=0)[i]/100))

mean 0.21041969696969698
max 0.38
train 90.38181818181818
chair: 0.13067499132092433
dog: 0.23347574143028688
bird: 0.17598453734817374
bottle: 0.17083532053627745
boat: 0.12847699438608529
tvmonitor: 0.25655922360467814
horse: 0.16334411280935343
aeroplane: 0.4331760856493477
person: 0.12110050148327661
car: 0.1994695787476536
