In [1]:
import torch
from cnp.encoders import StandardANPEncoder, ConvEncoder
from cnp.decoders import ConvDecoder
from cnp.lnp import StandardANP, LatentNeuralProcess
from cnp.cov import AddHomoNoise
from cnp.utils import move_channel_idx

from torch import nn

In [2]:
# All-ones toy batch; assumed layout is (batch, num_points, 1) — 5 context and
# 8 target points per task.
x_context, y_context = torch.ones(10, 5, 1), torch.ones(10, 5, 1)
x_target, y_target = torch.ones(10, 8, 1), torch.ones(10, 8, 1)

# NOTE(review): input_dim=2 here (vs input_dim=1 for the full ANP below) —
# presumably x and y are concatenated before encoding; confirm in cnp.encoders.
encoder = StandardANPEncoder(input_dim=2, latent_dim=128)
latent_dist = encoder(x_context, y_context, x_target)
latent_dist.sample().shape

torch.Size([10, 8, 128])

In [3]:
# Scalar inputs; AddHomoNoise is (by its name) a homoscedastic noise model.
input_dim, add_noise = 1, AddHomoNoise()

anp = StandardANP(input_dim=input_dim, add_noise=add_noise)

In [4]:
# Forward pass with 100 latent samples. The recorded output shape of result[1]
# is (100, 10, 8, 8) — presumably (num_samples, batch, n_target, n_target).
result = anp(num_samples=100,
             x_context=x_context,
             y_context=y_context,
             x_target=x_target)
result[1].shape

torch.Size([100, 10, 8, 8])

In [5]:
# Fit the ANP to the toy batch. The per-step loss is recorded so convergence can
# be inspected afterwards instead of being discarded.
ANP_NUM_STEPS = 1000
ANP_LEARNING_RATE = 1e-2

optim = torch.optim.Adam(lr=ANP_LEARNING_RATE, params=anp.parameters())

anp_losses = []
for step in range(ANP_NUM_STEPS):
    optim.zero_grad()

    loss = anp.loss(x_context, y_context, x_target, y_target, num_samples=10)

    loss.backward()
    optim.step()

    # .backward() above already requires a scalar loss, so .item() is safe.
    anp_losses.append(loss.item())

# First vs. last recorded loss as a quick convergence check.
(anp_losses[0], anp_losses[-1]) if anp_losses else None

In [6]:
class HalfUNet(nn.Module):
    """Small U-Net: three stride-2 down-sampling and three up-sampling layers.

    Operates on PyTorch's channels-first layout `(batch, channels, *spatial)`
    with 1, 2 or 3 spatial dimensions (whichever `nn.Conv{d}d` /
    `nn.ConvTranspose{d}d` exist). Skip connections concatenate features along
    the channel axis, and a final 1x1 convolution maps the `2 * in_channels`
    features of `cat([x, h6])` to `out_channels`.

    Each spatial size of the input must be divisible by 8 (2**3). Each down
    layer maps a length n to ceil(n / 2) and each up layer exactly doubles it,
    so other sizes make the skip-connection concatenation fail.

    Args:
        input_dim (int): Number of spatial dimensions; selects the conv classes.
        in_channels (int): Channels of the input tensor.
        out_channels (int): Channels of the output tensor.
        kernel_size (int, optional): Odd kernel size of the six strided
            (transposed) convolutions. Defaults to 5.

    Raises:
        ValueError: If `kernel_size` is not a positive odd integer.
    """

    def __init__(self,
                 input_dim,
                 in_channels,
                 out_channels,
                 kernel_size=5):

        super().__init__()

        if kernel_size < 1 or kernel_size % 2 != 1:
            raise ValueError(f'kernel_size must be a positive odd integer, '
                             f'got {kernel_size}.')

        conv = getattr(nn, f'Conv{input_dim}d')
        convt = getattr(nn, f'ConvTranspose{input_dim}d')

        self.activation = nn.ReLU()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        # Counts all six stride-2 layers (three halve, three double). Only
        # 2**3 is needed for shape compatibility, but the value is kept at 6
        # because callers use 2 ** num_halving_layers to size their grids.
        self.num_halving_layers = 6

        c = in_channels

        # With padding = k // 2 and stride 2, a conv maps n -> ceil(n / 2);
        # adding output_padding = 1 makes the transposed conv map m -> 2 * m.
        down = dict(kernel_size=kernel_size, stride=2, padding=kernel_size // 2)
        up = dict(down, output_padding=1)

        # Layers are created in the original order so that parameter
        # initialisation (and state_dict keys) are unchanged.
        self.l1 = conv(in_channels=c, out_channels=c, **down)
        self.l2 = conv(in_channels=c, out_channels=2 * c, **down)
        self.l3 = conv(in_channels=2 * c, out_channels=2 * c, **down)

        self.l4 = convt(in_channels=2 * c, out_channels=2 * c, **up)
        # Input is cat([h4, h2]): 2c + 2c channels.
        self.l5 = convt(in_channels=4 * c, out_channels=c, **up)
        # Input is cat([h5, h1]): c + c channels.
        self.l6 = convt(in_channels=2 * c, out_channels=c, **up)

        # 1x1 conv on cat([x, h6]): c + c channels -> out_channels.
        self.last_layer_multiplier = conv(in_channels=2 * c,
                                          out_channels=out_channels,
                                          kernel_size=1,
                                          stride=1,
                                          padding=0)

    def forward(self, x):
        """Forward pass through the U-Net.

        Args:
            x (tensor): Inputs of shape `(batch, in_channels, *spatial)`, with
                every spatial size divisible by 8.

        Returns:
            tensor: Outputs of shape `(batch, out_channels, *spatial)`.
        """
        h1 = self.activation(self.l1(x))
        h2 = self.activation(self.l2(h1))
        h3 = self.activation(self.l3(h2))
        h4 = self.activation(self.l4(h3))

        # Skip connections: concatenate along the channel axis (dim=1).
        h5 = self.activation(self.l5(torch.cat([h4, h2], dim=1)))
        h6 = self.activation(self.l6(torch.cat([h5, h1], dim=1)))

        return self.last_layer_multiplier(torch.cat([x, h6], dim=1))

In [7]:
# Shape check on a channels-first zero batch (batch=100, channels=8,
# length=128): the output keeps batch and length, with out_channels=2.
half_unet = HalfUNet(input_dim=1, in_channels=8, out_channels=2)
zeros = torch.zeros(100, 8, 128)
half_unet(zeros).shape

torch.Size([100, 2, 128])

In [31]:
class StandardConvNP(LatentNeuralProcess):
    """Convolutional latent neural process for scalar outputs.

    Builds a `LatentConvEncoder` around an encoder HalfUNet and a `ConvDecoder`
    around a decoder HalfUNet. Both, together with `add_noise`, are passed to
    `LatentNeuralProcess.__init__`. All keyword arguments default to the
    values previously hard-coded here.

    Args:
        input_dim (int): Dimensionality of the inputs x.
        add_noise: Noise module forwarded to `LatentNeuralProcess`.
        encoder_conv_input_channels (int, optional): Channels of the gridded
            representation fed to the encoder CNN. Defaults to 8.
        latent_function_channels (int, optional): Channels of the latent
            function; the encoder CNN outputs twice this many (a mean and a
            scale channel per latent channel). Defaults to 1.
        decoder_conv_output_channels (int, optional): Output channels of the
            decoder CNN. Defaults to 8.
        decoder_out_channels (int, optional): Half of the `out_channels`
            passed to `ConvDecoder`. Defaults to 32.
        points_per_unit (int, optional): Forwarded to encoder and decoder;
            presumably the grid density per unit of input space. Also sets
            `init_length_scale = 2 / points_per_unit`. Defaults to 32.
        grid_margin (float, optional): Forwarded to encoder and decoder.
            Defaults to 0.2.
    """

    def __init__(self,
                 input_dim,
                 add_noise,
                 encoder_conv_input_channels=8,
                 latent_function_channels=1,
                 decoder_conv_output_channels=8,
                 decoder_out_channels=32,
                 points_per_unit=32,
                 grid_margin=0.2):

        # Dimension of output is 1 for scalar outputs -- do not change
        output_dim = 1

        # Encoder convolutional architecture: outputs 2 * latent channels,
        # which LatentConvEncoder splits into interleaved mean / log-scale.
        encoder_conv = HalfUNet(input_dim=input_dim,
                                in_channels=encoder_conv_input_channels,
                                out_channels=2 * latent_function_channels)

        # Decoder convolutional architecture: consumes a latent function sample.
        decoder_conv = HalfUNet(input_dim=input_dim,
                                in_channels=latent_function_channels,
                                out_channels=decoder_conv_output_channels)

        # Grid settings shared by encoder and decoder. The multiplier is
        # 2 ** num_halving_layers, presumably so grid sizes divide evenly
        # through the U-Net's strided layers.
        grid_multiplier = 2 ** encoder_conv.num_halving_layers
        init_length_scale = 2.0 / points_per_unit

        encoder = LatentConvEncoder(input_dim=input_dim,
                                    conv_architecture=encoder_conv,
                                    init_length_scale=init_length_scale,
                                    points_per_unit=points_per_unit,
                                    grid_multiplier=grid_multiplier,
                                    grid_margin=grid_margin)

        # NOTE(review): ConvDecoder receives 2 * decoder_out_channels; confirm
        # the intended split of these channels in cnp.decoders.
        decoder = ConvDecoder(input_dim=input_dim,
                              conv_architecture=decoder_conv,
                              conv_out_channels=decoder_conv.out_channels,
                              out_channels=2 * decoder_out_channels,
                              init_length_scale=init_length_scale,
                              points_per_unit=points_per_unit,
                              grid_multiplier=grid_multiplier,
                              grid_margin=grid_margin)

        super().__init__(encoder=encoder,
                         decoder=decoder,
                         add_noise=add_noise)

        self.input_dim = input_dim
        self.output_dim = output_dim
        
        
class LatentConvEncoder(ConvEncoder):
    """ConvEncoder whose gridded output parameterises a Normal latent function.

    The base `ConvEncoder` produces a representation with
    `conv_architecture.in_channels` channels. `conv_architecture` then maps it
    to channels that are read in interleaved pairs along dim 1: even channels
    give the Normal's loc and odd channels give its log-scale.

    Args:
        input_dim (int): Dimensionality of the inputs.
        conv_architecture (nn.Module): CNN applied to the base representation;
            its `in_channels` and `out_channels` attributes are read here.
        init_length_scale, points_per_unit, grid_multiplier, grid_margin:
            Forwarded unchanged to `ConvEncoder.__init__`.
    """

    def __init__(self,
                 input_dim,
                 conv_architecture,
                 init_length_scale,
                 points_per_unit,
                 grid_multiplier,
                 grid_margin):

        repr_channels = conv_architecture.in_channels

        # Assigned before the base initialiser runs (plain ints are allowed on
        # an nn.Module before Module.__init__).
        self.conv_input_channels = repr_channels
        # One latent channel per (loc, log-scale) pair of CNN output channels.
        self.conv_output_channels = conv_architecture.out_channels // 2

        super().__init__(input_dim=input_dim,
                         out_channels=repr_channels,
                         init_length_scale=init_length_scale,
                         points_per_unit=points_per_unit,
                         grid_multiplier=grid_multiplier,
                         grid_margin=grid_margin)

        self.conv_architecture = conv_architecture

    def forward(self, x_context, y_context, x_target):
        """Return a Normal distribution over the latent function.

        Returns:
            torch.distributions.Normal: loc = CNN output channels 0, 2, 4, ...
            and scale = exp of channels 1, 3, 5, ... (split along dim 1).
        """
        # NOTE(review): x_context is passed as the base class's third (target)
        # argument, so x_target is unused here. Confirm whether the latent grid
        # is meant to cover the target inputs as well.
        gridded = super().forward(x_context, y_context, x_context)
        features = self.conv_architecture(gridded)

        loc, log_scale = features[:, 0::2], features[:, 1::2]

        return torch.distributions.Normal(loc=loc, scale=log_scale.exp())

In [32]:
# Seed the global torch RNG so this random toy batch, and everything drawn after
# it in a fresh Restart & Run All (e.g. ConvNP weight init), is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Random toy batch; assumed layout is (batch, num_points, 1).
x_context = torch.randn(size=(10, 5, 1))
x_target = torch.randn(size=(10, 8, 1))

y_context = torch.randn(size=(10, 5, 1))
y_target = torch.randn(size=(10, 8, 1))

In [33]:
# Build the ConvNP for scalar inputs and run one forward pass with 3 latent samples.
convnp = StandardConvNP(add_noise=AddHomoNoise(), input_dim=1)
mean, cov = convnp(x_context, y_context, x_target, num_samples=3)

In [34]:
# Training objective on the toy batch, estimated with num_samples=3 latent
# samples; the NegBackward grad_fn in the output shows it is a negated quantity.
convnp.loss(x_context, y_context, x_target, y_target,
            num_samples=3)

tensor(1.6421, grad_fn=<NegBackward>)

In [35]:
# Optional ConvNP training run. It is off by default so Restart & Run All stays
# fast; set TRAIN_CONVNP = True to fit `convnp` to the toy batch.
TRAIN_CONVNP = False
CONVNP_NUM_STEPS = 1000
CONVNP_LEARNING_RATE = 1e-2
CONVNP_LOG_EVERY = 100

if TRAIN_CONVNP:
    # Separate name so the ANP optimiser (`optim`) is not clobbered.
    convnp_optim = torch.optim.Adam(lr=CONVNP_LEARNING_RATE,
                                    params=convnp.parameters())
    convnp_losses = []

    for step in range(CONVNP_NUM_STEPS):
        convnp_optim.zero_grad()

        loss = convnp.loss(x_context, y_context, x_target, y_target, num_samples=10)

        loss.backward()
        convnp_optim.step()

        convnp_losses.append(loss.item())
        # Throttled logging instead of printing the loss on every step.
        if step % CONVNP_LOG_EVERY == 0:
            print(f'step {step}: loss {convnp_losses[-1]:.4f}')