In [2]:
import os

# The CUDA caching allocator reads this variable when it is initialised; set it
# before importing torch so it is guaranteed to take effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import logging
from datetime import datetime

import torch

# Log to the notebook output and to a per-day file. `force=True` replaces any
# root handlers left from a previous execution of this cell (or installed by an
# imported library); without it basicConfig silently does nothing on re-run.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(f"train-{datetime.now().date()}.log"),
    ],
    force=True,
)

# Prefer the first GPU when one is available; everything else runs on CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
f"Using device {device}"

2024-06-03 07:46:08,999 - INFO - PyTorch version: 2.2.2


'Using device cuda:0'

In [2]:
from scipy.stats import loguniform, uniform
from editor.models import MODELS

# Search space shared by every loss configuration. The search loop samples
# list-valued entries with `random.choice` and scipy distributions with `.rvs()`.
# NOTE: scipy's `uniform(loc, scale)` samples [loc, loc + scale], so
# `scheduler_gamma` is drawn from [0.1, 1.0] (not [0.1, 0.9]).
common_hyperparameters = dict(
    batch_size=[16, 32, 64],
    edit_count=[8, 16],
    bin_count=[16, 32],
    clip_gradients=[True, False],
    learning_rate=loguniform(0.0001, 0.005),
    scheduler_gamma=uniform(0.1, 0.9),
    num_epochs=[5],
    model_type=list(MODELS),
)

# One search space per loss; the search loop first picks one of these at
# random, then samples every key inside it.
hyperparameters = [
    # Progressive pooling loss (currently disabled):
    # dict(
    #     common_hyperparameters,
    #     loss=["progressive"],
    #     loss_sizes=[[4, 8, 16, 32], [8, 16, 32], [16, 32], [8, 32]],
    #     loss_damping=uniform(0.2, 5),  # samples [0.2, 5.2]
    # ),
    dict(common_hyperparameters, loss=["kl"]),
]

In [3]:
from typing import Any, Dict
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
from torch.optim import Adam
from tqdm.notebook import tqdm
from torch.nn.utils import clip_grad_norm_
from editor.training import ProgressivePoolingLoss
from editor.utils import get_next_run_name
from editor.visualisation import plot_histograms_in_2d
from editor.training import create_data_loaders
from editor.models import create_model, test_models
from config import DATA, MODELS_PATH
from datetime import timedelta, datetime

# Smoke-test the model architectures from editor.models before training; its
# printed output reports, per model, whether the output shape matches the input shape.
test_models()


def _normalise_histograms(histograms: torch.Tensor) -> torch.Tensor:
    """Rescale predicted histograms so their bins sum to one.

    Sums over dims 2, 3 and 4 (presumably the three histogram bin axes —
    TODO confirm) with ``keepdim=True`` so the division broadcasts, i.e. every
    slice along the leading dims is normalised independently.
    """
    totals = torch.sum(histograms, dim=(2, 3, 4), keepdim=True)
    return histograms / totals


def _compute_loss(
    loss_name: str,
    loss_function: torch.nn.Module,
    predicted: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    """Apply ``loss_function`` in the form the selected loss expects.

    Raises:
        KeyError: if ``loss_name`` is neither "kl" nor "progressive".
    """
    if loss_name == "kl":
        # KLDivLoss takes log-probabilities as input; clamp to a tiny positive
        # floor first so empty bins do not yield log(0) = -inf.
        predicted = torch.clamp(predicted, 0.0000000000000000000001, 1)
        return loss_function(torch.log(predicted), target)
    if loss_name == "progressive":
        return loss_function(predicted, target)
    raise KeyError(loss_name)


def _log_sample_prediction(
    writer: "SummaryWriter",
    model: torch.nn.Module,
    test_data_loader: Any,
    epoch: int,
) -> None:
    """Add a TensorBoard figure of original / edited / predicted histograms.

    Uses the first sample of the first test batch. The model is put in eval
    mode for the forward pass and always restored to train mode afterwards.
    """
    model.eval()
    try:
        with torch.no_grad():
            edited_histogram, original_histogram = next(iter(test_data_loader))
            edited_histogram = edited_histogram.to(device)
            predicted_original = _normalise_histograms(model(edited_histogram))
            writer.add_figure(
                "histogram",
                plot_histograms_in_2d(
                    {
                        # Index before .cpu() so only one sample leaves the GPU.
                        "original": original_histogram[0].cpu().numpy().squeeze(),
                        "edited": edited_histogram[0].cpu().numpy().squeeze(),
                        "predicted": predicted_original[0].cpu().numpy().squeeze(),
                    }
                ),
                epoch,
            )
    finally:
        model.train()


def train(
    hyperparameters: Dict[str, Any], max_duration: timedelta, use_tqdm: bool
) -> Path:
    """Train one model and save its weights.

    Batch losses, the learning rate, hyperparameters with the mean epoch loss,
    and a sample prediction figure are written to TensorBoard under
    ``runs/<run name>``; weights go to ``MODELS_PATH/<run name>.pth``.

    Args:
        hyperparameters: reads batch_size, edit_count, bin_count,
            clip_gradients, learning_rate, scheduler_gamma, num_epochs,
            model_type and loss ("kl" or "progressive"; the latter also reads
            loss_sizes and loss_damping).
        max_duration: wall-clock budget for the whole run, checked before
            every batch.
        use_tqdm: wrap each epoch's batches in a tqdm progress bar.

    Returns:
        Path of the saved state dict.

    Raises:
        TimeoutError: when ``max_duration`` is exceeded. Weights reached so
            far are still saved (they are saved on any exit from training).
        KeyError: when ``hyperparameters["loss"]`` is not a known loss.
    """
    import gc

    start_time = datetime.now()
    # Resolve the run name once so checkpoint and log dir always share it.
    run_name = get_next_run_name(Path("runs"))
    log_dir = Path("runs") / run_name
    model_path = (MODELS_PATH / run_name).with_suffix(".pth")
    loss_name = hyperparameters["loss"]

    with SummaryWriter(log_dir) as writer:
        train_data_loader, test_data_loader = create_data_loaders(
            data=DATA,
            edit_count=hyperparameters["edit_count"],
            bin_count=hyperparameters["bin_count"],
            training_batch_size=hyperparameters["batch_size"],
        )

        model = (
            create_model(
                type=hyperparameters["model_type"],
                bin_count=hyperparameters["bin_count"],
            )
            .train()
            .to(device)
        )
        optimizer = Adam(model.parameters(), lr=hyperparameters["learning_rate"])
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=hyperparameters["scheduler_gamma"]
        )
        loss_function = {
            "progressive": lambda: ProgressivePoolingLoss(
                target_sizes=hyperparameters["loss_sizes"],
                damping=hyperparameters["loss_damping"],
            ),
            "kl": lambda: torch.nn.KLDivLoss(reduction="batchmean"),
        }[loss_name]().to(device)

        try:
            # Inside the try so that an OOM while tracing the graph still
            # reaches the GPU cleanup below.
            writer.add_graph(model, next(iter(train_data_loader))[0].to(device))

            for epoch in range(hyperparameters["num_epochs"]):
                epoch_loss = 0.0
                batch_count = 0
                writer.add_scalar(
                    "Actual learning rate", scheduler.get_last_lr()[0], epoch
                )
                batches = (
                    tqdm(train_data_loader, desc=f"Epoch {epoch}", unit="batch")
                    if use_tqdm
                    else train_data_loader
                )
                for batch_id, (edited_histogram, original_histogram) in enumerate(
                    batches
                ):
                    if datetime.now() - start_time > max_duration:
                        raise TimeoutError(f"Time limit {max_duration} exceeded")
                    edited_histogram = edited_histogram.to(device)
                    original_histogram = original_histogram.to(device)

                    optimizer.zero_grad()
                    predicted_original = _normalise_histograms(
                        model(edited_histogram)
                    )
                    loss = _compute_loss(
                        loss_name, loss_function, predicted_original, original_histogram
                    )

                    # One device sync; reuse the Python float for TensorBoard.
                    loss_value = loss.item()
                    epoch_loss += loss_value
                    batch_count += 1
                    writer.add_scalar(
                        "Loss/train/batch",
                        loss_value,
                        epoch * len(train_data_loader) + batch_id,
                    )
                    loss.backward()

                    if hyperparameters["clip_gradients"]:
                        clip_grad_norm_(model.parameters(), max_norm=1.0)
                    optimizer.step()

                writer.add_hparams(
                    {
                        k: str(v) if isinstance(v, list) else v
                        for k, v in hyperparameters.items()
                    },
                    {
                        # Mean per batch, not a sum: the number of batches
                        # depends on batch_size and edit_count, so a sum is not
                        # comparable across hyperparameter runs.
                        "Loss/train/epoch": epoch_loss / max(batch_count, 1),
                    },
                    global_step=epoch,
                    run_name=log_dir.absolute(),
                )
                _log_sample_prediction(writer, model, test_data_loader, epoch)
                scheduler.step()
        finally:
            torch.save(model.state_dict(), model_path)
            # The optimizer (and the scheduler through it) still reference the
            # parameters and Adam state; drop them too, otherwise empty_cache()
            # cannot return that memory before the next run in this process.
            del model, optimizer, scheduler
            gc.collect()
            torch.cuda.empty_cache()
        return model_path

Testing model SimpleCNN
Test passed! Output shape matches input shape.
Testing model Residual
Test passed! Output shape matches input shape.
Testing model NormalisedCNN
Test passed! Output shape matches input shape.
Testing model SmartRes
Test passed! Output shape matches input shape.
Testing model attention2
Test passed! Output shape matches input shape.
Testing model advanced_attention
Test passed! Output shape matches input shape.
Testing model Res2
Test passed! Output shape matches input shape.
Testing model attention1
Test passed! Output shape matches input shape.


In [4]:
# Manual single-run example, kept for reference (uncomment to run).
# NOTE(review): train() also requires `max_duration` and `use_tqdm` (see its
# signature); example values are filled in below.
# train(
#     {
#         "batch_size": 64,
#         "edit_count": 25,
#         "bin_count": 32,
#         "clip_gradients": True,
#         "learning_rate": 0.005,
#         "scheduler_gamma": 0.7,
#         "num_epochs": 20,
#         "model_type": "NormalisedCNN",
#         "loss": "progressive",
#         "loss_sizes": [16, 32],
#         "loss_damping": 2,
#     },
#     max_duration=timedelta(hours=2),
#     use_tqdm=True,
# )

In [5]:
from random import choice
from itertools import count
import json


# Random hyperparameter search: sample a configuration, skip exact repeats,
# train it, and continue until the kernel is interrupted.
MAX_RUN_DURATION = timedelta(hours=2)

tried = set()

for _ in count():
    search_space = choice(hyperparameters)
    current_hyperparameters = {
        k: v.rvs() if hasattr(v, "rvs") else choice(v)
        for k, v in search_space.items()
    }
    # Canonical (sorted-key) JSON doubles as the dedup key and a readable log.
    key = json.dumps(current_hyperparameters, indent=2, sort_keys=True)
    if key in tried:
        continue
    tried.add(key)
    # Computed separately: reusing the same quote type inside an f-string
    # expression is only valid syntax on Python 3.12+.
    run_name = get_next_run_name(Path("runs"))
    logging.info(f"Starting {run_name} with hparams {key}")
    try:
        train(current_hyperparameters, max_duration=MAX_RUN_DURATION, use_tqdm=False)
    except KeyboardInterrupt:
        logging.info("Interrupted, stopping")
        break
    except TimeoutError:
        logging.warning("Timeout, aborting experiment")
    except Exception as e:
        # exc_info logs the traceback of the failure itself; stack_info only
        # showed the stack of this logging call.
        logging.error(
            f"Error with hparams {current_hyperparameters}:\n\t{e}", exc_info=True
        )

2024-06-02 21:42:49,762 - INFO - Starting run_51 with hparams {
  "batch_size": 16,
  "bin_count": 64,
  "clip_gradients": true,
  "edit_count": 16,
  "learning_rate": 0.0019018860481580008,
  "loss": "kl",
  "model_type": "Residual",
  "num_epochs": 10,
  "scheduler_gamma": 0.5124233085818609
}
2024-06-02 21:42:49,787 - INFO - Loaded 359668 training images and 39964 test images
2024-06-02 23:43:03,698 - INFO - Starting run_52 with hparams {
  "batch_size": 16,
  "bin_count": 16,
  "clip_gradients": false,
  "edit_count": 8,
  "learning_rate": 2.9976475506468536e-05,
  "loss": "kl",
  "model_type": "SmartRes",
  "num_epochs": 10,
  "scheduler_gamma": 0.8138813825657673
}
2024-06-02 23:43:03,991 - INFO - Loaded 179834 training images and 19982 test images
  scale = np.sqrt(self._sizes) * dpi / 72.0 * self._factor
2024-06-02 23:52:17,393 - INFO - Starting run_53 with hparams {
  "batch_size": 8,
  "bin_count": 32,
  "clip_gradients": false,
  "edit_count": 8,
  "learning_rate": 0.0002765

CUDA out of memory. Tried to allocate 6.00 GiB. GPU 0 has a total capacity of 15.99 GiB of which 0 bytes is free. Including non-PyTorch memory, this process has 17179869184.00 GiB memory in use. Of the allocated memory 39.04 GiB is allocated by PyTorch, and 2.02 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Error occurs, No graph saved


2024-06-03 02:05:41,071 - INFO - Starting run_58 with hparams {
  "batch_size": 64,
  "bin_count": 16,
  "clip_gradients": false,
  "edit_count": 16,
  "learning_rate": 5.8262398455352215e-05,
  "loss": "kl",
  "model_type": "attention2",
  "num_epochs": 10,
  "scheduler_gamma": 0.17181073763193916
}
2024-06-03 02:05:41,262 - INFO - Loaded 359668 training images and 39964 test images
2024-06-03 03:49:02,268 - INFO - Starting run_59 with hparams {
  "batch_size": 16,
  "bin_count": 16,
  "clip_gradients": false,
  "edit_count": 32,
  "learning_rate": 0.00017213076448986518,
  "loss": "kl",
  "model_type": "NormalisedCNN",
  "num_epochs": 10,
  "scheduler_gamma": 0.1302383221350669
}
2024-06-03 03:49:02,397 - INFO - Loaded 719337 training images and 79927 test images
2024-06-03 04:28:45,612 - INFO - Starting run_60 with hparams {
  "batch_size": 16,
  "bin_count": 16,
  "clip_gradients": false,
  "edit_count": 32,
  "learning_rate": 0.00010975854085067054,
  "loss": "kl",
  "model_type":