In [1]:
import numpy as np
import torch
from torch.autograd import Variable
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import random
from copy import copy, deepcopy
from torch.distributions.kl import kl_divergence

In [2]:
class function_dist():
    def __init__(self):
        self.param_low = -1
        self.param_high = 1
    
    def f(self, x, param):
        return np.sin(param*x)
    
    def sample(self, number, range_low, range_high):
        sample = []
        lin = np.linspace(range_low, range_high, 100)
        for i in range(number):
            param = np.random.uniform(self.param_low, self.param_high)
            fx = self.f(lin, param)
            sample.append([lin, fx])
        return sample

f = function_dist()
data = f.sample(100, -1, 1)

In [170]:
class process(nn.Module):
    def __init__(self, x_dim, y_dim, z_dim, r_dim):
        super(process, self).__init__()
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.z_dim = z_dim
        self.r_dim = r_dim
        self.gaussian = torch.distributions.Normal(torch.zeros(r_dim), torch.ones(r_dim))
        
        self.fc1 = nn.Linear(z_dim + x_dim, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, y_dim)
        self.relu = nn.ReLU()
        
        self.rz1 = nn.Linear(r_dim, 10)
        self.rz2 = nn.Linear(10, 10)
        self.rzmu = nn.Linear(10, z_dim)
        self.rzlogvar = nn.Linear(10, z_dim)
        
        self.dr1 = nn.Linear(x_dim + y_dim, 10)
        self.dr2 = nn.Linear(10, 10)
        self.drmu = nn.Linear(10, r_dim)
        self.drlogvar = nn.Linear(10, r_dim)
        
        self.optimizer = torch.optim.Adam(self.parameters(), lr = 3e-4)
        
    def forward(self, z, x):
        zx = []
        for y in x:
            zx.append(torch.cat((z, torch.tensor([y]))).reshape(3, 1))

        zx = torch.cat(zx)
        print(zx)
        o1 = self.relu(self.fc1(zx))
        o2 = self.relu(self.fc2(o1))
        output = self.fc3(o2)
        return output
    
    def z_from_prior(self):
        o1 = self.relu(self.rz1(self.gaussian.rsample()))
        o2 = self.relu(self.rz2(o1))
        mu = self.rzmu(o2)
        logvar = self.rzlogvar(o2)
        return mu, logvar.mul(1/2).exp()
        
        
    def z_from_posterior(self, r_mean, r_std_dev):
        norm = torch.distributions.Normal(r_mean, r_std_dev).rsample()
        o1 = self.relu(self.rz1(norm))
        o2 = self.relu(self.rz2(o1))
        mu = self.rzmu(o2)
        logvar = self.rzlogvar(o2)
        return mu, logvar.mul(1/2).exp()
        

    def draw_sample(self, z_mean, z_std_dev):
        posterior = torch.distributions.Normal(z_mean, z_std_dev)
        z = self.gaussian.rsample()
        lin = np.linspace(-1, 1, 100)
        ys = []
        for i in range(100):
            ys.append(self.forward(z, torch.tensor([lin[i]])).detach().numpy())
        return lin, np.array(ys).T[0]
    
    def plot_samples(self, z_mean, z_std_dev, number):
        for _ in range(number):
            sample_x, sample_y = self.draw_biased_sample(z_mean, z_std_dev)
            plt.plot(sample_x, sample_y, c = 'b')
        plt.show()
        
    def calculate_r(self, data):
        o1 = self.relu(self.dr1(data))
        o2 = self.relu(self.dr2(o1))
        mu = self.drmu(o2)
        logvar = self.drlogvar(o2)
        return mu, logvar.mul(1/2).exp()
    
    def aggregate(self, mus, stds):
        return mus.sum(dim=0)/len(mus), stds.sum(dim=0)/len(stds)
    
    def lossf(self, prior_dist, posterior_dist, x, y):
        MSE = ((x-y)**2).sum()/len(x)
        kl = kl_divergence(prior_dist, posterior_dist)
        return MSE + kl
    
    def train(self, data, epochs):
        for epoch in range(epochs):
            for function in data:
                self.optimizer.zero_grad()
                num_context_points = random.randint(0, 100)
                function = np.array(function).T
                np.random.shuffle(function)
                context = function[0:num_context_points, :]
                function = torch.tensor(function, requires_grad=True).float()
                context = torch.tensor(context, requires_grad=True).float()
                r_mu, r_std = self.calculate_r(function)
                r_mu, r_std = self.aggregate(r_mu, r_std)
                rc_mu, rc_std = self.calculate_r(context)
                rc_mu, rc_std = self.aggregate(rc_mu, rc_std)
                z_prior = self.z_from_posterior(r_mu, r_std)
                
                z_posterior_mu, z_posterior_std = self.z_from_posterior(rc_mu, rc_std)
                z_posterior = torch.distributions.Normal(z_posterior_mu, z_posterior_std).rsample()
                x = context[:,0]
                newy = []
                for point in x:
                    print(z_posterior, x)
                    newy.append(self.forward(z_posterior, x))
                y = context[:,1]
                

In [171]:
model = process(1, 1, 2, 2)

In [172]:
model.train(data, 1)

tensor([-0.6585, -0.7952], grad_fn=<AddBackward0>) tensor([-0.7576,  0.0101, -0.5556,  1.0000, -0.9394, -0.8182,  0.1515,  0.1111,
         0.4343,  0.1717,  0.7374, -0.1919, -0.1313], grad_fn=<SelectBackward>)
tensor([[-0.6585],
        [-0.7952],
        [-0.7576],
        [-0.6585],
        [-0.7952],
        [ 0.0101],
        [-0.6585],
        [-0.7952],
        [-0.5556],
        [-0.6585],
        [-0.7952],
        [ 1.0000],
        [-0.6585],
        [-0.7952],
        [-0.9394],
        [-0.6585],
        [-0.7952],
        [-0.8182],
        [-0.6585],
        [-0.7952],
        [ 0.1515],
        [-0.6585],
        [-0.7952],
        [ 0.1111],
        [-0.6585],
        [-0.7952],
        [ 0.4343],
        [-0.6585],
        [-0.7952],
        [ 0.1717],
        [-0.6585],
        [-0.7952],
        [ 0.7374],
        [-0.6585],
        [-0.7952],
        [-0.1919],
        [-0.6585],
        [-0.7952],
        [-0.1313]], grad_fn=<CatBackward>)


RuntimeError: size mismatch, m1: [39 x 1], m2: [3 x 10] at ..\aten\src\TH/generic/THTensorMath.cpp:961