In [7]:
%pdb

Automatic pdb calling has been turned ON


In [10]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.nn import Parameter


class Kernel(nn.Module):
    """Compositional kernel parsed from a compact expression string.

    Grammar (no whitespace allowed):

        expression     := addition
        addition       := multiplication ('+' multiplication)*
        multiplication := parenthesis ('*' parenthesis)*
        parenthesis    := '(' expression ')' | value
        value          := 'C' (Constant) | 'W' (WhiteNoise)
                        | 'R' (RBF)      | 'E' (ExpSinSq)

    Parsing runs once, in ``__init__``, and stores the result in
    ``self.value``: a nested dict tree whose inner nodes are
    ``{'op': '+' | '*', 'values': [...]}`` and whose leaves hold the
    variational Gaussian parameters (``'_mu'``, ``'_var'``) of each kernel
    hyperparameter. Every such Parameter is also appended to ``self.params``
    so that ``nn.Module`` registers it.

    ``forward`` draws reparameterised samples of each hyperparameter (one per
    batch row), evaluates the kernel tree and returns the summed KL divergence
    of the hyperparameter posteriors against a standard-normal prior.
    """

    def __init__(self, string):
        super().__init__()
        self.string = string
        self.params = nn.ParameterList()
        self.n_param_groups = 0
        self.index = 0  # parser cursor into self.string
        self.value = self.getValue()
        # Mean / scale of the N(0, 1) prior used by sampleGaussian. They are
        # registered as Parameters (kept as-is for state_dict compatibility)
        # but only ever read through .data, so no gradient reaches them.
        self.t0 = Parameter(torch.zeros(1))
        self.t1 = Parameter(torch.ones(1))

    def getValue(self):
        """Parse the whole of ``self.string`` and return the kernel tree.

        Raises:
            Exception: if characters remain after a complete expression.
        """
        value = self.parseExpression()
        if self.hasNext():
            raise Exception(
                "Unexpected character found: '" +
                self.peek() +
                "' at index " +
                str(self.index))
        return value

    def peek(self):
        """Return the current character, or '' once the input is exhausted."""
        return self.string[self.index:self.index + 1]

    def hasNext(self):
        """True while unparsed characters remain."""
        return self.index < len(self.string)

    def parseExpression(self):
        """expression := addition"""
        return self.parseAddition()

    def parseAddition(self):
        """addition := multiplication ('+' multiplication)*

        A single operand is returned unwrapped; otherwise a '+' node is built.
        """
        values = [self.parseMultiplication()]
        while self.peek() == '+':
            self.index += 1
            values.append(self.parseMultiplication())
        return values[0] if len(values) == 1 else {'op': '+', 'values': values}

    def parseMultiplication(self):
        """multiplication := parenthesis ('*' parenthesis)*

        A single operand is returned unwrapped; otherwise a '*' node is built.
        """
        values = [self.parseParenthesis()]
        while self.peek() == '*':
            self.index += 1
            values.append(self.parseParenthesis())
        return values[0] if len(values) == 1 else {'op': '*', 'values': values}

    def parseParenthesis(self):
        """parenthesis := '(' expression ')' | value

        Raises:
            Exception: if an opening '(' has no matching ')'.
        """
        if self.peek() != '(':
            return self.parseValue()
        self.index += 1
        value = self.parseExpression()
        if self.peek() != ')':
            raise Exception(
                "No closing parenthesis found at character "
                + str(self.index))
        self.index += 1
        return value

    def makeGaussianParams(self, var_init=-5):
        """Create and register the (mu, raw-scale) Parameters of one hyperparameter.

        Both Parameters are appended to ``self.params`` and tagged with
        ``.type`` ('mu' / 'var') and ``.param_group`` (a running group index).
        The raw scale starts at ``var_init``; sampleGaussian maps it through
        softplus, so the default -5 gives a small initial scale (~0.0067).

        NOTE(review): the ``.type`` tag shadows ``torch.Tensor.type()`` on these
        Parameters (e.g. ``Module.type(dtype)`` would fail). It is kept because
        other cells may group optimizer parameters by it — consider renaming.

        Returns:
            dict with keys '_mu' and '_var' referencing the two Parameters.
        """
        _mu = Parameter(torch.randn(1))
        _mu.type = "mu"
        _mu.param_group = self.n_param_groups
        self.params.append(_mu)

        _var = Parameter(torch.ones(1) * var_init)
        _var.type = "var"
        _var.param_group = self.n_param_groups
        self.params.append(_var)

        self.n_param_groups += 1
        return {'_mu': _mu, '_var': _var}

    def sampleGaussian(self, params, n_runs=1):
        """Draw reparameterised samples of one Gaussian hyperparameter.

        Args:
            params: dict returned by makeGaussianParams ('_mu', '_var').
            n_runs: number of samples to draw.

        Returns:
            (x, kl): ``x`` has shape (n_runs, 1); ``kl`` has shape (n_runs,) and
            is KL(q || N(0, 1)) for this hyperparameter, repeated per sample.
        """
        mu = params['_mu'] * 1  # plain tensor (not a Parameter) for the distribution
        # softplus keeps the value positive. Normal's second argument is the
        # standard deviation (scale), despite the '_var' name.
        scale = F.softplus(params['_var'])
        dist = Normal(mu, scale)
        x = dist.rsample(sample_shape=torch.Size([n_runs]))
        prior = Normal(self.t0.data, self.t1.data)
        # Public dispatcher; for two Normals it resolves to the same closed form
        # as the private torch.distributions.kl._kl_normal_normal.
        kl = torch.distributions.kl_divergence(dist, prior).repeat(n_runs)
        return x, kl

    def parseValue(self):
        """value := 'C' | 'W' | 'R' | 'E' — build a leaf with fresh Gaussian params.

        Raises:
            NotImplementedError: for any other character (including end of input).
        """
        char = self.peek()
        self.index += 1
        if char == "C":
            return {'op': "Constant", '_c': self.makeGaussianParams()}
        elif char == "W":
            return {'op': 'WhiteNoise', '_var': self.makeGaussianParams()}
        elif char == "R":
            return {'op': 'RBF', '_l2': self.makeGaussianParams()}
        elif char == "E":
            # Dict literal evaluation order fixes parameter order: '_t' then '_l2'.
            return {'op': 'ExpSinSq', '_t': self.makeGaussianParams(), '_l2': self.makeGaussianParams()}
        else:
            raise NotImplementedError("Cannot parse char: " + str(char))

    def forward(self, x1, x2, kernel=None):
        """Evaluate the kernel tree on paired inputs.

        Args:
            x1, x2: input tensors. Dim 0 is the batch (one hyperparameter sample
                per row) and dim 1 is read as ``n`` for the Constant kernel.
                Presumably both are (batch, n) — TODO confirm against callers.
            kernel: subtree to evaluate; defaults to the whole parsed tree.

        Returns:
            (t, kl): kernel values (broadcast of the inputs with the per-row
            samples) and the summed KL term, shape (batch,).

        Raises:
            NotImplementedError: for an unknown node 'op'.
        """
        batch_size = x1.size(0)
        n = x1.size(1)
        if kernel is None:
            kernel = self.value
        op = kernel['op']
        if op in ('+', '*'):
            t, kl = self.forward(x1, x2, kernel['values'][0])
            for v in kernel['values'][1:]:
                _t, _kl = self.forward(x1, x2, v)
                # Deliberately out-of-place: child results (e.g. RBF's .exp()
                # output) are saved by autograd, and modifying them in place
                # makes backward() fail. Out-of-place ops also broadcast freely.
                t = t + _t if op == '+' else t * _t
                kl = kl + _kl
            return t, kl
        elif op == 'Constant':
            _c, kl_c = self.sampleGaussian(kernel['_c'], n_runs=batch_size)
            c = F.softplus(_c).repeat(1, n)
            return c, kl_c
        elif op == 'WhiteNoise':
            _var, kl_var = self.sampleGaussian(kernel['_var'], n_runs=batch_size)
            # The -3 offset biases the noise level small (softplus(-3) ~ 0.049).
            var = F.softplus(_var - 3)
            return (x1 == x2).float() * var, kl_var
        elif op == 'RBF':
            _l2, kl_l2 = self.sampleGaussian(kernel['_l2'], n_runs=batch_size)
            l2 = F.softplus(_l2)
            return (-(x1 - x2) ** 2 / (2 * l2)).exp(), kl_l2
        elif op == 'ExpSinSq':
            _t, kl_t = self.sampleGaussian(kernel['_t'], n_runs=batch_size)
            t = F.softplus(_t)  # acts as the period of the sin^2 term
            _l2, kl_l2 = self.sampleGaussian(kernel['_l2'], n_runs=batch_size)
            l2 = F.softplus(_l2)
            return (-2 * (math.pi * (x1 - x2).abs() / t).sin() ** 2 / l2).exp(), kl_t + kl_l2
        else:
            raise NotImplementedError()


In [19]:
# Build (RBF + RBF) * WhiteNoise; parsing registers three Gaussian parameter groups.
k = Kernel(string="(R+R)*W")

In [20]:
# Display the parsed kernel tree: nested dicts whose leaves hold the '_mu'/'_var' Parameters.
k.value

{'op': '*',
 'values': [{'op': '+',
   'values': [{'op': 'RBF', '_l2': {'_mu': Parameter containing:
      tensor([1.2557], requires_grad=True), '_var': Parameter containing:
      tensor([-5.], requires_grad=True)}},
    {'op': 'RBF', '_l2': {'_mu': Parameter containing:
      tensor([0.7595], requires_grad=True), '_var': Parameter containing:
      tensor([-5.], requires_grad=True)}}]},
  {'op': 'WhiteNoise', '_var': {'_mu': Parameter containing:
    tensor([-0.0822], requires_grad=True), '_var': Parameter containing:
    tensor([-5.], requires_grad=True)}}]}