In [None]:
import numpy as np
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

import torch.optim as optim
from itertools import count
import sys; sys.path.append('../utils/'); from ggnn import Propagator

In [2]:
def gen_mix_mvn(n_components=30, d=2, n_samples=400):
    # Generate random multivariate Gaussians
    pi = np.random.dirichlet([0.8]*n_components)
    mu = np.random.multivariate_normal(np.zeros(d), np.eye(d)*25, size=(n_components))
    mu = np.random.rand(n_components, d) * 10 - 5
    sigma = np.zeros((n_components, d, d))
    for n in range(n_components):
        _tmpmat = np.random.rand(d,d)
        Q, _junk = np.linalg.qr(_tmpmat)
        lam = np.random.exponential(1, d)
        sigma[n] = np.matmul(np.matmul(Q,np.diag(lam)),Q.T)

    # Draw samples
    z = np.random.multinomial(n_samples, pi)
    smps = np.zeros((n_samples, d))
    indexes = np.stack((np.cumsum(np.concatenate(([0], z[:-1]))),
                        np.cumsum(z)), axis=1)
    for ixs, n, m, s in zip(indexes, z, mu, sigma):
        smps[slice(*ixs)] = np.random.multivariate_normal(m, s, size=n)
    
    return smps, (pi, mu, sigma)

In [3]:
n_components = 30   # number of Gaussians in original mixture
n_approx = 5        # number of Gaussians to approximate with
n_samples = 800     # number of samples to draw from mixture for visualisation.

# generate random Gaussian mixture
smps, pars = gen_mix_mvn(30, n_samples=n_samples)



In [4]:
torch.set_default_dtype(torch.double)

In [91]:
class Policy(nn.Module):
    def __init__(self, n_steps, n_dim, state_dim, annotation_dim, global_dim, global_annotation_dim):
        """
        n_steps: number of message passing rounds
        """
        super(Policy, self).__init__()
        
        self.n_steps = n_steps
        self.n_dim = n_dim
        self.state_dim = state_dim
        self.annotation_dim = annotation_dim
        self.global_dim = global_dim
        self.global_annotation_dim = global_annotation_dim

        
        self.propagator = Propagator(state_dim)
        
        self.out = nn.Sequential(
            nn.Linear(global_annotation_dim + n_dim * state_dim,1),
#             nn.Tanh(),
#             nn.Linear(global_dim, 1)
        )
        
        self.saved_log_probs = []
        self.rewards =[]
    def forward(self, pars):
        """
        pars: pi, mu, sigma 
        """

        n_components = len(pars[0])
        prop_state = []
        component_score = torch.zeros(1,n_components)
        for i in range(n_components):
            # prepare input
            annotations = torch.zeros(self.n_dim, self.annotation_dim, )
            annotations[:,0] = torch.from_numpy(pars[1][i])
            padding = torch.zeros(self.n_dim, self.state_dim - self.annotation_dim)
            state = torch.cat((annotations,padding), dim = 1)
            
            # adjacency matrix
            sigma = torch.from_numpy(pars[2][i])
            
            # message passing GRU type
            for t in range(self.n_steps):
                state = self.propagator(state, sigma)

            prop_state.append(state)

            # prediction
            global_state = torch.cat((torch.Tensor([[pars[0][i]]]), state.reshape(1,-1) ), dim=1)
            component_score[0,i] = self.out(global_state).item()


        return F.softmax(component_score, dim =1)




In [92]:
# GNN
n_steps = 4
global_dim = 4
global_annotation_dim = 1
n_dim = 2
state_dim = 4
annotation_dim = 1



policy = Policy(n_steps, n_dim, state_dim, annotation_dim, global_dim, global_annotation_dim)
policy.double()

Policy(
  (propagator): Propagator(
    (reset_gate): Sequential(
      (0): Linear(in_features=8, out_features=4, bias=True)
      (1): Sigmoid()
    )
    (update_gate): Sequential(
      (0): Linear(in_features=8, out_features=4, bias=True)
      (1): Sigmoid()
    )
    (transform): Sequential(
      (0): Linear(in_features=8, out_features=4, bias=True)
      (1): Tanh()
    )
  )
  (out): Sequential(
    (0): Linear(in_features=9, out_features=1, bias=True)
  )
)

In [135]:
def select_action(pars, n_approx):
    probs = policy(pars)
    m = Categorical(probs)
    n_ = 0
    sampled = []
    for _ in count():
        action = m.sample()
        if action.item() not in sampled:
            policy.saved_log_probs.append(m.log_prob(action))
            n_ += 1
            sampled.append(action.item())
        if n_== n_approx:
            break
    return sampled


def select_1action(pars):
    probs = policy(pars)
    m =Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()


In [130]:
# KL divergence calculation
def calc_prob(x, pars):
    n_dim = len(x)
    n_mix = len(pars[0])
    prob = 0
    for m in range(n_mix):
        sigma_inv = np.linalg.inv(pars[2][m])
        _a = np.exp(-np.dot(x - pars[1][m], np.dot(sigma_inv, x - pars[1][m]))/2)
        _a = _a / (np.sqrt(2*np.pi)**n_dim * np.sqrt(np.linalg.det(pars[2][m])))
        prob += pars[0][m] * _a
    
    return prob


def calc_kl_from_samples(sample, pars1, pars2):
    kl = 0
    for i in sample:
        p = calc_prob(i, pars1)
        q = calc_prob(i, pars2)
        if q > 1e-10:
            kl += p * (np.log(p) - np.log(q))
    return kl


In [133]:
def calc_reward(pars, indices):
    ixs_top4 = indices
    ixs_other = np.array(list(set(np.argsort(pars[0])) - set(ixs_top4)))
    pi_approx = np.concatenate((pars[0][ixs_top4], [sum(pars[0][ixs_other])]))
    mu_approx = pars[1][ixs_top4]
    mu_other = np.dot(pars[0][ixs_other], pars[1][ixs_other])/pi_approx[-1]
    sigma_approx = pars[2][ixs_top4]
    sigma_other = -np.outer(mu_other, mu_other)
    for i in ixs_other:
        sigma_other += pars[0][i] * (pars[2][i] + np.outer(pars[1][i], pars[1][i]))

    mu_approx = np.concatenate((mu_approx, mu_other[None,:]), axis=0)
    sigma_approx = np.concatenate((sigma_approx, sigma_other[None,:]), axis=0)
    
    approx_pars = (pi_approx, mu_approx, sigma_approx)
    kl = calc_kl_from_samples(smps, pars, approx_pars)
    return kl
    

In [148]:
del policy.saved_log_probs[:]
del policy.rewards[:]

In [149]:
n_approx = 2
mix = select_1action(pars)
reward = calc_reward(pars, [mix])
policy.rewards.append(reward)

In [150]:
# optimize
optimizer = optim.Adam(policy.parameters(), 1e-3)

loss = policy.saved_log_probs[0] * policy.rewards[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn