In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch
import copy
import mlflow
import mlflow.pytorch
from pathlib import Path

In [None]:
from src.training.SmallRawDataset import SmallRawDataset
from src.training.losses.ShadowAwareLoss import ShadowAwareLoss
from src.training.VGGFeatureExtractor import VGGFeatureExtractor
from src.training.train_loop import train_one_epoch, visualize
from src.training.utils import apply_gamma_torch
from src.training.load_config import load_config
from src.Restorer.Cond_NAF import  make_full_model_RGGB


In [None]:
# Load the run configuration, then resolve the dataset directory and the CSV named by
# `secondary_align_csv` inside it.
run_config = load_config()
dataset_path = Path(run_config['jpeg_output_subdir'])
align_csv = dataset_path.joinpath(run_config['secondary_align_csv'])

In [None]:
device = run_config['device']

# Optimisation hyperparameters; the base learning rate is scaled linearly with the batch size.
batch_size = run_config['batch_size']
lr = run_config['lr_base'] * batch_size
clipping = run_config['clipping']  # forwarded unchanged to train_one_epoch

num_epochs = run_config['num_epochs_pretraining']
val_split = run_config['val_split']
crop_size = run_config['crop_size']

# MLflow tracking uses a local file store rooted at `mlflow_path`.
experiment = run_config['mlflow_experiment']
mlflow_path = run_config['mlflow_path']
# Path.as_uri() needs an absolute path, so resolve first. A hand-built f"file://{path}" is
# malformed for relative paths (the first path component is parsed as the URI host) and
# for Windows paths; as_uri() also percent-encodes characters such as spaces correctly.
mlflow.set_tracking_uri(Path(mlflow_path).expanduser().resolve().as_uri())
mlflow.set_experiment(experiment)

In [None]:
# Architecture hyperparameters for the restorer; the `rggb` flag is additionally
# forwarded to train_one_epoch in the training cell.
model_params = run_config['model_params']
rggb = model_params['rggb']

In [None]:
dataset = SmallRawDataset(dataset_path, align_csv, crop_size=crop_size)

# Split dataset into train and val.
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
assert train_size > 0, f"val_split={val_split} leaves no training samples"

# Global seed for the remaining torch randomness (model initialisation in later cells,
# DataLoader shuffling).
torch.manual_seed(42)
# A dedicated generator pins the train/val partition independently of the global RNG, so
# changing the training seed (or consuming RNG earlier) never leaks val samples into train.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)

# Set the validation copy to use the same crops on every pass. deepcopy duplicates the
# Subset *and* its underlying dataset, so flipping `validation` on the copy leaves the
# training subset (which still shares the original dataset object) unaffected.
val_dataset = copy.deepcopy(val_dataset)
val_dataset.dataset.validation = True

# Reshuffle training samples every epoch; keep validation order fixed so metrics are comparable.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
# Build the restoration network and place it on the configured device.
model = make_full_model_RGGB(model_params, model_name=None).to(device)

# Single dict of hyperparameters for MLflow: run_config first, then model_params, whose
# values win on any key collision.
# NOTE(review): run_config's own nested 'model_params' entry is logged too (as a
# stringified dict), duplicating the flattened keys — consider excluding it.
params = dict(run_config)
params.update(model_params)

In [None]:
# Only the restorer's parameters are optimised; the feature extractor below is not
# registered with the optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Feature extractor handed to the loss as its `vgg_feature_extractor`.
vfe = VGGFeatureExtractor(
    config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512)),
    feature_layers=[14],
    activation=nn.ReLU,
).to(device)

# The loss hyperparameters are read from run_config under keys identical to the
# ShadowAwareLoss keyword names.
loss_fn = ShadowAwareLoss(
    **{
        key: run_config[key]
        for key in ('alpha', 'beta', 'l1_weight', 'ssim_weight', 'tv_weight', 'vgg_loss_weight')
    },
    apply_gamma_fn=apply_gamma_torch,
    vgg_feature_extractor=vfe,
    device=device,
)

In [None]:
# One MLflow run: log the hyperparameters in `params`, train for `num_epochs` epochs on
# `train_loader`, then log the final model to the run under `run_config['run_path']`.
# NOTE(review): `val_loader` is built earlier but never used here, so no validation metric
# is recorded for this run — consider evaluating on it (e.g. once per epoch).
with mlflow.start_run(run_name=run_config['run_name']) as run:
    mlflow.log_params(params)

    for epoch in range(num_epochs):
        # NOTE(review): max_batches=0 presumably means "no batch limit" — confirm in train_loop.
        train_one_epoch(
            epoch,
            model,
            optimizer,
            train_loader,
            device,
            loss_fn,
            clipping,
            log_interval=10,
            sleep=0.0,
            rggb=rggb,
            max_batches=0,
        )

    mlflow.pytorch.log_model(pytorch_model=model, name=run_config['run_path'])

In [None]:
# Display the ID of the MLflow run opened in the training cell (e.g. to locate the logged model).
run.info.run_id