In [187]:
import torch
import torch.nn as nn
import scipy.stats as stats

In [188]:
n = 100

x = torch.concat([torch.tensor(stats.norm.rvs(loc=(-1,-1), scale=1, size=(n,2)), dtype=torch.float32), torch.tensor(stats.norm.rvs(loc=(1,1), scale=1, size=(n,2)), dtype=torch.float32)])
y = torch.concat([torch.zeros(n//2), torch.ones(n//2), torch.zeros(n//2), torch.ones(n//2)])
c = torch.concat([torch.zeros(n//2), torch.ones(n//2), torch.zeros(n//2), torch.ones(n//2)])

X = torch.stack(torch.unbind(x,dim=0))
Y = torch.stack(torch.unbind(y,dim=0))
C = torch.stack(torch.unbind(c,dim=0))

In [189]:
class ICNNet(nn.Module):
    def __init__(self, input_size = 2, layer_sizes = [2,32,64,32,8,1], context_layer_sizes=[1,32,64,32,8,1]):
        super(ICNNet, self).__init__()
        self.n_layers = len(layer_sizes)

        self.layers_activation = nn.ModuleList([nn.Softplus() for _ in range(self.n_layers-1)])

        self.layers_z = nn.ModuleList([nn.Linear(layer_sizes[i], layer_sizes[i+1], bias=False) for i in range(self.n_layers-1)])

        self.layers_zu = nn.ModuleList([nn.Sequential(nn.Linear(context_layer_sizes[i], layer_sizes[i]), nn.ReLU()) for i in range(self.n_layers-1)])

        self.layers_x = nn.ModuleList([nn.Linear(input_size, layer_sizes[i+1], bias=False) for i in range(self.n_layers-1)])

        self.layers_xu = nn.ModuleList([nn.Linear(context_layer_sizes[i], input_size) for i in range(self.n_layers-1)])

        self.layers_u = nn.ModuleList([nn.Linear(context_layer_sizes[i], layer_sizes[i+1]) for i in range(self.n_layers-1)])

        self.layers_v = nn.ModuleList([nn.Sequential(nn.Linear(context_layer_sizes[i], context_layer_sizes[i+1]), nn.ReLU()) for i in range(self.n_layers-1)])

    def forward(self, x, c):
        input = x
        u = c
        for i in range(self.n_layers-1):
            x = self.layers_activation[i](self.layers_z[i](x * self.layers_zu[i](u)) + self.layers_x[i](input * self.layers_xu[i](u)) + self.layers_u[i](u))
            u = self.layers_v[i](u)
        return x

In [195]:
from torch import optim
epochs = 100

# Initialize the model
model = ICNNet()

# Define the loss function and the optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

for epoch in range(epochs):
    for i in range(len(X)):
        optimizer.zero_grad() # Zero the gradients
        
        output = model(X[i].unsqueeze(0), C[i].unsqueeze(0))  # Assuming context c is same as input x

        loss = criterion(output, Y[i].unsqueeze(0)) # Compute the loss
        loss.backward() # Backward pass
        optimizer.step() # Update the parameters

        for layers_k in model.layers_z:
            for param in layers_k.parameters():
                param.data.clamp_min_(0)

    print(f"Epoch {epoch+1}/{epochs} Loss: {loss.item()}")

Epoch 1/10 Loss: 0.018610745668411255
Epoch 2/10 Loss: 0.0027061128057539463
Epoch 3/10 Loss: 0.0007559371297247708
Epoch 4/10 Loss: 0.00012830697232857347
Epoch 5/10 Loss: 3.4207691896881443e-06
Epoch 6/10 Loss: 5.353360847948352e-06
Epoch 7/10 Loss: 9.492339813732542e-06
Epoch 8/10 Loss: 8.261000402853824e-06
Epoch 9/10 Loss: 5.956099812465254e-06
Epoch 10/10 Loss: 6.11777659287327e-06
