https://towardsdatascience.com/how-to-use-pytorch-as-a-general-optimizer-a91cbf72a7fb

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import nn
from torch.functional import F
from copy import copy
import seaborn as sns

from jithub.models import model_classes
model = model_classes.ADEXPModel()
from neuronunit.optimization.model_parameters import MODEL_PARAMS
from neuronunit.tests.target_spike_current import SpikeCountSearch

import quantities as qt

sns.set_style("whitegrid")
n = 1000
noise = torch.Tensor(np.random.normal(0, 0.02, size=n))
x = torch.arange(n)

##
# simulate a ground truth neuron here.
##

a, k, b = 0.7, .01, 0.2
ground_truth = a * np.exp(-k * x) + b# + noise
print(np.shape(ground_truth))



torch.Size([1000])


In [2]:

class Model(nn.Module):
    """Custom Pytorch model for gradient optimization.
    """
    def __init__(self):
        
        super().__init__()
        weights = torch.distributions.Uniform(0.1, 1.0).sample((10,))
        self.weights = nn.Parameter(weights)        
        self.nspikes = 0
        
    def forward(self, X):
        """Implement function to be optimised. In this case, an exponential decay
        function (a + exp(-k * X) + b),
        """
        
        
        params = {}
        i = 0
        for k,j in MODEL_PARAMS['ADEXP'].items():
            wv = self.weights.detach().numpy()[0]
            params[k] = np.mean([wv*j[0],wv*j[1]])
            i += 1
        model = model_classes.ADEXPModel()
        model.attrs = params
        
        observation = {}
        observation["value"] = self.nspikes
        scs = SpikeCountSearch(observation)
        target_current = scs.generate_prediction(model)
        inject_param = {
            "padding":0 * qt.ms,
            "delay": 0 * qt.ms,
            "amplitude": 900 * qt.pA,
            "duration": 1000 * qt.ms,
            "dt":0.25
        }
        try:
            inject_param["amplitude"] = target_current["value"]* qt.pA
            model.inject_square_current(**inject_param)
            vm = model.get_membrane_potential()
            vm_=[v[0] for v in vm]
            #plt.plot(vm.times,vm)
            #plt.show()

            #concise_spk_time = model.get_spike_train()
        except:
            # some model parameters will cause try block to fail.
            #concise_spk_time = []
            vm = 1000.0
        #return concise_spk_time
        #a, k, b = self.weights
        print(vm)
        n = 1000
#noise = torch.Tensor(np.random.normal(0, 0.02, size=n))
        x = torch.arange(n)
        #import pdb
        #pdb.set_trace()
        return torch.as_tensor(vm_)#concise_spk_time#a * torch.exp(-k * X) + b
    
def training_loop(model, optimizer, n=1000):
    "Training loop for torch model."
    losses = []
    for i in range(n):
        preds = model(x)
        print(np.shape(preds),np.shape(ground_truth))
        model.nspikes=10
        print(preds.clone().detach().numpy())
        #print(ground_truth)
        loss = F.mse_loss(preds, ground_truth).sqrt()
        print(loss)
        #print(optimizer.requires_grad)
        print(loss.requires_grad)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        losses.append(loss)  
    return losses
"""
If you are familiar with Pytorch there is nothing too fancy going on here. The key thing that we are doing here is defining our own weights and manually registering these as Pytorch parameters — that is what these lines do:
weights = torch.distributions.Uniform(0, 0.1).sample((3,))
# make weights torch parameters
self.weights = nn.Parameter(weights)
The lines below detemine the function to be optimised. You can replace these with the definition of the function you want to minimise.
a, k, b = self.weights
return a * torch.exp(-k * X) + b
By calling nn.Parameter the weight we define will behave and function in the same way as standard Pytorch parameters — i.e they can calculate gradients and be updated in response to a loss function. The training loop is simply iterating over n epochs, each time estimating the mean squared error and updating the gradients.
Time to run the model, we’ll use Adam for the optimization.
# instantiate model
"""
m = Model()
m.n_spikes=10
# Instantiate optimizer


In [3]:
opt = torch.optim.Adam(m.parameters(), lr=0.001)
losses = training_loop(m, opt)
plt.figure(figsize=(14, 7))
plt.plot(losses)
print(m.weights)

#Losses over 1000 epochs — Image by Author..
#The plot above shows the loss function over 1000 epochs — you can see that after ~600 it is showing no signs of further improvement. The estimated weights for a, k, b are 0.697, 0.0099, 0.1996, so extremely close to the parameters that define the function and we can use the trained model to estimate the function:
preds = m(x)
plt.figure(figsize=(14, 7))
plt.scatter(x, preds.detach().numpy())
plt.scatter(x, y, alpha=.3)

[[-15.23517741]
 [-15.23594829]
 [-15.23670814]
 ...
 [-15.23069351]
 [-15.23069246]
 [-15.23069141]] mV


TypeError: len() of unsized object

In [None]:
np.shape(vm_)