In [1]:
import torch
import importlib
import doctest

In [2]:
# Reload gradops_2 so on-disk edits are picked up without a kernel restart,
# then run its doctests.
from networks.gradops import gradops_2
importlib.reload(gradops_2)
gradops_2_doctest_result = doctest.testmod(gradops_2)
print(gradops_2_doctest_result)
# Fail loudly under Restart & Run All instead of only printing the summary.
assert gradops_2_doctest_result.failed == 0, (
    f"gradops_2 doctests failed: {gradops_2_doctest_result}"
)

TestResults(failed=0, attempted=39)


In [3]:
# Reload the networks package and its TGV/PDHG modules so on-disk edits are
# picked up without a kernel restart. tgv_pdhg is reloaded before tgv_pdhg_net,
# presumably because the latter builds on it (confirm against the module).
import networks
importlib.reload(networks)

from networks import tgv_pdhg
importlib.reload(tgv_pdhg)

from networks import tgv_pdhg_net
importlib.reload(tgv_pdhg_net)
tgv_pdhg_net_doctest_result = doctest.testmod(tgv_pdhg_net)
print(tgv_pdhg_net_doctest_result)
# Fail loudly under Restart & Run All instead of only printing the summary.
assert tgv_pdhg_net_doctest_result.failed == 0, (
    f"tgv_pdhg_net doctests failed: {tgv_pdhg_net_doctest_result}"
)

# Import the class only after the reload so the name refers to the fresh code.
from networks.tgv_pdhg_net import TgvPdhgNet

TestResults(failed=0, attempted=14)


In [4]:
# Solver-only TGV-PDHG network (no CNN attached: cnn=None), placed on the GPU
# when one is visible and on the CPU otherwise.
pdhg_device = "cuda" if torch.cuda.is_available() else "cpu"
net = TgvPdhgNet(
    cnn=None,
    device=pdhg_device,
    scale_factor=0.1,
    constraint_activation="softplus",
)

In [5]:
def test_scalar():
    """Run the module-level ``net`` with scalar TGV weights and check the output shape.

    The reconstruction ``u_T`` must have the same shape as the input ``u``.
    """
    alpha0 = 0.1
    alpha1 = 0.1
    params = (alpha0, alpha1)
    # Put the input on the network's device; otherwise, when CUDA is available,
    # the network would be handed a CPU tensor.
    u = torch.randn(1, 1, 512, 512).to(net.device)
    T = 128
    u_T = net(u=u, regularisation_params=params, T=T)
    print(f"u_T.shape = {u_T.shape}")
    assert u_T.shape == u.shape
    
test_scalar()

u_T.shape = torch.Size([1, 1, 512, 512])


In [6]:
def test_reg_maps():
    """Run the module-level ``net`` with pixel-wise TGV weight maps and check the output shape.

    The reconstruction ``u_T`` must have the same shape as the input ``u``.
    """
    # TGV weights should be non-negative; torch.randn would make roughly half of
    # the pixels negative. Draw maps in [0, 0.1) to match the scalar test's scale.
    alpha0 = (0.1 * torch.rand(1, 1, 512, 512)).to(net.device)
    alpha1 = (0.1 * torch.rand(1, 1, 512, 512)).to(net.device)
    params = (alpha0, alpha1)
    # Keep every tensor on the network's device so CUDA runs do not mix devices.
    u = torch.randn(1, 1, 512, 512).to(net.device)
    T = 128
    u_T = net(u=u, regularisation_params=params, T=T)
    print(f"u_T.shape = {u_T.shape}")
    assert u_T.shape == u.shape
    
test_reg_maps()

u_T.shape = torch.Size([1, 1, 512, 512])


In [7]:
# Reload unet_2d so on-disk edits are picked up without a kernel restart,
# then run its doctests.
from networks import unet_2d
importlib.reload(unet_2d)
unet_2d_doctest_result = doctest.testmod(unet_2d)
print(unet_2d_doctest_result)
# Fail loudly under Restart & Run All instead of only printing the summary.
assert unet_2d_doctest_result.failed == 0, (
    f"unet_2d doctests failed: {unet_2d_doctest_result}"
)

TestResults(failed=0, attempted=5)


In [8]:
from networks.unet_2d import UNet2d


def test_unet():
    """Check that a small UNet2d keeps batch and spatial size and emits two channels.

    The two output channels presumably correspond to the two TGV weight maps
    (alpha0, alpha1); confirm against UNet2d.
    """
    unet = UNet2d(init_filters=8, n_blocks=2)
    x = torch.randn(1, 1, 16, 16)  # 4D input, the normal use case when training
    print(x.shape)
    y = unet(x)
    print(y.shape)
    assert y.shape == torch.Size([1, 2, 16, 16])

test_unet()

torch.Size([1, 1, 16, 16])
torch.Size([1, 2, 16, 16])


In [10]:
def test_combination():
    """Run the TGV-PDHG network with a UNet supplying the regularisation maps.

    Exercises both call styles:
    - passing maps from ``get_regularisation_param_maps`` explicitly;
    - omitting ``regularisation_params``, where the network presumably derives
      them from its CNN internally (confirm against TgvPdhgNet.forward).

    Both reconstructions must keep the input's shape.
    """
    unet = UNet2d(init_filters=8, n_blocks=2)
    pdhg_net = TgvPdhgNet(
        cnn=unet,
        device="cuda" if torch.cuda.is_available() else "cpu",
        constraint_activation="softplus",
        scale_factor=0.1,
    )
    u = torch.randn(1, 1, 512, 512).to(pdhg_net.device)
    # Shape checks only. The UNet's weights presumably require grad, so with
    # autograd enabled intermediates from all T solver steps could be retained.
    with torch.no_grad():
        reg_maps = pdhg_net.get_regularisation_param_maps(u)
        print(f"reg_maps.shape = {reg_maps.shape}")
        u_T_explicit = pdhg_net(u=u, regularisation_params=reg_maps, T=128)
        print(f"u_T.shape = {u_T_explicit.shape}")

        u_T_implicit = pdhg_net(u, T=128)
        print(f"u_T.shape = {u_T_implicit.shape}")

    # The maps must cover the image grid; both reconstructions keep the input shape.
    assert reg_maps.shape[-2:] == u.shape[-2:]
    assert u_T_explicit.shape == u.shape
    assert u_T_implicit.shape == u.shape
    
test_combination()

reg_maps.shape = torch.Size([2, 512, 512])
u_T.shape = torch.Size([1, 1, 512, 512])
u_T.shape = torch.Size([1, 1, 512, 512])
