In [9]:
import torch
import os
import importlib
import doctest
from datetime import datetime

In [10]:
# Each project module below is imported, reloaded with importlib.reload, and
# only then are names imported from it. Reloading re-executes the module in the
# running kernel, so the `from ... import ...` that follows binds names from
# the reloaded module object. The statement order within each group matters.

# Dataset loading: reload, run the module's doctests, print the summary.
from data import turtle_data_loading
importlib.reload(turtle_data_loading)
turtle_data_loading_doctest_result = doctest.testmod(turtle_data_loading)
print(turtle_data_loading_doctest_result)
from data.turtle_data_loading import get_datasets

# Data loaders: `get_data_loaders` is first bound to the module here ...
from data import get_data_loaders
importlib.reload(get_data_loaders)
get_data_loaders_doctest_result = doctest.testmod(get_data_loaders)
print(get_data_loaders_doctest_result)
# ... and is rebound by the next line to the same-named attribute of that
# module (it is called as a function later in this notebook).
from data.get_data_loaders import get_data_loaders

# Training entry point: reloaded only, no doctest run for this module.
from scripts import train
importlib.reload(train)
from scripts.train import start_training

# Network definition: reload, run the module's doctests, print the summary.
from networks import unet_3d
importlib.reload(unet_3d)
unet_3d_doctest_result = doctest.testmod(unet_3d)
print(unet_3d_doctest_result)
from networks.unet_3d import UNet3d

TestResults(failed=0, attempted=0)
TestResults(failed=0, attempted=0)
Using new version 2 of train.py
TestResults(failed=0, attempted=0)


In [11]:
def _build_pdhg_net(config, device):
    """Create (or restore) the primal-dual network described by ``config``.

    Three cases are handled, selected by ``config["pretrained_model_path"]``
    and ``config["is_state_dict"]``:

    * path is ``None``: a new network is built with a freshly initialised
      ``UNet3d`` attached as ``pdhg_net.cnn``.
    * path is set and ``is_state_dict`` is true: the same network is built and
      the saved state dict is then loaded into it.
    * path is set and ``is_state_dict`` is false: the checkpoint is loaded as a
      whole pickled model object.

    Args:
        config: Run configuration dict; the keys read here are
            ``pretrained_model_path``, ``is_state_dict``, ``model_states_dir``,
            ``regularisation``, ``up_bound``, ``T`` and the UNet keys
            (``in_channels``, ``out_channels``, ``init_filters``, ``n_blocks``,
            ``activation``, ``downsampling_kernel``, ``downsampling_mode``,
            ``upsampling_kernel``, ``upsampling_mode``).
        device: Device string the network is moved to.

    Returns:
        The network, moved to ``device``.

    Raises:
        ValueError: If a network has to be built and
            ``config["regularisation"]`` is not ``"TV"``.
    """
    pretrained_name = config["pretrained_model_path"]
    checkpoint_path = f"{config['model_states_dir']}/{pretrained_name}.pt"

    if pretrained_name is not None and not config["is_state_dict"]:
        # The checkpoint holds a whole pickled model, so weights_only must be
        # False. NOTE(review): unpickling executes arbitrary code; only load
        # checkpoints you created yourself.
        return torch.load(
            checkpoint_path, map_location=device, weights_only=False
        ).to(device)

    if config["regularisation"] != "TV":
        raise ValueError(f"Regularisation {config['regularisation']} not implemented.")

    # Imported lazily: only needed when a network is actually constructed.
    from networks.static_img_primal_dual_nn import StaticImagePrimalDualNN

    # Construct primal-dual operator with nn
    pdhg_net = StaticImagePrimalDualNN(
        cnn_block=None,
        device=device,
        up_bound=config["up_bound"],
        T=config["T"],
    ).to(device)

    # The CNN is attached in every case, so that a saved state dict has matching
    # modules to load into (previously it was attached only for a new model).
    pdhg_net.cnn = UNet3d(
        in_channels=config["in_channels"],
        out_channels=config["out_channels"],
        init_filters=config["init_filters"],
        n_blocks=config["n_blocks"],
        activation=config["activation"],
        downsampling_kernel=config["downsampling_kernel"],
        downsampling_mode=config["downsampling_mode"],
        upsampling_kernel=config["upsampling_kernel"],
        upsampling_mode=config["upsampling_mode"],
    ).to(device)

    if pretrained_name is not None:
        # Module.load_state_dict has no `device` argument; device placement is
        # done by torch.load through map_location.
        state_dict = torch.load(checkpoint_path, map_location=device)
        pdhg_net.load_state_dict(state_dict)

    return pdhg_net


def start_benchmark_tv_model():
    """Configure and launch a training run of the TV-regularised benchmark model.

    Steps, in order: create a timestamped directory under ``models/``, assemble
    the run configuration, set the torch default device, build the network (see
    ``_build_pdhg_net``), switch it to training mode, load the datasets and data
    loaders, print the tensor shapes of the first training batch, and hand
    everything to ``start_training``.

    Returns:
        None.
    """
    turtle_data_path = "../data/turtle_id_2022/turtles-data/data"
    now = datetime.now()

    # Prepare to save the model
    save_dir = "models"
    # Spaces and colons in the timestamp are replaced before it is used in a path.
    model_name = f"benchmark-TV-{now}".replace(" ", "_").replace(":", "-")
    model_states_dir = f"{save_dir}/{model_name}"

    os.makedirs(model_states_dir, exist_ok=True)

    config = {
        "model_name": model_name,
        "project": "benchmark",
        "dataset": turtle_data_path,
        "train_data_path": f"{turtle_data_path}/train.txt",
        "val_data_path": f"{turtle_data_path}/val.txt",
        "test_data_path": f"{turtle_data_path}/test.txt",
        "train_num_samples": 10,
        "val_num_samples": 5,
        "test_num_samples": 5,
        "data_gen_num_threads": 16,

        "resize_square": 512,
        # NOTE(review): stored as a string, not a list; presumably parsed by
        # the data-loading code. TODO confirm.
        "sigmas": "[0.05, 0.1, 0.15, 0.2, 0.25, 0.3]",
        "batch_size": 1,
        "random_seed": 42,
        "n_dim_data": 3,

        "architecture": "UNET-PDHG",
        "regularisation": "TV",
        "in_channels": 1,
        "out_channels": 2,
        "init_filters": 128,
        "n_blocks": 3,
        "activation": "LeakyReLU",
        "downsampling_kernel": (2, 2, 1),
        "downsampling_mode": "max_pool",
        "upsampling_kernel": (2, 2, 1),
        "upsampling_mode": "linear_interpolation",

        "optimizer": "Adam",
        "learning_rate": 1e-4,
        "loss_function": "MSELoss",

        "up_bound": 0,
        # "scale_factor": 1,
        # "constraint_activation": "sigmoid", # "softplus" or "sigmoid"
        # "scale_factor": 0.01,
        # "constraint_activation": "softplus",
        # "T": 256, # Higher T, NET does not have to try as hard? Less overfitting?
        "T": 128,
        # "T": 16,

        "epochs": 10_000,
        "device": "cuda:0",
        "start_epoch": 0,
        "wandb_mode": "online",
        "save_epoch_wandb": 100,
        "save_epoch_local": 10,
        "save_dir": save_dir,
        "model_states_dir": model_states_dir,
        # NOTE(review): a pretrained checkpoint is looked up inside
        # model_states_dir, which is the timestamped directory created above,
        # so the file has to be placed there by some other means. Confirm this
        # is the intended location before setting a path here.
        "pretrained_model_path": None,
        "is_state_dict": False,
    }

    device = config["device"]
    # Global side effect: tensors created without an explicit device from here
    # on are allocated on `device`.
    torch.set_default_device(device)

    pdhg_net = _build_pdhg_net(config, device)
    pdhg_net.train(True)

    datasets = get_datasets(config, device=device)

    data_loader_train, data_loader_valid, data_loader_test = get_data_loaders(config, datasets)

    del data_loader_test # Not used for now

    # Sanity check: show the tensor shapes of the first training batch only.
    for noisy, clean in data_loader_train:
        print(noisy.shape)
        print(clean.shape)
        break

    start_training(
        pdhg_net=pdhg_net,
        data_loader_train=data_loader_train,
        data_loader_valid=data_loader_valid,
        config=config,
    )

In [12]:
# Launch the benchmark run: builds the model and data loaders from the config
# hard-coded in start_benchmark_tv_model, then calls start_training.
start_benchmark_tv_model()

Loading original image paths in images_crop_resize_512_greyscale 

100%|██████████| 10/10 [00:00<00:00, 200684.40it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_512_greyscale_noisy_0_05 

100%|██████████| 10/10 [00:00<00:00, 249660.95it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_512_greyscale_noisy_0_1 

100%|██████████| 10/10 [00:00<00:00, 215092.51it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_512_greyscale_noisy_0_15 

100%|██████████| 10/10 [00:00<00:00, 211833.54it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_512_greyscale_noisy_0_2 

100%|██████████| 10/10 [00:00<00:00, 208671.84it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_512_greyscale_noisy_0_25 

100%|██████████| 10/10 [00:00<00:00, 249660.95it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_512_greyscale_noisy_0_3 

100%|██████████| 10/10 [00:00<00:00, 248183.67it/s]


Loading original images 

100%|██████████| 10/10 [00:00<00:00, 85.41it/s]


Loading noisy images sigma=0.05 

100%|██████████| 10/10 [00:00<00:00, 100.63it/s]


Loading noisy images sigma=0.1 

100%|██████████| 10/10 [00:00<00:00, 101.84it/s]


Loading noisy images sigma=0.15 

100%|██████████| 10/10 [00:00<00:00, 106.62it/s]


Loading noisy images sigma=0.2 

100%|██████████| 10/10 [00:00<00:00, 77.43it/s]


Loading noisy images sigma=0.25 

100%|██████████| 10/10 [00:00<00:00, 83.71it/s]


Loading noisy images sigma=0.3 

100%|██████████| 10/10 [00:00<00:00, 87.82it/s]


Loading original image paths in images_crop_resize_512_greyscale 

100%|██████████| 5/5 [00:00<00:00, 81920.00it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_512_greyscale_noisy_0_05 

100%|██████████| 5/5 [00:00<00:00, 104857.60it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_512_greyscale_noisy_0_1 

100%|██████████| 5/5 [00:00<00:00, 109798.53it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_512_greyscale_noisy_0_15 

100%|██████████| 5/5 [00:00<00:00, 106997.55it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_512_greyscale_noisy_0_2 

100%|██████████| 5/5 [00:00<00:00, 99391.09it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_512_greyscale_noisy_0_25 

100%|██████████| 5/5 [00:00<00:00, 107546.26it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_512_greyscale_noisy_0_3 

100%|██████████| 5/5 [00:00<00:00, 108100.62it/s]


Loading original images 

100%|██████████| 5/5 [00:00<00:00, 92.33it/s]


Loading noisy images sigma=0.05 

100%|██████████| 5/5 [00:00<00:00, 94.04it/s]


Loading noisy images sigma=0.1 

100%|██████████| 5/5 [00:00<00:00, 90.54it/s]


Loading noisy images sigma=0.15 

100%|██████████| 5/5 [00:00<00:00, 91.44it/s]


Loading noisy images sigma=0.2 

100%|██████████| 5/5 [00:00<00:00, 92.57it/s]


Loading noisy images sigma=0.25 

100%|██████████| 5/5 [00:00<00:00, 87.64it/s]


Loading noisy images sigma=0.3 

100%|██████████| 5/5 [00:00<00:00, 88.23it/s]


Loading original image paths in images_crop_resize_512_greyscale 

100%|██████████| 5/5 [00:00<00:00, 124091.83it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_512_greyscale_noisy_0_05 

100%|██████████| 5/5 [00:00<00:00, 99864.38it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_512_greyscale_noisy_0_1 

100%|██████████| 5/5 [00:00<00:00, 131072.00it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_512_greyscale_noisy_0_15 

100%|██████████| 5/5 [00:00<00:00, 105916.77it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_512_greyscale_noisy_0_2 

100%|██████████| 5/5 [00:00<00:00, 101311.69it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_512_greyscale_noisy_0_25 

100%|██████████| 5/5 [00:00<00:00, 108660.73it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_512_greyscale_noisy_0_3 

100%|██████████| 5/5 [00:00<00:00, 107546.26it/s]


Loading original images 

100%|██████████| 5/5 [00:00<00:00, 92.44it/s]


Loading noisy images sigma=0.05 

100%|██████████| 5/5 [00:00<00:00, 93.15it/s]


Loading noisy images sigma=0.1 

100%|██████████| 5/5 [00:00<00:00, 89.25it/s]


Loading noisy images sigma=0.15 

100%|██████████| 5/5 [00:00<00:00, 90.86it/s]


Loading noisy images sigma=0.2 

100%|██████████| 5/5 [00:00<00:00, 106.01it/s]


Loading noisy images sigma=0.25 

100%|██████████| 5/5 [00:00<00:00, 101.10it/s]


Loading noisy images sigma=0.3 

100%|██████████| 5/5 [00:00<00:00, 101.47it/s]


torch.Size([1, 1, 512, 512, 1])
torch.Size([1, 1, 512, 512, 1])


2024-07-01 05:37:49,458 - DEBUG - Starting new HTTPS connection (1): api.wandb.ai:443
2024-07-01 05:37:49,766 - DEBUG - https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 None
2024-07-01 05:37:49,991 - DEBUG - https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 None
[34m[1mwandb[0m: Currently logged in as: [33mtrung-vuthanh24[0m ([33mwof[0m). Use [1m`wandb login --relogin`[0m to force relogin
2024-07-01 05:37:50,074 - DEBUG - Popen(['git', 'cat-file', '--batch-check'], cwd=/mnt/h/GIT/DISSERTATION/LearningRegularizationParameterMaps, stdin=<valid stream>, shell=False, universal_newlines=False)


  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

KeyboardInterrupt: 