In [1]:
import torch
import logging
import os
from datetime import datetime
from config import LOGS_PATH

# Log to the console and to a per-session file under LOGS_PATH.
# The directory is created up front because FileHandler fails if it is missing.
LOGS_PATH.mkdir(parents=True, exist_ok=True)
# force=True replaces (and closes) handlers installed by an earlier run of this
# cell. Without it basicConfig is a no-op on re-run, but the FileHandler below
# is still constructed, leaving an empty log file and an open handle behind.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(
            LOGS_PATH / f"train-{datetime.now().isoformat(timespec='minutes')}.log"
        ),
    ],
    force=True,
)

# Ask PyTorch's CUDA caching allocator to use expandable segments (reduces
# fragmentation). This should be set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
f"Using device {device}"

'Using device cuda:0'

In [2]:
from scipy.stats import loguniform, uniform, randint
from editor.models import MODELS

# Search space shared by every loss family. Plain lists are sampled with
# `random.choice`; scipy distributions are sampled through their `.rvs()`.
common_hyperparameters = dict(
    batch_size=[64],
    edit_count=[12],
    bin_count=[16],
    clip_gradients=[False],
    learning_rate=loguniform(1e-4, 5e-3),
    # uniform(loc, scale) covers [loc, loc + scale], i.e. [0.7, 1.0] here.
    scheduler_gamma=uniform(loc=0.7, scale=0.3),
    num_epochs=[24],
    # num_epochs=randint(5, 10),
    model_type=[*MODELS],
)

# One entry per loss family; each run of the sweep picks one entry at random.
hyperparameters = [
    # dict(
    #     **common_hyperparameters,
    #     loss=["progressive"],
    #     loss_sizes=[[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],
    #     # NOTE(review): uniform(0.2, 5) spans [0.2, 5.2]; use
    #     # uniform(0.2, 4.8) if [0.2, 5] was intended.
    #     loss_damping=uniform(0.2, 5),
    # ),
    dict(common_hyperparameters, loss=["kl"]),
]

In [3]:
from pathlib import Path
from typing import List, Any, Dict
from torch.utils.data import DataLoader
from config import CACHE_PATH
from editor.training import HistogramDataset


def get_data_loader(
    data: List[Path], hyperparameters: Dict[str, Any], shuffle: bool = True
) -> DataLoader:
    """Build a batched loader of (edited, original) histogram pairs.

    Args:
        data: image paths wrapped in a ``HistogramDataset``.
        hyperparameters: a sampled run configuration; ``edit_count``,
            ``bin_count`` and ``batch_size`` are read from it.
        shuffle: reshuffle the data every epoch. Defaults to True (the previous
            behaviour); evaluation loaders may pass False for a stable order.

    Returns:
        A ``DataLoader`` that uses one worker process per CPU.
    """
    return DataLoader(
        dataset=HistogramDataset(
            paths=data,
            edit_count=hyperparameters["edit_count"],
            bin_count=hyperparameters["bin_count"],
            delete_corrupt_images=False,
            cache_path=CACHE_PATH,
        ),
        batch_size=hyperparameters["batch_size"],
        shuffle=shuffle,
        # os.cpu_count() can return None; fall back to loading in-process.
        num_workers=os.cpu_count() or 0,
    )


def serialise_hparams(hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a hyperparameter dict into values ``SummaryWriter.add_hparams`` accepts.

    ``add_hparams`` only takes int, float, str, bool or ``torch.Tensor`` values.
    Lists (and any other unsupported object) are stringified. Numpy scalars, such
    as the ``np.int64`` returned by ``scipy.stats.randint(...).rvs()``, are
    unwrapped to plain Python numbers. ``None`` is passed through unchanged.

    Args:
        hyperparameters: one sampled run configuration.

    Returns:
        A new dict with the same keys, in the same order.
    """
    import numbers

    def _to_hparam_value(value: Any) -> Any:
        if value is None or isinstance(value, (bool, int, float, str, torch.Tensor)):
            return value
        # numpy integer/floating scalars are registered with the numbers ABCs.
        if isinstance(value, numbers.Integral):
            return int(value)
        if isinstance(value, numbers.Real):
            return float(value)
        return str(value)

    return {key: _to_hparam_value(value) for key, value in hyperparameters.items()}

In [4]:
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
from tqdm.notebook import tqdm
from torch.nn.utils import clip_grad_norm_
from editor.training import ProgressivePoolingLoss
from editor.utils import get_next_run_name
from editor.visualisation import plot_histograms_in_2d
from editor.models import create_model, test_models
from datetime import timedelta, datetime
import json
from config import MODELS_PATH, RUNS_PATH, TRAIN_DATA, TEST_DATA


# Sanity-check the model definitions from editor.models before launching any
# training run. Judging by the cell output, it builds each model and checks that
# its output shape matches its input shape (details live in editor.models).
test_models()


# Smallest probability kept before taking the log for the KL loss, so empty
# predicted bins give a large-but-finite log value instead of -inf.
_KL_MIN_PROBABILITY = 1e-22


def _normalise_histogram(prediction: torch.Tensor) -> torch.Tensor:
    """Rescale each predicted histogram so that its bins sum to one.

    Sums over dims 2-4 with keepdim so the division broadcasts; presumably these
    are the three bin axes of a (batch, channel, bins, bins, bins) tensor — TODO
    confirm against the model output shape.
    """
    total = torch.sum(prediction, dim=(2, 3, 4), keepdim=True)
    return prediction / total


def _compute_loss(
    loss_name: str,
    loss_function: torch.nn.Module,
    prediction: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """Evaluate the configured loss on an already-normalised prediction.

    ``"kl"``: ``KLDivLoss`` expects log-probabilities as input, so the prediction
    is clamped to [_KL_MIN_PROBABILITY, 1] and logged first.
    ``"progressive"``: the prediction is passed through unchanged.

    Raises:
        KeyError: if ``loss_name`` is not one of the supported losses.
    """
    if loss_name == "kl":
        clamped = torch.clamp(prediction, _KL_MIN_PROBABILITY, 1)
        return loss_function(torch.log(clamped), target)
    if loss_name == "progressive":
        return loss_function(prediction, target)
    raise KeyError(loss_name)


def _json_default(value: Any) -> Any:
    """``json.dump`` fallback: unwrap numpy/torch scalars, otherwise stringify."""
    try:
        return value.item()
    except (AttributeError, ValueError, RuntimeError):
        return str(value)


def train(
    hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool
) -> Path:
    """Train one model for a sampled hyperparameter configuration and save it.

    Writes a fresh TensorBoard run under ``RUNS_PATH`` containing the model graph,
    the learning rate per epoch, per-batch training loss, a sample histogram
    figure per epoch and the hyperparameters with per-epoch summed train/test
    losses. The weights (``.pth``) and hyperparameters (``.json``) are saved under
    ``MODELS_PATH`` in a ``finally`` block, so a partially trained model is kept
    even when training is interrupted or fails.

    Args:
        hyperparameters: one concrete configuration. Reads ``model_type``,
            ``bin_count``, ``learning_rate``, ``scheduler_gamma``, ``loss``,
            ``num_epochs`` and ``clip_gradients``. Also reads ``loss_sizes`` /
            ``loss_damping`` for the "progressive" loss, plus the keys used by
            ``get_data_loader``.
        max_duration: wall-clock budget for the whole run, checked before every
            training batch.
        use_tqdm: wrap the training loader in a tqdm progress bar.

    Returns:
        Path of the saved ``.pth`` state dict.

    Raises:
        TimeoutError: when ``max_duration`` is exceeded (the model is still saved).
    """
    start_time = datetime.now()
    loss_name = hyperparameters["loss"]

    log_dir = RUNS_PATH / get_next_run_name(RUNS_PATH)
    with SummaryWriter(log_dir) as writer:
        train_data_loader = get_data_loader(TRAIN_DATA, hyperparameters)
        test_data_loader = get_data_loader(TEST_DATA, hyperparameters)

        model = (
            create_model(
                type=hyperparameters["model_type"],
                bin_count=hyperparameters["bin_count"],
            )
            .train()
            .to(device)
        )
        writer.add_graph(model, next(iter(train_data_loader))[0].to(device))

        optimizer = Adam(model.parameters(), lr=hyperparameters["learning_rate"])
        # Multiply the learning rate by `scheduler_gamma` after every epoch.
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=hyperparameters["scheduler_gamma"]
        )

        # Factories, so the "progressive"-only keys are read only when selected.
        loss_function = {
            "progressive": lambda: ProgressivePoolingLoss(
                target_sizes=hyperparameters["loss_sizes"],
                damping=hyperparameters["loss_damping"],
            ),
            "kl": lambda: torch.nn.KLDivLoss(reduction="batchmean"),
        }[loss_name]().to(device)

        try:
            for epoch in range(hyperparameters["num_epochs"]):
                writer.add_scalar(
                    "Actual learning rate", scheduler.get_last_lr()[0], epoch
                )
                train_batches = (
                    tqdm(train_data_loader, desc=f"Epoch {epoch}", unit="batch")
                    if use_tqdm
                    else train_data_loader
                )

                epoch_loss = 0
                for batch_id, (edited_histogram, original_histogram) in enumerate(
                    train_batches
                ):
                    if datetime.now() - start_time > max_duration:
                        raise TimeoutError(f"Time limit {max_duration} exceeded")
                    edited_histogram = edited_histogram.to(device)
                    original_histogram = original_histogram.to(device)

                    optimizer.zero_grad()
                    predicted_original = _normalise_histogram(model(edited_histogram))
                    loss = _compute_loss(
                        loss_name, loss_function, predicted_original, original_histogram
                    )

                    # Log the detached Python float rather than the graph-attached tensor.
                    batch_loss = loss.item()
                    epoch_loss += batch_loss
                    writer.add_scalar(
                        "Loss/train/batch",
                        batch_loss,
                        global_step=epoch * len(train_data_loader) + batch_id,
                    )
                    loss.backward()

                    if hyperparameters["clip_gradients"]:
                        clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                logging.info(f"Epoch {epoch} train loss: {epoch_loss}")

                model.eval()
                epoch_test_loss = 0
                with torch.no_grad():
                    for batch_id, (edited_histogram, original_histogram) in enumerate(
                        test_data_loader
                    ):
                        edited_histogram = edited_histogram.to(device)
                        original_histogram = original_histogram.to(device)
                        predicted_original = _normalise_histogram(
                            model(edited_histogram)
                        )

                        # Preview the first sample of the first test batch. Reusing
                        # this loop avoids spinning up a second worker pool per epoch.
                        if batch_id == 0:
                            writer.add_figure(
                                "histogram",
                                plot_histograms_in_2d(
                                    {
                                        "original": original_histogram.cpu()[0]
                                        .numpy()
                                        .squeeze(),
                                        "edited": edited_histogram.cpu()[0]
                                        .numpy()
                                        .squeeze(),
                                        "predicted": predicted_original.cpu()[0]
                                        .numpy()
                                        .squeeze(),
                                    }
                                ),
                                epoch,
                            )

                        loss = _compute_loss(
                            loss_name,
                            loss_function,
                            predicted_original,
                            original_histogram,
                        )
                        epoch_test_loss += loss.item()

                writer.add_hparams(
                    serialise_hparams(hyperparameters),
                    {
                        "Loss/test/epoch": epoch_test_loss,
                        "Loss/train/epoch": epoch_loss,
                    },
                    global_step=epoch,
                    run_name=log_dir.absolute(),
                )
                logging.info(f"Epoch {epoch} test loss: {epoch_test_loss}")

                model.train()
                scheduler.step()
        finally:
            # Always persist what was trained, including after a timeout,
            # KeyboardInterrupt or error, which then keeps propagating.
            run_name = get_next_run_name(MODELS_PATH)
            model_path = (MODELS_PATH / run_name).with_suffix(".pth")
            params_path = (MODELS_PATH / run_name).with_suffix(".json")

            logging.info(f"Saving model to {model_path}")
            try:
                with open(model_path, "wb") as f:
                    torch.save(model.state_dict(), f)
                with open(params_path, "w") as f:
                    # Numpy scalars (e.g. from scipy `rvs`) are not JSON-native; a
                    # TypeError here would mask the exception that ended training.
                    json.dump(hyperparameters, f, indent=2, default=_json_default)
            finally:
                del model
                torch.cuda.empty_cache()
        return model_path

Testing model Residual
Test passed! Output shape matches input shape.


In [5]:
# train(
#     {
#         "batch_size": 64,
#         "edit_count": 8,
#         "bin_count": 16,
#         "clip_gradients": False,
#         "learning_rate": 0.0005220900529274365,
#         "scheduler_gamma": 0.5479991284291021,
#         "num_epochs": 24,
#         "model_type": "Residual",
#         "loss": "kl",
#     },
#     max_duration=timedelta(hours=8),
#     use_tqdm=True,
# )

In [6]:
from random import choice
from itertools import count
import json


# Random search: keep sampling configurations and training until interrupted.
for _ in count():
    # Pick one loss family, then sample each hyperparameter: scipy distributions
    # through `.rvs()`, plain lists through `random.choice`.
    current_hyperparameters = {
        k: v.rvs() if hasattr(v, "rvs") else choice(v)
        for k, v in choice(hyperparameters).items()
    }
    # default=str: numpy integers (e.g. from `randint.rvs()`) are not JSON-native.
    # This call is outside the try below, so a TypeError would end the whole sweep.
    key = json.dumps(current_hyperparameters, indent=2, sort_keys=True, default=str)
    logging.info(f"Starting {get_next_run_name(RUNS_PATH)} with hparams {key}")
    try:
        train(current_hyperparameters, max_duration=timedelta(hours=8), use_tqdm=False)
    except KeyboardInterrupt:
        logging.info("Interrupted, stopping")
        break
    except TimeoutError:
        logging.warning("Timeout, aborting experiment")
    except Exception as e:
        # exc_info attaches the failing exception's traceback; stack_info only
        # recorded this loop's own call stack.
        logging.error(
            f"Error with hparams {current_hyperparameters}:\n\t{e}", exc_info=True
        )

2024-06-16 20:21:51,962 - INFO - Starting run_0 with hparams {
  "batch_size": 64,
  "bin_count": 32,
  "clip_gradients": false,
  "edit_count": 8,
  "learning_rate": 0.0013249692770052317,
  "loss": "kl",
  "model_type": "Residual",
  "num_epochs": 16,
  "scheduler_gamma": 1.3114281184948258
}
2024-06-16 20:21:52,012 - INFO - Loaded 22479 original images
2024-06-16 20:21:52,016 - INFO - Loaded 2498 original images
2024-06-16 20:35:43,995 - INFO - Epoch 0 train loss: 6540.840226650238
2024-06-16 20:36:15,017 - INFO - Epoch 0 test loss: 1531.6546006202698
2024-06-16 20:49:58,543 - INFO - Epoch 1 train loss: 5763.938045859337
2024-06-16 20:50:29,893 - INFO - Epoch 1 test loss: 1608.853798866272
2024-06-16 21:04:13,577 - INFO - Epoch 2 train loss: 5448.952376246452
2024-06-16 21:04:45,607 - INFO - Epoch 2 test loss: 1465.128571987152
2024-06-16 21:18:31,962 - INFO - Epoch 3 train loss: 5633.2793600559235
2024-06-16 21:19:09,149 - INFO - Epoch 3 test loss: 1330.329261302948
2024-06-16 21:3