In [22]:
import torch
import os
import importlib
import doctest
from datetime import datetime

In [23]:
def _reload_and_doctest(module):
    """Reload ``module``, run its doctests, print the summary and return it.

    Reloading first means edits made to the ``.py`` file since the kernel
    started are picked up before the doctests run.
    """
    importlib.reload(module)
    summary = doctest.testmod(module)
    print(summary)
    return summary


# Dataset construction helpers.
from data import turtle_data_loading
turtle_data_loading_doctest_result = _reload_and_doctest(turtle_data_loading)
from data.turtle_data_loading import get_datasets

# Data-loader factory. The name `get_data_loaders` refers to the *module* only
# until the last line of this group, which rebinds it to the *function*.
from data import get_data_loaders
get_data_loaders_doctest_result = _reload_and_doctest(get_data_loaders)
from data.get_data_loaders import get_data_loaders

# Training loop: reloaded only, no doctests are run for this module.
from scripts import train
importlib.reload(train)
from scripts.train import start_training

# U-Net used as the CNN block of the model.
from networks import unet_2d
unet_2d_doctest_result = _reload_and_doctest(unet_2d)
from networks.unet_2d import UNet2d

TestResults(failed=0, attempted=0)
TestResults(failed=0, attempted=0)
Using new version 2 of train.py
TestResults(failed=0, attempted=5)


In [24]:
def start_training_tgv_pdhg_model():
    """Assemble config, model and data loaders, then launch training.

    Steps, in order:

    1. Build the ``config`` dict (data paths, U-Net hyper-parameters,
       optimiser settings, logging and checkpoint settings). The run name is
       derived from the current timestamp.
    2. Make ``config["device"]`` the default torch device and create the
       directory ``<save_dir>/<model_name>`` for model states.
    3. Either construct a fresh ``UNet2d`` wrapped in ``TgvPdhgNet``
       (optionally restoring a state dict), or load a fully pickled model,
       depending on the ``pretrained_model_path`` / ``is_state_dict`` toggles
       defined inside the function.
    4. Load the datasets, build the data loaders and print the shapes of the
       first training batch as a sanity check.
    5. Call ``start_training`` starting at epoch 0.

    Returns:
        None. The return value of ``start_training`` is discarded.

    Raises:
        ValueError: if ``is_state_dict`` is enabled without a
            ``pretrained_model_path``.
    """
    turtle_data_path = "../data/turtle_id_2022/turtles-data/data"
    # Timestamp-based run name, e.g. "2024-06-30_14-47-56.806123"; spaces and
    # colons are replaced so the name can be used as a directory name.
    model_name = f"{datetime.now()}".replace(" ", "_").replace(":", "-")
    config = {
        "model_name": model_name,
        "project": "tgv-test",
        "dataset": turtle_data_path,
        "train_data_path": f"{turtle_data_path}/train.txt",
        "val_data_path": f"{turtle_data_path}/val.txt",
        "test_data_path": f"{turtle_data_path}/test.txt",
        "train_num_samples": 10,
        "val_num_samples": 5,
        "test_num_samples": 5,
        "data_gen_num_threads": 16,

        "resize_square": 256,
        # Kept as a string on purpose: it is passed on unchanged to the
        # project helpers that consume ``config``.
        "sigmas": "[0.05, 0.1, 0.15, 0.2, 0.25, 0.3]",
        "batch_size": 1,
        # NOTE(review): the seed is only stored in the config here; nothing in
        # this function applies it. Presumably a downstream helper does - confirm.
        "random_seed": 42,

        "architecture": "UNET-PDHG",
        "in_channels": 1,
        "out_channels": 2,
        "init_filters": 128,
        "n_blocks": 3,
        "activation": "LeakyReLU",
        "downsampling_kernel": (2, 2),
        "downsampling_mode": "max_pool",
        "upsampling_kernel": (2, 2),
        "upsampling_mode": "linear_interpolation",

        "optimizer": "Adam",
        "learning_rate": 1e-4,
        "loss_function": "MSELoss",

        "scale_factor": 0.5,
        "constraint_activation": "sigmoid",  # "softplus" or "sigmoid"
        # Other values tried: 256 (open question from the author: with a
        # higher T the net may not have to try as hard / overfit less?) and 16.
        "T": 128,

        "epochs": 10_000,
        "device": "cuda:0",

        "wandb_mode": "online",
        "save_epoch_wandb": 100,
        "save_epoch_local": 10,
        "save_dir": "models",
    }

    device = config["device"]
    # Process-wide setting: tensors created from here on default to `device`.
    torch.set_default_device(device)

    # Directory that will hold the states of this run.
    model_states_dir = f"{config['save_dir']}/{model_name}"
    os.makedirs(model_states_dir, exist_ok=True)

    # Checkpoint toggles:
    #   pretrained_model_path is None            -> train from scratch
    #   pretrained_model_path set, is_state_dict -> build the model, load weights
    #   pretrained_model_path set, otherwise     -> load the whole pickled model
    pretrained_model_path = None
    is_state_dict = False
    if is_state_dict and pretrained_model_path is None:
        # Previously this combination tried to load a file literally named "None.pt".
        raise ValueError(
            "is_state_dict=True requires pretrained_model_path to be set"
        )
    # NOTE(review): the checkpoint is looked up inside `model_states_dir`, which
    # is the timestamped directory created just above, so a checkpoint from an
    # earlier run will not be found there. Confirm the intended location
    # before enabling either pretrained branch.
    checkpoint_path = (
        None
        if pretrained_model_path is None
        else f"{model_states_dir}/{pretrained_model_path}.pt"
    )

    if pretrained_model_path is None or is_state_dict:
        from networks.tgv_pdhg_net import TgvPdhgNet

        # CNN block.
        unet = UNet2d(
            in_channels=config["in_channels"],
            out_channels=config["out_channels"],
            init_filters=config["init_filters"],
            n_blocks=config["n_blocks"],
            activation=config["activation"],
            downsampling_kernel=config["downsampling_kernel"],
            downsampling_mode=config["downsampling_mode"],
            upsampling_kernel=config["upsampling_kernel"],
            upsampling_mode=config["upsampling_mode"],
        ).to(device)

        # Primal-dual network wrapping the CNN.
        pdhg_net = TgvPdhgNet(
            cnn=unet,
            device=device,
            constraint_activation=config["constraint_activation"],
            scale_factor=config["scale_factor"],
        ).to(device)
        if is_state_dict:
            # map_location makes the weights land on `device` regardless of
            # the device they were saved from.
            # NOTE(security): torch.load unpickles the file - only load
            # checkpoints from a trusted source.
            pdhg_net.load_state_dict(
                torch.load(checkpoint_path, map_location=device)
            )
    else:
        # Whole-model checkpoint (architecture + weights pickled together).
        # NOTE(security): torch.load unpickles the file - only load
        # checkpoints from a trusted source.
        # NOTE(review): newer PyTorch versions default to `weights_only=True`,
        # under which loading a fully pickled module fails unless its classes
        # are allow-listed - confirm against the installed version.
        pdhg_net = torch.load(checkpoint_path, map_location=device).to(device)

    pdhg_net.train(True)

    datasets = get_datasets(config, device=device, n_dim_data=2)
    data_loader_train, data_loader_valid, data_loader_test = get_data_loaders(config, datasets)
    del data_loader_test  # Not used for now

    # Sanity check: print the shapes of the first training batch only.
    for noisy, clean in data_loader_train:
        print(noisy.shape)
        print(clean.shape)
        break

    start_training(
        pdhg_net=pdhg_net,
        config=config,
        data_loader_train=data_loader_train,
        data_loader_valid=data_loader_valid,
        model_states_dir=model_states_dir,
        start_epoch=0,
    )

In [25]:
# Run the full pipeline defined above: build the config, the model and the
# data loaders, then start training.
# NOTE(review): the output recorded for this cell ends with
# "ValueError: NaN returned by loss function..." while both progress bars are
# still at 0, i.e. the captured run aborted at the very start of training.
start_training_tgv_pdhg_model()

Loading original image paths in images_crop_resize_256_greyscale 

100%|██████████| 10/10 [00:00<00:00, 47934.90it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_256_greyscale_noisy_0_05 

100%|██████████| 10/10 [00:00<00:00, 95542.23it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_256_greyscale_noisy_0_1 

100%|██████████| 10/10 [00:00<00:00, 23327.61it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_256_greyscale_noisy_0_15 

100%|██████████| 10/10 [00:00<00:00, 27999.36it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_256_greyscale_noisy_0_2 

100%|██████████| 10/10 [00:00<00:00, 39383.14it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_256_greyscale_noisy_0_25 

100%|██████████| 10/10 [00:00<00:00, 25450.87it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_256_greyscale_noisy_0_3 

100%|██████████| 10/10 [00:00<00:00, 20620.96it/s]


Loading original images 

100%|██████████| 10/10 [00:00<00:00, 94.60it/s]


Loading noisy images sigma=0.05 

100%|██████████| 10/10 [00:00<00:00, 106.48it/s]


Loading noisy images sigma=0.1 

100%|██████████| 10/10 [00:00<00:00, 101.61it/s]


Loading noisy images sigma=0.15 

100%|██████████| 10/10 [00:00<00:00, 98.09it/s]


Loading noisy images sigma=0.2 

100%|██████████| 10/10 [00:00<00:00, 101.47it/s]


Loading noisy images sigma=0.25 

100%|██████████| 10/10 [00:00<00:00, 88.99it/s]


Loading noisy images sigma=0.3 

100%|██████████| 10/10 [00:00<00:00, 104.08it/s]


Loading original image paths in images_crop_resize_256_greyscale 

100%|██████████| 5/5 [00:00<00:00, 39494.39it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_256_greyscale_noisy_0_05 

100%|██████████| 5/5 [00:00<00:00, 20620.96it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_256_greyscale_noisy_0_1 

100%|██████████| 5/5 [00:00<00:00, 20440.08it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_256_greyscale_noisy_0_15 

100%|██████████| 5/5 [00:00<00:00, 19765.81it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_256_greyscale_noisy_0_2 

100%|██████████| 5/5 [00:00<00:00, 16231.83it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_256_greyscale_noisy_0_25 

100%|██████████| 5/5 [00:00<00:00, 20360.70it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_256_greyscale_noisy_0_3 

100%|██████████| 5/5 [00:00<00:00, 19508.39it/s]


Loading original images 

100%|██████████| 5/5 [00:00<00:00, 111.49it/s]


Loading noisy images sigma=0.05 

100%|██████████| 5/5 [00:00<00:00, 130.31it/s]


Loading noisy images sigma=0.1 

100%|██████████| 5/5 [00:00<00:00, 140.45it/s]


Loading noisy images sigma=0.15 

100%|██████████| 5/5 [00:00<00:00, 134.89it/s]


Loading noisy images sigma=0.2 

100%|██████████| 5/5 [00:00<00:00, 133.51it/s]


Loading noisy images sigma=0.25 

100%|██████████| 5/5 [00:00<00:00, 139.15it/s]


Loading noisy images sigma=0.3 

100%|██████████| 5/5 [00:00<00:00, 138.70it/s]


Loading original image paths in images_crop_resize_256_greyscale 

100%|██████████| 5/5 [00:00<00:00, 33879.68it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_256_greyscale_noisy_0_05 

100%|██████████| 5/5 [00:00<00:00, 29704.70it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_256_greyscale_noisy_0_1 

100%|██████████| 5/5 [00:00<00:00, 45789.34it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_256_greyscale_noisy_0_15 

100%|██████████| 5/5 [00:00<00:00, 19803.14it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_256_greyscale_noisy_0_2 

100%|██████████| 5/5 [00:00<00:00, 20203.78it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_256_greyscale_noisy_0_25 

100%|██████████| 5/5 [00:00<00:00, 46397.17it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_256_greyscale_noisy_0_3 

100%|██████████| 5/5 [00:00<00:00, 13107.20it/s]


Loading original images 

100%|██████████| 5/5 [00:00<00:00, 97.92it/s]


Loading noisy images sigma=0.05 

100%|██████████| 5/5 [00:00<00:00, 130.71it/s]


Loading noisy images sigma=0.1 

100%|██████████| 5/5 [00:00<00:00, 120.05it/s]


Loading noisy images sigma=0.15 

100%|██████████| 5/5 [00:00<00:00, 108.96it/s]


Loading noisy images sigma=0.2 

100%|██████████| 5/5 [00:00<00:00, 114.22it/s]


Loading noisy images sigma=0.25 

100%|██████████| 5/5 [00:00<00:00, 137.41it/s]


Loading noisy images sigma=0.3 

100%|██████████| 5/5 [00:00<00:00, 114.31it/s]


torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])




VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▅█
train_PSNR,▁██
train_SSIM,▁██
train_loss,█▁▁
training_iter_PSNR,▁▆█▅▆▆▇▇▇▅▇▆▆▆▆▆▆▇▅█▆█▅▇▆▇▇▆▆▆▆▇█▆▆▄▆▇█▅
training_iter_SSIM,▁▃█▃▆▄▄▄▅▃▄▄▇▃▃▄▄▄▃▅▆▆▃█▃▄▃▃▄▄▅▄▆▄▄▃▄▄▆▃
training_iter_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁
val_PSNR,▁▁
val_SSIM,▁▁
val_iter_PSNR,▅▄▆█▃▅█▂▅▇▄▄▅▃▄▄▃▃▂▅▄▆█▃▆█▂▅▇▁▄▅▃▄▃▃▃▂▅▄

0,1
epoch,3.0
train_PSNR,-20.73707
train_SSIM,0.30879
train_loss,135.80941
training_iter_PSNR,-24.51221
training_iter_SSIM,0.23214
training_iter_loss,282.63159
val_PSNR,-19.34791
val_SSIM,0.32476
val_iter_PSNR,-17.42106


2024-06-30 14:47:56,806 - DEBUG - Popen(['git', 'cat-file', '--batch-check'], cwd=/mnt/h/GIT/DISSERTATION/LearningRegularizationParameterMaps, stdin=<valid stream>, shell=False, universal_newlines=False)


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112403677776052, max=1.0…

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/60 [00:00<?, ?it/s]

ValueError: NaN returned by loss function...