In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch
import copy
import mlflow
import mlflow.pytorch
from pathlib import Path

In [None]:
from src.training.RawDatasetDNG import RawDatasetDNG
from src.training.losses.ShadowAwareLoss import ShadowAwareLoss
from src.training.VGGFeatureExtractor import VGGFeatureExtractor
from src.training.train_loop import train_one_epoch, visualize
from src.training.utils import apply_gamma_torch
from src.training.load_config import load_config
from src.Restorer.Cond_NAF import  make_full_model_RGGB


In [None]:
# Resolve the dataset directory and its alignment CSV from the run configuration.
run_config = load_config()
dataset_path = Path(run_config['cropped_raw_subdir'])
align_csv = dataset_path.joinpath(run_config['secondary_align_csv'])

In [None]:
# Fail fast if the configured dataset location is wrong, then echo the directory in use.
assert dataset_path.is_dir(), f"dataset directory not found: {dataset_path}"
assert align_csv.is_file(), f"alignment CSV not found: {align_csv}"
dataset_path

In [None]:
# Training hyper-parameters and experiment settings, all read from the run config.
device = run_config['device']

batch_size = run_config['batch_size']
# The learning rate is scaled linearly with the batch size.
lr = run_config['lr_base'] * batch_size
clipping = run_config['clipping']

num_epochs = run_config['num_epochs_finetuning']
cosine_annealing = run_config['cosine_annealing']

val_split = run_config['val_split']
crop_size = run_config['crop_size']
experiment = run_config['mlflow_experiment']
mlflow_path = run_config['mlflow_path']
colorspace = run_config['colorspace']
iso_range = run_config['iso_range']

rggb = True
# Build a proper file:// URI. Gluing "file://" onto the path only works for absolute POSIX
# paths: a relative path would be parsed as the URI host, and Windows drive paths break.
mlflow.set_tracking_uri(Path(mlflow_path).expanduser().resolve().as_uri())
mlflow.set_experiment(experiment)

# Copy of the full config; logged as MLflow params (extended with base-run info below).
params = {**run_config}

In [None]:

# Base checkpoint to fine-tune: a model logged by an earlier MLflow run under `run_path`.
RUN_ID = "7d9ffb05e2c747fe93647e06ef43e51b"
ARTIFACT_PATH = run_config['run_path']
# Record the provenance of the starting weights alongside this run's params.
params['base_RUN_ID'] = RUN_ID
params['base_ARTIFACT_PATH'] = ARTIFACT_PATH

model_uri = f"runs:/{RUN_ID}/{ARTIFACT_PATH}"

try:
    model = mlflow.pytorch.load_model(model_uri)
except Exception as e:
    # Every later cell needs `model`. Report the failure and stop here, instead of
    # continuing into a NameError or silently fine-tuning a stale `model` from an
    # earlier execution of this kernel.
    print(f"Error loading model from MLflow: {e}")
    raise

# Move the weights to the configured device before the optimizer is built from them,
# and put the model in training mode because it is about to be fine-tuned.
model = model.to(device)
model.train()
print(f"Model successfully loaded from MLflow URI: {model_uri}")

In [None]:
# Build the paired raw dataset, keep only the samples used for fine-tuning, and split it.
dataset = RawDatasetDNG(dataset_path, align_csv, colorspace, crop_size=crop_size)
bayer_paths = dataset.df.bayer_path
# Drop rows whose bayer_path contains 'crw' or 'dng_bayer', and keep only ISOs inside
# the configured (inclusive) range.
keep = (
    ~bayer_paths.str.contains('crw')
    & ~bayer_paths.str.contains('dng_bayer')
    & dataset.df.iso.between(iso_range[0], iso_range[1])
)
dataset.df = dataset.df[keep]
# NOTE(review): the filtered frame keeps its original index labels. That is fine if
# RawDatasetDNG indexes rows positionally (.iloc); confirm it does not use .loc.
assert len(dataset) > 0, "no samples left after filtering by path and ISO range"

# Split dataset into train and val
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size
assert train_size > 0, f"val_split={val_split} leaves no training samples"
torch.manual_seed(42)  # Seeds the global RNG so the split is reproducible
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# Both Subsets share `dataset`. Deep-copy the validation Subset so that enabling the
# fixed-crop validation mode does not also switch the training data into it.
val_dataset = copy.deepcopy(val_dataset)
val_dataset.dataset.validation = True

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
# Validation uses fixed crops; keep the batch order fixed too so results are comparable.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
# Optimizer over the loaded model, plus an optional cosine learning-rate schedule
# (None when cosine annealing is disabled in the config).
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
sched = (
    torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min=lr * 1e-6)
    if cosine_annealing
    else None
)

# VGG feature extractor handed to ShadowAwareLoss, presumably for its VGG/perceptual
# terms. feature_layers=[14] is a layer index whose meaning is defined by
# VGGFeatureExtractor (TODO confirm which layer it selects).
vgg_blocks = ((1, 64), (1, 128), (1, 256), (1, 512), (1, 512))
vfe = VGGFeatureExtractor(
    config=vgg_blocks,
    feature_layers=[14],
    activation=nn.ReLU,
).to(device)

# The loss weights come straight from the run config under the same keyword names.
loss_weights = {
    key: run_config[key]
    for key in (
        'alpha',
        'beta',
        'l1_weight',
        'ssim_weight',
        'tv_weight',
        'vgg_loss_weight',
        'percept_loss_weight',
    )
}
loss_fn = ShadowAwareLoss(
    **loss_weights,
    apply_gamma_fn=apply_gamma_torch,
    vgg_feature_extractor=vfe,
    device=device,
)

In [None]:
# The optimizer, scheduler, VGG feature extractor and loss are built once, in the setup
# cell above. Rebuilding them here would only duplicate that work and allocate a second
# VGG extractor on `device`. This cell only verifies that the setup cell has been run.
assert all(name in globals() for name in ('optimizer', 'sched', 'vfe', 'loss_fn')), \
    "run the optimizer / loss setup cell before training"

In [None]:
# Fine-tune for `num_epochs` epochs inside a single MLflow run, then log the resulting model.
with mlflow.start_run(run_name=run_config['run_name']) as run:
    mlflow.log_params(params)
    # The model may have been left in eval mode after loading. Switch to training mode
    # explicitly; this is a no-op if train_one_epoch already does it.
    model.train()
    for epoch in range(num_epochs):
        # max_batches=0 presumably means "no batch limit" — confirm in train_one_epoch.
        train_one_epoch(epoch, model, optimizer, train_loader, device, loss_fn, clipping,
                        log_interval=10, sleep=0.0, rggb=rggb, max_batches=0)
        # `sched` is None when cosine annealing is disabled.
        if sched is not None:
            sched.step()

    mlflow.pytorch.log_model(
        pytorch_model=model,
        name=run_config['run_path'],
    )

In [None]:
# ID of the fine-tuning run just completed. Together with `run_config['run_path']` it
# forms the runs:/<id>/<path> URI of the newly logged model.
finetuned_run_id = run.info.run_id
finetuned_run_id