# Imports

In [8]:
first_time_importing_torch = True

In [9]:
import time
# NOTE: Importing torch the first time will always take a long time!
if first_time_importing_torch:
    print(f"Importing torch ...")
    import_torch_start_time = time.time() 
import torch
if first_time_importing_torch:
    import_torch_end_time = time.time()
    print(f"Importing torch took {import_torch_end_time - import_torch_start_time} seconds")
    first_time_importing_torch = False

import torch.nn as nn
import torch.nn.functional as F

from functools import partial

Importing torch ...
Importing torch took 3.886222839355469e-05 seconds


# Print details helper

In [87]:
def print_details(x, name):
    print(f"{name}:")
    print(x)
    print(f"  shape: {x.shape}")
    # print(f"  dtype: {x.dtype}")
    # print(f"  device: {x.device}")
    # print(f"  is_complex: {x.is_complex()}")
    # print(f"  is_floating_point: {x.is_floating_point()}")
    # print(f"  is_contiguous: {x.is_contiguous()}")
    # print(f"  is_pinned: {x.is_pinned()}")
    print()


# GradOps

In [88]:
class GradOperators(torch.nn.Module):
    @staticmethod
    def diff_kernel(ndim, mode):
        if mode == "doublecentral":
            kern = torch.tensor((-1, 0, 1))
        elif mode == "central":
            kern = torch.tensor((-1, 0, 1)) / 2
        elif mode == "forward":
            kern = torch.tensor((0, -1, 1))
        elif mode == "backward":
            kern = torch.tensor((-1, 1, 0))
        else:
            raise ValueError(f"mode should be one of (central, forward, backward, doublecentral), not {mode}")
        kernel = torch.zeros(ndim, 1, *(ndim * (3,)))
        for i in range(ndim):
            idx = tuple([i, 0, *(i * (1,)), slice(None), *((ndim - i - 1) * (1,))])
            kernel[idx] = kern
        return kernel

    def __init__(self, dim:int=2, mode:str="doublecentral", padmode:str = "circular"):
        """
        An Operator for finite Differences / Gradients
        Implements the forward as apply_G and the adjoint as apply_GH.
        
        Args:
            dim (int, optional): Dimension. Defaults to 2.
            mode (str, optional): one of doublecentral, central, forward or backward. Defaults to "doublecentral".
            padmode (str, optional): one of constant, replicate, circular or refelct. Defaults to "circular".
        """
        super().__init__()
        self.register_buffer("kernel", self.diff_kernel(dim, mode), persistent=False)
        self._dim = dim
        self._conv = (torch.nn.functional.conv1d, torch.nn.functional.conv2d, torch.nn.functional.conv3d)[dim - 1]
        self._convT = (torch.nn.functional.conv_transpose1d, torch.nn.functional.conv_transpose2d, torch.nn.functional.conv_transpose3d)[dim - 1]
        self._pad = partial(torch.nn.functional.pad, pad=2 * dim * (1,), mode=padmode)
        if mode == 'central':
            self._norm = (self.dim) ** (1 / 2)
        else:
            self._norm = (self.dim * 4) ** (1 / 2)

    @property
    def dim(self):
        return self._dim
    
    def apply_G(self, x):
        """
        Forward
        """
        if x.is_complex():
            xr = torch.view_as_real(x).moveaxis(-1, 0)
        else:
            xr = x
        xr = xr.reshape(-1, 1, *x.shape[-self.dim :])
        xp = self._pad(xr)
        y = self._conv(xp, weight=self.kernel, bias=None, padding=0)
        if x.is_complex():
            y = y.reshape(2, *x.shape[: -self.dim], self.dim, *x.shape[-self.dim :])
            y = torch.view_as_complex(y.moveaxis(0, -1).contiguous())
        else:
            y = y.reshape(*x.shape[0 : -self.dim], self.dim, *x.shape[-self.dim :])
        return y

    def apply_GH(self, x):
        """
        Adjoint
        """
        if x.is_complex():
            xr = torch.view_as_real(x).moveaxis(-1, 0)
        else:
            xr = x
        xr = xr.reshape(-1, self.dim, *x.shape[-self.dim :])
        print_details(xr, "xr")

        xp = self._pad(xr)
        y = self._convT(xp, weight=self.kernel, bias=None, padding=2)
        if x.is_complex():
            y = y.reshape(2, *x.shape[: -self.dim - 1], *x.shape[-self.dim :])
            y = torch.view_as_complex(y.moveaxis(0, -1).contiguous())
        else:
            y = y.reshape(*x.shape[: -self.dim - 1], *x.shape[-self.dim :])
        return y
    
  
    def apply_GHG(self, x):
        if x.is_complex():
            xr = torch.view_as_real(x).moveaxis(-1, 0)
        else:
            xr = x
        xr = xr.reshape(-1, 1, *x.shape[-self.dim :])
        xp = self._pad(xr)
        tmp = self._conv(xp, weight=self.kernel, bias=None, padding=0)
        tmp = self._pad(tmp)
        y = self._convT(tmp, weight=self.kernel, bias=None, padding=2)
        if x.is_complex():
            y = y.reshape(2, *x.shape)
            y = torch.view_as_complex(y.moveaxis(0, -1).contiguous())
        else:
            y = y.reshape(*x.shape)
        return y

    def forward(self, x, direction=1):
        if direction>0:
            return self.apply_G(x)
        elif direction<0:
            return self.apply_GH(x)
        else:
            return self.apply_GHG(x)


    @property
    def normGHG(self):
        return self._norm


# GradOps test

In [91]:
def test_apply_G_2D():
    grad_ops = GradOperators(dim=2, mode="forward")
    print_details(grad_ops.kernel, "G.kernel")

    x = torch.tensor([[[1, 2], [3, 4]]], dtype=torch.float32)
    print_details(x, "x")

    # x_padde_circular = G._pad(x, mode="circular")
    # print_details(x_padde_circular, "x_padde_circular")

    # x_padded_constant = G._pad(x, mode="constant")
    # print_details(x_padded_constant, "x_padded_constant")
    
    y = grad_ops.apply_G(x)
    print_details(y, "y")

    z = grad_ops.apply_GH(y)
    print_details(z, "z")

In [92]:
test_apply_G_2D()

G.kernel:
tensor([[[[ 0.,  0.,  0.],
          [ 0., -1.,  0.],
          [ 0.,  1.,  0.]]],


        [[[ 0.,  0.,  0.],
          [ 0., -1.,  1.],
          [ 0.,  0.,  0.]]]])
  shape: torch.Size([2, 1, 3, 3])

x:
tensor([[[1., 2.],
         [3., 4.]]])
  shape: torch.Size([1, 2, 2])

y:
tensor([[[[ 2.,  2.],
          [-2., -2.]],

         [[ 1., -1.],
          [ 1., -1.]]]])
  shape: torch.Size([1, 2, 2, 2])

xr:
tensor([[[[ 2.,  2.],
          [-2., -2.]],

         [[ 1., -1.],
          [ 1., -1.]]]])
  shape: torch.Size([1, 2, 2, 2])

z:
tensor([[[-6., -2.],
         [ 2.,  6.]]])
  shape: torch.Size([1, 2, 2])



In [84]:
def test_apply_G_3D():
    G = GradOperators(dim=3, mode="forward")
    print_details(G.kernel, "G.kernel")

    x = torch.tensor([[[[1], [2]], [[3], [4]]]], dtype=torch.float32)
    print_details(x, "x")

    x_padde_circular = G._pad(x, mode="circular")
    print_details(x_padde_circular, "x_padde_circular")

    x_padded_constant = G._pad(x, mode="constant")
    print_details(x_padded_constant, "x_padded_constant")
    
    y = G.apply_G(x)
    print_details(y, "y")

In [85]:
test_apply_G_3D()

G.kernel:
tensor([[[[[ 0.,  0.,  0.],
           [ 0.,  0.,  0.],
           [ 0.,  0.,  0.]],

          [[ 0.,  0.,  0.],
           [ 0., -1.,  0.],
           [ 0.,  0.,  0.]],

          [[ 0.,  0.,  0.],
           [ 0.,  1.,  0.],
           [ 0.,  0.,  0.]]]],



        [[[[ 0.,  0.,  0.],
           [ 0.,  0.,  0.],
           [ 0.,  0.,  0.]],

          [[ 0.,  0.,  0.],
           [ 0., -1.,  0.],
           [ 0.,  1.,  0.]],

          [[ 0.,  0.,  0.],
           [ 0.,  0.,  0.],
           [ 0.,  0.,  0.]]]],



        [[[[ 0.,  0.,  0.],
           [ 0.,  0.,  0.],
           [ 0.,  0.,  0.]],

          [[ 0.,  0.,  0.],
           [ 0., -1.,  1.],
           [ 0.,  0.,  0.]],

          [[ 0.,  0.,  0.],
           [ 0.,  0.,  0.],
           [ 0.,  0.,  0.]]]]])
  shape: torch.Size([3, 1, 3, 3, 3])

x:
tensor([[[[1.],
          [2.]],

         [[3.],
          [4.]]]])
  shape: torch.Size([1, 2, 2, 1])

x_padde_circular:
tensor([[[[4., 4., 4.],
          [3., 3., 

# Primal Dual NN

In [27]:
class ClipAct(nn.Module):
    def forward(self, x, threshold):
        return clipact(x, threshold)


def clipact(x, threshold):
    is_complex = x.is_complex()
    if is_complex:
        x = torch.view_as_real(x)
        threshold = threshold.unsqueeze(-1)
    x = torch.clamp(x, -threshold, threshold)
    if is_complex:
        x = torch.view_as_complex(x)
    return x

In [69]:
class DynamicImagePrimalDualNN(nn.Module):
    def __init__(
        self,
        T=128,
        cnn_block=None,
        mode="lambda_cnn",
        up_bound=0,
        phase="training",
    ):
        super(DynamicImagePrimalDualNN, self).__init__()

        # gradient operators and clipping function
        dim = 3
        self.GradOps = GradOperators(dim, mode="forward", padmode="circular")

        # operator norms
        self.op_norm_AHA = torch.sqrt(torch.tensor(1.0))
        self.op_norm_GHG = torch.sqrt(torch.tensor(12.0))
        # operator norm of K = [A, \nabla]
        # https://iopscience.iop.org/article/10.1088/0031-9155/57/10/3065/pdf,
        # see page 3083
        self.L = torch.sqrt(self.op_norm_AHA**2 + self.op_norm_GHG**2)

        # function for projecting
        self.ClipAct = ClipAct()

        if mode == "lambda_xyt":
            # one single lambda for x,y and t
            self.lambda_reg = nn.Parameter(torch.tensor([-1.5]), requires_grad=True)

        elif mode == "lambda_xy_t":
            # one (shared) lambda for x,y and one lambda for t
            self.lambda_reg = nn.Parameter(
                torch.tensor([-4.5, -1.5]), requires_grad=True
            )

        elif mode == "lambda_cnn":
            # the CNN-block to estimate the lambda regularization map
            # must be a CNN yielding a two-channeld output, i.e.
            # one map for lambda_cnn_xy and one map for lambda_cnn_t
            self.cnn = cnn_block    # NOTE: This is actually the UNET!!! (At least in this project)
            self.up_bound = torch.tensor(up_bound)

        # number of terations
        self.T = T
        self.mode = mode

        # constants depending on the operators
        self.tau = nn.Parameter(
            torch.tensor(10.0), requires_grad=True
        )  # starting value approximately  1/L
        self.sigma = nn.Parameter(
            torch.tensor(10.0), requires_grad=True
        )  # starting value approximately  1/L

        # theta should be in \in [0,1]
        self.theta = nn.Parameter(
            torch.tensor(10.0), requires_grad=True
        )  # starting value approximately  1

        # distinguish between training and test phase;
        # during training, the input is padded using "reflect" padding, because
        # patches are used by reducing the number of temporal points;
        # while testing, "reflect" padding is used in x,y- direction, while
        # circular padding is used in t-direction
        self.phase = phase

    def get_lambda_cnn(self, x):
        # padding
        # arbitrarily chosen, maybe better to choose it depending on the
        # receptive field of the CNN or so;
        # seems to be important in order not to create "holes" in the
        # lambda_maps in t-direction
        npad_xy = 4
        npad_t = 8
        pad = (npad_t, npad_t, npad_xy, npad_xy, npad_xy, npad_xy)

        if self.phase == "training":
            x = F.pad(x, pad, mode="reflect")

        elif self.phase == "testing":
            pad_refl = (0, 0, npad_xy, npad_xy, npad_xy, npad_xy)
            pad_circ = (npad_t, npad_t, 0, 0, 0, 0)

            x = F.pad(x, pad_refl, mode="reflect")
            x = F.pad(x, pad_circ, mode="circular")

        # estimate parameter map
        lambda_cnn = self.cnn(x) # NOTE: The cnn is actually the UNET block!!! (At least in this project)

        # crop
        neg_pad = tuple([-pad[k] for k in range(len(pad))])
        lambda_cnn = F.pad(lambda_cnn, neg_pad)

        # double spatial map and stack
        lambda_cnn = torch.cat((lambda_cnn[:, 0, ...].unsqueeze(1), lambda_cnn), dim=1)

        # constrain map to be striclty positive; further, bound it from below
        if self.up_bound > 0:
            # constrain map to be striclty positive; further, bound it from below
            lambda_cnn = self.up_bound * self.op_norm_AHA * torch.sigmoid(lambda_cnn)
        else:
            lambda_cnn = 0.1 * self.op_norm_AHA * F.softplus(lambda_cnn)

        return lambda_cnn

    def forward(self, x, lambda_map=None):
        # initial reconstruction
        mb, _, Nx, Ny, Nt = x.shape
        device = x.device

        # starting values
        xbar = x.clone()
        x0 = x.clone()
        xnoisy = x.clone()

        # dual variable
        p = x.clone()
        q = torch.zeros(mb, 3, Nx, Ny, Nt, dtype=x.dtype).to(device)

        print("BEFORE LOOP")
        print_details(q, "q")

        # sigma, tau, theta
        sigma = (1 / self.L) * torch.sigmoid(self.sigma)  # \in (0,1/L)
        tau = (1 / self.L) * torch.sigmoid(self.tau)  # \in (0,1/L)
        theta = torch.sigmoid(self.theta)  # \in (0,1)

        # distinguish between the different cases
        if self.mode == "lambda_xyt":
            lambda_reg = F.softplus(self.lambda_reg)  # \in (0,\infty)

        elif self.mode == "lambda_xy_t":
            # get xy- and t-lambda
            lambda_reg_xy = torch.stack(2 * [self.lambda_reg[0]])
            lambda_reg_t = self.lambda_reg[1].unsqueeze(0)

            # conatentate xy -and t-lambda
            lambda_reg = (
                torch.cat([lambda_reg_xy, lambda_reg_t])
                .unsqueeze(0)
                .unsqueeze(-1)
                .unsqueeze(-1)
                .unsqueeze(-1)
            )
            lambda_reg = F.softplus(lambda_reg)

        elif self.mode == "lambda_cnn":
            if lambda_map is None:
                # estimate lambda reg from the image
                lambda_reg = self.get_lambda_cnn(x)
            else:
                lambda_reg = lambda_map

        print_details(self.L, "L")
        # Assert the L is sqrt(1 * 12)
        assert torch.allclose(self.L, torch.sqrt(torch.tensor(1.0 + 12.0)))
        print_details(sigma, "sigma")
        # Assert sigma is 1/L * sigmoid(10.0)
        assert torch.allclose(sigma, 1/self.L * torch.sigmoid(torch.tensor(10.0)))

        # Algorithm 2 - Unrolled PDHG algorithm (page 18)
        # TODO: In the paper, L is one of the inputs but not used anywhere in the pseudo code???
        for kT in range(self.T):
            print(f"Step {kT+1}/{self.T}")

            # update p
            p =  (p + sigma * (xbar - xnoisy) ) / (1. + sigma)

            # update q
            print_details(xbar, "xbar")
            grad = self.GradOps.apply_G(xbar)
            print_details(grad, "grad")
            q = self.ClipAct(q + sigma * grad, lambda_reg)

            x1 = x0 - tau * p - tau * self.GradOps.apply_GH(q)

            if kT != self.T - 1:
                # update xbar
                xbar = x1 + theta * (x1 - x0)
                x0 = x1
            
            print_details(x1, "x1")

        return x1


# Primal Dual NN test

In [74]:
def test_primitive_dual_nn():
    T = 2
    mode = "lambda_cnn"
    up_bound = 0
    phase = "training"
    cnn_block = None
    primal_dual_nn = DynamicImagePrimalDualNN(T, cnn_block, mode, up_bound, phase)
    print(primal_dual_nn)
    print()

    # x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
    x = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)
    # Add "time" dimension. Assume the order (height, width, time)
    x = x.unsqueeze(-1)
    # Add "channel" dimension. Assume the order (channel, height, width, time)
    x = x.unsqueeze(0)
    # Add "batch" dimension. Assume the order (batch, channel, height, width, time)
    x = x.unsqueeze(0)
    print_details(x, "x")

    lambda_scalar = torch.tensor([0.05], dtype=torch.float32)

    y = primal_dual_nn(x, lambda_scalar)
    print_details(y, "y")

In [75]:
test_primitive_dual_nn()

DynamicImagePrimalDualNN(
  (GradOps): GradOperators()
  (ClipAct): ClipAct()
)

x:
tensor([[[[[1.],
           [1.]],

          [[1.],
           [1.]]]]])
  shape: torch.Size([1, 1, 2, 2, 1])

BEFORE LOOP
q:
tensor([[[[[0.],
           [0.]],

          [[0.],
           [0.]]],


         [[[0.],
           [0.]],

          [[0.],
           [0.]]],


         [[[0.],
           [0.]],

          [[0.],
           [0.]]]]])
  shape: torch.Size([1, 3, 2, 2, 1])

L:
tensor(3.6056)
  shape: torch.Size([])

sigma:
tensor(0.2773, grad_fn=<MulBackward0>)
  shape: torch.Size([])

Step 1/2
xbar:
tensor([[[[[1.],
           [1.]],

          [[1.],
           [1.]]]]])
  shape: torch.Size([1, 1, 2, 2, 1])

grad:
tensor([[[[[[0.],
            [0.]],

           [[0.],
            [0.]]],


          [[[0.],
            [0.]],

           [[0.],
            [0.]]],


          [[[0.],
            [0.]],

           [[0.],
            [0.]]]]]])
  shape: torch.Size([1, 1, 3, 2, 2, 1])

x1:
te