In [1]:
import MGTomo.model as mgmodel
import numpy as np
import MGTomo.tomoprojection as mgproj
from MGTomo.utils import myexp, mylog, mydiv
import MGTomo.functions as fcts
from skimage.data import shepp_logan_phantom
from skimage.transform import resize
from MGTomo.optimize import armijo_linesearch

from MGTomo.gridop import P,R

import torch
from torch.func import grad

from torch.linalg import matrix_norm

import matplotlib.pyplot as plt 

In [2]:
# Number of coarse levels below the fine grid; the operator list A built later
# therefore holds max_levels + 1 entries (level 0 = finest).
max_levels = 3
# Smoothing (SMART) iterations per level, indexed by level number.
maxIter = [2] * (max_levels + 1)

In [3]:
N = 1023  # side length (pixels) of the square fine-grid image

# Shepp-Logan phantom resampled to N x N (no anti-aliasing).
x_orig = resize(shepp_logan_phantom(), (N, N), anti_aliasing = False)

# Ground-truth image as a torch tensor tracked by autograd.
x_torch = torch.tensor(x_orig, requires_grad = True)

In [4]:
# Build the projection operator A[l] and right-hand side b[l] for every level
# l = 0 (finest) .. max_levels, coarsening the grid with model.reduce_dim.
model = mgmodel.astra_model(N,{'mode' : 'line', 'num_angles' : 50, 'level_decrease' : 1})
fine_dim = model.dim
A = [mgproj.TomoTorch(model.proj_factory(fine_dim))]
# Detach the synthetic data. x_torch has requires_grad=True, and A[0] appears to
# be autograd-aware (b[-1] needs .detach() before .numpy() below). Without this,
# b[0] stays attached to x_torch's graph, and every backward pass through the
# fine objective would also propagate gradients (an extra adjoint projection)
# into x_torch.
b = [A[0](x_torch).detach()]
# Keyed by sqrt(#columns of A[l]) -- presumably the grid side length -- mapping to the level index.
level = {int(np.sqrt(A[0].shape[1])): 0}

for i in range(1,max_levels+1):
    coarse_dim = model.reduce_dim(fine_dim)
    A.append(mgproj.TomoTorch(model.proj_factory(coarse_dim)))
    b.append(torch.from_numpy(model.reduce_rhs(b[-1].detach().numpy(), fine_dim, coarse_dim)))
    level.update({int(np.sqrt(A[i].shape[1])): i})
    fine_dim=coarse_dim

In [5]:
# Exact computation of the step-size constant, kept for reference only.
# A later cell hard-codes c0 instead (presumably because sumnorm() is costly
# for N = 1023 -- TODO confirm the hard-coded value matches this computation).
#c0 = A[0].sumnorm()
#tau0 = 0.5 * 1/c0

In [6]:
def fh(x):
    """Fine-level (level 0) objective: fcts.kl_distance with operator A[0] and data b[0]."""
    return fcts.kl_distance(x, A[0], b[0])

In [7]:
# Step-size scaling constant. Presumably the precomputed value of A[0].sumnorm()
# for the current N / num_angles -- TODO recompute if the geometry changes.
c0 = 56.0952
# Base SMART step size: half of the reciprocal of c0.
tau0 = 0.5 / c0

In [8]:
def coarse_condition(y, grad_y, kappa, y_last = None):
    """Decide whether a coarse-grid correction should be attempted at ``y``.

    Both tests use the matrix 1-norm, ``matrix_norm(..., ord=1)``:

    * gradient test: ||R(grad_y)|| >= kappa * ||grad_y||;
    * distance test, applied only when ``y_last`` is not None:
      ||y_last - y|| >= kappa.

    Returns the gradient-test result (a boolean tensor) when ``y_last`` is None,
    otherwise ``gradient_test and distance_test``.
    """
    restricted_norm = matrix_norm(R(grad_y), ord = 1)
    threshold = kappa * matrix_norm(grad_y, ord = 1)
    passes_gradient_test = restricted_norm >= threshold

    if y_last is None:
        return passes_gradient_test

    moved_far_enough = matrix_norm(y_last - y, ord = 1) >= kappa
    return passes_gradient_test and moved_far_enough

In [9]:
def MLO(fh, y, last_pts: list, l=0, kappa = 0.5):
    """Run one multilevel optimization cycle starting at level ``l``.

    If ``coarse_condition`` accepts, a coarse model ``psi`` is built at level
    ``l+1``. ``psi`` is the level-(l+1) KL objective plus a linear correction term,
    so that grad psi(R(y)) == R(grad fh(y)). It is smoothed with SMART and, while
    coarser levels remain, handled recursively. The prolongated coarse step is
    then used as a search direction for an Armijo line search on ``fh``, followed
    by ``maxIter[l]`` SMART post-smoothing iterations.

    Parameters
    ----------
    fh : callable
        Objective at level ``l``; returns a scalar tensor.
    y : torch.Tensor
        Current iterate at level ``l``. It must be a leaf tensor with
        ``requires_grad=True``, because its gradient is read from ``y.grad``
        after ``fh(y).backward()``.
    last_pts : list
        Per-level record of the last point where a coarse correction was taken
        (``None`` if never). It is updated in place and also returned.
    l : int
        Current level index (0 = finest).
    kappa : float
        Threshold passed to ``coarse_condition``. The same value is used at all
        coarser levels reached by recursion.

    Returns
    -------
    tuple
        ``(z, last_pts, a)``: the new iterate, the (mutated) ``last_pts`` list,
        and the second value returned by ``armijo_linesearch`` (presumably the
        accepted step length).

    Relies on the notebook globals ``A``, ``b``, ``tau``, ``maxIter`` and
    ``max_levels``.
    """
    x = R(y).detach().requires_grad_(True)
    y0, x0 = y, x.clone().detach().requires_grad_(True)

    # Gradient of the level-l objective at the starting point.
    fhy0 = fh(y0)
    fhy0.backward(retain_graph = True)
    grad_fhy0 = y0.grad.clone()
    y0.grad.zero_()

    if coarse_condition(y, grad_fhy0, kappa, last_pts[l]):
        last_pts[l] = y0.clone().detach()

        fH = lambda x: fcts.kl_distance(x, A[l+1], b[l+1])
        fHx0 = fH(x0)
        fHx0.backward(retain_graph = True)
        grad_fHx0 = x0.grad.clone()
        x0.grad.zero_()

        # Correction making grad psi(x0) == R(grad fh(y0)). It is kept under its
        # own name so the `kappa` threshold is not overwritten and can be passed
        # on to coarser levels.
        v = R(grad_fhy0) - grad_fHx0
        # Constant copy of x0. Evaluating or differentiating psi then no longer
        # builds a graph back into x0, whose gradient was already consumed above.
        x0_ref = x0.detach()

        psi = lambda x: fH(x) + torch.sum(v * (x - x0_ref))

        # Pre-smoothing on the coarse model.
        for i in range(maxIter[l+1]):
            x.retain_grad()
            val = fcts.SMART(psi, x, tau[l+1])
            x = val.clone().detach().requires_grad_(True)

        if l < max_levels-1:
            x, last_pts, _ = MLO(psi, x, last_pts, l+1, kappa = kappa)

        assert psi(x) < psi(x0), 'psi(x) < psi(x0) = fH(x0) does not hold'
    else:
        print(l, ' : coarse correction not activated')

    # Prolongated coarse step; it is zero when no correction was taken.
    d = P(x-x0)
    z, a = armijo_linesearch(fh, y0, d)

    assert z.min() >= 0

    # Post-smoothing on the level-l objective.
    for i in range(maxIter[l]):
        z.retain_grad()
        zval = fcts.SMART(fh, z, tau[l])
        y0.grad.zero_()
        z = zval.clone().detach().requires_grad_(True)
    return z, last_pts, a

In [10]:
# One SMART step size per level (levels 0 .. max_levels), all equal to tau0.
tau = [tau0 for _level in range(max_levels + 1)]

In [None]:
# Run N_CYCLES multilevel cycles from a random start. The second value returned
# by each MLO call (presumably the Armijo step length) is recorded in `a`.
N_CYCLES = 100
torch.manual_seed(0)  # reproducible random starting image

a = []
z0 = torch.rand(N, N, requires_grad = True)
last_pts = [None]*(max_levels+1)
print(fh(z0))
for i in range(N_CYCLES):
    # MLO mutates last_pts in place and returns the same list; rebind it under
    # its own name so the data flow between cycles is explicit.
    val, last_pts, alpha = MLO(fh, z0, last_pts)

    z0 = val.clone().detach().requires_grad_(True)
    a.append(alpha)
    print(i, ': ', fh(z0))

tensor(5.9085e+08, dtype=torch.float64, grad_fn=<AddBackward0>)
0 :  tensor(286476.2362, dtype=torch.float64, grad_fn=<AddBackward0>)
1 :  tensor(164373.7015, dtype=torch.float64, grad_fn=<AddBackward0>)
2 :  tensor(96462.3385, dtype=torch.float64, grad_fn=<AddBackward0>)
3 :  tensor(64629.1772, dtype=torch.float64, grad_fn=<AddBackward0>)
4 :  tensor(46626.0025, dtype=torch.float64, grad_fn=<AddBackward0>)
5 :  tensor(35377.9832, dtype=torch.float64, grad_fn=<AddBackward0>)
6 :  tensor(27869.6740, dtype=torch.float64, grad_fn=<AddBackward0>)
7 :  tensor(22612.4408, dtype=torch.float64, grad_fn=<AddBackward0>)
8 :  tensor(18790.1198, dtype=torch.float64, grad_fn=<AddBackward0>)
9 :  tensor(15922.7988, dtype=torch.float64, grad_fn=<AddBackward0>)
10 :  tensor(13713.8416, dtype=torch.float64, grad_fn=<AddBackward0>)
11 :  tensor(11972.5598, dtype=torch.float64, grad_fn=<AddBackward0>)
12 :  tensor(10572.2615, dtype=torch.float64, grad_fn=<AddBackward0>)
13 :  tensor(9426.3316, dtype=torc

In [None]:
# Final iterate after the multilevel cycles.
fig, ax = plt.subplots()
ax.imshow(z0.detach().numpy(), cmap = 'gray')
ax.set_title('Multilevel SMART reconstruction (z0)')
ax.set_axis_off()
plt.show()