In [212]:
import torch
import torch.nn as nn
import scipy.stats as stats

__toy data__


In [213]:
n = 100

x = torch.concat([torch.tensor(stats.norm.rvs(loc=-1, scale=1, size=(n,1)), dtype=torch.float32), torch.tensor(stats.norm.rvs(loc=1, scale=1, size=(n,1)), dtype=torch.float32)])
y = torch.concat([torch.zeros(n//2), torch.ones(n//2), torch.zeros(n//2), torch.ones(n//2)])
c = torch.concat([torch.zeros(n//2), torch.ones(n//2), torch.zeros(n//2), torch.ones(n//2)])

X = torch.stack(torch.unbind(x,dim=0))
Y = torch.stack(torch.unbind(y,dim=0))
C = torch.stack(torch.unbind(c,dim=0))

__ICNN__

In [215]:
class ICNNet(nn.Module):
    def __init__(self, input_size = 1, layer_sizes = [1,32,64,32,8,1], context_layer_sizes=[1,32,64,32,8,1]):
        super(ICNNet, self).__init__()
        self.n_layers = len(layer_sizes)

        self.layers_activation = nn.ModuleList([nn.Softplus() for _ in range(self.n_layers-1)])

        self.layers_z = nn.ModuleList([nn.Linear(layer_sizes[i], layer_sizes[i+1], bias=False) for i in range(self.n_layers-1)])

        self.layers_zu = nn.ModuleList([nn.Sequential(nn.Linear(context_layer_sizes[i], layer_sizes[i]), nn.ReLU()) for i in range(self.n_layers-1)])

        self.layers_x = nn.ModuleList([nn.Linear(input_size, layer_sizes[i+1], bias=False) for i in range(self.n_layers-1)])

        self.layers_xu = nn.ModuleList([nn.Linear(context_layer_sizes[i], input_size) for i in range(self.n_layers-1)])

        self.layers_u = nn.ModuleList([nn.Linear(context_layer_sizes[i], layer_sizes[i+1]) for i in range(self.n_layers-1)])

        self.layers_v = nn.ModuleList([nn.Sequential(nn.Linear(context_layer_sizes[i], context_layer_sizes[i+1]), nn.ReLU()) for i in range(self.n_layers-1)])

    def forward(self, x, c):
        input = x
        u = c
        for i in range(self.n_layers-1):
            x = self.layers_activation[i](self.layers_z[i](x * self.layers_zu[i](u)) + self.layers_x[i](input * self.layers_xu[i](u)) + self.layers_u[i](u))
            u = self.layers_v[i](u)
        
        return x

__PICNN training__

In [216]:
from torch import optim
epochs = 10

# Initialize the model
model = ICNNet()

# Define the loss function and the optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

for epoch in range(epochs):
    for i in range(len(X)):
        optimizer.zero_grad() # Zero the gradients
        
        output = model(X[i].unsqueeze(0), C[i].unsqueeze(0))  # Assuming context c is same as input x

        loss = criterion(output, Y[i].unsqueeze(0)) # Compute the loss
        loss.backward() # Backward pass
        optimizer.step() # Update the parameters

        for layers_k in model.layers_z:
            for param in layers_k.parameters():
                param.data.clamp_min_(0)

    print(f"Epoch {epoch+1}/{epochs} Loss: {loss.item()}")

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/10 Loss: 0.0006194292800500989
Epoch 2/10 Loss: 1.8419825210003182e-05
Epoch 3/10 Loss: 7.869392106840678e-07
Epoch 4/10 Loss: 6.194000548020995e-07
Epoch 5/10 Loss: 9.118837169808103e-07
Epoch 6/10 Loss: 9.649445928516798e-06
Epoch 7/10 Loss: 1.960744157258887e-05
Epoch 8/10 Loss: 2.2288833861239254e-05
Epoch 9/10 Loss: 2.189047881984152e-05
Epoch 10/10 Loss: 2.2460253603640012e-05


__Makkuva training__

In [224]:
from torch import optim
epochs = 10
train_freq_f = 2

# Initialize the model
input_size = 1
ICNNf = ICNNet(layer_sizes = [input_size, 32,64,32,8,1], context_layer_sizes=[1,8,8,8,8,8])

output_size = 1
ICNNg = ICNNet(layer_sizes = [output_size, 32,64,32,8,1], context_layer_sizes=[1,8,8,8,8,8])

# Define the loss function and the optimizer
optimizer_f = optim.Adam(ICNNf.parameters())
optimizer_g = optim.Adam(ICNNg.parameters())

for epoch in range(epochs):
    for _ in range(train_freq_f) :
        for i in range(len(X)):
            optimizer.zero_grad() # Zero the gradients
            
            x = X[i].unsqueeze(0)
            c = C[i].unsqueeze(0)
            y = Y[i].unsqueeze(0).requires_grad_()

            output_g = ICNNg(y, c)
            diff_output_g = torch.autograd.grad(output_g, y, grad_outputs=torch.ones_like(output_g), create_graph=True)[0]

            loss = ICNNg(x, c) - ICNNf(diff_output_g, c)

            loss.backward() # Backward pass
            optimizer_f.step() # Update the parameters

            print(i, loss)

            for layers_k in ICNNf.layers_z:
                for param in layers_k.parameters():
                    param.data.clamp_min_(0)

    for i in range(len(X)):
            optimizer.zero_grad() # Zero the gradients
            
            x = X[i].unsqueeze(0)
            c = C[i].unsqueeze(0)
            y = Y[i].unsqueeze(0).requires_grad_()

            output_g = ICNNg(y, c)
            diff_output_g = torch.autograd.grad(output_g, y, grad_outputs=torch.ones_like(output_g), create_graph=True)[0]

            loss = - torch.dot(y, diff_output_g) - ICNNf(diff_output_g, c)

            loss.backward() # Backward pass
            optimizer_g.step() # Update the parameters

            print(i, loss)

            for layers_k in ICNNg.layers_z:
                for param in layers_k.parameters():
                    param.data.clamp_min_(0)

    print(f"Epoch {epoch+1}/{epochs} Loss: {loss.item()}")

0 tensor([[0.0687]], grad_fn=<SubBackward0>)
1 tensor([[0.0506]], grad_fn=<SubBackward0>)
2 tensor([[0.0500]], grad_fn=<SubBackward0>)
3 tensor([[0.0483]], grad_fn=<SubBackward0>)
4 tensor([[0.0460]], grad_fn=<SubBackward0>)
5 tensor([[0.0454]], grad_fn=<SubBackward0>)
6 tensor([[0.0423]], grad_fn=<SubBackward0>)
7 tensor([[0.0402]], grad_fn=<SubBackward0>)
8 tensor([[0.0386]], grad_fn=<SubBackward0>)
9 tensor([[0.0368]], grad_fn=<SubBackward0>)
10 tensor([[0.0346]], grad_fn=<SubBackward0>)
11 tensor([[0.0324]], grad_fn=<SubBackward0>)
12 tensor([[0.0297]], grad_fn=<SubBackward0>)
13 tensor([[0.0281]], grad_fn=<SubBackward0>)
14 tensor([[0.0266]], grad_fn=<SubBackward0>)
15 tensor([[0.0246]], grad_fn=<SubBackward0>)
16 tensor([[0.0214]], grad_fn=<SubBackward0>)
17 tensor([[0.0195]], grad_fn=<SubBackward0>)
18 tensor([[0.0163]], grad_fn=<SubBackward0>)
19 tensor([[0.0143]], grad_fn=<SubBackward0>)
20 tensor([[0.0118]], grad_fn=<SubBackward0>)
21 tensor([[0.0085]], grad_fn=<SubBackward0>