In [13]:
import MGTomo.model as mgmodel
import numpy as np
import astra
import MGTomo.tomoprojection as mgproj
from MGTomo.utils import myexp, mylog, mydiv
import MGTomo.Yfunctions as fcts
from scipy import interpolate
import h5py

import torch
from torch.func import grad

In [14]:
max_levels = 1
maxIter = [5,5]

In [15]:
# load image
images = ['vessel','gear','shepp', 'roux']
with h5py.File('images.h5','r') as f:
    x_orig = np.array(f[images[0]]) # shepp, vessel, gear
    
print(f'name={images[0]}, shape={x_orig.shape}, dtype={x_orig.dtype}')

name=vessel, shape=(1023, 1023), dtype=float64


In [16]:
x_orig = np.random.randn(63, 63)
#x_torch = torch.from_numpy(x_orig)
x_torch = torch.ones(63, 63)*0.2
x_torch.requires_grad = True

model = mgmodel.astra_model(63,{'mode' : 'line', 'num_angles' : 50, 'level_decrease' : 1})
fine_dim = model.dim
A = [mgproj.TomoTorch(model.proj_factory(fine_dim))]
b = [A[0](x_torch)]
level = {int(np.sqrt(A[0].shape[1])): 0}

for i in range(1,max_levels+1):
    corse_dim = model.reduce_dim(fine_dim)
    A.append(mgproj.TomoTorch(model.proj_factory(corse_dim)))
    b.append(torch.from_numpy(model.reduce_rhs(b[-1].detach().numpy(), fine_dim, corse_dim)))
    level.update({int(np.sqrt(A[i].shape[1])): i})
    fine_dim=corse_dim

In [17]:
print(level)

{63: 0, 31: 1}


In [18]:
for i in range(max_levels+1):
    assert b[i].shape[0]*b[i].shape[1] == A[i].shape[0], 'dimension mismatch'
    print(f'level {i}:', b[i].shape[0]*b[i].shape[1], A[i].shape[0], np.sqrt(A[i].shape[1]))

level 0: 3150 3150 63.0
level 1: 1550 1550 31.0


## Operators and objective definition

In [19]:
fh = lambda x: fcts.kl_distance(x, A[0], b[0])

In [20]:
def R(y):
    x = y[1:-1:2, 1:-1:2]
    return x

def bilinear_interpolation(v):
    inter = interpolate.RegularGridInterpolator((np.arange(1,2*v.shape[0],2), np.arange(1,2*v.shape[1],2)), v.detach().numpy(), method='linear', bounds_error=False, fill_value=0.0)
    xx, yy = np.meshgrid(np.arange(0,2*v.shape[0]+1), np.arange(0,2*v.shape[1]+1))
    return inter((yy,xx))

def P(x):
    return torch.from_numpy(bilinear_interpolation(x))

## Coarsening functions tester

In [21]:
def coarsen_fn(fh, x, y0, l):
    x0 = R(y0)
    x0.retain_grad()
    
    fH = lambda x: fcts.kl_distance(x, A[l], b[l])
    
    fhy0 = fh(y0)
    fhy0.backward()
    
    fHx0 = fH(x0)
    fHx0.backward()
    
    kappa = R(y0.grad) - x0.grad
    val = fH(x) + torch.sum(kappa * x)
    
    return val

In [22]:
y0 = torch.ones(63, 63)
y0.requires_grad = True

x = torch.ones(31, 31)*0.4

psi = lambda x: coarsen_fn(fh, x, y0, 1)
psi(x)

tensor(35329.7688, dtype=torch.float64)

## ML Tester

In [23]:
def MLO(fh, y, l=0):
    x = R(y)
    y0, x0 = y, R(y)
    y0.requires_grad = True
    #x0.requires_grad = True
    psi = lambda x: coarsen_fn(fh, x, y0, l+1)
    
    for i in range(maxIter[l]):
        x = fcts.SMART(psi, x, tau)
        
    if l < max_levels:
        x = MLO(psi, x, l+1)
        
    d = P(x-x0)
    y, a = fcts.armijo_linesearch(fh, y0, d)
    
    for i in range(maxIter[l]):
        y = fcts.SMART(psi, y, tau)
    
    return y

In [24]:
tau = 0.2
MLO(fh, y0)

TypeError: unsupported operand type(s) for *: 'float' and 'NoneType'