In [32]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
from torch.distributions.kl import kl_divergence

In [33]:
x_dim=2
y_dim=1

xspace = torch.linspace(-2, 2, 100).unsqueeze(0).transpose(0, 1)
yspace = torch.linspace(-2, 2, 100).unsqueeze(0).transpose(0, 1)
xx = torch.cat([xspace for i in range(yspace.size()[0])], dim=1)
xx1 = xx.transpose(0, 1)
linspace = torch.cat([xx1.unsqueeze(2), xx.unsqueeze(2)], dim=2).reshape(100**2,2)

result = linspace[:,0].unsqueeze(0).transpose(0, 1)
data = []
for i in range(2000):
    data.append([linspace, result])

In [34]:
class encoder(nn.Module):
    def __init__(self, output_sizes):
        super(encoder, self).__init__()
        self.output_sizes = output_sizes
        
        self.fc1 = nn.Linear(x_dim + y_dim, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3_sigma = nn.Linear(10, self.output_sizes)
        self.fc3_mu = nn.Linear(10, self.output_sizes)
        self.relu = nn.ReLU()
        
    def forward(self, x, y):
        #x and y are 1xn dimensional torch tensors
        
        xy = torch.cat([x, y], dim=1)
        l1 = self.relu(self.fc1(xy))
        l2 = self.relu(self.fc2(l1))
        logvar = self.fc3_sigma(l2)
        mu = self.fc3_mu(l2)
        sigma = torch.exp(logvar.mul(1/2))
        
        out_mu = torch.mean(mu, dim=0)
        out_sigma = torch.mean(sigma, dim=0)
        return out_mu.reshape(self.output_sizes, 1), out_sigma.reshape(self.output_sizes, 1), torch.distributions.Normal(out_mu, out_sigma)

In [35]:
class decoder(nn.Module):
    def __init__(self, encoded_size):
        super(decoder, self).__init__()
        self.encoded_size = encoded_size
        
        self.fc1 = nn.Linear(self.encoded_size + x_dim, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc21 = nn.Linear(10, 10)
        self.fc3_mu = nn.Linear(10, y_dim)
        self.fc3_sigma = nn.Linear(10, y_dim)
        self.relu = nn.ReLU()
        
    def forward(self, z, x):
        #x and z are 1xn dimensional torch tensors
        
        zmulti = torch.cat([z for i in range(x.size()[0])], dim=0)
    
        xz = torch.cat([x, zmulti], dim=1)
        
        l1 = self.relu(self.fc1(xz))
        l2 = self.relu(self.fc2(l1))
        l21 = self.relu(self.fc21(l2))
        out_mu = self.fc3_mu(l21)
        out_logvar = self.fc3_sigma(l21)
        out_sigma = torch.exp(out_logvar.mul(1/2))
        
        dist = torch.distributions.Normal(out_mu, out_sigma)
        
        return out_mu, out_sigma, dist

In [36]:
class CNP(nn.Module):
    def __init__(self, encoded_size):
        super(CNP, self).__init__()
        self.encoded_size = encoded_size
        self._encoder = encoder(encoded_size)
        self._decoder = decoder(encoded_size)
        
    def forward(self, context_x, context_y, target_x, target_y=None):
        en_mu, en_sigma, en_dist = self._encoder(context_x, context_y)
        representation = en_dist.rsample().unsqueeze(0)
        mu, sigma, dist = self._decoder(representation, target_x)
        
        if target_y is not None:
            log_p = dist.log_prob(target_y.transpose(0, 1)).sum()
            MSE = (mu - target_y.transpose(0,1)).pow(2).sum()
        else:
            log_p = None
            MSE = None
        return mu, sigma, log_p, en_dist, MSE

In [37]:
cnp = CNP(3)
optimizer = torch.optim.Adam(cnp.parameters(), lr=1e-3)
num_test_maximum = 1
plot_frequency = 1

def lossf(log_p, z_prior, z_posterior, MSE):
    KL =  kl_divergence(z_prior, z_posterior).prod()
    return - log_p + KL


def train(data, cnp, epochs, test_data=None):
    cnp.train()
    for epoch in range(epochs):
        iterations = 0
        total_loss = 0
        for function in data:
            optimizer.zero_grad()
            num_points = function[0].size()[0]
            perm = torch.randperm(num_points)
            #num_context = np.random.randint(num_points - num_test_maximum, num_points)
            num_context = num_test_maximum
            context_x = function[0][perm][0:num_context]
            context_y = function[1][perm][0:num_context]
            test_x = function[0][perm][num_context:num_points]
            test_y = function[1][perm][num_context:num_points]
            
            mu, sigma, log_p, en_dist, MSE = cnp(context_x, context_y, test_x, test_y)
            loss = lossf(log_p, torch.distributions.Normal(torch.zeros(cnp.encoded_size), 1), en_dist, MSE)
            loss.backward()
            optimizer.step()
            
            total_loss += loss
            iterations += 1
            
            if iterations % 200 == 0:
                print(iterations)
        
        if epoch % plot_frequency == 0:
            print("EPOCH: {}, LOSS {}".format(epoch, total_loss))
        '''
            test_x = test_data[0].unsqueeze(0)
            test_y = test_data[1].unsqueeze(0)
            
            linspace = torch.linspace(-2, 2, 100).unsqueeze(0)
            print(linspace)
            mu, sigma, _, _ = cnp(test_x, test_y, linspace)
            lin = (linspace.numpy()[0])
            low = np.array((mu-sigma).detach().numpy().T[0])
            high = np.array((mu+sigma).detach().numpy().T[0])
            plt.plot(linspace.numpy()[0], mu.detach().numpy())
            plt.fill_between(lin, low, high, facecolor='#65c9f7', interpolate=True)
            plt.scatter(test_x.numpy(), test_y.numpy(), c='black')
            plt.show()
        '''
            
def plot_priors(np, number):
    norm = torch.distributions.Normal(torch.zeros(np.encoded_size), 1)
    for i in range(number):
        z = norm.rsample().unsqueeze(0).transpose(0, 1)
        lin = torch.linspace(-2, 2, 100).unsqueeze(0)
        ys, _, _ = np._decoder(z, lin)
        yplot = ys.transpose(0,1).squeeze(0).detach().numpy()
        xplot = lin.squeeze(0).numpy()
        plt.plot(xplot, yplot)

In [25]:
train(data, cnp, 100000)

200


KeyboardInterrupt: 

In [26]:
xspace = torch.linspace(-2, 2, 100).unsqueeze(0).transpose(0, 1)
yspace = torch.linspace(-2, 2, 100).unsqueeze(0).transpose(0, 1)
xx = torch.cat([xspace for i in range(yspace.size()[0])], dim=1)
xx1 = xx.transpose(0, 1)
linspace = torch.cat([xx1.unsqueeze(2), xx.unsqueeze(2)], dim=2).reshape(100**2,2)

In [27]:
function = data[0]

In [28]:
num_points = function[0].size()[0]
perm = torch.randperm(num_points)
num_context = np.random.randint(num_points - num_test_maximum, num_points)
context_x = function[0][perm][0:num_context]
context_y = function[1][perm][0:num_context]
test_x = function[0][perm][num_context:num_points]
test_y = function[1][perm][num_context:num_points]


mu, sigma, log_p, en_dist, mse = cnp(context_x, context_y, test_x, test_y)

In [29]:
mu

tensor([[-0.0021]], grad_fn=<AddmmBackward>)

In [30]:
test_x

tensor([[-0.5859,  1.6768]])

In [31]:
test_y

tensor([[-0.5859]])