# Learning rule from Loss Function

## General prescription

We consider the full neural+synaptic system in the quasi-steady-state (QSS) approximation: the synaptic subsystem evolves while $E$ and $I$ sit at their quasi-steady-state values. In this approximation, the general prescription for deriving the learning rule from a loss function $L$ is:

1. Consider a generic loss function depending on $E$ and $I$  
$L = L(E,I)$


2. The learning rules follow gradient descent on $L$ with learning rate $\alpha$  
$\displaystyle \Delta W_{EE} = - \alpha \frac{\partial L}{\partial W_{EE}}$  
$\displaystyle \Delta W_{EI} = - \alpha \frac{\partial L}{\partial W_{EI}}$  
$\dots \mbox{etc.}$  


3. The partial derivatives of the loss function are  
$\displaystyle \frac{\partial L}{\partial W_{EE}} = \frac{\partial L}{\partial E} \frac{\partial E}{\partial W_{EE}} + \frac{\partial L}{\partial I} \frac{\partial I}{\partial W_{EE}}$  
$\displaystyle \frac{\partial L}{\partial W_{EI}} = \ldots$  


4. The partial derivatives $\displaystyle \frac{\partial E}{\partial W_{EE}}, \ldots$ etc. are to be taken from the quasi-steady-state values of $E$ and $I$
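
Steps 2 and 3 can be sketched in plain Python (no Sage needed). This is only an illustration of the chain rule: the function name, the dictionary keys and the numbers are made up for the example and are not part of the notebook.

```python
# Sketch of steps 2-3: gradient descent on L(E, I) via the chain rule.
# Names and numbers are illustrative assumptions, not taken from the notebook.

WEIGHTS = ("EE", "EI", "IE", "II")

def delta_weights(dLdE, dLdI, gradE, gradI, alpha):
    """Return Delta W_XY = -alpha * (dL/dE * dE/dW_XY + dL/dI * dI/dW_XY).

    gradE[k] and gradI[k] hold dE/dW_k and dI/dW_k for k in WEIGHTS.
    """
    return {
        k: -alpha * (dLdE * gradE[k] + dLdI * gradI[k])
        for k in WEIGHTS
    }

# Made-up values, only to exercise the formula.
example = delta_weights(
    dLdE=2.0, dLdI=-1.0,
    gradE={"EE": 1.0, "EI": -2.0, "IE": 0.0, "II": 0.0},
    gradI={"EE": 0.5, "EI": 0.0, "IE": 3.0, "II": -1.0},
    alpha=0.1,
)
print(example)
```

Everything model-specific enters through `gradE` and `gradI`, which is why the rest of the notebook concentrates on computing those derivatives.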

## Partial derivatives of $E$ and $I$

### Neural system

In [1]:
%%capture
load('up states - Neural subsystem stability.sage.py')

In [2]:
var('dEdWEE,dEdWEI,dEdWIE,dEdWII')
var('dIdWEE,dIdWEI,dIdWIE,dIdWII')
var('alpha')
show(f_E)
show(f_I)

### Quasi-steady-state approximation
(see e.g. `up states - Homeostatic stability.ipynb`)

$\displaystyle \tau_E\frac{dE}{dt} \ll E \implies \tau_E\frac{dE}{dt} \sim 0$  
$\displaystyle \tau_I\frac{dI}{dt} \ll I \implies \tau_I\frac{dI}{dt} \sim 0$

then

$E = E_{ss}(W_{EE},\ldots)$  
$I = I_{ss}(W_{EE},\ldots)$

Compute the steady-state values $E=E_{ss}$ and $I=I_{ss}$ as implicit functions (the derivatives will have simpler expressions this way):  
$E_{ss} = f(E_{ss},I_{ss})$  
$I_{ss} = g(E_{ss},I_{ss})$

In [3]:
E_implicit = f_E.subs(dEdt==0)*tau_E + E
I_implicit = f_I.subs(dIdt==0)*tau_I + I
show(E_implicit)
show(I_implicit)

Also solve explicitly for $E$ and $I$:

In [4]:
neuralFixedPoint = solve([E_implicit,I_implicit],E,I)
E_ss = neuralFixedPoint[0][0]
I_ss = neuralFixedPoint[0][1]
show(E_ss)
show(I_ss)
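
As a numerical counterpart to the symbolic solve, here is a minimal Python sketch. It assumes a threshold-linear steady state, $E = g_E(W_{EE}E - W_{EI}I - \Theta_E)$ and $I = g_I(W_{IE}E - W_{II}I - \Theta_I)$, which is consistent with the derivative equations used later in this notebook; the thresholds `Theta_E`, `Theta_I` and all parameter values are assumptions, since `f_E` and `f_I` are loaded from another file.

```python
# Assumed model (f_E and f_I are not shown in this notebook):
#   E = g_E*(W_EE*E - W_EI*I - Theta_E)
#   I = g_I*(W_IE*E - W_II*I - Theta_I)

def fixed_point(W, g_E, g_I, Theta_E, Theta_I):
    """Solve the 2x2 linear system for (E_ss, I_ss) by Cramer's rule."""
    # Rewrite as  a*E + b*I = r1,  -c*E + d*I = r2.
    a = 1.0 - g_E * W["EE"]
    b = g_E * W["EI"]
    c = g_I * W["IE"]
    d = 1.0 + g_I * W["II"]
    det = a * d + b * c
    r1 = -g_E * Theta_E
    r2 = -g_I * Theta_I
    E = (r1 * d - b * r2) / det
    I = (a * r2 + c * r1) / det
    return E, I

# Illustrative parameter values (an assumption, chosen to give positive rates).
PARAMS = dict(g_E=1.0, g_I=4.0, Theta_E=4.8, Theta_I=25.0)
W = {"EE": 5.0, "EI": 1.09, "IE": 10.0, "II": 1.54}

E_ss, I_ss = fixed_point(W, **PARAMS)

# Residuals of the implicit equations; both should vanish up to rounding.
res_E = E_ss - PARAMS["g_E"] * (W["EE"] * E_ss - W["EI"] * I_ss - PARAMS["Theta_E"])
res_I = I_ss - PARAMS["g_I"] * (W["IE"] * E_ss - W["II"] * I_ss - PARAMS["Theta_I"])
print(E_ss, I_ss, res_E, res_I)
```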

Also solve explicitly for the weights:

In [5]:
var('E_set,I_set')
var('W_EEup,W_EIup,W_IEup,W_IIup')
Wup = [W_EE==W_EEup,W_EI==W_EIup,W_IE==W_IEup,W_II==W_IIup]
synapticFixedPoint = solve([E.subs(E_ss)==E_set,I.subs(I_ss)==I_set],W_EI,W_II)[0]
WEIWII_up = [synapticFixedPoint[k].subs(Wup) for k in [0,1]]
show(synapticFixedPoint)
show(WEIWII_up)

### Partial derivatives of $E$ and $I$ from their implicit expressions

Take the derivatives of the implicit expressions above assuming $E$ and $I$ depend on all weights and using the product rule when needed, e.g.:  
$\displaystyle \frac{\partial E}{\partial W_{EE}} = g_E E + g_E W_{EE}\frac{\partial E}{\partial W_{EE}} - g_E W_{EI}\frac{\partial I}{\partial W_{EE}}$

In [6]:
dEimpdWEE = dEdWEE == g_E*(E + W_EE*dEdWEE) - g_E*W_EI*dIdWEE
dEimpdWEI = dEdWEI == g_E*W_EE*dEdWEI - g_E*(I + W_EI*dIdWEI)
dEimpdWIE = dEdWIE == g_E*W_EE*dEdWIE - g_E*W_EI*dIdWIE
dEimpdWII = dEdWII == g_E*W_EE*dEdWII - g_E*W_EI*dIdWII
dIimpdWEE = dIdWEE == g_I*W_IE*dEdWEE - g_I*W_II*dIdWEE
dIimpdWEI = dIdWEI == g_I*W_IE*dEdWEI - g_I*W_II*dIdWEI
dIimpdWIE = dIdWIE == g_I*(E + W_IE*dEdWIE) - g_I*W_II*dIdWIE
dIimpdWII = dIdWII == g_I*W_IE*dEdWII - g_I*(I + W_II*dIdWII)
dEimp = [dEimpdWEE,dEimpdWEI,dEimpdWIE,dEimpdWII]
dIimp = [dIimpdWEE,dIimpdWEI,dIimpdWIE,dIimpdWII]
show(dEimpdWEE)
show(dEimpdWEI)
show(dEimpdWIE)
show(dEimpdWII)
show(dIimpdWEE)
show(dIimpdWEI)
show(dIimpdWIE)
show(dIimpdWII)

In [7]:
# pass one flat list of the eight equations to solve
aux = solve(dEimp + dIimp,dEdWEE,dEdWEI,dEdWIE,dEdWII,dIdWEE,dIdWEI,dIdWIE,dIdWII)
gradE = aux[0][0:4]
gradI = aux[0][4:8]
show(gradE)
show(gradI)
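
The implicit derivatives can be sanity-checked numerically. The sketch below solves the same eight linear equations in plain Python and compares them with central finite differences of the fixed point. It assumes the threshold-linear steady state implied by the equations above; the thresholds and all parameter values are illustrative assumptions, not taken from the notebook.

```python
def solve2(a, b, c, d, r1, r2):
    """Solve a*x + b*y = r1, -c*x + d*y = r2 by Cramer's rule."""
    det = a * d + b * c
    return (r1 * d - b * r2) / det, (a * r2 + c * r1) / det

def coeffs(W, g_E, g_I):
    return 1.0 - g_E * W["EE"], g_E * W["EI"], g_I * W["IE"], 1.0 + g_I * W["II"]

def fixed_point(W, g_E, g_I, Theta_E, Theta_I):
    # Assumed model: E = g_E*(W_EE*E - W_EI*I - Theta_E), same form for I.
    return solve2(*coeffs(W, g_E, g_I), -g_E * Theta_E, -g_I * Theta_I)

def implicit_gradients(W, g_E, g_I, Theta_E, Theta_I):
    """dE/dW_k and dI/dW_k from the differentiated implicit equations."""
    E, I = fixed_point(W, g_E, g_I, Theta_E, Theta_I)
    abcd = coeffs(W, g_E, g_I)
    # Explicit (product-rule) source terms of each differentiated equation.
    sources = {"EE": (g_E * E, 0.0), "EI": (-g_E * I, 0.0),
               "IE": (0.0, g_I * E), "II": (0.0, -g_I * I)}
    gradE, gradI = {}, {}
    for k, (r1, r2) in sources.items():
        gradE[k], gradI[k] = solve2(*abcd, r1, r2)
    return gradE, gradI

def finite_difference(W, k, h, **params):
    """Central finite difference of (E_ss, I_ss) with respect to weight k."""
    Wp, Wm = dict(W), dict(W)
    Wp[k] += h
    Wm[k] -= h
    Ep, Ip = fixed_point(Wp, **params)
    Em, Im = fixed_point(Wm, **params)
    return (Ep - Em) / (2 * h), (Ip - Im) / (2 * h)

# Illustrative values (assumptions).
PARAMS = dict(g_E=1.0, g_I=4.0, Theta_E=4.8, Theta_I=25.0)
W = {"EE": 5.0, "EI": 1.09, "IE": 10.0, "II": 1.54}

gradE, gradI = implicit_gradients(W, **PARAMS)
for k in W:
    fdE, fdI = finite_difference(W, k, 1e-6, **PARAMS)
    print(k, gradE[k], fdE, gradI[k], fdI)
```

Each pair of differentiated equations shares the same 2x2 matrix and differs only in its right-hand side, so one small solver covers all four weights.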

## Loss function

$L(E,I) = \frac{1}{2}(E - E_{set})^2 + \frac{1}{2}(I - I_{set})^2$

### Exact result

Take for instance the change in $W_{EE}$:  
$\displaystyle \Delta W_{EE} = -\alpha \frac{\partial L}{\partial W_{EE}}$  
and substitute the expressions obtained above into the derivative  
$\displaystyle \frac{\partial L}{\partial W_{EE}} = \frac{\partial L}{\partial E} \frac{\partial E}{\partial W_{EE}} + \frac{\partial L}{\partial I} \frac{\partial I}{\partial W_{EE}}$

In [8]:
L = (E - E_set)^2/2 + (I - I_set)^2/2
dLdE = diff(L,E)
dLdI = diff(L,I)
show(dLdE)
show(dLdI)

In [9]:
dLdWEE = dLdE*dEdWEE.subs(gradE) + dLdI*dIdWEE.subs(gradI)
dLdWEI = dLdE*dEdWEI.subs(gradE) + dLdI*dIdWEI.subs(gradI)
dLdWIE = dLdE*dEdWIE.subs(gradE) + dLdI*dIdWIE.subs(gradI)
dLdWII = dLdE*dEdWII.subs(gradE) + dLdI*dIdWII.subs(gradI)
deltaWEE = -alpha*dLdWEE.full_simplify()
deltaWEI = -alpha*dLdWEI.full_simplify()
deltaWIE = -alpha*dLdWIE.full_simplify()
deltaWII = -alpha*dLdWII.full_simplify()
show(deltaWEE)
show(deltaWEI)
show(deltaWIE)
show(deltaWII)

The resulting expressions for $\Delta W_{EE}, \ldots$ are complicated: they depend on all weights and on $E$ and $I$ (which themselves depend on all weights).
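
To see the exact rule at work, the following Python sketch iterates $\Delta W_{XY} = -\alpha\,\partial L/\partial W_{XY}$ on a numerical version of the model and records the loss. It assumes the threshold-linear steady state implied by the derivative equations above; the thresholds, parameter values, set points, starting weights and step size are all illustrative assumptions.

```python
def solve2(a, b, c, d, r1, r2):
    """Solve a*x + b*y = r1, -c*x + d*y = r2 by Cramer's rule."""
    det = a * d + b * c
    return (r1 * d - b * r2) / det, (a * r2 + c * r1) / det

def rates_and_gradients(W, g_E, g_I, Theta_E, Theta_I):
    """Steady state (E, I) and its derivatives with respect to each weight."""
    abcd = (1.0 - g_E * W["EE"], g_E * W["EI"], g_I * W["IE"], 1.0 + g_I * W["II"])
    E, I = solve2(*abcd, -g_E * Theta_E, -g_I * Theta_I)
    sources = {"EE": (g_E * E, 0.0), "EI": (-g_E * I, 0.0),
               "IE": (0.0, g_I * E), "II": (0.0, -g_I * I)}
    grads = {k: solve2(*abcd, r1, r2) for k, (r1, r2) in sources.items()}
    return E, I, grads

def loss(E, I, E_set, I_set):
    return 0.5 * (E - E_set) ** 2 + 0.5 * (I - I_set) ** 2

def descend(W, E_set, I_set, alpha, steps, **params):
    """Apply Delta W = -alpha * dL/dW repeatedly; return weights and loss history."""
    W = dict(W)
    history = []
    for _ in range(steps):
        E, I, grads = rates_and_gradients(W, **params)
        history.append(loss(E, I, E_set, I_set))
        for k, (dE, dI) in grads.items():
            W[k] -= alpha * ((E - E_set) * dE + (I - I_set) * dI)
    E, I, _ = rates_and_gradients(W, **params)
    history.append(loss(E, I, E_set, I_set))
    return W, history

# Illustrative values (assumptions); W_EE starts slightly off.
PARAMS = dict(g_E=1.0, g_I=4.0, Theta_E=4.8, Theta_I=25.0)
W0 = {"EE": 5.05, "EI": 1.09, "IE": 10.0, "II": 1.54}

W_final, history = descend(W0, E_set=5.0, I_set=14.0, alpha=1e-4, steps=200, **PARAMS)
print(history[0], history[-1])
```

The step size has to be small here because some of the derivatives are large for these parameter values; a larger `alpha` may overshoot.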

### Approximation: small weights

The partial derivatives $\displaystyle \frac{\partial E}{\partial W_{EE}},\ldots$ are complicated functions of the weights but are very simple (linear) functions of $E$ and $I$, e.g.:

In [10]:
show(gradE[0].full_simplify().factor())

We therefore attempt a _sui generis_ multivariate Taylor expansion around zero weights, expanding only the explicit weight dependence of the derivatives and leaving $E$ and $I$ unexpanded.

First-order approximation:

In [11]:
var('W_EE,W_EI,W_IE,W_II')
deltaWEE_sw = taylor(deltaWEE,(W_EE,0),(W_EI,0),(W_IE,0),(W_II,0),1)
deltaWEI_sw = taylor(deltaWEI,(W_EE,0),(W_EI,0),(W_IE,0),(W_II,0),1)
deltaWIE_sw = taylor(deltaWIE,(W_EE,0),(W_EI,0),(W_IE,0),(W_II,0),1)
deltaWII_sw = taylor(deltaWII,(W_EE,0),(W_EI,0),(W_IE,0),(W_II,0),1)
show(deltaWEE_sw)
show(deltaWEI_sw)
show(deltaWIE_sw)
show(deltaWII_sw)
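
A quick numerical check of the small-weight idea, as a sketch. Solving the two $W_{EE}$ equations from the implicit system by hand gives $\partial E/\partial W_{EE} = g_E E\,(1+g_I W_{II})/\Delta$ and $\partial I/\partial W_{EE} = g_E E\, g_I W_{IE}/\Delta$ with $\Delta = (1-g_E W_{EE})(1+g_I W_{II}) + g_E g_I W_{EI} W_{IE}$, whose first-order expansions in the weights are $g_E E\,(1+g_E W_{EE})$ and $g_E E\, g_I W_{IE}$. These closed forms are derived here for illustration and should be compared against the `taylor` output above; the gain values are assumptions.

```python
def exact_coeffs(W, g_E, g_I):
    """Exact dE/dW_EE and dI/dW_EE, each divided by g_E*E."""
    a = 1.0 - g_E * W["EE"]
    b = g_E * W["EI"]
    c = g_I * W["IE"]
    d = 1.0 + g_I * W["II"]
    det = a * d + b * c
    return d / det, c / det

def first_order_coeffs(W, g_E, g_I):
    """First-order expansion of the same two coefficients around zero weights."""
    return 1.0 + g_E * W["EE"], g_I * W["IE"]

def max_error(w, g_E=1.0, g_I=4.0):
    """Largest absolute error when all four weights equal w (illustrative gains)."""
    W = {k: w for k in ("EE", "EI", "IE", "II")}
    exact = exact_coeffs(W, g_E, g_I)
    approx = first_order_coeffs(W, g_E, g_I)
    return max(abs(x - y) for x, y in zip(exact, approx))

for w in (0.01, 0.001):
    print(w, max_error(w))
```

Shrinking the weights tenfold reduces the error roughly a hundredfold, as expected when the neglected terms are second order in the weights.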

### Approximation: weights are not small but near the Up state

Same idea as before: we attempt a _sui generis_ multivariate Taylor expansion of the derivatives $\partial E/\partial W_{EE}, \ldots$ around the Up state (here to zeroth order, `degree = 0`), without further expanding $E$ and $I$.

In [12]:
var('W_EE,W_EI,W_IE,W_II')
degree = 0
aux = taylor(dEdWEE.subs(gradE),(W_EE,W_EEup),(W_EI,W_EIup),(W_IE,W_IEup),(W_II,W_IIup),degree)
dEdWEE_up = aux.subs(WEIWII_up).factor()
aux = taylor(dEdWEI.subs(gradE),(W_EE,W_EEup),(W_EI,W_EIup),(W_IE,W_IEup),(W_II,W_IIup),degree)
dEdWEI_up = aux.subs(WEIWII_up).factor()
aux = taylor(dEdWIE.subs(gradE),(W_EE,W_EEup),(W_EI,W_EIup),(W_IE,W_IEup),(W_II,W_IIup),degree)
dEdWIE_up = aux.subs(WEIWII_up).factor()
aux = taylor(dEdWII.subs(gradE),(W_EE,W_EEup),(W_EI,W_EIup),(W_IE,W_IEup),(W_II,W_IIup),degree)
dEdWII_up = aux.subs(WEIWII_up).factor()
show(dEdWEE_up)
show(dEdWEI_up)
show(dEdWIE_up)
show(dEdWII_up)

In [13]:
var('W_EE,W_EI,W_IE,W_II')
aux = taylor(dIdWEE.subs(gradI),(W_EE,W_EEup),(W_EI,W_EIup),(W_IE,W_IEup),(W_II,W_IIup),degree)
dIdWEE_up = aux.subs(WEIWII_up).factor()
aux = taylor(dIdWEI.subs(gradI),(W_EE,W_EEup),(W_EI,W_EIup),(W_IE,W_IEup),(W_II,W_IIup),degree)
dIdWEI_up = aux.subs(WEIWII_up).factor()
aux = taylor(dIdWIE.subs(gradI),(W_EE,W_EEup),(W_EI,W_EIup),(W_IE,W_IEup),(W_II,W_IIup),degree)
dIdWIE_up = aux.subs(WEIWII_up).factor()
aux = taylor(dIdWII.subs(gradI),(W_EE,W_EEup),(W_EI,W_EIup),(W_IE,W_IEup),(W_II,W_IIup),degree)
dIdWII_up = aux.subs(WEIWII_up).factor()
show(dIdWEE_up)
show(dIdWEI_up)
show(dIdWIE_up)
show(dIdWII_up)

In [14]:
var('W_EE,W_EI,W_IE,W_II')
dLdWEE_up = (dLdE*dEdWEE_up + dLdI*dIdWEE_up)
dLdWEI_up = (dLdE*dEdWEI_up + dLdI*dIdWEI_up)
dLdWIE_up = (dLdE*dEdWIE_up + dLdI*dIdWIE_up)
dLdWII_up = (dLdE*dEdWII_up + dLdI*dIdWII_up)
deltaWEE_up = -alpha*dLdWEE_up
deltaWEI_up = -alpha*dLdWEI_up
deltaWIE_up = -alpha*dLdWIE_up
deltaWII_up = -alpha*dLdWII_up
show(deltaWEE_up)
show(deltaWEI_up)
show(deltaWIE_up)
show(deltaWII_up)

Define new parameters and rewrite learning rules:

In [15]:
AA = positive_WEI_cond.subs(Wup).lhs()
BB = (sum(positive_WII_cond.subs(Wup).lhs().operands()[0:2])/g_I).factor()
CC = paradox_cond.lhs().subs(Wup)
DD = (I_up.rhs().numerator()/g_I).factor().subs(Wup)
show(AA)
show(BB)
show(CC)
show(DD)

The following three parameters are positive:  
1. $AA>0$ because it is the left-hand side of the "positive $W_{EI}$" condition
2. $BB>0$ because it is the sum of the first two terms in the "positive $W_{II}$" condition (divided by $g_I$)
3. $DD>0$ because it is the numerator of $I_{up}$ (divided by $g_I$), which must be positive since the denominator of $I_{up}$ is positive

Parameter $CC$ is equal to the paradoxical condition, so it is positive if the system is in the paradoxical regime.

In [16]:
deltaWEE_up_v2 = alpha*W_IEup*g_E*I_set*E*(I_set-I)/DD + alpha*BB*g_E*E*(E_set-E)/DD
deltaWEI_up_v2 = -alpha*W_IEup*g_E*I_set*I*(I_set-I)/DD - alpha*BB*g_E*I*(E_set-E)/DD
deltaWIE_up_v2 = - alpha*AA*E*(E_set-E)/DD - alpha*CC*I_set*E*(I_set-I)/DD
deltaWII_up_v2 = alpha*CC*I_set*I*(I_set-I)/DD + alpha*AA*I*(E_set-E)/DD
show(deltaWEE_up_v2)
show(deltaWEI_up_v2)
show(deltaWIE_up_v2)
show(deltaWII_up_v2)
# confirm: each difference below should simplify to zero
show((deltaWEE_up - deltaWEE_up_v2).expand())
show((deltaWEI_up - deltaWEI_up_v2).expand())
show((deltaWIE_up - deltaWIE_up_v2).expand())
show((deltaWII_up - deltaWII_up_v2).expand())

### Full approximation

Substitute $E=E_{ss}(W_{EE},W_{EI},W_{IE},W_{II})$ and $I=I_{ss}(W_{EE},W_{EI},W_{IE},W_{II})$ and compute the full Taylor expansion as a function of the weights.
Now $\Delta W_{XY}$ is a polynomial function of all weights with no apparent homeostatic form:

In [17]:
var('W_EE,W_EI,W_IE,W_II')
deltaWEE_full = taylor(deltaWEE.subs(E_ss,I_ss),(W_EE,0),(W_EI,0),(W_IE,0),(W_II,0),1)
deltaWEI_full = taylor(deltaWEI.subs(E_ss,I_ss),(W_EE,0),(W_EI,0),(W_IE,0),(W_II,0),1)
deltaWIE_full = taylor(deltaWIE.subs(E_ss,I_ss),(W_EE,0),(W_EI,0),(W_IE,0),(W_II,0),1)
deltaWII_full = taylor(deltaWII.subs(E_ss,I_ss),(W_EE,0),(W_EI,0),(W_IE,0),(W_II,0),1)
R.<W_EE,W_EI,W_IE,W_II> = SR[]
deltaWEE_full_poly = deltaWEE_full.polynomial(SR)
deltaWEI_full_poly = deltaWEI_full.polynomial(SR)
deltaWIE_full_poly = deltaWIE_full.polynomial(SR)
deltaWII_full_poly = deltaWII_full.polynomial(SR)
show(deltaWEE_full_poly)
show(deltaWEI_full_poly)
show(deltaWIE_full_poly)
show(deltaWII_full_poly)