In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch
import copy
import mlflow
import mlflow.pytorch
from pathlib import Path

In [None]:
from src.training.DemosaicingDataset import DemosaicingDataset
from src.training.losses.ShadowAwareLoss import ShadowAwareLoss
from src.training.VGGFeatureExtractor import VGGFeatureExtractor
from src.training.train_loop import train_one_epoch, visualize
from src.training.utils import apply_gamma_torch
from src.training.load_config import load_config
from src.Restorer.Cond_NAF import  make_full_model_RGGB_Demosaicing


In [None]:
run_config = load_config('config_demosaicing.yaml')
# Root of the cropped raw data, and the alignment CSV that lives inside it.
dataset_path = Path(run_config['cropped_raw_subdir'])
align_csv = dataset_path.joinpath(run_config['secondary_align_csv'])

In [None]:
# Run settings, all read from the YAML config.
device = run_config['device']

batch_size = run_config['batch_size']
lr = run_config['lr_base'] * batch_size  # base LR scaled linearly with batch size
clipping = run_config['clipping']

num_epochs = run_config['num_epochs_pretraining']
cosine_annealing = run_config['cosine_annealing']

val_split = run_config['val_split']
crop_size = run_config['crop_size']
experiment = run_config['mlflow_experiment']
mlflow_path = run_config['mlflow_path']
colorspace = run_config['colorspace']
iso_range = run_config['iso_range']

rggb = True  # feed 'cfa_rggb' to the model instead of 'cfa_sparse'

# Path.as_uri() always yields a valid absolute file:// URI. The previous
# f"file://{path}" broke for relative paths (first component parsed as a host)
# and for Windows paths.
mlflow.set_tracking_uri(Path(mlflow_path).expanduser().absolute().as_uri())
mlflow.set_experiment(experiment)

params = dict(run_config)

In [None]:

# MLflow run whose logged model is inspected below; can be overridden with `run_id` in the config.
RUN_ID = run_config.get('run_id', "1e253a47ff7d43e1a432cce4ed083c7b")
ARTIFACT_PATH = run_config['run_path']

model_uri = f"runs:/{RUN_ID}/{ARTIFACT_PATH}"

try:
    model = mlflow.pytorch.load_model(model_uri)
except Exception as e:
    # Fail loudly: every later cell needs `model`, and swallowing the error here
    # only resurfaces later as an unrelated-looking NameError.
    raise RuntimeError(f"Error loading model from MLflow URI {model_uri}") from e

# Inputs are moved to `device` later, so the weights must live there too.
model = model.to(device)
model.eval()
print(f"Model successfully loaded from MLflow URI: {model_uri}")

In [None]:
dataset = DemosaicingDataset(dataset_path, align_csv, colorspace, output_crop_size=crop_size, downsample_factor=4)

# Drop CRW and DNG-derived Bayer files, and keep only the configured ISO range (inclusive).
# Literal (non-regex) matching; a missing bayer_path counts as a match, so that row is dropped.
bayer_paths = dataset.df.bayer_path
keep = (
    ~bayer_paths.str.contains('crw', regex=False, na=True)
    & ~bayer_paths.str.contains('dng_bayer', regex=False, na=True)
    & dataset.df.iso.between(iso_range[0], iso_range[1])
)
# Renumber rows 0..n-1 so label-based and positional lookups agree after filtering.
dataset.df = dataset.df[keep].reset_index(drop=True)
print(len(dataset.df))

# Split dataset into train and val
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
torch.manual_seed(42)  # For reproducibility: random_split draws from the global generator
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# Deep-copy so the "same crops" validation flag affects only the val split's dataset.
val_dataset = copy.deepcopy(val_dataset)
val_dataset.dataset.validation = True

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Keep validation order fixed so per-batch metrics are comparable between runs.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
# Adam optimiser plus an optional cosine LR schedule decaying to lr * 1e-6 over num_epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
sched = (
    torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min=lr * 1e-6)
    if cosine_annealing
    else None
)

# VGG feature extractor handed to the loss below; features requested from layer index 14.
vfe = VGGFeatureExtractor(
    config=((1, 64), (1, 128), (1, 256), (1, 512), (1, 512)),
    feature_layers=[14],
    activation=nn.ReLU,
).to(device)

# Loss hyper-parameters are read from the config under the same names ShadowAwareLoss takes.
loss_config_keys = (
    'alpha',
    'beta',
    'l1_weight',
    'ssim_weight',
    'tv_weight',
    'vgg_loss_weight',
    'percept_loss_weight',
)
loss_fn = ShadowAwareLoss(
    **{key: run_config[key] for key in loss_config_keys},
    apply_gamma_fn=apply_gamma_torch,
    vgg_feature_extractor=vfe,
    device=device,
)

In [None]:
from time import perf_counter
import time
from tqdm import tqdm
import torch
import torch.nn as nn
from src.training.utils import apply_gamma_torch
import mlflow

def make_conditioning(conditioning, device):
    """Build the model's conditioning tensor from the first conditioning feature.

    Args:
        conditioning: tensor indexed as (B, ...) with at least one column; only
            column 0 is used.
        device: device on which the returned tensor is allocated.

    Returns:
        Tensor of shape (B, 1) in the default float dtype, on ``device``, whose
        single column is ``conditioning[:, 0]``.
    """
    batch_size = conditioning.shape[0]
    # Allocate directly on the target device instead of building on the CPU and copying over.
    conditioning_extended = torch.zeros(batch_size, 1, device=device)
    conditioning_extended[:, 0] = conditioning[:, 0]
    return conditioning_extended


from colour_demosaicing import (
    mosaicing_CFA_Bayer,
    demosaicing_CFA_Bayer_Menon2007
)
def _plot_comparison(pred_img, menon_img, gt_img):
    """Plot prediction, Menon2007 baseline and ground truth, plus their pairwise differences.

    Inputs are same-shaped numpy images (H, W, C) in gamma-encoded display space.
    Difference panels are offset by 0.5 so zero error renders mid-grey; every panel
    is clipped to [0, 1], the float range imshow accepts for RGB data.
    """
    panels = (
        ('pred', pred_img),
        ('Menon', menon_img),
        ('gt', gt_img),
        ('pred - gt', pred_img - gt_img + 0.5),
        ('Menon - pred', menon_img - pred_img + 0.5),
        ('Menon - gt', menon_img - gt_img + 0.5),
    )
    fig, axes = plt.subplots(2, 3, figsize=(30, 15))
    for ax, (title, img) in zip(axes.flat, panels):
        ax.set_title(title)
        ax.imshow(img.clip(0.0, 1.0))
    plt.show()
    # Close (not just clear) so figures do not pile up when many indices are plotted.
    plt.close(fig)


def visualize(idxs, _model, dataset, _device, _loss_func, rggb=False):
    """Run ``_model`` on selected dataset items and plot each result against
    Menon2007 demosaicing and the ground truth.

    Args:
        idxs: iterable of indices into ``dataset``.
        _model: module called as ``_model(cfa_input, conditioning)``.
        dataset: indexable returning dicts with 'conditioning', 'ground_truth',
            'cfa_sparse' and (when ``rggb``) 'cfa_rggb' tensors.
        _device: device inputs are moved to; its type also selects the autocast backend.
        _loss_func: loss called as ``_loss_func(pred, gt)``.
        rggb: feed 'cfa_rggb' to the model instead of 'cfa_sparse'.

    Returns:
        (mean loss, mean L1 between prediction and ground truth, elapsed seconds).
        Means are 0.0 when ``idxs`` is empty.
    """
    was_training = _model.training
    # Inference only: layers with train/eval-dependent behaviour (e.g. dropout) must run in eval mode.
    _model.eval()
    autocast_device = torch.device(_device).type
    total_loss, total_final_image_loss, n_images = 0.0, 0.0, 0
    start = perf_counter()

    try:
        for idx in idxs:
            row = dataset[idx]
            conditioning = make_conditioning(
                row['conditioning'].unsqueeze(0).float().to(_device), _device
            )
            gt = row['ground_truth'].unsqueeze(0).float().to(_device)
            sparse = row['cfa_sparse'].unsqueeze(0).float().to(_device)
            model_input = row['cfa_rggb'].unsqueeze(0).float().to(_device) if rggb else sparse

            with torch.no_grad():
                with torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
                    pred = _model(model_input, conditioning)
                # Autocast can return bfloat16, which numpy (and therefore imshow) cannot handle.
                pred = pred.float()
                total_loss += float(_loss_func(pred, gt))
                total_final_image_loss += nn.functional.l1_loss(pred, gt).item()
            n_images += gt.shape[0]

            pred_img = apply_gamma_torch(pred[0].cpu().permute(1, 2, 0)).cpu().numpy()
            gt_img = apply_gamma_torch(gt[0].cpu().permute(1, 2, 0)).cpu().numpy()
            # Classical baseline: collapse the sparse CFA channels into one mosaic plane and demosaic it.
            bayer = apply_gamma_torch(sparse[0].sum(dim=0)).cpu().numpy()
            menon_img = demosaicing_CFA_Bayer_Menon2007(bayer)
            _plot_comparison(pred_img, menon_img, gt_img)
    finally:
        _model.train(was_training)

    denom = max(1, n_images)
    mean_loss = total_loss / denom
    mean_final_image_loss = total_final_image_loss / denom
    elapsed = perf_counter() - start
    print(
        f"Loss: {mean_loss:.6f} "
        f"Final image L1 loss: {mean_final_image_loss:.6f} "
        f"Time: {elapsed:.1f}s "
        f"Images: {n_images}")

    return mean_loss, mean_final_image_loss, elapsed

In [None]:
import numpy as np
TARGET_ISO = 200  # ISO of the validation images to visualise

# Positions within val_dataset (not the full dataset) of samples shot at TARGET_ISO.
subset_indices = np.asarray(val_dataset.indices)  # indices into the underlying (filtered) dataset
mask = val_dataset.dataset.df.iso.to_numpy()[subset_indices] == TARGET_ISO
matching_indices_in_subset = np.flatnonzero(mask)
matching_indices_in_subset

In [None]:
# Plot every selected validation sample and report mean losses.
visualize(
    matching_indices_in_subset,
    _model=model,
    dataset=val_dataset,
    _device=device,
    _loss_func=loss_fn,
    rggb=rggb,
)