In [1]:
import random
from pathlib import Path

import torch
from torch import ge, gt, le
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter

from congrads.checkpoints import CheckpointManager
from congrads.constraints import BinaryConstraint, Constraint, ScalarConstraint
from congrads.core import CongradsCore
from congrads.datasets import BiasCorrection
from congrads.descriptor import Descriptor
from congrads.metrics import MetricManager
from congrads.networks import MLPNetwork
from congrads.utils import (
    CSVLogger,
    preprocess_BiasCorrection,
    split_data_loaders,
)

In [None]:
# Set seed for reproducibility.
# A fixed master seed (42) deterministically yields three derived seeds,
# drawn in order from Python's `random` module.
random.seed(42)
seeds = [random.randint(10, 10**6) for _ in range(3)]

# torch.manual_seed seeds the CPU generator (and, per the PyTorch docs, CUDA
# as well); the CUDA-specific calls then re-seed, so seeds[2] via
# manual_seed_all is what finally applies to every GPU.
torch.manual_seed(seeds[0])
torch.cuda.manual_seed(seeds[1])
torch.cuda.manual_seed_all(seeds[2])

In [3]:
# CUDA for PyTorch: use the first GPU when one is available, else the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

# cuDNN autotuning (benchmark=True) may pick different, non-deterministic
# kernels from run to run, which defeats the manual seeding of this notebook.
# With DETERMINISTIC on, autotuning is disabled and deterministic cuDNN
# kernels are requested; set it to False to trade reproducibility for speed.
DETERMINISTIC = True
torch.backends.cudnn.benchmark = not DETERMINISTIC
torch.backends.cudnn.deterministic = DETERMINISTIC

In [None]:
# Load the BiasCorrection dataset from ./datasets (download=True presumably
# fetches it when absent) and run the library's preprocessing function on it.
data = BiasCorrection("./datasets", preprocess_BiasCorrection, download=True)

# Split into data loaders. `loader_args` holds the shared DataLoader options;
# the validation/test overrides turn shuffling off (presumably merged over
# `loader_args` inside split_data_loaders — confirm).
loaders = split_data_loaders(
    data,
    loader_args=dict(
        batch_size=100,
        shuffle=True,
        num_workers=6,
        prefetch_factor=2,
    ),
    valid_loader_args=dict(shuffle=False),
    test_loader_args=dict(shuffle=False),
)

In [5]:
# Build the MLP and place it on the selected device in a single expression.
# Positional sizes are presumably (inputs=25, outputs=2) — confirm against
# MLPNetwork; the 2 outputs line up with the "Tmax"/"Tmin" descriptor entries.
network = MLPNetwork(25, 2, n_hidden_layers=3, hidden_dim=35).to(device)

In [6]:
# Mean-squared-error objective, minimised with Adam over every network
# parameter at a learning rate of 1e-3.
criterion = MSELoss()
optimizer = Adam(params=network.parameters(), lr=1e-3)

In [7]:
# Name the network's output columns so constraints can refer to them by
# name: column 0 -> "Tmax", column 1 -> "Tmin".
descriptor = Descriptor()
for output_index, output_name in enumerate(("Tmax", "Tmin")):
    descriptor.add("output", output_index, output_name)

In [8]:
# Constraints definition.
# Descriptor and device are set as class-level attributes before any
# constraint is constructed (presumably read by the constraint classes).
Constraint.descriptor = descriptor
Constraint.device = device

# Each output gets a `ge 0` and a `le 1` bound (Tmin first, then Tmax),
# followed by the ordering constraint Tmax > Tmin.
constraints = [
    ScalarConstraint(output_name, comparator, bound)
    for output_name in ("Tmin", "Tmax")
    for comparator, bound in ((ge, 0), (le, 1))
] + [BinaryConstraint("Tmax", gt, "Tmin")]

In [9]:
# Initialize metric manager.
# The same instance is handed to CongradsCore and read by the logging
# callbacks, which aggregate and reset its "during_training" and
# "after_training" groups.
metric_manager = MetricManager()

In [10]:
# Assemble the training core from the pieces built above. Arguments are
# positional: descriptor, constraints, data loaders, network, loss,
# optimizer, metric manager, and target device.
core = CongradsCore(
    descriptor, constraints, loaders, network,
    criterion, optimizer, metric_manager, device,
)

In [11]:
# Set up metric logging: aggregated metric values are appended to a CSV file.
CSV_LOG_PATH = "logs/BiasCorrection.csv"

# NOTE(review): it is not visible here whether CSVLogger creates missing
# parent directories, so create the log directory up front to keep
# Restart-&-Run-All working on a fresh checkout.
Path(CSV_LOG_PATH).parent.mkdir(parents=True, exist_ok=True)
csv_logger = CSVLogger(CSV_LOG_PATH)


def _log_metric_group(group: str, epoch: int) -> None:
    """Aggregate one metric group, record it in the CSV log, and reset it.

    Args:
        group: Metric-manager group name, e.g. ``"during_training"`` or
            ``"after_training"``.
        epoch: Epoch index recorded alongside each metric value.
    """
    # Values are assumed to be scalar tensors (hence `.item()`) — confirm
    # against MetricManager.aggregate.
    for name, value in metric_manager.aggregate(group).items():
        csv_logger.add_value(name, value.item(), epoch)

    # Write changes to disk after every call.
    csv_logger.save()

    # Clear accumulated values so the next aggregation starts fresh.
    metric_manager.reset(group)


def on_epoch_end(epoch: int) -> None:
    """Log and reset the ``"during_training"`` metrics.

    Registered with ``core.fit`` as its ``on_epoch_end`` callback.

    Args:
        epoch: Epoch index used as the CSV row key.
    """
    _log_metric_group("during_training", epoch)


def on_train_end(epoch: int) -> None:
    """Log and reset the ``"after_training"`` metrics.

    Registered with ``core.fit`` as its ``on_train_end`` callback.

    Args:
        epoch: Epoch index used as the CSV row key.
    """
    _log_metric_group("after_training", epoch)

In [None]:
# Start training from epoch 0 with a cap of 5 epochs, hooking in the
# CSV-logging callbacks for per-epoch and end-of-training metrics.
core.fit(
    start_epoch=0, max_epochs=5,
    on_epoch_end=on_epoch_end, on_train_end=on_train_end,
)