# Notebook for HGRU implementation

## Horizontal gated recurrent units (hGRUs)

**Introduction**
* A recurrent neural model of contextual interactions developed by Mély
* Model units are indexed by their 2D position $(x, y)$ and feature channel $k$
* Neural activity is governed by the following differential equation:

<img src="image\1.png">

* $X \in \mathbb{R}^{W \times H \times K}$: feedforward drive (neural response to a stimulus)
* $H^{1} \in \mathbb{R}^{W \times H \times K}$: recurrent circuit input
* $H^{2} \in \mathbb{R}^{W \times H \times K}$: recurrent circuit output

*Advantages of inhibition*
1. Separate input and output stages allow shunting inhibition
2. Excitation acts linearly on the input
3. Inhibition introduces a nonlinearity on the output

$W^{I}, W^{E} \in \mathbb{R}^{S \times S \times K \times K}$: inhibitory and excitatory hypercolumn connectivity kernels

$\mu$ and $\alpha$ control the linear and quadratic inhibition by $C^{1} \in \mathbb{R}^{W \times H \times K}$

$\gamma$ scales the excitation by $C^{2} \in \mathbb{R}^{W \times H \times K}$

$\xi$ scales the feedforward drive

Each stage is linearly rectified (ReLU): $[\cdot]_{+} = \max(\cdot, 0)$

$\eta, \epsilon, \tau, \sigma$ are time constants

* Set $\eta = \tau$ and $\sigma = \epsilon$

$\Delta t = \frac{\eta}{\epsilon^{2}}$

#### Deriving HGRU from contextual neural circuit model

Rearranging Eq. (1):
<img src="image\4.png">

Discretizing the equation with Euler's method, setting $\eta = \tau$ and $\sigma = \epsilon$:

<img src="image\5.png">

Applying it with time step *h*:

<img src="image\6.png">



Distributing *h*

<img src="image\7.png">



Choosing $h = \frac{\eta}{\epsilon^2}$:

<img src="image\8.png">
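Why $h = \eta/\epsilon^2$ is a convenient choice can be checked numerically. The sketch below takes one forward-Euler step of the input stage, assuming it has the form $\eta \dot{H}^{1} + \epsilon^{2} H^{1} = [\,\cdot\,]_{+}$ (the exact equation is in the figures above); `drive` is a hypothetical stand-in for the bracketed right-hand side evaluated at the current step, and the values of $\eta$, $\epsilon$ and the random arrays are arbitrary.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def euler_step_h1(h1, drive, eta, eps, h):
    """One forward-Euler step of eta * dH1/dt + eps**2 * H1 = [drive]_+.

    `drive` is the bracketed right-hand side, precomputed by the caller.
    """
    dh1_dt = (-eps**2 * h1 + relu(drive)) / eta
    return h1 + h * dh1_dt

rng = np.random.default_rng(0)
h1 = rng.normal(size=(4, 4))
drive = rng.normal(size=(4, 4))
eta, eps = 2.0, 0.5

# With h = eta / eps**2 the leak term cancels the previous state exactly,
# so the update reduces to eps**-2 * [drive]_+ and no longer depends on h1.
h = eta / eps**2
next_h1 = euler_step_h1(h1, drive, eta, eps, h)
print(np.allclose(next_h1, relu(drive) / eps**2))  # True
```

Any other step size leaves a term proportional to the previous state, which is what the gates introduced next make learnable.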

Discrete-time approximation of the initial dynamical system:
* Because RNNs are difficult to train, gates manage the flow of information over time.

<img src="image\9.png">

$\sigma$ is the squashing pointwise nonlinearity

$U^{(.)}$ -> convolutional kernel

$b^{(.)}$ -> Bias
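A gate of the form $G = \sigma(U * H + b)$ with a $1 \times 1$ kernel mixes channels independently at every pixel, so it is just a matrix product over the channel axis. The NumPy sketch below shows this, assuming a single unbatched state of shape (K, H, W) and storing the $1 \times 1$ kernel as a (K, K) matrix; in the notebook the same role is played by `nn.Conv2d(hidden_size, hidden_size, 1)`.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gate(u, b, h):
    """G = sigmoid(U * H + b) for a 1x1 convolution kernel U.

    u: 1x1 kernel stored as a (K_out, K_in) matrix
    b: bias, shape (K_out,)
    h: hidden state, shape (K_in, H, W)
    """
    pre = np.einsum('oi,ihw->ohw', u, h) + b[:, None, None]
    return sigmoid(pre)

rng = np.random.default_rng(1)
K, H, W = 3, 5, 5
h = rng.normal(size=(K, H, W))
u = rng.normal(size=(K, K))
b = rng.normal(size=K)
g = gate(u, b, h)
print(g.shape)  # (3, 5, 5)
```

Every gate value lies in $(0, 1)$, so the gate acts as a soft, per-pixel and per-channel switch on the state it multiplies.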

Applying these gates to the equation above gives:


<img src="image\10.png">

### Code snippets

```python
# Excerpt from hConvGRUCell.__init__: parameters of the cell.
self.u1_gate = nn.Conv2d(hidden_size, hidden_size, 1)  # U^1: 1x1 kernel of gate G^1
self.u2_gate = nn.Conv2d(hidden_size, hidden_size, 1)  # U^2: 1x1 kernel of gate G^2

# Horizontal kernels W (inhibitory and excitatory), shape (K, K, S, S)
self.w_gate_inh = nn.Parameter(torch.empty(hidden_size, hidden_size, kernel_size, kernel_size))
self.w_gate_exc = nn.Parameter(torch.empty(hidden_size, hidden_size, kernel_size, kernel_size))

# Per-channel (K-dimensional) scaling parameters, shape (K, 1, 1)
self.alpha = nn.Parameter(torch.empty((hidden_size, 1, 1)))  # quadratic inhibition
self.gamma = nn.Parameter(torch.empty((hidden_size, 1, 1)))  # scales excitation C^2
self.kappa = nn.Parameter(torch.empty((hidden_size, 1, 1)))  # linear term of stage 2
self.w = nn.Parameter(torch.empty((hidden_size, 1, 1)))      # quadratic term of stage 2
self.mu = nn.Parameter(torch.empty((hidden_size, 1, 1)))     # linear inhibition

# Excerpt from hConvGRUCell.forward: one timestep of the recurrence.
i = timestep  # index in [0, timesteps); selects this step's batch-norm layers
if self.batchnorm:
    # Stage 1: gated horizontal inhibition C^1 computed from H^2[t-1]
    g1_t = torch.sigmoid(self.bn[i*4+0](self.u1_gate(prev_state2)))
    c1_t = self.bn[i*4+1](F.conv2d(prev_state2 * g1_t, self.w_gate_inh, padding=self.padding))

    # H^1[t]: inhibition applied to the feedforward drive X
    next_state1 = F.relu(input_ - F.relu(c1_t*(self.alpha*prev_state2 + self.mu)))

    # Stage 2: update gate G^2 and horizontal excitation C^2 computed from H^1[t]
    g2_t = torch.sigmoid(self.bn[i*4+2](self.u2_gate(next_state1)))
    c2_t = self.bn[i*4+3](F.conv2d(next_state1, self.w_gate_exc, padding=self.padding))
    h2_t = F.relu(self.kappa*next_state1 + self.gamma*c2_t + self.w*next_state1*c2_t)

    # Gated update of the hidden state H^2[t]
    prev_state2 = (1 - g2_t)*prev_state2 + g2_t*h2_t

else:
    g1_t = torch.sigmoid(self.u1_gate(prev_state2))
    c1_t = F.conv2d(prev_state2 * g1_t, self.w_gate_inh, padding=self.padding)

    next_state1 = torch.tanh(input_ - c1_t*(self.alpha*prev_state2 + self.mu))

    # No batch norm in this branch
    g2_t = torch.sigmoid(self.u2_gate(next_state1))
    c2_t = F.conv2d(next_state1, self.w_gate_exc, padding=self.padding)
    h2_t = torch.tanh(self.kappa*(next_state1 + self.gamma*c2_t) + (self.w*(next_state1*(self.gamma*c2_t))))

    # self.n holds the learnable per-timestep gain
    prev_state2 = self.n[timestep]*((1 - g2_t)*prev_state2 + g2_t*h2_t)

```

##### Trainable convolutional recurrent neural network

<img src="image\2.png">

### HGRU formulation

A model built on Eq. (2) can learn complex interactions between units via horizontal connections.

Modifications introduced to Eq. (2):
* Learnable gates (GRU-like)
* The excitation stage ($H^2$) is made symmetric with the inhibition stage ($H^1$), giving the model the ability to learn both linear and quadratic interactions at each processing stage
* To control unstable gradients: a squashing pointwise nonlinearity and a learned parameter that globally scales activity at every processing time step

<img src="image\3.png">



* The hGRU can learn nonlinear interactions between spatially neighbouring units in the feedforward drive **X**, encoded in the hidden state $H^2$

Stage 1:
1. Horizontal inhibition (blue) is calculated by applying a gain to $H^2[t-1]$ and convolving the result with **W**
2. Linear ($+$) and quadratic ($\times$) terms control how this inhibition converges onto **X**

Stage 2:
1. Horizontal excitation is computed by convolving $H^1[t]$ with **W**
2. A second set of linear and quadratic terms controls the excitation
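The two stages and the gated update can be traced in a small NumPy sketch of one hGRU timestep without batch normalization. It mirrors the non-batch-norm branch of the notebook's `hConvGRUCell.forward` and reuses its parameter names (`alpha`, `mu`, `kappa`, `gamma`, `w`, plus `n_t` for the per-timestep gain). `conv_same` is a hypothetical helper that reproduces the zero-padded cross-correlation of `F.conv2d`; the unbatched layout, sizes and random weights are illustrative assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def conv_same(x, w):
    """Zero-padded 'same' cross-correlation, like F.conv2d(x, w, padding=S // 2).

    x: (K_in, H, W); w: (K_out, K_in, S, S) with odd S.
    """
    k_out, _, s, _ = w.shape
    pad = s // 2
    _, height, width = x.shape
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
    out = np.zeros((k_out, height, width))
    for dy in range(s):
        for dx in range(s):
            patch = xp[:, dy:dy + height, dx:dx + width]
            out += np.einsum('oi,ihw->ohw', w[:, :, dy, dx], patch)
    return out

def gate(u, b, h):
    """sigmoid(U * H + b) with a 1x1 kernel U stored as a (K, K) matrix."""
    return sigmoid(np.einsum('oi,ihw->ohw', u, h) + b[:, None, None])

def hgru_step(x, h2_prev, p):
    """One hGRU timestep; x and h2_prev have shape (K, H, W)."""
    # Stage 1: gain G1 on H2[t-1], then horizontal inhibition C1.
    g1 = gate(p['u1'], p['b1'], h2_prev)
    c1 = conv_same(h2_prev * g1, p['w_inh'])
    # Linear (mu) and quadratic (alpha * H2[t-1]) inhibition of X.
    h1 = np.tanh(x - c1 * (p['alpha'] * h2_prev + p['mu']))
    # Stage 2: update gate G2 and horizontal excitation C2 from H1[t].
    g2 = gate(p['u2'], p['b2'], h1)
    c2 = conv_same(h1, p['w_exc'])
    h2_tilde = np.tanh(p['kappa'] * (h1 + p['gamma'] * c2)
                       + p['w'] * (h1 * (p['gamma'] * c2)))
    # Gated update, scaled by the per-timestep gain.
    return p['n_t'] * ((1 - g2) * h2_prev + g2 * h2_tilde)

rng = np.random.default_rng(0)
K, H, W, S = 4, 12, 12, 5

def col(v):
    """Per-channel parameter of shape (K, 1, 1)."""
    return np.full((K, 1, 1), v)

p = {
    'u1': 0.1 * rng.normal(size=(K, K)), 'b1': np.zeros(K),
    'u2': 0.1 * rng.normal(size=(K, K)), 'b2': np.zeros(K),
    'w_inh': 0.05 * rng.normal(size=(K, K, S, S)),
    'w_exc': 0.05 * rng.normal(size=(K, K, S, S)),
    'alpha': col(0.1), 'mu': col(1.0), 'kappa': col(0.5),
    'gamma': col(1.0), 'w': col(0.5), 'n_t': 1.0,
}
x = rng.random((K, H, W))   # stand-in feedforward drive X
h2 = np.zeros((K, H, W))    # H2[0]
for t in range(8):
    h2 = hgru_step(x, h2, p)
print(h2.shape)  # (4, 12, 12)
```

With `n_t = 1` the update is a convex combination of the previous state and a `tanh` output, so a state that starts at zero stays inside $[-1, 1]$.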


### Summary

The feedforward drive **X** corresponds to activity from the preceding convolutional layer.

The **hGRU** encodes spatial dependencies via time-varying hidden states ($H^1$ and $H^2$).

These states are updated via reset and update gates ($G^1$ and $G^2$).

The gates are computed by convolving the hidden states with kernels $U^{1}, U^{2} \in \mathbb{R}^{1 \times 1 \times K \times K}$ and shifting the result by a bias.

A pointwise nonlinearity ($\sigma$) normalizes the gate activities.

<img src="image\12.png">


Horizontal interactions between units are computed with the kernel **W** $\in \mathbb{R}^{S \times S \times K \times K}$, where $S$ is the spatial extent of these connections in a single time step.

**W** is constrained to have symmetric weights between channels, which roughly halves the number of learnable parameters.
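The channel symmetry constraint and its parameter saving can be made concrete. The sketch assumes the PyTorch kernel layout (K_out, K_in, S, S) used by `w_gate_inh` and `w_gate_exc`, and uses the notebook's sizes K = 25 and S = 15. Only the $K(K+1)/2$ unordered channel pairs are free, which is slightly more than half of $K^2$.

```python
import numpy as np

def symmetrize(w):
    """Make W symmetric between channels: W[i, j] == W[j, i] at every spatial tap.

    w: kernel of shape (K, K, S, S) in the PyTorch (out, in, kh, kw) layout.
    """
    return 0.5 * (w + w.transpose(1, 0, 2, 3))

def free_params(K, S, symmetric):
    """Number of independent weights in an S x S x K x K kernel."""
    pairs = K * (K + 1) // 2 if symmetric else K * K
    return pairs * S * S

rng = np.random.default_rng(0)
w = symmetrize(rng.normal(size=(25, 25, 15, 15)))
print(np.allclose(w, w.transpose(1, 0, 2, 3)))                 # True
print(free_params(25, 15, False), free_params(25, 15, True))   # 140625 73125
```

The notebook's cell applies the same averaging to the *gradients* through `register_hook`. Averaging only the updates keeps the weights symmetric if they also start symmetric, which an orthogonal initialization does not guarantee.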

Hidden states are recomputed via horizontal interactions at each time step $t \in [0, T]$.

$H^2[t-1]$ is modulated by the gain $G^1[t]$; convolving the result with **W** gives $C^1[t]$.

$C^1[t]$ is the horizontal inhibition of the hGRU, which is applied to **X** via $\mu$ and $\alpha$.

Both are $K$-dimensional, scaling the linear and quadratic terms of the horizontal interaction with **X**.
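"K-dimensional" here means one value per channel. The notebook stores such parameters with shape (K, 1, 1) so that they broadcast over the batch and both spatial axes of an (N, K, H, W) activation. The NumPy sketch below uses random stand-ins for $C^1$, $H^2[t-1]$ and **X**; the sizes are arbitrary.

```python
import numpy as np

N, K, H, W = 2, 25, 8, 8
rng = np.random.default_rng(0)
x = rng.random((N, K, H, W))        # stand-in feedforward drive X
c1 = rng.random((N, K, H, W))       # stand-in horizontal inhibition C1[t]
h2_prev = rng.random((N, K, H, W))  # stand-in H2[t-1]

# One value per channel, stored as (K, 1, 1) like the notebook's nn.Parameters;
# broadcasting aligns them with the channel axis of (N, K, H, W).
alpha = np.full((K, 1, 1), 0.1)
mu = np.ones((K, 1, 1))

inhibition = c1 * (alpha * h2_prev + mu)  # linear (mu) + quadratic (alpha) terms
h1 = np.tanh(x - inhibition)
print(inhibition.shape)  # (2, 25, 8, 8)
```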

The pointwise nonlinearity $\xi$ is a hyperbolic tangent, which normalizes activities into the range $[-1, 1]$.

<img src="image\13.png">

$C^2[t]$ represents the horizontal excitation.

Its linear and quadratic contributions are controlled by $\kappa$, $\omega$ and $\beta$.

A learnable $T$-dimensional time gain $\eta$ helps control unstable gradients during training.


## Describing the Dataset

### Pathfinder challenge

* A synthetic visual task inspired by visual psychology

* The task is to decide whether two circles are connected by a path (**binary classification**)

* Multiple shorter paths are also present (**distractor paths**)

* Aim: detect long-range spatial dependencies

**Positive examples** place the two circles at the two ends of a single path; **negative examples** place them on different paths.

Three variants with path lengths of **6, 9 and 14** paddles
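What the binary label means can be illustrated on a toy grid. The sketch below is purely hypothetical: real Pathfinder images are not simple '#'/'.' pixel grids, and this is not how the dataset is generated. It only shows the definition of a positive example (both markers on the same connected path) using a breadth-first flood fill.

```python
from collections import deque

def connected(grid, start, end):
    """Return 1 if `start` and `end` lie on the same path (4-connected), else 0.

    grid: list of equal-length strings; '#' marks path pixels, '.' background.
    """
    rows, cols = len(grid), len(grid[0])
    seen = {start}
    queue = deque([start])
    while queue:
        r, c = queue.popleft()
        if (r, c) == end:
            return 1
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if (0 <= nr < rows and 0 <= nc < cols
                    and (nr, nc) not in seen and grid[nr][nc] == '#'):
                seen.add((nr, nc))
                queue.append((nr, nc))
    return 0

image = [
    "#####...",
    "....#...",
    "....###.",
    "#.......",
    "###.....",
]
print(connected(image, (0, 0), (2, 6)))  # 1: both ends of the same path
print(connected(image, (0, 0), (4, 2)))  # 0: second marker sits on a distractor
```

The model has no such explicit search; it has to learn to propagate evidence along the path through its horizontal connections, which is why longer paths are harder.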

<img src="image\14.png">


#### Selected hyperparameters

**Recurrent Model**
* 8 time steps
* 15x15 horizontal connection kernels (**W**)

*Observations*

* As path length increases, performance decreases
* Longer paths require more time steps, similar to how human performance depends on the distance along the path

**Feedforward Model**
* Three kernel sizes are used: 10x10, 15x15 and 20x20 (small, medium and large)
* Performance is compared while keeping the number of parameters the same as in the hGRU (36, 16 and 9 kernels, respectively)
* Another combination uses a 2-pixel dilated convolution before convolving the input
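The kernel counts above are chosen so that the three feedforward variants have the same number of weights. The arithmetic below counts weights per input channel and ignores biases; it only shows that the three configurations match each other.

```python
# Feedforward configurations listed above: (number of kernels, spatial size).
configs = {"small": (36, 10), "medium": (16, 15), "large": (9, 20)}

# Weights per input channel: n_kernels * size * size (biases ignored).
weights = {label: n * s * s for label, (n, s) in configs.items()}
for label, count in weights.items():
    print(label, count)
# small 3600
# medium 3600
# large 3600
```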

### Importing the required modules

In [1]:
import os
import time
import torch
from torchvision.transforms import Compose as transcompose
import torch.nn.parallel
import torch.optim
import numpy as np

from dataset import DataSetPol
from hgru import hConvGRU, FFConvNet
from transforms import GroupScale, Augmentation, Stack, ToTorchFormatTensor
from misc_functions import AverageMeter, accuracy, plot_grad_flow
from statistics import mean

#from opts import parser

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

In [2]:
# For models
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
from torch.nn import init

### Setting up the parameters:

In [15]:
train_list = 'train_images_9.txt'
val_list = 'test_images_9.txt'
name = 'hgru'

# ========================= Learning Configs ==========================
epochs = 10
batch_size = 64 #batch-size
lr = .001 #learning rate
lr_steps = [20, 40] # epochs at which to decay the learning rate by 10x (not applied in the training loop below)

# ========================= Monitor Configs ==========================
print_freq = 200 #print-frequency
ef = 1 #eval frequency
parallel = True
start_epoch=0
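`lr_steps` describes a step-decay schedule, but the training loop below never reads it, and with `epochs = 10` neither milestone would be reached anyway. The hypothetical helper below computes the learning rate such a schedule would give; in PyTorch the same behaviour is provided by `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_steps, gamma=0.1)` with `scheduler.step()` called once per epoch.

```python
def step_lr(base_lr, epoch, milestones, gamma=0.1):
    """Learning rate after decaying by `gamma` at every milestone already reached."""
    n_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** n_decays

for epoch in (0, 19, 20, 39, 40):
    print(epoch, step_lr(0.001, epoch, [20, 40]))
```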

In [16]:
plt.ion()
plt.show()

global best_prec1
best_prec1 = 0

transform_list = transcompose([GroupScale((150,150)), Augmentation(), Stack(), 
                               ToTorchFormatTensor(div=True)])

print("Loading training dataset")
train_loader = torch.utils.data.DataLoader(DataSetPol("/media/data_cifs/curvy_2snakes_300/", 
                                                      train_list, 
                                                      transform = transform_list ),
                                           batch_size=batch_size,   
                                           shuffle=True, num_workers=4, pin_memory=True)

print("Loading validation dataset")
val_loader = torch.utils.data.DataLoader(DataSetPol("/media/data_cifs/curvy_2snakes_300/", 
                                                    val_list, transform = transform_list),
                                         batch_size=batch_size, shuffle=False, 
                                         num_workers=4, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device used is: ',device)


Loading training dataset
Loading validation dataset
device used is:  cuda


In [17]:
def validate(val_loader, model, iter, criterion, logger=None):
    """Evaluate `model` on `val_loader` and return the mean top-1 accuracy."""
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (imgs, target) in enumerate(val_loader):
            target = target.to(device)
            imgs = imgs.to(device)
            output = model(imgs)

            loss = criterion(output, target)
            losses.update(loss.item(), imgs.size(0))

            [prec1] = accuracy(output.data, target, topk=(1,))
            top1.update(prec1.item(), imgs.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t Time: {batch_time.avg:.3f}\t Loss: {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec: {top1.val:.3f} ({top1.avg:.3f})\t'.format(i, len(val_loader),
                                                                       batch_time=batch_time,
                                                                       loss=losses, top1=top1))

    print('Testing Results: Prec@1 {top1.avg:.3f} Loss {loss.avg:.5f}'
          .format(top1=top1, loss=losses))

    return top1.avg

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save `state`, tagging the file name with the best accuracy and the epoch."""
    filename = '_'.join((name, 'accuracy', str(float(state['best_prec1'])),
                         'epoch', str(state['epoch']), filename))
    torch.save(state, filename)
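`AverageMeter` comes from the project's `misc_functions` module, which is not shown here. Judging only from how the loops use it (`update(val, n)`, `.val`, `.avg`, `.history`), a minimal stand-in could look like the sketch below; the class name `SimpleAverageMeter` is hypothetical and chosen so it does not shadow the imported class.

```python
class SimpleAverageMeter:
    """Hypothetical stand-in for misc_functions.AverageMeter.

    Tracks the latest value, a sample-weighted running average and the
    history of recorded values.
    """

    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0
        self.history = []

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.history.append(val)

meter = SimpleAverageMeter()
meter.update(0.5, n=64)   # e.g. batch accuracy 0.5 over 64 images
meter.update(1.0, n=32)   # a smaller final batch
print(meter.avg)  # 0.6666666666666666
```

Weighting by `n` keeps the average correct when the last batch is smaller than `batch_size`.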






In [None]:
if __name__ == '__main__':
    
    #global best_prec1

    print("Init model")
    if parallel == True:
        model = hConvGRU(timesteps=8, filt_size = 15)
        model = torch.nn.DataParallel(model).to(device)
        print("Loading parallel finished on GPU count:", torch.cuda.device_count())
    else:
        model = hConvGRU(timesteps=8, filt_size = 15).to(device)
        print("Loading finished")

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    lr_init = lr
    print("Starting training: ")
    f_val= []
    f_training = []
    train_loss_history = []
    for epoch in range(start_epoch, epochs):
        
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
    
        model.train()
        end = time.perf_counter() #like a stop watch for the code
        for i, (imgs, target) in enumerate(train_loader):
            data_time.update(time.perf_counter() - end)
            
            imgs = imgs.to(device)
            target = target.to(device)

            # Forward pass: run the images through the model to get class logits.
            output = model(imgs)

            loss = criterion(output, target)
            [prec1] = accuracy(output.data, target, topk=(1,))

            losses.update(loss.item(), imgs.size(0))
            top1.update(prec1.item(), imgs.size(0))

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.perf_counter() - end)
            
            end = time.perf_counter()
            if i % (print_freq) == 0:
                print('Epoch: [{0}][{1}/{2}]\t lr: {lr:g}\t Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t Data: {data_time.val:.3f} ({data_time.avg:.3f})\t'
                       'Prec: {top1.val:.3f} ({precprint:.3f}) ({top1.avg:.3f})\t Loss: {loss.val:.6f} ({lossprint:.6f}) ({loss.avg:.6f})'.format(epoch, i, len(train_loader), batch_time=batch_time,
                        data_time=data_time, loss=losses, lossprint= mean(losses.history[-print_freq:]), lr=lr, top1=top1, precprint= mean(top1.history[-print_freq:])))
            
        f_training.append(top1.avg)
        train_loss_history += losses.history
        if (epoch + 1) % ef == 0 or epoch == epochs - 1:
            prec = validate(val_loader, model, (epoch + 1) * len(train_loader), criterion)
            f_val.append(prec)
            is_best = prec > best_prec1
            if is_best:
                best_prec1 = max(prec, best_prec1)
                save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                }, is_best)

    # Save training and validation accuracy curves to separate .npy files.
    np.save("{}_train_acc.npy".format(name), np.array(f_training))
    np.save("{}_val_acc.npy".format(name), np.array([float(p) for p in f_val]))




Init model
Training with filter size: 15 x 15
Loading parallel finished on GPU count: 8
Starting training: 
Epoch: [0][0/7422]	 lr: 0.001	 Time: 3.838 (3.838)	 Data: 2.766 (2.766)	Prec: 56.250 (56.250) (56.250)	 Loss: 0.711636 (0.711636) (0.711636)
Epoch: [0][200/7422]	 lr: 0.001	 Time: 1.482 (1.356)	 Data: 0.000 (0.014)	Prec: 57.812 (51.516) (51.539)	 Loss: 0.683204 (0.701643) (0.701692)
Epoch: [0][400/7422]	 lr: 0.001	 Time: 1.503 (1.409)	 Data: 0.000 (0.007)	Prec: 50.000 (51.188) (51.364)	 Loss: 0.688915 (0.694152) (0.697932)
Epoch: [0][600/7422]	 lr: 0.001	 Time: 1.494 (1.429)	 Data: 0.000 (0.005)	Prec: 57.812 (52.953) (51.893)	 Loss: 0.685336 (0.691379) (0.695751)
Epoch: [0][800/7422]	 lr: 0.001	 Time: 1.510 (1.438)	 Data: 0.000 (0.004)	Prec: 37.500 (53.195) (52.218)	 Loss: 0.709131 (0.691124) (0.694596)
Epoch: [0][1000/7422]	 lr: 0.001	 Time: 1.400 (1.443)	 Data: 0.000 (0.003)	Prec: 51.562 (53.992) (52.572)	 Loss: 0.689172 (0.690171) (0.693712)
Epoch: [0][1200/7422]	 lr: 0.001	 T

In [None]:
x = np.load('gabor_serre.npy')

In [4]:
x.shape

(25, 1, 7, 7)
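The shape (25, 1, 7, 7) is the (out_channels, in_channels, height, width) layout that `nn.Conv2d(1, 25, kernel_size=7)` expects for its weight, so the file can be copied straight into `conv0`. How `gabor_serre.npy` was built is not documented here; as a hypothetical illustration only, the sketch builds a Gabor bank of the same shape from 5 orientations x 5 phases. The wavelength, envelope width and orientation/phase grid are assumptions.

```python
import numpy as np

def gabor(size, theta, wavelength, sigma, phase):
    """A single real-valued Gabor kernel of shape (size, size)."""
    half = size // 2
    y, x = np.mgrid[-half:half + 1, -half:half + 1]
    x_rot = x * np.cos(theta) + y * np.sin(theta)
    y_rot = -x * np.sin(theta) + y * np.cos(theta)
    envelope = np.exp(-(x_rot**2 + y_rot**2) / (2 * sigma**2))
    carrier = np.cos(2 * np.pi * x_rot / wavelength + phase)
    return envelope * carrier

# Hypothetical bank: 5 orientations x 5 phases = 25 filters.
thetas = np.linspace(0, np.pi, 5, endpoint=False)
phases = np.linspace(0, np.pi, 5, endpoint=False)
bank = np.stack([gabor(7, t, 4.0, 2.0, ph) for t in thetas for ph in phases])
bank = bank[:, None, :, :].astype(np.float32)  # (out=25, in=1, 7, 7)
print(bank.shape)  # (25, 1, 7, 7)
```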

In [9]:
class hConvGRU(nn.Module):
    """hGRU classifier: Gabor front end -> recurrent hGRU cell -> readout."""

    def __init__(self, timesteps=8, filt_size=9):
        super().__init__()
        self.timesteps = timesteps

        # 7x7 front end initialised with the 25 filters stored in gabor_serre.npy
        self.conv0 = nn.Conv2d(1, 25, kernel_size=7, padding=3)
        part1 = np.load("gabor_serre.npy")
        self.conv0.weight.data = torch.FloatTensor(part1)

        # Recurrent hGRU cell with filt_size x filt_size horizontal kernels
        self.unit1 = hConvGRUCell(25, 25, filt_size, timesteps=timesteps)
        print("Training with filter size:", filt_size, "x", filt_size)

        self.bn = nn.BatchNorm2d(25, eps=1e-03)

        # Readout: 1x1 conv to 2 channels, max pool over the 150x150 map, linear layer
        self.conv6 = nn.Conv2d(25, 2, kernel_size=1)
        init.xavier_normal_(self.conv6.weight)
        init.constant_(self.conv6.bias, 0)

        self.maxpool = nn.MaxPool2d(150, stride=1)

        self.bn2 = nn.BatchNorm2d(2, eps=1e-03)

        self.fc = nn.Linear(2, 2)
        init.xavier_normal_(self.fc.weight)
        init.constant_(self.fc.bias, 0)

    def forward(self, x):
        internal_state = None

        x = self.conv0(x)
        x = torch.pow(x, 2)  # squared filter responses form the feedforward drive X

        # Unroll the cell; the hidden state is initialised inside the cell at timestep 0
        for i in range(self.timesteps):
            internal_state = self.unit1(x, internal_state, timestep=i)
        output = self.bn(internal_state)
        output = F.leaky_relu(self.conv6(output))
        output = self.maxpool(output)
        output = self.bn2(output)
        output = output.view(output.size(0), -1)  # (N, 2) for 150x150 inputs
        output = self.fc(output)
        return output
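The readout relies on the spatial size collapsing to 1x1 before `view`. The sketch traces one spatial side through the layers with the output-size formula documented for PyTorch's `Conv2d` and `MaxPool2d`, $\lfloor (n + 2p - d(k-1) - 1)/s \rfloor + 1$, assuming the 150x150 inputs produced by `GroupScale((150, 150))`. Any other input size would either make the max pool fail (smaller inputs) or leave more than 2 features for `self.fc` (larger inputs).

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis for Conv2d / MaxPool2d."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

side = 150                                         # GroupScale((150, 150))
side = conv_out(side, kernel=7, padding=3)         # conv0 -> 150
# The hGRU cell pads by kernel_size // 2, so 15x15 horizontal kernels keep 150.
side = conv_out(side, kernel=15, padding=15 // 2)  # -> 150
side = conv_out(side, kernel=1)                    # conv6 -> 150
side = conv_out(side, kernel=150, stride=1)        # maxpool -> 1
print(side)  # 1
```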

In [11]:
class hConvGRUCell(nn.Module):
    """
    Generate a convolutional GRU cell
    """

    def __init__(self, input_size, hidden_size, kernel_size, batchnorm=True, timesteps=8):
        super().__init__()
        self.padding = kernel_size // 2
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.timesteps = timesteps
        self.batchnorm = batchnorm

        # 1x1 gate kernels U^1 and U^2
        self.u1_gate = nn.Conv2d(hidden_size, hidden_size, 1)
        self.u2_gate = nn.Conv2d(hidden_size, hidden_size, 1)

        # Horizontal inhibitory / excitatory kernels W, shape (K, K, S, S)
        self.w_gate_inh = nn.Parameter(torch.empty(hidden_size, hidden_size, kernel_size, kernel_size))
        self.w_gate_exc = nn.Parameter(torch.empty(hidden_size, hidden_size, kernel_size, kernel_size))

        # Per-channel scaling parameters, shape (K, 1, 1)
        self.alpha = nn.Parameter(torch.empty((hidden_size, 1, 1)))
        self.gamma = nn.Parameter(torch.empty((hidden_size, 1, 1)))
        self.kappa = nn.Parameter(torch.empty((hidden_size, 1, 1)))
        self.w = nn.Parameter(torch.empty((hidden_size, 1, 1)))
        self.mu = nn.Parameter(torch.empty((hidden_size, 1, 1)))

        if self.batchnorm:
            # Four batch-norm layers per timestep (g1, c1, g2, c2)
            self.bn = nn.ModuleList([nn.BatchNorm2d(hidden_size, eps=1e-03)
                                     for _ in range(4 * timesteps)])
        else:
            # Learnable per-timestep gain
            self.n = nn.Parameter(torch.randn(self.timesteps, 1, 1))

        init.orthogonal_(self.w_gate_inh)
        init.orthogonal_(self.w_gate_exc)

        # Average each gradient with its channel transpose so updates to W are symmetric
        self.w_gate_inh.register_hook(lambda grad: (grad + torch.transpose(grad, 1, 0)) * 0.5)
        self.w_gate_exc.register_hook(lambda grad: (grad + torch.transpose(grad, 1, 0)) * 0.5)

        init.orthogonal_(self.u1_gate.weight)
        init.orthogonal_(self.u2_gate.weight)

        if self.batchnorm:
            for bn in self.bn:
                init.constant_(bn.weight, 0.1)

        init.constant_(self.alpha, 0.1)
        init.constant_(self.gamma, 1.0)
        init.constant_(self.kappa, 0.5)
        init.constant_(self.w, 0.5)
        init.constant_(self.mu, 1)

        # Gate biases: b1 = log(U(1, T - 1)), b2 = -b1
        init.uniform_(self.u1_gate.bias.data, 1, self.timesteps - 1)
        self.u1_gate.bias.data.log_()
        self.u2_gate.bias.data = -self.u1_gate.bias.data

    def forward(self, input_, prev_state2, timestep=0):
        # Random initial hidden state H^2 at the first timestep
        if timestep == 0:
            prev_state2 = torch.empty_like(input_)
            init.xavier_normal_(prev_state2)

        i = timestep
        if self.batchnorm:
            g1_t = torch.sigmoid(self.bn[i*4+0](self.u1_gate(prev_state2)))
            c1_t = self.bn[i*4+1](F.conv2d(prev_state2 * g1_t, self.w_gate_inh, padding=self.padding))

            next_state1 = F.relu(input_ - F.relu(c1_t*(self.alpha*prev_state2 + self.mu)))

            g2_t = torch.sigmoid(self.bn[i*4+2](self.u2_gate(next_state1)))
            c2_t = self.bn[i*4+3](F.conv2d(next_state1, self.w_gate_exc, padding=self.padding))

            h2_t = F.relu(self.kappa*next_state1 + self.gamma*c2_t + self.w*next_state1*c2_t)

            prev_state2 = (1 - g2_t)*prev_state2 + g2_t*h2_t

        else:
            g1_t = torch.sigmoid(self.u1_gate(prev_state2))
            c1_t = F.conv2d(prev_state2 * g1_t, self.w_gate_inh, padding=self.padding)
            next_state1 = torch.tanh(input_ - c1_t*(self.alpha*prev_state2 + self.mu))
            g2_t = torch.sigmoid(self.u2_gate(next_state1))
            c2_t = F.conv2d(next_state1, self.w_gate_exc, padding=self.padding)
            h2_t = torch.tanh(self.kappa*(next_state1 + self.gamma*c2_t) + (self.w*(next_state1*(self.gamma*c2_t))))
            prev_state2 = self.n[timestep]*((1 - g2_t)*prev_state2 + g2_t*h2_t)

        return prev_state2
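The gate-bias initialization at the end of `hConvGRUCell.__init__` draws $b^1$ uniformly from $[1, T-1]$, takes its logarithm and sets $b^2 = -b^1$. The NumPy sketch below reproduces that intent for $T = 8$ and $K = 25$ (values taken from the notebook; the seed is arbitrary). One PyTorch detail matters here: `Tensor.log()` returns a new tensor and leaves the original untouched, so the logarithm only reaches the bias when the in-place `Tensor.log_()` is used or the result is assigned back.

```python
import numpy as np

rng = np.random.default_rng(0)
timesteps, hidden_size = 8, 25

# b1 = log(U(1, T - 1)) per channel; the second gate gets the opposite sign.
b1 = np.log(rng.uniform(1.0, timesteps - 1.0, size=hidden_size))
b2 = -b1
print(b1.min() >= 0.0, b1.max() <= np.log(timesteps - 1))  # True True
```

Since the uniform draw is at least 1, $b^1$ is non-negative, so the $G^1$ gate starts biased open and the $G^2$ gate starts biased closed.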

In [12]:
model = hConvGRU(timesteps=8, filt_size = 15)
print(model)

Training with filter size: 15 x 15
hConvGRU(
  (conv0): Conv2d(1, 25, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (unit1): hConvGRUCell(
    (u1_gate): Conv2d(25, 25, kernel_size=(1, 1), stride=(1, 1))
    (u2_gate): Conv2d(25, 25, kernel_size=(1, 1), stride=(1, 1))
    (bn): ModuleList(
      (0): BatchNorm2d(25, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (1): BatchNorm2d(25, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): BatchNorm2d(25, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (3): BatchNorm2d(25, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (4): BatchNorm2d(25, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (5): BatchNorm2d(25, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (6): BatchNorm2d(25, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (7): BatchNorm2d(25, eps=0.001, momentum=0.1, affine=True, 
