In [1]:
import torch
import logging
import os
from editor.utils import set_up_logging
from config import LOGS_PATH

set_up_logging(LOGS_PATH)

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
f"Using device {device}"

'Using device cuda:0'

In [2]:
from scipy.stats import loguniform, uniform, randint
from editor.models import MODELS, test_models


hyperparameters = [{
    "batch_size": [32, 64, 128],
    "edit_count": [12],
    "bin_count": [16],
    "learning_rate": loguniform(5e-4, 5e-3),
    "scheduler_gamma": uniform(loc=0.8, scale=0.15),
    "num_epochs": [12],
    "elu_alpha": uniform(0.5, 1.5),
    "leaky_relu_slope": uniform(0, 0.03),
    "dropout_prob": uniform(0, 0.1),
    "features": [[16, 32, 64], [32, 64, 128], [8, 16, 32], [8, 8, 8], [16, 16, 16]],
    "kernel_sizes": [[3, 3, 3]],
    "model_type": ["Residual3"],  # list(MODELS.keys()),
    "clip_gradients": [True, False],
    "use_instance_norm": [True, False],
    "use_elu": [True, False],
    "leaky_relu_alpha": uniform(0, 0.05),
}]

test_models()

2024-06-22 15:59:06,999 - INFO - Testing model Dummy
2024-06-22 15:59:07,004 - INFO - Test passed! Output shape matches input shape.
2024-06-22 15:59:07,004 - INFO - Testing model SimpleCNN
2024-06-22 15:59:07,478 - INFO - Test passed! Output shape matches input shape.
2024-06-22 15:59:07,481 - INFO - Testing model Residual
2024-06-22 15:59:08,560 - INFO - Test passed! Output shape matches input shape.
2024-06-22 15:59:08,566 - INFO - Testing model Residual2
2024-06-22 15:59:09,671 - INFO - Test passed! Output shape matches input shape.
2024-06-22 15:59:09,676 - INFO - Testing model Residual3
2024-06-22 15:59:11,272 - INFO - Test passed! Output shape matches input shape.


In [3]:
from pathlib import Path
from typing import List, Any, Dict
from torch.utils.data import DataLoader
from config import CACHE_PATH
from editor.training import HistogramDataset


def get_data_loader(data: List[Path], hyperparameters: Dict[str, Any]) -> DataLoader:
    return DataLoader(
        dataset=HistogramDataset(
            paths=data,
            edit_count=hyperparameters["edit_count"],
            bin_count=hyperparameters["bin_count"],
            delete_corrupt_images=False,
            cache_path=CACHE_PATH,
        ),
        batch_size=hyperparameters["batch_size"],
        shuffle=True,
        num_workers=os.cpu_count(),
    )


def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
    return {k: str(v) if isinstance(v, list) else v for k, v in hyperparameters.items()}

In [4]:
from typing import Optional
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
from tqdm.notebook import tqdm
from editor.utils import get_next_run_name
from editor.visualisation import plot_histograms_in_2d
from editor.models import create_model, save_model
from datetime import timedelta, datetime
from config import MODELS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA


def train(
    hyperparameters: Dict[str, Any],
    max_duration: Optional[timedelta] = None,
    use_tqdm: bool = True,
) -> Path:
    start_time = datetime.now()

    log_dir = RUNS_PATH / get_next_run_name(RUNS_PATH)
    with SummaryWriter(log_dir) as writer:
        train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)
        test_data_loader = get_data_loader(TEST_DATA, hyperparameters)

        model = create_model(
            type=hyperparameters["model_type"],
            bin_count=hyperparameters["bin_count"],
            device=device,
        ).train()
        writer.add_graph(model, next(iter(train_data_loader))[0].to(device))

        optimizer = Adam(model.parameters(), lr=hyperparameters["learning_rate"])
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=hyperparameters["scheduler_gamma"]
        )
        loss_function = torch.nn.KLDivLoss(reduction="batchmean").to(device)

        try:
            for epoch in range(hyperparameters["num_epochs"]):
                epoch_loss = 0
                writer.add_scalar(
                    "Actual learning rate", scheduler.get_last_lr()[0], epoch
                )
                for batch_id, (edited_histogram, original_histogram) in enumerate(
                    tqdm(train_data_loader, desc=f"Epoch {epoch}", unit="batch")
                    if use_tqdm
                    else train_data_loader
                ):
                    current_time = datetime.now()
                    if (
                        max_duration is not None
                        and current_time - start_time > max_duration
                    ):
                        raise TimeoutError(f"Time limit {max_duration} exceeded")

                    optimizer.zero_grad()
                    predicted_original = model(edited_histogram.to(device))
                    loss = loss_function(
                        torch.log(torch.clamp(predicted_original, 1e-10, 1)),
                        original_histogram.to(device),
                    )

                    epoch_loss += loss.item()
                    writer.add_scalar(
                        "Loss/train/batch",
                        loss,
                        global_step=epoch * len(train_data_loader) + batch_id,
                    )
                    loss.backward()
                    optimizer.step()

                logging.info(f"Epoch {epoch} train loss: {epoch_loss}")
                with torch.no_grad():
                    model.eval()
                    loader = iter(test_data_loader)
                    edited_histogram, original_histogram = next(loader)
                    predicted_original = model(edited_histogram.to(device))
                    writer.add_figure(
                        "histogram",
                        plot_histograms_in_2d(
                            {
                                "original": original_histogram[0].numpy().squeeze(),
                                "edited": edited_histogram.cpu()[0].numpy().squeeze(),
                                "predicted": predicted_original.cpu()[0]
                                .numpy()
                                .squeeze(),
                            }
                        ),
                        epoch,
                    )

                    epoch_test_loss = 0
                    for batch_id, (edited_histogram, original_histogram) in enumerate(
                        test_data_loader
                    ):
                        predicted_original = model(edited_histogram.to(device))
                        epoch_test_loss += loss_function(
                            torch.log(torch.clamp(predicted_original, 1e-10, 1)),
                            original_histogram.to(device),
                        ).item()
                writer.add_hparams(
                    serialise_hparams(hyperparameters),
                    {
                        "Loss/test/epoch": epoch_test_loss,
                        "Loss/train/epoch": epoch_loss,
                    },
                    global_step=epoch,
                    run_name=log_dir.absolute(),
                )
                logging.info(f"Epoch {epoch} test loss: {epoch_test_loss}")

                model.train()
                scheduler.step()
        finally:
            model_path = MODELS_PATH / get_next_run_name(MODELS_PATH)
            save_model(model, hyperparameters, model_path)
            del model
        return model_path

In [7]:
# train(
#     {
#         "batch_size": 128,
#         "edit_count": 12,
#         "bin_count": 16,
#         "learning_rate": 1e-3,
#         "scheduler_gamma": 0.8,
#         "elu_alpha": 1,
#         "dropout_prob": 0.05,
#         "features": [8, 16, 32],
#         "kernel_sizes": [3, 3, 3],
#         "num_epochs": 12,
#         "model_type": "Residual3",
#         "clip_gradients": True,
#         "use_instance_norm": True,
#         "use_elu": False,
#         "leaky_relu_alpha": 0.01,
#     }
# )

2024-06-22 16:57:28,986 - INFO - Loaded 22479 original images
2024-06-22 16:57:28,991 - INFO - Loaded 2498 original images


Epoch 0:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:00:34,218 - INFO - Epoch 0 train loss: 11718.475350141525
  scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor
2024-06-22 17:00:43,669 - INFO - Epoch 0 test loss: 575.4344878196716


Epoch 1:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:03:46,881 - INFO - Epoch 1 train loss: 9741.187401413918
2024-06-22 17:03:56,471 - INFO - Epoch 1 test loss: 536.2769713401794


Epoch 2:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:06:59,896 - INFO - Epoch 2 train loss: 9120.070751070976
2024-06-22 17:07:09,641 - INFO - Epoch 2 test loss: 553.2901458740234


Epoch 3:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:10:12,942 - INFO - Epoch 3 train loss: 5763.117876529694
2024-06-22 17:10:22,622 - INFO - Epoch 3 test loss: 507.6950304508209


Epoch 4:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:13:25,908 - INFO - Epoch 4 train loss: 6363.094870328903
2024-06-22 17:13:36,532 - INFO - Epoch 4 test loss: 532.4468264579773


Epoch 5:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:16:40,056 - INFO - Epoch 5 train loss: 4596.043945550919
2024-06-22 17:16:49,784 - INFO - Epoch 5 test loss: 438.763400554657


Epoch 6:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:19:53,205 - INFO - Epoch 6 train loss: 5266.503381967545
2024-06-22 17:20:02,990 - INFO - Epoch 6 test loss: 573.5293898582458


Epoch 7:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:23:06,445 - INFO - Epoch 7 train loss: 5163.991681098938
2024-06-22 17:23:16,136 - INFO - Epoch 7 test loss: 672.4951323270798


Epoch 8:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:26:19,453 - INFO - Epoch 8 train loss: 12930.857147455215
2024-06-22 17:26:29,204 - INFO - Epoch 8 test loss: 636.4001806974411


Epoch 9:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:29:32,560 - INFO - Epoch 9 train loss: 13841.072596549988
2024-06-22 17:29:42,246 - INFO - Epoch 9 test loss: 2833.8614711761475


Epoch 10:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:32:45,585 - INFO - Epoch 10 train loss: 15531.006411075592
2024-06-22 17:32:55,355 - INFO - Epoch 10 test loss: 469.56569051742554


Epoch 11:   0%|          | 0/2108 [00:00<?, ?batch/s]

2024-06-22 17:35:58,670 - INFO - Epoch 11 train loss: 17766.0949113369
2024-06-22 17:36:08,527 - INFO - Epoch 11 test loss: 3254.2825841903687
2024-06-22 17:36:08,529 - INFO - Saving model to /home/andras/projects/bipolaroid/models/run_66.pth
2024-06-22 17:36:08,529 - INFO - Parameter count: 429457


PosixPath('/home/andras/projects/bipolaroid/models/run_66')

In [6]:
from random import choice
from itertools import count
import json


for _ in count():
    current_hyperparameters = {
        k: v.rvs() if hasattr(v, "rvs") else choice(v)
        for k, v in choice(hyperparameters).items()
    }
    key = json.dumps(current_hyperparameters, indent=2, sort_keys=True)
    logging.info(f"Starting {get_next_run_name(RUNS_PATH)} with hparams {key}")
    try:
        train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)
    except KeyboardInterrupt as e:
        logging.info("Interrupted, stopping")
        break
    except TimeoutError as e:
        logging.warning(f"Timeout, aborting experiment")
    except Exception as e:
        logging.error(
            f"Error with hparams {current_hyperparameters}:\n\t{e}", stack_info=True
        )