In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch
import copy
import mlflow
import mlflow.pytorch
from pathlib import Path

In [None]:
from src.training.DemosaicingDataset import DemosaicingDataset
from src.training.losses.ShadowAwareLoss import ShadowAwareLoss
from src.training.VGGFeatureExtractor import VGGFeatureExtractor
from src.training.train_loop import train_one_epoch, visualize
from src.training.utils import apply_gamma_torch
from src.training.load_config import load_config
from src.Restorer.Cond_NAF import  make_full_model_RGGB_Demosaicing


In [None]:
CONFIG_FILE = 'config_demosaicing.yaml'
run_config = load_config(CONFIG_FILE)

# Dataset root, and the alignment CSV resolved relative to it.
dataset_path = Path(run_config['cropped_raw_subdir'])
align_csv = dataset_path.joinpath(run_config['secondary_align_csv'])

In [None]:
device = run_config['device']

# Seed torch before the model is constructed so weight initialisation is reproducible.
SEED = 42
torch.manual_seed(SEED)

batch_size = run_config['batch_size']
# Base learning rate is scaled linearly with the batch size.
lr = run_config['lr_base'] * batch_size
clipping = run_config['clipping']

num_epochs = run_config['num_epochs_pretraining']
cosine_annealing = run_config['cosine_annealing']

val_split = run_config['val_split']
crop_size = run_config['crop_size']
experiment = run_config['mlflow_experiment']
mlflow_path = run_config['mlflow_path']
colorspace = run_config['colorspace']
iso_range = run_config['iso_range']

# f"file://{path}" is only a valid URI for absolute POSIX paths: a relative value such as
# "mlruns" would be parsed as the URI host. Build the URI from the resolved path instead,
# and pass through values that are already URIs (e.g. a remote tracking server).
if '://' in str(mlflow_path):
    mlflow_tracking_uri = str(mlflow_path)
else:
    mlflow_tracking_uri = Path(mlflow_path).resolve().as_uri()
mlflow.set_tracking_uri(mlflow_tracking_uri)
mlflow.set_experiment(experiment)

# `rggb` and the MLflow `params` dict are defined in the model cell from `model_params`.

In [None]:
model_params = run_config['model_params']
# Whether the model consumes the packed RGGB CFA input (read by the training loop).
rggb = model_params['rggb']

model = make_full_model_RGGB_Demosaicing(model_params, model_name=None).to(device)

# Parameters logged to MLflow: the run config, with model hyper-parameters
# taking precedence on key clashes.
params = dict(run_config)
params.update(model_params)

In [None]:
dataset = DemosaicingDataset(dataset_path, align_csv, colorspace, output_crop_size=crop_size, downsample_factor=4)

# Drop rows whose bayer_path contains 'crw' or 'dng_bayer', and keep only ISO values
# inside iso_range (both bounds inclusive). One combined mask selects the same rows as
# filtering step by step.
frames = dataset.df
keep_mask = (
    ~frames.bayer_path.str.contains('crw')
    & ~frames.bayer_path.str.contains('dng_bayer')
    & (frames.iso >= iso_range[0])
    & (frames.iso <= iso_range[1])
)
dataset.df = frames[keep_mask]
print(len(dataset.df))

# Split dataset into train and val
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
torch.manual_seed(42)  # For reproducibility of the split (and later loader shuffling)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# Deep-copy the validation subset so that switching its underlying dataset to
# validation mode (same crops every time) does not affect the training subset,
# which still shares the original dataset object.
val_dataset = copy.deepcopy(val_dataset)
val_dataset.dataset.validation = True

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Validation needs no shuffling; a fixed order keeps evaluation deterministic.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Optional per-epoch cosine decay towards a floor of lr * 1e-6; None when disabled.
sched = (
    torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min=lr * 1e-6)
    if cosine_annealing
    else None
)

vfe = VGGFeatureExtractor(
    config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512)),
    feature_layers=[14],
    activation=nn.ReLU,
).to(device)

# Loss weights are read from the run config under these exact keys.
loss_weight_keys = (
    'alpha',
    'beta',
    'l1_weight',
    'ssim_weight',
    'tv_weight',
    'vgg_loss_weight',
    'percept_loss_weight',
)
loss_fn = ShadowAwareLoss(
    **{key: run_config[key] for key in loss_weight_keys},
    apply_gamma_fn=apply_gamma_torch,
    vgg_feature_extractor=vfe,
    device=device,
)

In [None]:
from time import perf_counter
import time
from tqdm import tqdm
import torch
import torch.nn as nn
from src.training.utils import apply_gamma_torch
import mlflow

def make_conditioning(conditioning, device):
    """Build the model's conditioning tensor from the first conditioning column.

    Args:
        conditioning: Tensor of at least two dimensions; only ``conditioning[:, 0]``
            is used (assumed to have shape ``(B,)`` — confirm against the dataset).
        device: Device the returned tensor is allocated on.

    Returns:
        A new tensor of shape ``(B, 1)`` in torch's default floating dtype on
        ``device``, where ``B = conditioning.shape[0]``. It does not alias the input.
    """
    B = conditioning.shape[0]
    # Allocate directly on the target device instead of creating a CPU tensor
    # and copying it across on every batch.
    conditioning_extended = torch.zeros(B, 1, device=device)
    conditioning_extended[:, 0] = conditioning[:, 0]
    return conditioning_extended


def train_one_epoch(epoch, _model, _optimizer, _loader, _device, _loss_func, _clipping,
                    log_interval=10, sleep=0.0, rggb=False,
                    max_batches=0):
    """Run one training epoch over ``_loader`` and log epoch metrics to MLflow.

    Each batch is expected to be a dict with keys ``'conditioning'``,
    ``'ground_truth'`` and ``'cfa_sparse'`` or ``'cfa_rggb'`` (chosen by
    ``rggb``). The model is called as ``_model(cfa_input, conditioning)`` and
    optimised on ``_loss_func(pred, gt)``.

    Args:
        epoch: Epoch index, used in the progress-bar label and as the MLflow step.
        _model: Model to train; switched to train mode.
        _optimizer: Optimizer over ``_model``'s parameters.
        _loader: Training loader; must support ``len()``.
        _device: Device batches are moved to (string or ``torch.device``).
        _loss_func: Callable ``(pred, gt)`` returning a scalar loss tensor.
        _clipping: Max gradient norm passed to ``clip_grad_norm_``.
        log_interval: Refresh the progress-bar loss every this many batches
            (values <= 0 disable the refresh).
        sleep: Seconds to sleep after each batch.
        rggb: If True feed ``'cfa_rggb'`` to the model, otherwise ``'cfa_sparse'``.
        max_batches: If > 0, stop after exactly this many batches.

    Returns:
        Tuple ``(train_loss, elapsed_seconds)``, where ``train_loss`` is the
        accumulated loss divided by the number of images seen (at least 1).
    """
    _model.train()
    input_key = 'cfa_rggb' if rggb else 'cfa_sparse'
    # The cache flush below only applies to Apple's MPS backend.
    on_mps = torch.device(_device).type == 'mps'

    total_loss, n_images, total_l1_loss = 0.0, 0, 0.0
    start = perf_counter()
    pbar = tqdm(enumerate(_loader), total=len(_loader), desc=f"Train Epoch {epoch}")

    for batch_idx, batch in pbar:
        conditioning = make_conditioning(batch['conditioning'].float().to(_device), _device)
        gt = batch['ground_truth'].float().to(_device)
        # Only transfer the CFA layout the model actually consumes.
        cfa_input = batch[input_key].float().to(_device)
        batch_images = gt.shape[0]

        _optimizer.zero_grad(set_to_none=True)
        pred = _model(cfa_input, conditioning)
        loss = _loss_func(pred, gt)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(_model.parameters(), _clipping)
        _optimizer.step()

        # NOTE(review): the loss is summed per batch but divided by the image count below;
        # if _loss_func returns a batch mean, "train_loss" is mean / batch_size — confirm.
        total_loss += float(loss.detach().cpu())
        n_images += batch_images

        # Monitoring-only metric, so no autograd graph is needed. F.l1_loss defaults to
        # reduction='mean'; weight it by the batch size so dividing by n_images yields
        # a true per-image mean even when the last batch is smaller.
        with torch.no_grad():
            total_l1_loss += float(nn.functional.l1_loss(pred, gt)) * batch_images
        del loss, pred
        if on_mps:
            torch.mps.empty_cache()

        if log_interval > 0 and (batch_idx + 1) % log_interval == 0:
            pbar.set_postfix({"loss": f"{total_loss/n_images:.4f}"})

        # Stop once exactly max_batches batches have been processed.
        if max_batches > 0 and batch_idx + 1 >= max_batches:
            break
        time.sleep(sleep)

    train_time = perf_counter() - start
    # Guard against an empty loader instead of raising ZeroDivisionError.
    denom = max(1, n_images)
    train_loss = total_loss / denom
    l1_loss = total_l1_loss / denom
    print(f"[Epoch {epoch}] "
          f"Train loss: {train_loss:.6f} "
          f"L1 loss: {l1_loss:.6f} "
          f"Time: {train_time:.1f}s "
          f"Images: {n_images}")
    mlflow.log_metric("train_loss", train_loss, step=epoch)
    mlflow.log_metric("l1_loss", l1_loss, step=epoch)
    mlflow.log_metric("epoch_duration_s", train_time, step=epoch)
    mlflow.log_metric("learning_rate", _optimizer.param_groups[0]['lr'], step=epoch)

    return train_loss, train_time

In [None]:
with mlflow.start_run(run_name=run_config['run_name']) as run:
    mlflow.log_params(params)

    for epoch in range(num_epochs):
        train_one_epoch(
            epoch, model, optimizer, train_loader, device, loss_fn, clipping,
            log_interval=10, sleep=0.0, rggb=rggb, max_batches=0,
        )
        # When enabled, the cosine scheduler advances once per epoch.
        if cosine_annealing:
            sched.step()

    # Log the trained model to this MLflow run.
    mlflow.pytorch.log_model(pytorch_model=model, name=run_config['run_path'])

In [None]:
# Display the ID of the MLflow run opened in the training cell.
run.info.run_id