In [7]:
import torch
import torch.nn as nn
import scipy.stats as stats

In [4]:
class ICNNet(nn.Module):
    """Partially input-convex network f(x, c) in the PICNN style (Amos et al., 2017).

    With z_0 = x and u_0 = c, every layer k = 0 .. n_layers-2 computes

        z_{k+1} = softplus( Wz_k (z_k * relu(Wzu_k u_k + bz_k))
                            + Wx_k (x * (Wxu_k u_k + bx_k))
                            + Wu_k u_k + bu_k )
        u_{k+1} = tanh(V_k u_k + v_k)

    and the network returns z_{n_layers-1}.

    Args:
        layer_sizes: widths of the x/z path. ``layer_sizes[0]`` is the feature
            size of ``x`` and ``layer_sizes[-1]`` is the output size.
        context_layer_sizes: widths of the context (u) path.
            ``context_layer_sizes[0]`` is the feature size of ``c``. Must have at
            least ``len(layer_sizes)`` entries; extra entries are ignored.

    Raises:
        ValueError: if ``layer_sizes`` has fewer than two entries, or
            ``context_layer_sizes`` is shorter than ``layer_sizes``.

    NOTE(review): convexity in ``x`` also requires the ``Wz_k`` weights for
    k >= 1 to be non-negative. Nothing in this class enforces that (e.g. by
    clamping after each optimizer step). TODO confirm how training handles it.
    """

    def __init__(self, layer_sizes=(2, 32, 64, 32, 8, 1), context_layer_sizes=(2, 32, 64, 32, 8, 1)):
        # nn.Module.__init__ must run before any submodule is assigned.
        super().__init__()

        layer_sizes = list(layer_sizes)
        context_layer_sizes = list(context_layer_sizes)
        if len(layer_sizes) < 2:
            raise ValueError("layer_sizes needs at least an input and an output size")
        if len(context_layer_sizes) < len(layer_sizes):
            raise ValueError(
                "context_layer_sizes must have at least as many entries as layer_sizes"
            )

        self.n_layers = len(layer_sizes)
        n_in = layer_sizes[0]  # feature size of x, reused by every Wx_k / Wxu_k
        steps = range(self.n_layers - 1)

        # nn.ModuleList (not a plain list) registers the layers, so they appear in
        # parameters()/state_dict() and follow .to(device).

        # Convex, non-decreasing activation applied to each z_{k+1}.
        self.layers_activation = nn.ModuleList([nn.Softplus() for _ in steps])

        # Wz_k: z_k -> z_{k+1} (no bias).
        self.layers_z = nn.ModuleList([
            nn.Sequential(nn.Linear(layer_sizes[i], layer_sizes[i + 1], bias=False))
            for i in steps
        ])

        # Wzu_k, bz_k: u_k -> ReLU gate multiplied element-wise with z_k,
        # so its output width is that of z_k.
        self.layers_zu = nn.ModuleList([
            nn.Sequential(nn.Linear(context_layer_sizes[i], layer_sizes[i]), nn.ReLU())
            for i in steps
        ])

        # Wx_k: (gated) original input x -> z_{k+1} (no bias).
        self.layers_x = nn.ModuleList([
            nn.Sequential(nn.Linear(n_in, layer_sizes[i + 1], bias=False))
            for i in steps
        ])

        # Wxu_k, bx_k: u_k -> gate multiplied element-wise with x, so width n_in.
        self.layers_xu = nn.ModuleList([
            nn.Sequential(nn.Linear(context_layer_sizes[i], n_in))
            for i in steps
        ])

        # Wu_k, bu_k: direct context contribution to z_{k+1}.
        self.layers_u = nn.ModuleList([
            nn.Sequential(nn.Linear(context_layer_sizes[i], layer_sizes[i + 1]))
            for i in steps
        ])

        # V_k, v_k: context path u_k -> u_{k+1}.
        self.layers_v = nn.ModuleList([
            nn.Sequential(nn.Linear(context_layer_sizes[i], context_layer_sizes[i + 1]), nn.Tanh())
            for i in steps
        ])

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Evaluate f(x, c).

        Args:
            x: input along which the network is meant to be convex; its last
                dimension must equal ``layer_sizes[0]``.
            c: context input; its last dimension must equal
                ``context_layer_sizes[0]``.

        Returns:
            Tensor whose last dimension is ``layer_sizes[-1]``.
        """
        z = x  # z_0 is the input itself
        u = c
        for act, w_z, w_zu, w_x, w_xu, w_u, v in zip(
            self.layers_activation,
            self.layers_z,
            self.layers_zu,
            self.layers_x,
            self.layers_xu,
            self.layers_u,
            self.layers_v,
        ):
            z = act(w_z(z * w_zu(u)) + w_x(x * w_xu(u)) + w_u(u))
            u = v(u)
        return z

In [10]:
# Reproducible toy sample: 100 rows whose two coordinates are independent
# unit-variance normals centred at (-1, -1). Seeded so Restart & Run All
# gives the same data every time.
DATA_SEED = 0
data = torch.tensor(
    stats.norm.rvs(loc=(-1, -1), scale=1, size=(100, 2), random_state=DATA_SEED),
    dtype=torch.float32,
)