In [4]:
import torch

from data.turtle_data_loading import get_datasets
from scripts.train import start_training

In [5]:
def train(config_overrides=None, pretrained_model_path=None, is_state_dict=False, start_epoch=0):
    """Build the turtle_id_2022 UNET-PDHG experiment config and start training.

    Calling ``train()`` with no arguments reproduces the fixed run this
    notebook was written for. The keyword arguments let a new or resumed run
    be launched without editing this function.

    Parameters
    ----------
    config_overrides : dict or None
        Entries merged over the default ``config`` below (e.g. ``{"T": 256}``
        or ``{"device": "cpu"}``); keys present in both take the override's
        value, unknown keys are added as-is. The ``*_data_path`` entries are
        derived from the default data directory, so override them explicitly
        together with ``"dataset"``.
    pretrained_model_path : str or None
        Forwarded unchanged to ``start_training`` (presumably a checkpoint to
        start from -- confirm against ``scripts.train``).
    is_state_dict : bool
        Forwarded unchanged to ``start_training`` (presumably whether the
        file at ``pretrained_model_path`` is a bare state dict -- confirm).
    start_epoch : int
        Forwarded unchanged to ``start_training``.

    Raises
    ------
    RuntimeError
        If ``config["device"]`` names a CUDA device that is not available.
        This is checked before ``start_training`` is called, so the run fails
        immediately instead of after dataset generation.
    """
    turtle_data_path = "../data/turtle_id_2022/turtles-data/data"

    config = {
        "project": "turtle_id_2022",
        "dataset": turtle_data_path,
        "train_data_path": f"{turtle_data_path}/train.txt",
        "val_data_path": f"{turtle_data_path}/val.txt",
        "test_data_path": f"{turtle_data_path}/test.txt",
        "train_num_samples": 100,
        "val_num_samples": 10,
        "test_num_samples": 10,
        "data_gen_num_threads": 16,

        "resize_square": 256,
        # Passed as a string; presumably parsed by the data loader.
        # NOTE(review): the logged run only loaded the noisy_0_1 ... noisy_0_3
        # folders -- no 0.05 folder appears; confirm the first sigma is used.
        "sigmas": "[0.05, 0.1, 0.15, 0.2, 0.25, 0.3]",
        "batch_size": 1,
        # NOTE(review): confirm start_training seeds the RNGs from this value.
        "random_seed": 42,

        "architecture": "UNET-PDHG",
        "in_channels": 1,
        "out_channels": 2,
        "init_filters": 32,
        "n_blocks": 3,
        "activation": "LeakyReLU",
        "downsampling_kernel": (2, 2, 1),
        "downsampling_mode": "max_pool",
        "upsampling_kernel": (2, 2, 1),
        "upsampling_mode": "linear_interpolation",

        "optimizer": "Adam",
        "learning_rate": 1e-4,
        "loss_function": "MSELoss",

        "up_bound": 0,  # previously tried: 0.5
        # Previously tried 256 ("higher T, net does not have to try as hard?
        # less overfitting?") and 16.
        "T": 128,

        # At the ~250 s/epoch seen in the logged run this is ~700 h of training;
        # that run's validation loss was lowest around epoch 50.
        "epochs": 10_000,
        "device": "cuda:0",

        "wandb_mode": "online",
        "save_epoch_wandb": 100,
        "save_epoch_local": 10,
        "save_dir": "tmp_2",
    }
    if config_overrides:
        config.update(config_overrides)

    device = torch.device(config["device"])
    if device.type == "cuda":
        # Fail fast: otherwise the error only surfaces on the first CUDA
        # allocation, after start_training has already generated the datasets.
        if not torch.cuda.is_available():
            raise RuntimeError(
                f"config['device'] is {config['device']!r} but CUDA is not available; "
                "pass config_overrides={'device': 'cpu'} to run on the CPU."
            )
        if device.index is not None and device.index >= torch.cuda.device_count():
            raise RuntimeError(
                f"config['device'] is {config['device']!r} but only "
                f"{torch.cuda.device_count()} CUDA device(s) are visible."
            )

    # Use the device as a context manager rather than torch.set_default_device.
    # The global default would stay in force after train() returns or is
    # interrupted, making every later cell in the kernel allocate on the GPU.
    with device:
        start_training(
            config=config,
            get_datasets=get_datasets,
            pretrained_model_path=pretrained_model_path,
            is_state_dict=is_state_dict,
            start_epoch=start_epoch,
        )

In [6]:
# Long-running: the logged run below took ~250 s/epoch and was stopped
# manually (KeyboardInterrupt) at epoch 142 of 10_000.
train()

Using 16 threads
Using skimage.util.random_noise to add noise to images
Multiprocessing 438 subfolders in 16 threads
Processed 8729 images
Loading images from images_crop_resize_256_greyscale_noisy_0_1


100it [00:00, 103.66it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_15


100it [00:00, 109.20it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_2


100it [00:00, 101.32it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_25


100it [00:00, 109.89it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_3


100it [00:00, 102.68it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_1


10it [00:00, 105.61it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_15


10it [00:00, 101.62it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_2


10it [00:00, 101.95it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_25


10it [00:00, 102.36it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_3


10it [00:00, 98.68it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_1


10it [00:00, 124.04it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_15


10it [00:00, 115.19it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_2


10it [00:00, 103.77it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_25


10it [00:00, 101.64it/s]


Loading images from images_crop_resize_256_greyscale_noisy_0_3


10it [00:00, 99.09it/s]
2024-06-20 14:31:46,556 - DEBUG - Popen(['git', 'cat-file', '--batch-check'], cwd=/mnt/h/GIT/DISSERTATION/LearningRegularizationParameterMaps, stdin=<valid stream>, shell=False, universal_newlines=False)


  0%|          | 10/10000 [42:21<700:06:00, 252.29s/it]

Epoch 10 - VALIDATION LOSS: 0.0018477293744217605 - VALIDATION PSNR: 28.3865966796875 - VALIDATION SSIM: 0.7155333151093286


  0%|          | 20/10000 [1:23:38<684:13:30, 246.81s/it]

Epoch 20 - VALIDATION LOSS: 0.001793908738764003 - VALIDATION PSNR: 28.593542098999023 - VALIDATION SSIM: 0.7438529670248885


  0%|          | 30/10000 [2:05:11<688:10:29, 248.49s/it]

Epoch 30 - VALIDATION LOSS: 0.0017062908544903622 - VALIDATION PSNR: 28.813276290893555 - VALIDATION SSIM: 0.7543425480204594


  0%|          | 40/10000 [2:46:23<683:22:27, 247.00s/it]

Epoch 40 - VALIDATION LOSS: 0.0016818353143753485 - VALIDATION PSNR: 28.90021514892578 - VALIDATION SSIM: 0.7575325509647951


  0%|          | 50/10000 [3:27:55<689:00:30, 249.29s/it]

Epoch 50 - VALIDATION LOSS: 0.0016721557301934808 - VALIDATION PSNR: 28.935842514038086 - VALIDATION SSIM: 0.7611580367657129


  1%|          | 60/10000 [4:09:11<682:45:16, 247.28s/it]

Epoch 60 - VALIDATION LOSS: 0.001685275664785877 - VALIDATION PSNR: 28.908388137817383 - VALIDATION SSIM: 0.7590342682771429


  1%|          | 70/10000 [4:50:34<683:47:14, 247.90s/it]

Epoch 70 - VALIDATION LOSS: 0.0017016070877434687 - VALIDATION PSNR: 28.866308212280273 - VALIDATION SSIM: 0.7555956423493154


  1%|          | 80/10000 [5:32:01<683:35:57, 248.08s/it]

Epoch 80 - VALIDATION LOSS: 0.0017137280601309613 - VALIDATION PSNR: 28.82942771911621 - VALIDATION SSIM: 0.7540249431447605


  1%|          | 90/10000 [6:13:29<687:05:28, 249.60s/it]

Epoch 90 - VALIDATION LOSS: 0.0017224647020339034 - VALIDATION PSNR: 28.80537223815918 - VALIDATION SSIM: 0.7522422117772871


  1%|          | 100/10000 [6:54:59<682:27:51, 248.17s/it]

Epoch 100 - VALIDATION LOSS: 0.0017500001023290678 - VALIDATION PSNR: 28.688112258911133 - VALIDATION SSIM: 0.7469124926571422


  1%|          | 110/10000 [7:36:29<689:22:25, 250.93s/it]

Epoch 110 - VALIDATION LOSS: 0.0017459169382345863 - VALIDATION PSNR: 28.734872817993164 - VALIDATION SSIM: 0.74866897115615


  1%|          | 120/10000 [8:18:01<684:18:39, 249.34s/it]

Epoch 120 - VALIDATION LOSS: 0.0017609591837390325 - VALIDATION PSNR: 28.689184188842773 - VALIDATION SSIM: 0.7464418423276459


  1%|▏         | 130/10000 [8:59:21<679:29:38, 247.84s/it]

Epoch 130 - VALIDATION LOSS: 0.0017706468258984387 - VALIDATION PSNR: 28.674047470092773 - VALIDATION SSIM: 0.7459473139388055


  1%|▏         | 140/10000 [9:41:10<684:50:13, 250.04s/it]

Epoch 140 - VALIDATION LOSS: 0.0017773114173905923 - VALIDATION PSNR: 28.649147033691406 - VALIDATION SSIM: 0.7447830574071959


  1%|▏         | 142/10000 [9:52:07<685:06:25, 250.19s/it]


KeyboardInterrupt: 