In [None]:
import torch
import numpy as np

from tqdm import tqdm

from torch.utils.data import DataLoader

from sklearn.metrics import f1_score
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold

from yasf.layers.layers import InputLayer, LIFLayer
from yasf.learning.learning import STDP
from yasf.encoding.encoders import PoissonEncoder
from yasf.utils.data import DataWrapper

def _make_loader(x, y, time, dt, intensity, shuffle):
    """Wrap features/labels in a Poisson-encoded DataWrapper and return a batch-size-1 DataLoader."""
    dataset = DataWrapper(x=x,
                          y=y,
                          sample_transform=PoissonEncoder(time=time, dt=dt, intensity=intensity),
                          target_transform=None,
                          global_transform=None,
                          )
    return DataLoader(dataset, batch_size=1, shuffle=shuffle)


def _simulate_sample(sample, input_layer, class_layer, net_weights, rec_weights,
                     n_steps, device, stdp=None, wmin=0., wmax=1.):
    """Drive one encoded sample through input_layer -> class_layer and return output rates.

    Parameters
    ----------
    sample : torch.Tensor
        The encoded sample as yielded by the DataLoader (``batch[0]``). After
        ``squeeze()`` it is indexed as ``[ts, :]``; presumably shape
        (time steps, n_inputs). TODO confirm against PoissonEncoder.
    input_layer, class_layer :
        The network layers. Both are reset via ``reset_state()`` after the sample.
    net_weights : torch.Tensor
        Feed-forward weights of shape (n_inputs, n_outputs). When ``stdp`` is
        given, this tensor is modified IN PLACE, so the caller keeps the learned
        weights.
    rec_weights : torch.Tensor
        Recurrent weights passed to ``class_layer`` as ``w_rec``.
    n_steps : int
        Number of simulation steps to run.
    device : torch.device
        Device the simulation runs on.
    stdp : callable or None
        If given, ``stdp(input_times, output_times)`` is added to ``net_weights``
        after every step and the weights are clamped to ``[wmin, wmax]``.
        ``None`` means inference only (weights untouched).

    Returns
    -------
    numpy.ndarray
        Accumulated ``output_spikes`` divided by ``n_steps``, one value per
        output neuron.
    """
    # Squeeze and transfer once per sample rather than once per time step.
    spikes = sample.squeeze().to(device)
    spike_counts = torch.zeros(net_weights.shape[-1], device=device)

    for ts in range(n_steps):
        input_spikes, input_times = input_layer(spikes[ts, :])
        output_spikes, output_times = class_layer(input_spikes,
                                                  w=net_weights,
                                                  w_rec=rec_weights,
                                                  )
        if stdp is not None:
            # In-place update keeps the caller's tensor in sync with learning.
            net_weights += stdp(input_times, output_times)
            net_weights.clamp_(wmin, wmax)
        spike_counts += output_spikes

    # .cpu() is a no-op for CPU tensors, so no device check is needed.
    rates = (spike_counts / n_steps).cpu().numpy()

    # Ensure the next sample does not inherit this sample's layer state.
    input_layer.reset_state()
    class_layer.reset_state()
    return rates


def run(params):
    """Score one hyperparameter configuration with 5-fold stratified CV on Iris.

    For each fold:
    - fresh layers and weights are built;
    - training samples are run through the network while STDP adapts the
      feed-forward weights (labels are never shown to the network);
    - test samples are run with the learned weights frozen;
    - a GradientBoostingClassifier is fit on the training-fold output rates and
      scored with macro F1 on the test-fold rates.

    Parameters
    ----------
    params : dict
        Required keys: "gpu", "time", "intensity", "lsz", "thresh", "refrac",
        "decay", "tau", "lr", "inh". "time" and "lsz" are cast to int.
        Optional "seed": when present, seeds torch and numpy's global RNG.

    Returns
    -------
    float
        ``-mean(macro F1)`` over the folds that were evaluated. The sign is
        negated, presumably so a minimizing optimizer can use it as a loss.
        Evaluation stops after the first fold scoring below 0.7, in which case
        the mean covers only the folds run so far.
    """
    seed = params.get("seed")
    if seed is not None:
        # torch RNG drives weight init, DataLoader shuffling and presumably the
        # Poisson encoder; sklearn falls back to numpy's global RNG.
        torch.manual_seed(seed)
        np.random.seed(seed)

    gpu = params["gpu"]
    device = torch.device("cuda:0" if torch.cuda.is_available() and gpu else "cpu")
    print(f"Running on device: {device}.")

    X, Y = load_iris(as_frame=False, return_X_y=True)
    X = torch.as_tensor(X, dtype=torch.float32)
    n_inputs = X.shape[-1]

    skf = StratifiedKFold(n_splits=5)

    # encoding arguments
    time = int(params["time"])
    dt = 1
    intensity = params["intensity"]
    n_steps = int(time / dt)

    # neuron arguments
    lsz = int(params["lsz"])
    rest = 0
    reset = 0
    # A per-neuron threshold drawn from N(thresh, params["tstd"]) was tried previously.
    thresh = params["thresh"]
    refrac = params["refrac"]
    decay = params["decay"]

    # STDP arguments: symmetric time constant and learning rate for both directions.
    tau_plus = tau_minus = params["tau"]
    A_plus = A_minus = params["lr"]

    # net arguments
    init_w = 0.1
    wmin = 0.
    wmax = 1.

    # A fold scoring below this macro F1 aborts the remaining folds.
    min_fold_f1 = 0.7

    f1s = []

    for train_idx, test_idx in skf.split(X, Y):

        train_loader = _make_loader(X[train_idx], Y[train_idx], time, dt, intensity, shuffle=True)
        test_loader = _make_loader(X[test_idx], Y[test_idx], time, dt, intensity, shuffle=False)

        # Fresh feed-forward weights per fold, uniform in [0, init_w).
        net_weights = init_w * torch.rand(n_inputs, lsz).to(device)
        # -inh between every pair of distinct output neurons; zero on the diagonal.
        rec_weights = -params["inh"] * (torch.ones(lsz, lsz) - torch.diag(torch.ones(lsz))).to(device)

        input_layer = InputLayer(n=n_inputs, dt=dt, time=time, device=device)
        class_layer = LIFLayer(n=lsz,
                               dt=dt,
                               time=time,
                               rest=rest,
                               reset=reset,
                               thresh=thresh,
                               refrac=refrac,
                               decay=decay,
                               device=device,
                               )
        stdp = STDP(tau_plus=tau_plus,
                    tau_minus=tau_minus,
                    A_plus=A_plus,
                    A_minus=A_minus,
                    )

        # Learning phase: STDP updates net_weights in place; labels are only recorded.
        train_samples, train_labels = [], []
        for batch in tqdm(train_loader):
            train_samples.append(_simulate_sample(batch[0], input_layer, class_layer,
                                                  net_weights, rec_weights, n_steps, device,
                                                  stdp=stdp, wmin=wmin, wmax=wmax))
            train_labels.append(batch[1].squeeze().numpy())

        # Inference phase: weights are frozen.
        test_samples, test_labels = [], []
        for batch in tqdm(test_loader):
            test_samples.append(_simulate_sample(batch[0], input_layer, class_layer,
                                                 net_weights, rec_weights, n_steps, device))
            test_labels.append(batch[1].squeeze().numpy())

        # Supervised readout on the spike-rate features.
        gbc = GradientBoostingClassifier()
        gbc.fit(train_samples, train_labels)

        f1 = f1_score(test_labels, gbc.predict(test_samples), average='macro')
        f1s.append(f1)
        print(f1)

        if f1 < min_fold_f1:
            break

    print(params)
    return -np.mean(f1s)

In [None]:
%%time

# One hyperparameter configuration for `run`.
# NOTE(review): the provenance of these values (e.g. which search produced them)
# is not recorded here — document it.
# `run` casts "time" and "lsz" with int(), so their float values are acceptable.
params_rec = dict(
    decay=0.7040816326530612,
    inh=34.0,
    intensity=83.46938775510205,
    lr=0.009448979591836736,
    lsz=702.5252525252524,
    refrac=1.5714285714285714,
    tau=10.612244897959183,
    thresh=4.7142857142857135,
    time=800.0,
    gpu=True,
)

run(params_rec)