In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import rc
plt.style.use('ggplot')
import torch
from torch.nn.functional import softmax
from scipy.special import softmax as np_softmax
from scipy.linalg import toeplitz, circulant
import numpy as np
from torch.autograd import Variable
from torch.distributions import uniform, cauchy, normal, relaxed_bernoulli
import matplotlib.animation as animation
from IPython.display import HTML
import time
import ipdb

In [None]:
def sigmoid(x):
    """Elementwise logistic function 1 / (1 + exp(-x)).

    Delegates to ``torch.sigmoid``. The hand-written formula overflows ``exp(-x)``
    to ``inf`` for strongly negative inputs (about x < -88 in float32). Its forward
    value is still 0 there, but autograd then multiplies ``0 * inf`` and yields NaN
    gradients; ``torch.sigmoid`` has a stable forward and backward pass.

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape as ``x`` with values in [0, 1].
    """
    return torch.sigmoid(x)

def normalize(x, eps=1e-12):
    """Scale ``x`` to unit L2 norm, computed over *all* elements of the tensor.

    The norm is clamped from below by ``eps`` (same convention as
    ``torch.nn.functional.normalize``). As a result, an all-zero input returns
    zeros instead of ``0/0 = NaN``, and the gradient stays finite at the origin.

    Args:
        x: input tensor of any shape; it is normalized as one flat vector.
        eps: lower bound on the norm used as the denominator.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return x / torch.linalg.vector_norm(x).clamp_min(eps)

In [None]:
# Fetch distribution objects for intrinsic frequencies

def get_dist(dist_name):
    """Build the distribution of intrinsic frequencies and its density at the center.

    Args:
        dist_name: ``'cauchy'`` (loc 0, scale 1) or ``'uniform'`` (on [-1, 1]).

    Returns:
        (dist, g0) where ``dist`` is a ``torch.distributions`` object and ``g0`` is
        the density evaluated at the distribution's center. It is a 0-d tensor for
        Cauchy (equal to 1 / (pi * scale)) and a Python float for uniform
        (1 / (high - low)).

    Raises:
        ValueError: if ``dist_name`` is not recognised.
    """
    if dist_name == 'cauchy':
        loc = 0.0
        scale = 1.0
        dist = cauchy.Cauchy(loc, scale)
        # log_prob requires a Tensor argument when distribution argument
        # validation is enabled (the default), so wrap the float location.
        g0 = torch.exp(dist.log_prob(torch.tensor(loc)))
        return dist, g0
    elif dist_name == 'uniform':
        low = -1.0
        high = 1.0
        g0 = 1. / (high - low)
        # Use the same bounds as g0 so the two can never drift apart.
        dist = uniform.Uniform(low, high)
        return dist, g0
    raise ValueError(
        "Unknown dist_name {!r}; expected 'cauchy' or 'uniform'.".format(dist_name))

In [None]:
# Networks

# Predict connectivity from omega
class connectivity_net(torch.nn.Module):
    """Feed-forward MLP mapping an input vector to ``num_out`` values.

    Layout of ``self.layers``:
        Linear(num_in -> num_hid_units), ReLU,
        then ``num_hid_layers`` x [Linear(num_hid_units -> num_hid_units), ReLU],
        then Linear(num_hid_units -> num_out).

    ``transform`` selects an optional output non-linearity: ``'softmax'`` (over the
    last dimension) or ``'sigmoid'``. Any other value returns the raw linear output.
    """

    def __init__(self, num_in, num_out, num_hid_units=256, num_hid_layers=1, transform=None):
        super(connectivity_net, self).__init__()
        self.transform = transform
        self.num_out = num_out
        # Modules are created in the same order as they are applied, so parameter
        # initialisation order and state_dict keys (layers.0, layers.2, ...) are fixed.
        stack = [torch.nn.Linear(num_in, num_hid_units), torch.nn.ReLU()]
        for _ in range(num_hid_layers):
            stack += [torch.nn.Linear(num_hid_units, num_hid_units), torch.nn.ReLU()]
        stack.append(torch.nn.Linear(num_hid_units, self.num_out))
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, x):
        """Run ``x`` through every layer, then apply the configured output transform."""
        hidden = x
        for module in self.layers:
            hidden = module(hidden)
        if self.transform == 'softmax':
            return softmax(hidden, dim=-1)
        if self.transform == 'sigmoid':
            return sigmoid(hidden)
        return hidden
    
class connectivity_GRU(torch.nn.Module):
    """GRU that reads the same input for ``T`` steps and emits ``num_out`` values per step.

    Each step's GRU output goes through ReLU -> Linear(num_hid_units -> num_out) and
    then an optional transform (``'softmax'`` over the last dim, or ``'sigmoid'``).
    """

    def __init__(self, num_in, num_out, num_hid_units=256, num_hid_layers=1, T=10, transform=None):
        super(connectivity_GRU, self).__init__()
        self.num_out = num_out
        self.num_hid_units = num_hid_units
        self.num_hid_layers = num_hid_layers
        self.T = T
        self.transform = transform

        self.gru = torch.nn.GRU(num_in, num_hid_units, num_hid_layers)
        self.fc = torch.nn.Linear(self.num_hid_units, self.num_out)
        self.relu = torch.nn.ReLU()

    def forward(self, x, h):
        """Feed ``x`` to the GRU ``T`` times and read out every step.

        Args:
            x: GRU input; presumably (batch, num_in) — TODO confirm with callers.
            h: initial hidden state, e.g. from ``init_hidden``
               (num_hid_layers, batch, num_hid_units).

        Returns:
            Tensor of shape (T, *x.shape[:-1], num_out). The final hidden state is discarded.
        """
        # nn.GRU defaults to batch_first=False: repeat x along a new leading time axis.
        seq = torch.stack([x] * self.T)
        out, _ = self.gru(seq, h)
        # nn.Linear acts on the last dimension, so all T steps are projected in one call.
        out = self.fc(self.relu(out))
        if self.transform == 'softmax':
            out = softmax(out, dim=-1)
        elif self.transform == 'sigmoid':
            out = torch.sigmoid(out)
        return out

    def init_hidden(self, batch_size):
        """Zero hidden state (num_hid_layers, batch_size, num_hid_units) on the parameters' device/dtype."""
        weight = next(self.parameters())
        return weight.new_zeros(self.num_hid_layers, batch_size, self.num_hid_units)
    
class connectivity_cGRU(torch.nn.Module):
    """GRU controller that predicts a coupling matrix and rolls out oscillator phases.

    Two modes, selected by ``feedback``:

    * ``feedback=False`` (open loop): ``omega`` is fed to the GRU for ``GRU_steps``
      steps. The last step's read-out is turned into one connectivity via
      ``make_connectivity``, which is then held fixed for ``kuramoto_steps`` calls
      to ``kuramoto_step``.
    * ``feedback=True`` (closed loop): at every phase step the GRU receives
      ``[omega, phase]`` concatenated on the last dim (hence a GRU input twice as
      wide). It emits a fresh connectivity, and the phase is advanced with it.

    ``make_connectivity`` and ``kuramoto_step`` are not defined in this block; they
    must exist at module level before ``forward`` is called.
    """

    def __init__(self, num_in, num_out, num_hid_units=256, num_hid_layers=1, transform=None, feedback=False):
        super(connectivity_cGRU, self).__init__()
        self.num_out = num_out
        self.num_hid_units = num_hid_units
        self.num_hid_layers = num_hid_layers
        self.transform = transform
        self.feedback = feedback

        # In feedback mode the GRU input is [omega, phase], i.e. twice as wide.
        gru_in = 2 * num_in if feedback else num_in
        self.gru = torch.nn.GRU(gru_in, num_hid_units, num_hid_layers)
        self.fc = torch.nn.Linear(self.num_hid_units, self.num_out)
        self.relu = torch.nn.ReLU()

    def _readout(self, gru_out):
        """Map GRU output through ReLU -> Linear and the optional output transform."""
        out = self.fc(self.relu(gru_out))
        if self.transform == 'softmax':
            return softmax(out, dim=-1)
        if self.transform == 'sigmoid':
            return torch.sigmoid(out)
        return out

    def forward(self, omega, h, phase, coupling_strength=.3, alpha=1e-1,
                GRU_steps=10, kuramoto_steps=100, return_connectivities=False):
        """Predict connectivity and integrate the phases for ``kuramoto_steps`` steps.

        Args:
            omega: intrinsic frequencies; presumably (batch, num_in) — TODO confirm.
            h: initial GRU hidden state, e.g. from ``init_hidden``.
            phase: initial phases, concatenated with ``omega`` in feedback mode.
            coupling_strength: scalar multiplying the predicted connectivity.
            alpha: forwarded to ``kuramoto_step``.
            GRU_steps: number of GRU steps in open-loop mode (ignored with feedback).
            kuramoto_steps: number of ``kuramoto_step`` calls.
            return_connectivities: also return the connectivities that were used.

        Returns:
            ``phase_trajectory`` (a list with one phase tensor per step), or
            ``(phase_trajectory, connectivities)`` when ``return_connectivities``.
            ``connectivities`` has one entry per step with feedback, otherwise a
            single entry.
        """
        phase_trajectory = []
        connectivities = []
        if self.feedback:
            for _ in range(kuramoto_steps):
                # One GRU step on [omega, phase]; the hidden state carries over.
                data = torch.cat([omega, phase], dim=-1).unsqueeze(0)
                out, h = self.gru(data, h)
                connectivity = make_connectivity(self._readout(out.squeeze(0)))
                if return_connectivities:
                    connectivities.append(connectivity)
                phase = kuramoto_step(phase, coupling_strength * connectivity, omega, alpha=alpha)
                phase_trajectory.append(phase)
        else:
            data = torch.stack([omega for _ in range(GRU_steps)])
            out, h = self.gru(data, h)
            # Only the final GRU step drives the dynamics, so only it is read out.
            connectivity = make_connectivity(self._readout(out[-1]))
            if return_connectivities:
                connectivities.append(connectivity)
            for _ in range(kuramoto_steps):
                phase = kuramoto_step(phase, coupling_strength * connectivity, omega, alpha=alpha)
                phase_trajectory.append(phase)

        if return_connectivities:
            return phase_trajectory, connectivities
        return phase_trajectory

    def init_hidden(self, batch_size):
        """Zero hidden state (num_hid_layers, batch_size, num_hid_units) on the parameters' device/dtype."""
        weight = next(self.parameters())
        return weight.new_zeros(self.num_hid_layers, batch_size, self.num_hid_units)

In [None]:
# Optimize connectivity network

def optimize_connectivity_cgru(num_units, num_out, num_links, omega_name='uniform', iterations=100,
                               transform=None, kuramoto_steps=100,
                               burn_in_steps=75, lr=.01, alpha=.1, coupling_strength=.7,
                               num_hid_units=256, num_hid_layers=1, batch_size=256, optimizer='Adam',
                               view_inner_opt=-1, view_lr=-1, view_connectivities=-1, GRU_steps=10,
                               feedback=False, verbose=0, pretrained=False,
                               pretrained_path='/media/data_cifs_lrs/projects/prj_synchrony/results/models/brede.pt'):
    """Train a ``connectivity_cGRU`` to minimise the circular variance of the phase flow.

    Each iteration does the following:

    1. Sample a (batch_size, num_units) batch of intrinsic frequencies from ``omega_name``.
    2. Start every oscillator at phase pi and roll the network out for ``kuramoto_steps`` steps.
    3. Take one optimiser step on ``circular_variance`` of the trajectory after ``burn_in_steps``.

    Args:
        num_units: number of oscillators (GRU input width).
        num_out: width of the network read-out passed to ``make_connectivity``.
        num_links, view_inner_opt, view_lr: accepted for call compatibility; unused here.
        omega_name: frequency distribution name understood by ``get_dist``.
        iterations: number of optimiser steps.
        alpha, coupling_strength, kuramoto_steps, GRU_steps: forwarded to the network.
        view_connectivities: if truthy, keep the connectivities of the final iteration.
        optimizer: ``'Adam'`` or ``'SGD'``.
        verbose: print per-iteration loss and timing when truthy.
        pretrained: load weights from ``pretrained_path`` before training.

    Returns:
        (cvh, cn, connectivities): per-iteration losses as numpy scalars, the trained
        network, and the final iteration's connectivities (``None`` if not requested
        or if ``iterations == 0``).

    Raises:
        ValueError: if ``optimizer`` is not recognised.
        FloatingPointError: if the loss becomes NaN. It is raised before the update
            so the weights stay finite; use ``%debug`` to inspect the failing frame.
    """
    omega_dist, _ = get_dist(omega_name)
    cn = connectivity_cGRU(num_units, num_out, num_hid_units=num_hid_units,
                           num_hid_layers=num_hid_layers, transform=transform, feedback=feedback)
    if pretrained:
        cn.load_state_dict(torch.load(pretrained_path))

    optimizers = {'Adam': torch.optim.Adam, 'SGD': torch.optim.SGD}
    if optimizer not in optimizers:
        raise ValueError("Unknown optimizer {!r}; expected one of {}.".format(optimizer, sorted(optimizers)))
    opt = optimizers[optimizer](cn.parameters(), lr=lr)

    cvh = []
    # Defined up front so the return works even when connectivities are never requested.
    connectivities = None
    for i in range(iterations):
        start = time.perf_counter()
        omega = omega_dist.sample(sample_shape=torch.Size([batch_size, num_units]))
        opt.zero_grad()

        phase = np.pi * torch.ones(batch_size, num_units)
        h = cn.init_hidden(batch_size)

        keep_connectivities = bool(view_connectivities) and i == iterations - 1
        result = cn(omega, h, phase, kuramoto_steps=kuramoto_steps, GRU_steps=GRU_steps,
                    coupling_strength=coupling_strength, alpha=alpha,
                    return_connectivities=keep_connectivities)
        if keep_connectivities:
            flow, connectivities = result
        else:
            flow = result

        # Stack the per-step phases on a leading time axis, then swap so time is axis 1
        # (batch, time, units) — assumes kuramoto_step preserves the phase shape.
        flow = torch.stack(flow).transpose(1, 0)
        cv = circular_variance(flow[:, burn_in_steps:, :])
        cvh.append(cv.detach().cpu().numpy())
        if torch.isnan(cv).any():
            raise FloatingPointError(
                'Loss became NaN at iteration {}; aborting before the optimizer step.'.format(i))
        cv.backward()
        opt.step()
        stop = time.perf_counter()
        if verbose:
            print('Iteration {}. Loss: {}. Time/batch: {}'.format(i, cvh[-1], stop - start))

    return cvh, cn, connectivities