In [None]:
import MGTomo.model as mgmodel
import MGTomo.tomoprojection as mgproj
from MGTomo.utils import mylog
import MGTomo.functions as fcts
from MGTomo.optimize import armijo_linesearch, box_bounds_optimized
import MGTomo.coarse_corrections as CC
from MGTomo.gridop import RBox as R, PBox as P

from MGTomo import gridop

import time
import numpy as np
import torch
from torch.linalg import norm

import matplotlib.pyplot as plt 
from skimage import data
from skimage.transform import resize

import datetime


hparams = {
    "image": "shepp_logan",
    "CC": "Bregman",
    "N": 1023,
    "max_levels": 2,
    "maxIter": [1,10,10,16,32,128],
    "num_angels0": 200,
    "P_inf" : 1,
    #"SL_iterate_count": 0,
    "ML_iterate_count": 120,
    "kappa": 0.49,
    "eps": 0.1,
    #"SL_image_indices": range(0,0,0),
    "ML_image_indices": range(0,10),
    'remark': 'testing different eps'
}

x_orig = data.shepp_logan_phantom()
x_orig = resize(x_orig, (hparams["N"],hparams["N"]), anti_aliasing = False)

x_torch = torch.tensor(x_orig, requires_grad = True)


model = mgmodel.astra_model(hparams["N"],{'mode' : 'line', 'num_angles' : hparams["num_angels0"], 'level_decrease' : 1})
fine_dim = model.dim
A = [mgproj.TomoTorch(model.proj_factory(fine_dim))]
b = [A[0](x_torch)]
level = {int(np.sqrt(A[0].shape[1])): 0}
P_nonzero = []


for i in range(1,hparams["max_levels"]+1):
    coarse_dim = model.reduce_dim(fine_dim)
    model_coarse = mgmodel.astra_model(coarse_dim, {'mode' : 'line', 'num_angles' : min(int(coarse_dim*np.pi/4),100), 'level_decrease' : 1})
    A.append(mgproj.TomoTorch(model_coarse.proj_factory(coarse_dim)))
    x_resized = resize(x_orig, (coarse_dim, coarse_dim), anti_aliasing=False)
    xT_resized = torch.tensor(x_resized, requires_grad = True)
    b.append(A[-1](xT_resized))
    P_nonzero.append(gridop.compute_nonzero_elements_of_P(coarse_dim))
    level.update({int(np.sqrt(A[i].shape[1])): i})
    fine_dim=coarse_dim

for i in range(hparams["max_levels"]+1):
    assert b[i].shape[0]*b[i].shape[1] == A[i].shape[0], 'dimension mismatch'
    print(f'level {i}:', b[i].shape[0], np.sqrt(A[i].shape[1]))

fh = lambda x: fcts.kl_distance(x, A[0], b[0])
tau = [torch.reciprocal(Ai.sumnorm_opt()) * 0.5 for Ai in A]

def MLO_box(fh, y, lh, uh, last_pts: list, l=0, kappa = hparams["kappa"], eps = hparams["eps"]):
    x = R(y).detach().requires_grad_(True)
    fhy0 = fh(y)
    #fhy0.backward(retain_graph = True)
    fhy0.backward()
    grad_fhy0 = y.grad.clone()
    y.grad = None
    
    if CC.coarse_condition_bregman(y, grad_fhy0, kappa, eps, last_pts[l]):
    #if True:
        print(l, ' : coarse correction activated')
        last_pts[l] = y.clone().detach()
    
        x0 = x.clone().detach().requires_grad_(True)
        fH = lambda x: fcts.kl_distance(x, A[l+1], b[l+1]) 
        fHx0 = fH(x0)
        fHx0.backward(retain_graph = True)
        grad_fHx0 = x0.grad.clone()
        x0.grad = None

        kappa = R(grad_fhy0) - grad_fHx0
        del grad_fHx0

        with torch.no_grad():
            psi = lambda x: fH(x) + torch.sum(kappa * x)
            lH, uH = box_bounds_optimized(y, x, hparams["P_inf"], lh, uh, P_nonzero[l])

        logvH_new = mylog(x - lH) - mylog(uH - x)
        for i in range(hparams["maxIter"][l+1]):
            #x.retain_grad()
            val, logvH_new = fcts.BSMART_general(psi, x, logvH_new, tau[l+1], lH, uH)
            x = val.detach().requires_grad_(True)
            del val
            x.grad = None
            
        if l < hparams["max_levels"]-1:
            x, last_pts = MLO_box(psi, x,lH, uH, last_pts, l+1)

        d = P(x-x0)
        z, _ = armijo_linesearch(fh, y, d)
        y = z.detach().requires_grad_(True)
    else: 
        print(l, ' : coarse correction not activated')

    logvh_new = mylog(y - lh) - mylog(uh - y)
    
    for i in range(hparams["maxIter"][l]):
        #y.retain_grad()
        yval, logvh_new = fcts.BSMART_general(fh, y, logvh_new, tau[l], lh, uh)
        y = yval.detach().requires_grad_(True)
        del yval
        y.grad = None
    return y, last_pts


z0 = torch.ones(hparams["N"], hparams["N"]) * 0.5
z0.requires_grad_(True)
last_pts = [None]*(hparams["max_levels"]+1)

lh = torch.zeros_like(z0)
uh = torch.ones_like(z0)

rel_f_err = []
rel_f_err.append((norm(z0 - x_torch, 'fro')/norm(z0, 'fro')).item())

norm_fval = []
norm_fval.append(torch.tensor(1.))

fhz = fh(z0)

fhz.backward(retain_graph=True)
Gz0 = norm(z0.grad, 'fro')
z0.grad = None

norm_grad = []
norm_grad.append(torch.tensor(1.))

iteration_times_ML = []
iteration_times_ML.append(0)

for i in range(hparams['ML_iterate_count']):
    iteration_start_time_ML = time.time()
    
    val, ylast = MLO_box(fh, z0, lh, uh, last_pts)
    iteration_end_time_ML = time.time()
    iteration_time_ML = iteration_end_time_ML - iteration_start_time_ML

    iteration_times_ML.append(iteration_time_ML)
    z0 = val.clone().detach().requires_grad_(True)
    rel_f_err.append((norm(z0-x_torch, 'fro')/norm(z0, 'fro')).item())
    fval = fh(z0)
    norm_fval.append((fval/fhz).item())
    fval.backward(retain_graph=True)
    norm_grad.append((norm(z0.grad, 'fro')/Gz0).item())
    z0.grad = None

    print(f"Iteration {i}: {fh(z0)} - Time: {iteration_time_ML:.6f} seconds")

print(f"Overall time for all iterations: {sum(iteration_times_ML):.6f} seconds")
cumaltive_times_ML = [sum(iteration_times_ML[:i+1]) for i in range(len(iteration_times_ML))]

level 0: 200 1023.0
level 1: 100 511.0
level 2: 100 255.0
0  : coarse correction activated
1  : coarse correction activated
Iteration 0: 7427490.200336085 - Time: 30.567667 seconds
tensor(239182.2500, requires_grad=True)
0  : coarse correction activated
tensor(23719.2285, requires_grad=True)
1  : coarse correction activated
Iteration 1: 3799077.389332043 - Time: 30.256962 seconds
tensor(2800.0151, requires_grad=True)
0  : coarse correction activated
tensor(61.1547, requires_grad=True)
1  : coarse correction activated
Iteration 2: 2058670.2134430325 - Time: 30.417014 seconds
tensor(1538.0952, requires_grad=True)
0  : coarse correction activated
tensor(37.3569, requires_grad=True)
1  : coarse correction activated
Iteration 3: 1224709.2592820497 - Time: 30.350555 seconds
tensor(829.0059, requires_grad=True)
0  : coarse correction activated
tensor(27.6955, requires_grad=True)
1  : coarse correction activated
Iteration 4: 812946.0525858547 - Time: 30.375039 seconds
tensor(452.6335, requires

KeyboardInterrupt: 

In [2]:
for i in range(40,50):
    iteration_start_time_ML = time.time()
    
    val, ylast = MLO_box(fh, z0, lh, uh, last_pts)
    iteration_end_time_ML = time.time()
    iteration_time_ML = iteration_end_time_ML - iteration_start_time_ML

    iteration_times_ML.append(iteration_time_ML)
    z0 = val.clone().detach().requires_grad_(True)
    rel_f_err.append((norm(z0-x_torch, 'fro')/norm(z0, 'fro')).item())
    fval = fh(z0)
    norm_fval.append((fval/fhz).item())
    fval.backward(retain_graph=True)
    norm_grad.append((norm(z0.grad, 'fro')/Gz0).item())
    z0.grad = None

    print(f"Iteration {i}: {fh(z0)} - Time: {iteration_time_ML:.6f} seconds")

0  : coarse correction activated
1  : coarse correction activated


Iteration 40: 27196.51908176774 - Time: 30.538825 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 41: 26573.063465931365 - Time: 30.528802 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 42: 25972.213658156772 - Time: 30.647747 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 43: 25392.84823452975 - Time: 30.466639 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 44: 24833.919101396372 - Time: 30.430084 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 45: 24294.447865086673 - Time: 30.463020 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 46: 23773.516898911603 - Time: 30.263099 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 47: 23270.258159563375 - Time: 31.127694 seconds
0  : coarse correction activated
1  : coarse correction activated
It

In [3]:
for i in range(60,80):
    iteration_start_time_ML = time.time()
    
    val, ylast = MLO_box(fh, z0, lh, uh, last_pts)
    iteration_end_time_ML = time.time()
    iteration_time_ML = iteration_end_time_ML - iteration_start_time_ML

    iteration_times_ML.append(iteration_time_ML)
    z0 = val.clone().detach().requires_grad_(True)
    rel_f_err.append((norm(z0-x_torch, 'fro')/norm(z0, 'fro')).item())
    fval = fh(z0)
    norm_fval.append((fval/fhz).item())
    fval.backward(retain_graph=True)
    norm_grad.append((norm(z0.grad, 'fro')/Gz0).item())
    z0.grad = None

    print(f"Iteration {i}: {fh(z0)} - Time: {iteration_time_ML:.6f} seconds")

0  : coarse correction activated
1  : coarse correction activated
Iteration 60: 21858.58799242591 - Time: 30.717260 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 61: 21418.32727738731 - Time: 30.748801 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 62: 20992.10570035365 - Time: 30.651742 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 63: 20579.32037413214 - Time: 30.415018 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 64: 20179.38972685728 - Time: 30.452938 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 65: 19791.773311998968 - Time: 30.488975 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 66: 19415.9425740786 - Time: 30.243572 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 67: 19051.422692879587 - Time: 30.465314 seconds
0  : co

In [4]:
for i in range(80,100):
    iteration_start_time_ML = time.time()
    
    val, ylast = MLO_box(fh, z0, lh, uh, last_pts)
    iteration_end_time_ML = time.time()
    iteration_time_ML = iteration_end_time_ML - iteration_start_time_ML

    iteration_times_ML.append(iteration_time_ML)
    z0 = val.clone().detach().requires_grad_(True)
    rel_f_err.append((norm(z0-x_torch, 'fro')/norm(z0, 'fro')).item())
    fval = fh(z0)
    norm_fval.append((fval/fhz).item())
    fval.backward(retain_graph=True)
    norm_grad.append((norm(z0.grad, 'fro')/Gz0).item())
    z0.grad = None

    print(f"Iteration {i}: {fh(z0)} - Time: {iteration_time_ML:.6f} seconds")

0  : coarse correction activated
1  : coarse correction activated
Iteration 80: 15159.423940177528 - Time: 30.513409 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 81: 14913.6847847833 - Time: 30.307264 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 82: 14674.25428416332 - Time: 30.213955 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 83: 14440.914383777686 - Time: 30.401590 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 84: 14213.460516511315 - Time: 30.161213 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 85: 13991.680456978596 - Time: 30.688899 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 86: 13775.383632288898 - Time: 30.250399 seconds
0  : coarse correction activated
1  : coarse correction activated
Iteration 87: 13564.385671819307 - Time: 30.143014 seconds
0  