In [1]:
import torch
from pysymbolic.models.special_functions import MeijerG
import numpy as np

In [2]:
class meijerPytorch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, theta, order=None):
        input_ = input.tolist()
        theta_ = theta.tolist()
        order_ = order.tolist()
        func = MeijerG(theta=theta_, order=[int(i) for i in order_])
        ctx.save_for_backward(input, theta,order)
        output = func.evaluate(input_)
        return torch.FloatTensor(output)

    @staticmethod
    def backward(ctx, grad_output):
        input, theta, order = ctx.saved_tensors
        input_ = input.tolist()
        theta_ = theta.tolist()
        order_ = order.tolist()
        
        grad_input = grad_weight = grad_bias = None
        func = MeijerG(theta=theta_, order=[int(i) for i in order_])
        grads = torch.FloatTensor([i[0] for i in func.gradients(input)])

        return None, grads, None

After creating the operation meijerPytorch that defines the forward and backward pass we test the implementation on a sinusoidale test function

In [None]:
meij = meijerPytorch.apply
x = torch.FloatTensor(np.linspace(0.05,0.95,1))
y = torch.sin(x*6)
theta = torch.FloatTensor([1,1,1,0,1])
learning_rate= 0.1

for i in range(20):
    theta.requires_grad = True
    order = torch.FloatTensor([1, 2, 2, 2])

    t = meij(x,theta ,order)
    loss = torch.mean((t-y)**2)
    print(loss)
    loss.backward()
    with torch.no_grad():
        theta -= learning_rate * theta.grad


tensor(0.0609, grad_fn=<MeanBackward0>)
tensor(0.0634, grad_fn=<MeanBackward0>)
tensor(0.0676, grad_fn=<MeanBackward0>)
tensor(0.0721, grad_fn=<MeanBackward0>)
tensor(0.0762, grad_fn=<MeanBackward0>)
tensor(0.0794, grad_fn=<MeanBackward0>)
tensor(0.0818, grad_fn=<MeanBackward0>)
tensor(0.0835, grad_fn=<MeanBackward0>)
tensor(0.0847, grad_fn=<MeanBackward0>)
tensor(0.0855, grad_fn=<MeanBackward0>)
tensor(0.0861, grad_fn=<MeanBackward0>)
tensor(0.0864, grad_fn=<MeanBackward0>)
tensor(0.0867, grad_fn=<MeanBackward0>)
