In [1]:
import torch
from datetime import datetime

from data.turtle_data_loading import get_datasets
from scripts import train
from scripts.train import start_training

Using new version 2 of turtle_data_generate.py


2024-06-25 15:47:24,149 - DEBUG - Popen(['git', 'version'], cwd=/mnt/h/GIT/DISSERTATION/LearningRegularizationParameterMaps/turtle_id, stdin=None, shell=False, universal_newlines=False)
2024-06-25 15:47:24,151 - DEBUG - Popen(['git', 'version'], cwd=/mnt/h/GIT/DISSERTATION/LearningRegularizationParameterMaps/turtle_id, stdin=None, shell=False, universal_newlines=False)
2024-06-25 15:47:24,213 - DEBUG - Trying paths: ['/root/.docker/config.json', '/root/.dockercfg']
2024-06-25 15:47:24,214 - DEBUG - No config file found


Using new version 2 of train.py


In [2]:
def train(pretrained_model_path=None, is_state_dict=False, start_epoch=0):
    """Configure and launch a UNET-PDHG training run on the turtle data.

    Builds the run configuration dictionary, sets the torch default device
    from it, and hands the configuration plus the ``get_datasets`` callable
    to ``start_training``.

    NOTE: defining this function rebinds the notebook-level name ``train``,
    which the imports cell had bound to the ``scripts.train`` module.

    Args:
        pretrained_model_path: Forwarded unchanged to ``start_training``.
            Defaults to ``None``, the value that used to be hard-coded.
            Presumably a checkpoint to continue from -- confirm against
            ``start_training``.
        is_state_dict: Forwarded unchanged to ``start_training``.
            Defaults to ``False``, the value that used to be hard-coded.
        start_epoch: Forwarded unchanged to ``start_training``.
            Defaults to ``0``, the value that used to be hard-coded.

    Returns:
        None. The result of ``start_training`` is not propagated.
    """
    turtle_data_path = "../data/turtle_id_2022/turtles-data/data"

    # The timestamp makes every model name unique per invocation.
    run_timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    project = "turtle_id_2022_t_out_0_1"
    model_name = f"model-{project}-{run_timestamp}"

    # Same value the original triple-quoted literal produced (leading newline
    # and trailing indentation included), spelled out so the whitespace is visible.
    description = "\nBigger U-Net\n    "

    config = {
        # --- Run identity and data ---
        "project": project,
        "description": description,
        "dataset": turtle_data_path,
        "train_data_path": f"{turtle_data_path}/train.txt",
        "val_data_path": f"{turtle_data_path}/val.txt",
        "test_data_path": f"{turtle_data_path}/test.txt",
        "train_num_samples": 100,
        "val_num_samples": 10,
        "test_num_samples": 1, # Don't use 0, 0 means all!
        "data_gen_num_threads": 16,

        "resize_square": 512,
        "sigmas": "[0.05, 0.1, 0.15, 0.2, 0.25, 0.3]",
        "batch_size": 1,
        "random_seed": 42,

        # --- Model ---
        "architecture": "UNET-PDHG",
        "model_name": model_name,
        "in_channels": 1,
        "out_channels": 2,
        "init_filters": 128,
        "n_blocks": 5,
        # "n_blocks": 3,
        "activation": "LeakyReLU",
        "downsampling_kernel": (2, 2, 1),
        "downsampling_mode": "max_pool",
        "upsampling_kernel": (2, 2, 1),
        "upsampling_mode": "linear_interpolation",

        # --- Optimisation ---
        "optimizer": "Adam",
        "learning_rate": 1e-4,
        "loss_function": "MSELoss",

        # "up_bound": 0.5, # If > 0, it will use Lambda = up_bound * sigmoid(Lambda)
        "up_bound": 0, # When using <= 0. it will default to Lambda = 0.1 * softplus(Lambda)
        # "up_bound": 0.2, # I saw that there was too much regularization with the default, so I'm trying a small value.
        # "up_bound": 0.001, # Too small and it will stop changing the image at all at some point (Lambda goes to 0)
        # "T": 256, # Higher T, NET does not have to try as hard? Less overfitting?
        "T": 128, # From the paper, it seems 128 is just as good as 256, but faster.
        # "T": 16,

        "epochs": 100_000,
        "device": "cuda:0",

        # --- Logging and checkpointing ---
        "wandb_mode": "online",
        "save_epoch_wandb": 100,
        "save_epoch_local": 10,
        "save_dir": "models",
    }

    # Applied before start_training is called, so it is in effect for the whole run.
    torch.set_default_device(config["device"])

    # The get_datasets callable itself is passed on; it is not called here.
    start_training(
        config=config,
        get_datasets=get_datasets,
        pretrained_model_path=pretrained_model_path,
        is_state_dict=is_state_dict,
        start_epoch=start_epoch,
    )

In [3]:
# Run the training job configured in train(); the dataset-loading progress
# bars and per-epoch validation logs that follow are emitted by this call.
train()

Loading original image paths in images_crop_resize_512_greyscale 

100%|██████████| 100/100 [00:00<00:00, 1072712.02it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_512_greyscale_noisy_0_05 

100%|██████████| 100/100 [00:00<00:00, 1594792.40it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_512_greyscale_noisy_0_1 

100%|██████████| 100/100 [00:00<00:00, 1519675.36it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_512_greyscale_noisy_0_15 

100%|██████████| 100/100 [00:00<00:00, 1536375.09it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_512_greyscale_noisy_0_2 

100%|██████████| 100/100 [00:00<00:00, 1503334.77it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_512_greyscale_noisy_0_25 

100%|██████████| 100/100 [00:00<00:00, 1530767.88it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_512_greyscale_noisy_0_3 

100%|██████████| 100/100 [00:00<00:00, 1519675.36it/s]


Loading original images 

100%|██████████| 100/100 [00:01<00:00, 99.17it/s]


Loading noisy images sigma=0.05 

100%|██████████| 100/100 [00:00<00:00, 102.68it/s]


Loading noisy images sigma=0.1 

100%|██████████| 100/100 [00:01<00:00, 99.09it/s]


Loading noisy images sigma=0.15 

100%|██████████| 100/100 [00:01<00:00, 93.93it/s]


Loading noisy images sigma=0.2 

100%|██████████| 100/100 [00:01<00:00, 95.55it/s]


Loading noisy images sigma=0.25 

100%|██████████| 100/100 [00:01<00:00, 92.29it/s]


Loading noisy images sigma=0.3 

100%|██████████| 100/100 [00:01<00:00, 91.31it/s]


Loading original image paths in images_crop_resize_512_greyscale 

100%|██████████| 10/10 [00:00<00:00, 191520.73it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_512_greyscale_noisy_0_05 

100%|██████████| 10/10 [00:00<00:00, 208671.84it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_512_greyscale_noisy_0_1 

100%|██████████| 10/10 [00:00<00:00, 219597.07it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_512_greyscale_noisy_0_15 

100%|██████████| 10/10 [00:00<00:00, 212908.83it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_512_greyscale_noisy_0_2 

100%|██████████| 10/10 [00:00<00:00, 220752.84it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_512_greyscale_noisy_0_25 

100%|██████████| 10/10 [00:00<00:00, 219597.07it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_512_greyscale_noisy_0_3 

100%|██████████| 10/10 [00:00<00:00, 210769.05it/s]


Loading original images 

100%|██████████| 10/10 [00:00<00:00, 65.35it/s]


Loading noisy images sigma=0.05 

100%|██████████| 10/10 [00:00<00:00, 100.07it/s]


Loading noisy images sigma=0.1 

100%|██████████| 10/10 [00:00<00:00, 98.68it/s]


Loading noisy images sigma=0.15 

100%|██████████| 10/10 [00:00<00:00, 95.50it/s]


Loading noisy images sigma=0.2 

100%|██████████| 10/10 [00:00<00:00, 95.16it/s]


Loading noisy images sigma=0.25 

100%|██████████| 10/10 [00:00<00:00, 92.48it/s]


Loading noisy images sigma=0.3 

100%|██████████| 10/10 [00:00<00:00, 93.57it/s]


Loading original image paths in images_crop_resize_512_greyscale 

100%|██████████| 1/1 [00:00<00:00, 24528.09it/s]


Loading noisy image paths sigma=0.05 in images_crop_resize_512_greyscale_noisy_0_05 

100%|██████████| 1/1 [00:00<00:00, 24105.20it/s]


Loading noisy image paths sigma=0.1 in images_crop_resize_512_greyscale_noisy_0_1 

100%|██████████| 1/1 [00:00<00:00, 23831.27it/s]


Loading noisy image paths sigma=0.15 in images_crop_resize_512_greyscale_noisy_0_15 

100%|██████████| 1/1 [00:00<00:00, 18558.87it/s]


Loading noisy image paths sigma=0.2 in images_crop_resize_512_greyscale_noisy_0_2 

100%|██████████| 1/1 [00:00<00:00, 24244.53it/s]


Loading noisy image paths sigma=0.25 in images_crop_resize_512_greyscale_noisy_0_25 

100%|██████████| 1/1 [00:00<00:00, 26051.58it/s]


Loading noisy image paths sigma=0.3 in images_crop_resize_512_greyscale_noisy_0_3 

100%|██████████| 1/1 [00:00<00:00, 27060.03it/s]


Loading original images 

100%|██████████| 1/1 [00:00<00:00, 62.32it/s]


Loading noisy images sigma=0.05 

100%|██████████| 1/1 [00:00<00:00, 97.92it/s]


Loading noisy images sigma=0.1 

100%|██████████| 1/1 [00:00<00:00, 80.06it/s]


Loading noisy images sigma=0.15 

100%|██████████| 1/1 [00:00<00:00, 96.78it/s]


Loading noisy images sigma=0.2 

100%|██████████| 1/1 [00:00<00:00, 92.46it/s]


Loading noisy images sigma=0.25 

100%|██████████| 1/1 [00:00<00:00, 94.80it/s]


Loading noisy images sigma=0.3 

100%|██████████| 1/1 [00:00<00:00, 89.79it/s]
2024-06-25 15:47:46,691 - DEBUG - Starting new HTTPS connection (1): api.wandb.ai:443
2024-06-25 15:47:47,058 - DEBUG - https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 None
2024-06-25 15:47:47,282 - DEBUG - https://api.wandb.ai:443 "POST /graphql HTTP/11" 200 None
[34m[1mwandb[0m: Currently logged in as: [33mtrung-vuthanh24[0m ([33mwof[0m). Use [1m`wandb login --relogin`[0m to force relogin
2024-06-25 15:47:47,361 - DEBUG - Popen(['git', 'cat-file', '--batch-check'], cwd=/mnt/h/GIT/DISSERTATION/LearningRegularizationParameterMaps, stdin=<valid stream>, shell=False, universal_newlines=False)


  0%|          | 10/100000 [1:20:19<13223:15:13, 476.08s/it]

Epoch 10 - VALIDATION LOSS: 0.0013461580803171577 - VALIDATION PSNR: 30.504953384399414 - VALIDATION SSIM: 0.79832830496112


  0%|          | 20/100000 [2:39:01<13143:29:55, 473.26s/it]

Epoch 20 - VALIDATION LOSS: 0.001270540133312655 - VALIDATION PSNR: 30.71758460998535 - VALIDATION SSIM: 0.8030047093453389


  0%|          | 30/100000 [3:57:42<13126:54:11, 472.71s/it]

Epoch 30 - VALIDATION LOSS: 0.0012144653973033806 - VALIDATION PSNR: 30.99583625793457 - VALIDATION SSIM: 0.8200432144885984


  0%|          | 40/100000 [5:16:24<13132:52:39, 472.97s/it]

Epoch 40 - VALIDATION LOSS: 0.0011922972810377056 - VALIDATION PSNR: 31.0939884185791 - VALIDATION SSIM: 0.8225964670549991


  0%|          | 50/100000 [6:35:09<13140:45:50, 473.30s/it]

Epoch 50 - VALIDATION LOSS: 0.0011837294544723893 - VALIDATION PSNR: 31.154296875 - VALIDATION SSIM: 0.8228334613463308


  0%|          | 60/100000 [7:53:50<13125:47:59, 472.81s/it]

Epoch 60 - VALIDATION LOSS: 0.0012126622937406258 - VALIDATION PSNR: 31.035606384277344 - VALIDATION SSIM: 0.8173995097848582


  0%|          | 70/100000 [9:12:46<13221:40:48, 476.31s/it]

Epoch 70 - VALIDATION LOSS: 0.0012313059759132253 - VALIDATION PSNR: 30.97110366821289 - VALIDATION SSIM: 0.8161794873818009


wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)
  0%|          | 80/100000 [10:31:45<13118:44:05, 472.65s/it]

Epoch 80 - VALIDATION LOSS: 0.0012531021364945141 - VALIDATION PSNR: 30.870908737182617 - VALIDATION SSIM: 0.8130054550489504


  0%|          | 86/100000 [11:25:58<13282:30:46, 478.58s/it]


KeyboardInterrupt: 