In [1]:
import MGTomo.model as mgmodel
import time
import numpy as np
import MGTomo.tomoprojection as mgproj
from MGTomo.utils import myexp, mylog, mydiv
import MGTomo.functions as fcts
from skimage import data
from skimage.transform import resize
from MGTomo.optimize import armijo_linesearch_box

from MGTomo.gridop import P,R, RBox, PBox

import torch
from torch.func import grad

from torch.linalg import matrix_norm

import matplotlib.pyplot as plt 

In [2]:
N = 255
# load image
x_orig = data.shepp_logan_phantom()
x_orig = resize(x_orig, (N,N), anti_aliasing = False)

x_torch = torch.tensor(x_orig, requires_grad = True)

In [85]:
model = mgmodel.astra_model(N,{'mode' : 'line', 'num_angles' : 500, 'level_decrease' : 1})
fine_dim = model.dim
A = mgproj.TomoTorch(model.proj_factory(fine_dim))
b = A(x_torch)

In [4]:
fh = lambda x: fcts.kl_distance(x, A, b)

In [5]:
def GD1(x, f):
    x1 = fcts.BSMART(f, x, 1)
    val = fcts.kl_distance_no_matrix(x1, x) - 0.3*1e-3 *f(x)

    return val

In [86]:
x0 = torch.rand(N, N, requires_grad = True)
GD1(x0, fh)

tensor(3605.0027, dtype=torch.float64, grad_fn=<SubBackward0>)

In [37]:
x0

tensor([[0.7975, 0.2443, 0.5240,  ..., 0.9488, 0.3437, 0.6177],
        [0.0794, 0.3140, 0.6351,  ..., 0.9731, 0.1198, 0.5478],
        [0.3069, 0.4643, 0.9647,  ..., 0.0530, 0.7452, 0.7094],
        ...,
        [0.8773, 0.8940, 0.3376,  ..., 0.8064, 0.8091, 0.0849],
        [0.3125, 0.7933, 0.4930,  ..., 0.9378, 0.6460, 0.3157],
        [0.7869, 0.7764, 0.8651,  ..., 0.0573, 0.5161, 0.6803]],
       requires_grad=True)

In [90]:
def AssB(x, f, tau):
    x1 = fcts.BSMART(f, x, 1)
    xtau = fcts.BSMART(f,x,tau)
    val = fcts.kl_distance_no_matrix(xtau, x) - tau * fcts.kl_distance_no_matrix(x1, x)

    return val.item()

In [67]:
# Create a range of tau values between 0 and 1
tau_values = np.linspace(0, 1, 1000)

# Placeholder for the AssB values
AssB_values = []

# Calculate the AssB values for each tau
for tau in tau_values:
    AssB_values.append(AssB(x0, fh, tau))

# Plot the results
plt.plot(tau_values, AssB_values)
plt.xlabel('Tau')
plt.ylabel('AssB value')
plt.title('AssB over values of Tau between 0 and 1')
plt.show()

KeyboardInterrupt: 

In [None]:
AssB_values

[-32558.8671875,
 -18894.13671875,
 -15564.16796875,
 -13083.55859375,
 -9142.443359375,
 -4957.9296875,
 -2022.890625,
 -620.88671875,
 -147.51171875,
 -28.25390625,
 -4.20703125,
 -0.5,
 -0.05078125,
 -0.00390625,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0,
 0.0]

norm of A is ca. 1/54 by computations done locally

In [91]:
tau = 1/54
x0 = torch.rand(N, N, requires_grad = True)
AssB(x0, fh, tau)

31881.673828125