In [1]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

In [2]:
sys.path.append("/home/rzhu/Documents/nwn_l2l/")
from nwnTorch.jn_models import *
from nwnTorch.nwn import *
from nwnTorch.generate_adj import *
from nwnTorch.misc import *

In [3]:
# Default floating-point dtype for newly created tensors (float32, on CPU).
# `torch.set_default_tensor_type('torch.FloatTensor')` is deprecated since
# PyTorch 2.1; on CPU it is equivalent to setting the default dtype.
# For GPU defaults on PyTorch >= 2.0 use `torch.set_default_device("cuda")`
# instead of the old 'torch.cuda.FloatTensor' tensor type.
torch.set_default_dtype(torch.float32)

In [51]:
# --- Data configuration -----------------------------------------------------
batch_size  = 8
num_classes = 10
data_path   = "/home/rzhu/data_access/data/mnist"

# Image pipeline: force 28x28 single-channel, convert to a tensor in [0, 1].
# Normalize((0,), (1,)) computes (x - 0) / 1, i.e. it leaves values unchanged.
# NOTE(review): Resize/Grayscale are presumably no-ops for MNIST digits; kept
# so the pipeline matches the original exactly.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

# Train split first, then test split (same order as two separate calls).
mnist_train, mnist_test = (
    datasets.MNIST(data_path, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [52]:
# Load the saved connectivity record and build a network from its adjacency matrix.
# NOTE(review): pkl_load presumably unpickles the file — only load trusted data.
# NOTE(review): `net` is constructed again (same arguments) in the next cell,
# so this instance is replaced there.
con = pkl_load("/home/rzhu/data_access/l2l_data/volterra_data/con0.pkl")
adj = torch.tensor(con["adj_matrix"])
net = NWN(adj, "sydney")

# Commented-out exploration (graph-distance matrix); not used below.
# distMat = graphical_distance(net.adjMat)
# R,C = np.where(distMat == distMat.max())

In [53]:
# Fresh network instance with the simulation parameters used for this run.
net = NWN(adj, "sydney")
E = net.number_of_junctions  # only needed by the commented-out initialisation below

# Parameters are assigned one by one, in this order, exactly as listed.
for _key, _val in {
    "Ron": 1e4,
    "grow": 5,
    "decay": 10,
    "precision": True,
    "collapse": False,
    "dt": 1e-3,
}.items():
    net.params[_key] = _val
del _key, _val

# Optional random initial junction state (disabled):
# net.junction_state.L    = torch.rand(E) * 0.3 - 0.15

In [54]:
# Batches of `batch_size` training images, in dataset order (DataLoader default).
train_loader = DataLoader(dataset=mnist_train, batch_size=batch_size)
# Stateful iterator: any cell that consumes it resumes from where it last stopped.
train_batch  = iter(train_loader)

In [55]:
# Scan window size and the network nodes used as electrodes:
# kernel_size**2 inputs (one per pixel in the window) plus one extra electrode.
kernel_size    = 3
num_nodes      = 1024  # NOTE(review): assumed to equal the network's node count — confirm against `adj`
electrode_seed = 0     # fixed seed so the electrode choice is reproducible across runs

# A dedicated generator keeps the selection reproducible without touching the global RNG.
_electrode_rng = torch.Generator().manual_seed(electrode_seed)
electrodes     = torch.randperm(num_nodes, generator=_electrode_rng)[:kernel_size**2 + 1]

In [56]:
# Default dtype and preferred compute device (CUDA when available, otherwise CPU).
# NOTE(review): neither name is used in the cells visible here.
dtype  = torch.float
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [57]:
import dask
from dask.distributed import Client, LocalCluster

In [58]:
def run_batch(net, electrodes, scan_batch):
    """Drive the network with every scan window of every sample in a batch.

    Each sample is presented one column at a time (one time step per column of
    ``scan_batch[i]``). The ``ksqr`` values of that column go to the first
    ``ksqr`` entries of the drive vector, and the final entry is always 0.
    After each ``net.sim`` call, ``net.I`` is recorded.

    Args:
        net: simulator object exposing ``sim(signal, electrodes)`` and ``I``.
            ``net.sim`` is called repeatedly on this same object, with no reset
            between samples.
        electrodes: electrode indices passed straight to ``net.sim``;
            presumably of length ``ksqr + 1`` — TODO confirm.
        scan_batch: tensor of shape ``(batch_size, ksqr, num_steps)``, e.g. the
            output of ``nn.Unfold``.

    Returns:
        Tensor of shape ``(batch_size, num_steps, ksqr + 1)`` holding ``net.I``
        at every step.

    Side effects:
        Sets ``net.steps = 12345`` before returning.
    """
    batch_size, ksqr, num_steps = scan_batch.shape
    n_channels = ksqr + 1  # ksqr pixel inputs + one extra entry that stays 0

    readout = torch.zeros(batch_size, num_steps, n_channels)

    for i in tqdm(range(batch_size)):
        sample = scan_batch[i]
        # Sized from the data, not the notebook-global `kernel_size`, so the
        # function is self-contained (e.g. when shipped to dask workers).
        sig_in = torch.zeros(n_channels)

        for t in range(num_steps):
            sig_in[:-1] = sample[:, t]
            net.sim(sig_in.reshape(1, -1), electrodes)
            readout[i, t, :] = net.I

    # NOTE(review): unexplained magic value kept for compatibility; its effect
    # depends on NWN internals not visible here — confirm intent.
    net.steps = 12345
    return readout

In [25]:
# Local dask cluster that runs the network simulations in parallel.
# It must exist before the cell that calls `Client(cluster)`; leaving it
# commented out makes that cell fail with NameError on a fresh kernel.
cluster = LocalCluster(
    n_workers=2,
    threads_per_worker=1,
    scheduler_port=12121,
    dashboard_address="localhost:11113",
)

In [34]:
# Connect a dask client to `cluster`.
client = Client(address=cluster)

In [35]:
# Display the client summary (scheduler address, dashboard link, workers).
client

0,1
Client  Scheduler: tcp://127.0.0.1:12121  Dashboard: http://127.0.0.1:11113/status,Cluster  Workers: 2  Cores: 2  Memory: 33.60 GB


In [59]:
# Build one lazy dask task per training batch. Each batch is unfolded into
# kernel_size x kernel_size sliding windows and run through the network by run_batch.
scanner     = nn.Unfold(kernel_size=kernel_size)
max_batches = 501  # same count the old `counter > 500` loop produced (501 x batch_size samples)
job_pool    = []
label_pool  = []

# Iterate the DataLoader itself (not a shared iterator) so re-running this
# cell always starts from the first batch instead of resuming mid-epoch.
for batch_idx, (batch_data, batch_label) in enumerate(train_loader):
    scan_batch = scanner(batch_data)
    job_pool.append(dask.delayed(run_batch)(net, electrodes, scan_batch))
    label_pool.append(batch_label)

    if batch_idx + 1 >= max_batches:
        break

In [37]:
from dask.distributed import progress

In [38]:
# Run every batch simulation on the cluster and collect the readouts in
# submission order. `client.compute` submits all tasks at once, and
# `client.gather` blocks until every result is back (one round trip instead
# of one blocking `.compute()` per task).
futures   = client.compute(job_pool)
collected = client.gather(futures)
progress(futures)

VBox()

In [66]:
# Stack the per-batch label tensors into one tensor, in batch order.
labels = torch.cat(label_pool, dim=0)

In [41]:
# Stack the per-batch readouts along the sample dimension, in batch order.
readout = torch.cat(collected, dim=0)

In [68]:
# Save the first `n_saved` readouts with their matching labels.
n_saved  = 4000
out_dict = dict(readout=readout[:n_saved], labels=labels[:n_saved])
pkl_save(out_dict, "/home/rzhu/data_access/data/mnist_nwnset0.pkl")

In [57]:
# batch_data, batch_label = next(train_batch)

# scanner    = nn.Unfold(kernel_size = kernel_size)
# scan_batch = scanner(batch_data)
# duration   = scan_batch.shape[-1]
# readout    = torch.zeros(batch_size, duration, kernel_size**2+1)

# for i in tqdm(range(batch_size)):
#     sample = scan_batch[i]
#     sig_in = torch.zeros(kernel_size**2+1)

#     for t in range(duration):
#         sig_in[:-1] = sample[:,t]
#         net.sim(sig_in.reshape(1,-1), electrodes)
#         readout[i,t,:] = net.I

In [70]:
# Pool the per-position readouts into compact features.
# `readout` is (samples, positions, channels), with positions laid out row-major
# over a square grid of scan windows. Sizes are derived from the data instead of
# hard-coding 26 (= 28 - kernel_size + 1) and 10 (= kernel_size**2 + 1).
readout = torch.cat(collected, dim=0)
n_samples, n_positions, n_channels = readout.shape
grid_side = int(round(n_positions ** 0.5))
assert grid_side * grid_side == n_positions, "scan positions do not form a square grid"

grid = readout.reshape(n_samples, grid_side, grid_side, n_channels)
F1   = grid.max(dim=1).values  # max over grid rows    -> (samples, grid_side, channels)
F2   = grid.max(dim=2).values  # max over grid columns -> (samples, grid_side, channels)
F    = torch.cat((F1, F2), dim=1)

In [79]:
# First `split` samples train the classifier; the rest are held out for testing.
split = 3000
trainY, testY = labels[:split], labels[split:]

# Drop the last readout channel, then flatten each sample to a feature vector.
# NOTE(review): that channel corresponds to the drive entry run_batch always
# sets to 0 — presumably a reference electrode; confirm.
trainX = F[:split, :, :-1].reshape(split, -1)
testX  = F[split:, :, :-1].reshape(F.shape[0] - split, -1)

In [81]:
def LDA_test(X, Y, test_X, test_Y, sub_sample = 1, return_map = False):
    """Fit a shrinkage LDA classifier and report its accuracy on held-out data.

    Args:
        X, Y: training features ``(n_samples, n_features)`` and integer labels.
        test_X, test_Y: held-out features and integer labels.
        sub_sample: keep every ``sub_sample``-th feature column. The same
            sub-sampling is applied to training and test features.
        return_map: if True, also return a confusion matrix.

    Returns:
        ``acc`` (mean test accuracy from ``model.score``), or
        ``(acc, fit_map)`` when ``return_map`` is True. ``fit_map[true, pred]``
        counts test samples; it is square, with side
        ``max(max(test_Y), max(predictions)) + 1``.
    """
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

    # Sub-sample both splits identically so the model is scored on features of
    # the same width it was trained on.
    train_feats = X[:, ::sub_sample]
    test_feats  = test_X[:, ::sub_sample]

    model = LDA(solver = 'lsqr', shrinkage = 'auto')
    model.fit(train_feats, Y)
    acc = model.score(test_feats, test_Y)
    if not return_map:
        return acc

    y_true = np.asarray(test_Y).astype(int)
    y_pred = np.asarray(model.predict(test_feats)).astype(int)
    # Size the map over both label sets so a predicted class above max(test_Y)
    # cannot index out of bounds.
    sz = int(max(y_true.max(), y_pred.max())) + 1
    fit_map = np.zeros((sz, sz))
    np.add.at(fit_map, (y_true, y_pred), 1)
    return acc, fit_map

In [82]:
# Test-set accuracy of the LDA readout on the pooled network features.
LDA_test(X=trainX, Y=trainY, test_X=testX, test_Y=testY)

0.9117063492063492

In [73]:
# Restore float32 as the default dtype for newly created tensors (on CPU).
# Replaces the deprecated `torch.set_default_tensor_type('torch.FloatTensor')`
# (deprecated since PyTorch 2.1). For GPU defaults on PyTorch >= 2.0 use
# `torch.set_default_device("cuda")`.
torch.set_default_dtype(torch.float32)
