In [7]:
import torch
from cnp.encoders import StandardANPEncoder
from cnp.lnp import StandardANP
from cnp.cov import AddHomoNoise

from torch import nn

In [None]:
# Dummy all-ones data; only the shapes matter for this smoke check.
# NOTE(review): layout presumably (batch, num_points, dim) — confirm against cnp.
context_shape = (10, 5, 1)
target_shape = (10, 8, 1)

x_context = torch.ones(context_shape)
y_context = torch.ones(context_shape)
x_target = torch.ones(target_shape)
y_target = torch.ones(target_shape)

# NOTE(review): input_dim=2 presumably counts x and y features together
# (1 + 1 here); verify against StandardANPEncoder.
encoder = StandardANPEncoder(input_dim=2, latent_dim=128)
encoder(x_context, y_context, x_target).sample().shape

In [None]:
# Noise module handed to the ANP (presumably homoscedastic, per its name).
add_noise = AddHomoNoise()

# NOTE(review): presumably the size of the x features (x_context's last
# axis has size 1 above) — confirm against StandardANP.
input_dim = 1

anp = StandardANP(input_dim=input_dim, add_noise=add_noise)

In [None]:
# Forward pass on the toy batch, drawing 100 samples.
result = anp(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
    num_samples=100,
)

# Inspect the shape of the second returned element.
result[1].shape

In [None]:
# Training hyper-parameters for this toy run.
NUM_STEPS = 1000
LEARNING_RATE = 1e-2
NUM_LOSS_SAMPLES = 10  # forwarded as `num_samples` to anp.loss

optim = torch.optim.Adam(lr=LEARNING_RATE, params=anp.parameters())

# Keep the scalar loss of every step so convergence can be inspected after
# the loop. `.item()` stores a plain float instead of the graph-carrying
# tensor (backward() without arguments already requires a scalar loss).
losses = []
for _ in range(NUM_STEPS):
    optim.zero_grad()

    loss = anp.loss(x_context, y_context, x_target, y_target,
                    num_samples=NUM_LOSS_SAMPLES)

    loss.backward()
    optim.step()
    losses.append(loss.item())

In [58]:
class HalfUNet(nn.Module):
    """Three-level U-Net with skip connections and a 1x1 output projection.

    The input passes through three stride-2 convolutions (l1-l3) and then
    three stride-2 transposed convolutions (l4-l6). After l4 and l5 the
    matching encoder activation (h2, h1) is concatenated on the channel
    axis. The result of l6 is concatenated with the raw input and mapped to
    `out_channels` by a 1x1 convolution. The spatial extent of the output
    equals that of the input.

    Args:
        input_dim (int): Number of spatial dimensions; selects
            `nn.Conv{input_dim}d` / `nn.ConvTranspose{input_dim}d`.
        in_channels (int): Number of input channels `C`.
        out_channels (int): Number of output channels.
    """

    def __init__(self,
                 input_dim,
                 in_channels,
                 out_channels):
        from functools import partial

        super().__init__()

        conv = getattr(nn, f'Conv{input_dim}d')
        convt = getattr(nn, f'ConvTranspose{input_dim}d')

        # With kernel 5 / padding 2, each stride-2 convolution maps an axis of
        # length n to ceil(n / 2). Each transposed convolution
        # (output_padding=1) maps it to exactly 2 * n.
        down = partial(conv, kernel_size=5, stride=2, padding=2)
        up = partial(convt, kernel_size=5, stride=2, padding=2,
                     output_padding=1)

        self.activation = nn.ReLU()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Number of stride-2 downsampling layers (l1, l2, l3). Every spatial
        # axis of the input must be divisible by 2 ** num_halving_layers,
        # otherwise the skip concatenations in forward() cannot line up.
        self.num_halving_layers = 3

        c = self.in_channels

        # Encoder path: C -> C -> 2C -> 2C channels, halving each spatial axis.
        self.l1 = down(in_channels=c, out_channels=c)
        self.l2 = down(in_channels=c, out_channels=2 * c)
        self.l3 = down(in_channels=2 * c, out_channels=2 * c)

        # Decoder path: input widths include the concatenated skip tensors.
        self.l4 = up(in_channels=2 * c, out_channels=2 * c)
        self.l5 = up(in_channels=4 * c, out_channels=c)  # input is [h4, h2]
        self.l6 = up(in_channels=2 * c, out_channels=c)  # input is [h5, h1]

        # 1x1 projection of [x, h6] (2C channels) to the requested outputs.
        self.last_layer_multiplier = conv(in_channels=2 * c,
                                          out_channels=self.out_channels,
                                          kernel_size=1,
                                          stride=1,
                                          padding=0)

    def forward(self, x):
        """Forward pass through the convolutional structure.

        Args:
            x (tensor): Inputs of shape `(batch, in_channels, *spatial)`,
                with `input_dim` spatial axes, each divisible by
                `2 ** num_halving_layers`.

        Returns:
            tensor: Outputs of shape `(batch, out_channels, *spatial)`.

        Raises:
            ValueError: If a spatial axis is not divisible by
                `2 ** num_halving_layers`. Without this check the skip
                concatenations fail with an opaque size-mismatch error.
        """
        factor = 2 ** self.num_halving_layers
        spatial = tuple(x.shape[2:])
        if any(size % factor for size in spatial):
            raise ValueError(
                'HalfUNet expects every spatial dimension to be divisible '
                f'by {factor}, got spatial shape {spatial}.'
            )

        h1 = self.activation(self.l1(x))
        h2 = self.activation(self.l2(h1))
        h3 = self.activation(self.l3(h2))
        h4 = self.activation(self.l4(h3))

        # Skip connections are concatenated along the channel axis (dim 1).
        h4 = torch.cat([h4, h2], dim=1)
        h5 = self.activation(self.l5(h4))

        h5 = torch.cat([h5, h1], dim=1)
        h6 = self.activation(self.l6(h5))
        h6 = torch.cat([x, h6], dim=1)

        return self.last_layer_multiplier(h6)

In [60]:
# Smoke check: 100 all-zero inputs with 8 channels over 128 points should
# come out with the same 128 points and 2 channels.
half_unet = HalfUNet(input_dim=1, in_channels=8, out_channels=2)
zeros = torch.zeros(100, 8, 128)
half_unet(zeros).shape

torch.Size([100, 2, 128])