In [1]:
import MGTomo.model as mgmodel
import numpy as np
import astra
import MGTomo.tomoprojection as mgproj
from MGTomo.utils import myexp, mylog, mydiv
import MGTomo.Yfunctions as fcts
from scipy import interpolate
from skimage.data import shepp_logan_phantom
from skimage.transform import resize

from MGTomo.gridop import P,R

import torch
from torch.func import grad

In [2]:
max_levels = 1
maxIter = [5,5]

In [3]:
N = 63
# load image
x_orig = shepp_logan_phantom()
x_orig = resize(x_orig, (N,N), anti_aliasing = False)

x_torch = torch.tensor(x_orig, requires_grad = True)

In [4]:
model = mgmodel.astra_model(N,{'mode' : 'line', 'num_angles' : 50, 'level_decrease' : 1})
fine_dim = model.dim
A = [mgproj.TomoTorch(model.proj_factory(fine_dim))]
b = [A[0](x_torch)]
level = {int(np.sqrt(A[0].shape[1])): 0}

for i in range(1,max_levels+1):
    coarse_dim = model.reduce_dim(fine_dim)
    A.append(mgproj.TomoTorch(model.proj_factory(coarse_dim)))
    b.append(torch.from_numpy(model.reduce_rhs(b[-1].detach().numpy(), fine_dim, coarse_dim)))
    level.update({int(np.sqrt(A[i].shape[1])): i})
    fine_dim=coarse_dim

In [5]:
fh = lambda x: fcts.kl_distance(x, A[0], b[0])

In [6]:
y0 = torch.ones(N, N)*0.5
y0.requires_grad = True

In [7]:
x0 = R(y0)
x0.retain_grad()
#print(x0)

fH = lambda x: fcts.kl_distance(x, A[1], b[1])

fhy0 = fh(y0)
fhy0.backward(retain_graph = True)

fHx0 = fH(x0)
fHx0.backward(retain_graph = True)

kappa = R(y0.grad) - x0.grad

x = torch.ones(coarse_dim, coarse_dim)*0.4
x.requires_grad = True

val = fH(x) + torch.sum(kappa * x)
y0.grad = None
print(val)

tensor(541820.6177, dtype=torch.float64, grad_fn=<AddBackward0>)


In [8]:
def coarsen_fn(fh, x, y0, l):
    x0 = R(y0)
    x0.retain_grad()
    
    fH = lambda x: fcts.kl_distance(x, A[l], b[l])
    
    fhy0 = fh(y0)
    fhy0.backward()
    
    fHx0 = fH(x0)
    fHx0.backward()
    
    kappa = R(y0.grad) - x0.grad
    val = fH(x) + torch.sum(kappa * (x-x0))
    y0.grad = None
    
    return val

In [9]:
y0 = torch.ones(N, N)*0.5
y0.requires_grad = True

In [10]:
x = torch.ones(coarse_dim, coarse_dim)*0.4
x.requires_grad = True

psi = lambda x: coarsen_fn(fh, x, y0, 1)
psi(x)

tensor(-2118311.8823, dtype=torch.float64, grad_fn=<AddBackward0>)

In [11]:
fH(x)

tensor(9794.1177, dtype=torch.float64, grad_fn=<AddBackward0>)

In [12]:
psi(x0)

tensor(220321.1527, dtype=torch.float64, grad_fn=<AddBackward0>)

In [13]:
fH(x0)

tensor(220321.1527, dtype=torch.float64, grad_fn=<AddBackward0>)