In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch
import copy
import mlflow
from pathlib import Path

In [None]:
from src.training.SmallRawDataset import SmallRawDataset

from src.training.losses.ShadowAwareLoss import ShadowAwareLoss
from src.training.VGGFeatureExtractor import VGGFeatureExtractor
from src.training.utils import apply_gamma_torch
from src.training.train_loop import visualize
from src.training.load_config import load_config


In [None]:
# --- Run configuration -------------------------------------------------------
run_config = load_config()
model_params = run_config['model_params']

# Data locations: the alignment CSV lives inside the JPEG output directory.
dataset_path = Path(run_config['jpeg_output_subdir'])
align_csv = dataset_path / run_config['secondary_align_csv']

# Hardware / optimisation settings.
device = run_config['device']
batch_size = run_config['batch_size']
lr = run_config['lr_base'] * batch_size  # base learning rate scaled by the batch size
clipping = run_config['clipping']

# Schedule and data-split settings.
num_epochs = run_config['num_epochs_pretraining']
val_split = run_config['val_split']
crop_size = run_config['crop_size']

# Model input flavour; forwarded to the `rggb=` argument of the evaluation helpers.
rggb = model_params['rggb']

# --- MLflow ------------------------------------------------------------------
experiment = run_config['mlflow_experiment']
mlflow_path = run_config['mlflow_path']
mlflow.set_tracking_uri(f"file://{mlflow_path}")
mlflow.set_experiment(experiment)

In [None]:

# MLflow run whose logged PyTorch model is evaluated in this notebook.
RUN_ID = "b0664f324e9444d3b3a5277d513d3642"
ARTIFACT_PATH = run_config['run_path']

model_uri = f"runs:/{RUN_ID}/{ARTIFACT_PATH}"

try:
    model = mlflow.pytorch.load_model(model_uri)
except Exception as e:
    # Report the failure, then re-raise: swallowing it would leave `model`
    # undefined and make every later cell fail with a misleading NameError.
    print(f"Error loading model from MLflow: {e}")
    raise
else:
    model.eval()
    print(f"Model successfully loaded from MLflow URI: {model_uri}")

In [None]:
dataset = SmallRawDataset(dataset_path, align_csv, crop_size=crop_size)

# Hold out a `val_split` fraction of the samples; the remainder is the training set.
val_size = int(val_split * len(dataset))
train_size = len(dataset) - val_size

# Fixed seed so the train/validation partition is the same on every run.
torch.manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Deep-copy the validation subset so that flipping `validation` on its underlying
# dataset does not also affect the dataset object shared with `train_dataset`.
# (The flag is meant to make validation use the same crops every time.)
val_dataset = copy.deepcopy(val_dataset)
val_dataset.dataset.validation = True

# Both loaders iterate in a fixed order, in the main process.
loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=0)
train_loader = DataLoader(train_dataset, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Feature extractor handed to the loss below as `vgg_feature_extractor`.
vgg_config = ((1, 64), (1, 128), (1, 256), (1, 512), (1, 512),)
vfe = VGGFeatureExtractor(
    config=vgg_config,
    feature_layers=[14],
    activation=nn.ReLU,
).to(device)

# The loss weights are read from the run configuration under the same key names
# that ShadowAwareLoss uses for its keyword arguments.
loss_weight_keys = (
    'alpha',
    'beta',
    'l1_weight',
    'ssim_weight',
    'tv_weight',
    'vgg_loss_weight',
)
loss_fn = ShadowAwareLoss(
    **{key: run_config[key] for key in loss_weight_keys},
    apply_gamma_fn=apply_gamma_torch,
    vgg_feature_extractor=vfe,
    device=device,
)

In [None]:
# with mlflow.start_run(run_id=existing_run_id) as run:
#     print(f"Re-opened run: {run.info.run_name}")
    
#     # Log your new validation metric
#     mlflow.log_metric("final_validation_accuracy", new_validation_metric)

In [None]:
import numpy as np
# `val_dataset.indices` maps each position in the validation subset to a row of
# the full dataset, so indexing the ISO column with it yields one ISO per
# validation sample.
subset_indices = np.array(val_dataset.indices)
iso_per_val_sample = val_dataset.dataset.df.iso.values[subset_indices]

# Positions *within the validation subset* whose ISO equals 65535.
mask = iso_per_val_sample == 65535
(matching_indices_in_subset,) = np.nonzero(mask)
matching_indices_in_subset

In [None]:
# Evaluate the loaded model on the validation samples whose ISO is 65535.
visualize(
    matching_indices_in_subset,
    model,
    val_dataset,
    device,
    loss_fn,
    rggb=rggb,
)

In [None]:
Train loss: 0.056005 Final image val loss: 0.009073 Time: 45.7s Images: 25


In [None]:
import os
import torch
import mlflow
import matplotlib.pyplot as plt
import numpy as np
from time import perf_counter


def make_conditioning(conditioning, device):
    """Return the first conditioning column as a ``(B, 1)`` tensor on ``device``.

    Args:
        conditioning: 2-D tensor whose first axis is the batch; only column 0 is read.
        device: device on which the result is placed.

    Returns:
        A freshly allocated tensor of shape ``(B, 1)`` holding ``conditioning[:, 0]``.
        It uses the default dtype of ``torch.zeros``; the input is cast on assignment.
    """
    n_rows = conditioning.shape[0]
    first_column = torch.zeros(n_rows, 1).to(device)
    first_column[:, 0] = conditioning[:, 0]
    return first_column


def validate_model(
    _model,
    val_dataset,
    _device,
    _loss_func,
    epoch=None,
    iso_values=(100, 400, 1600, 65535),
    n_examples=3,
    rggb=False,
    artifact_dir="val_examples",
):
    """
    Run validation over different ISO values and log metrics + sample images to MLflow.

    For every ISO in ``iso_values`` the first ``n_examples`` matching validation
    samples are scored with ``visualize``; a noisy / denoised / ground-truth
    comparison figure is then saved and logged for each of them.

    Must be called inside an active MLflow run, since it logs artifacts and metrics.

    Args:
        _model: torch model, called as ``_model(input, conditioning, noisy)``
        val_dataset: subset exposing ``indices`` and ``dataset.df.iso``; items are
            dicts with 'noisy', 'conditioning', 'aligned' and 'rggb' / 'sparse'
        _device: torch device (string or ``torch.device``)
        _loss_func: loss function forwarded to ``visualize``
        epoch: step under which the metrics are logged; ``None`` (default) lets
            MLflow use its default step
        iso_values: list/tuple of ISO values to evaluate
        n_examples: number of images to score and log per ISO
        rggb: whether to feed the 'rggb' input (otherwise 'sparse')
        artifact_dir: local directory, and MLflow artifact sub-path, for the
            example images

    Returns:
        dict mapping each ISO that had validation samples to
        ``{'val_loss', 'final_image_loss', 'duration'}`` as returned by ``visualize``.
    """

    _model.eval()
    os.makedirs(artifact_dir, exist_ok=True)

    # autocast wants the device *type* ("cuda", "mps", "cpu"), not "cuda:0".
    autocast_device_type = torch.device(_device).type

    # ISO of every validation sample; identical for all requested ISO values,
    # so it is looked up once instead of once per loop iteration.
    subset_indices = np.array(val_dataset.indices)
    subset_iso = val_dataset.dataset.df.iso.values[subset_indices]

    all_metrics = {}

    for iso in iso_values:
        # Positions inside the validation subset that have this ISO
        matching_indices_in_subset = np.nonzero(subset_iso == iso)[0]

        if len(matching_indices_in_subset) == 0:
            print(f"No validation samples for ISO {iso}")
            continue

        idxs = matching_indices_in_subset[:n_examples]

        total_loss, final_img_loss, duration = visualize(
            idxs, _model, val_dataset, _device, _loss_func, rggb=rggb
        )

        # Keep the per-ISO results so the summary below and the return value
        # are actually populated.
        all_metrics[iso] = {
            "val_loss": total_loss,
            "final_image_loss": final_img_loss,
            "duration": duration,
        }

        # Save and log a few example denoised images
        for i, idx in enumerate(idxs):
            row = val_dataset[int(idx)]
            noisy = row['noisy'].unsqueeze(0).float().to(_device)
            conditioning = make_conditioning(
                row['conditioning'].float().unsqueeze(0).to(_device), _device
            )
            gt = row['aligned'].unsqueeze(0).float().to(_device)
            model_input = row['rggb' if rggb else 'sparse'].unsqueeze(0).float().to(_device)

            with torch.no_grad():
                with torch.autocast(device_type=autocast_device_type, dtype=torch.bfloat16):
                    pred = _model(model_input, conditioning, noisy)

            # The prediction may come back in bfloat16 from the autocast region;
            # cast to float32 before plotting (NumPy has no bfloat16).
            pred_img = apply_gamma_torch(pred[0].float().cpu().permute(1, 2, 0))
            noisy_img = apply_gamma_torch(noisy[0].cpu().permute(1, 2, 0))
            gt_img = apply_gamma_torch(gt[0].cpu().permute(1, 2, 0))

            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            axs[0].imshow(noisy_img)
            axs[0].set_title(f"Noisy (ISO {iso})")
            axs[1].imshow(pred_img)
            axs[1].set_title("Denoised")
            axs[2].imshow(gt_img)
            axs[2].set_title("Ground Truth")
            for ax in axs:
                ax.axis('off')

            img_path = os.path.join(artifact_dir, f"iso{iso}_example{i}.png")
            fig.savefig(img_path, bbox_inches="tight")
            plt.close(fig)  # avoid accumulating open figures across the loop

            mlflow.log_artifact(img_path, artifact_path=f"{artifact_dir}/ISO_{iso}")

        # Log metrics per ISO
        mlflow.log_metrics({
            f"val_loss_ISO_{iso}": total_loss,
            f"final_img_loss_ISO_{iso}": final_img_loss,
            f"val_duration_ISO_{iso}": duration
        }, step=epoch)

    # Also print a summary table
    print("\nValidation Summary:")
    for iso, metrics in all_metrics.items():
        print(f"ISO {iso}: val_loss={metrics['val_loss']:.6f}, "
              f"final_img_loss={metrics['final_image_loss']:.6f}, "
              f"time={metrics['duration']:.1f}s")

    return all_metrics


In [None]:
with mlflow.start_run(run_name=run_config['run_name']) as run:
    # `validate_model` takes the logging step as `epoch`; this notebook scores a
    # single already-trained model, so everything is logged at step 0.
    # `rggb` comes from the run configuration, matching the `visualize` call
    # made earlier in this notebook on the same model.
    metrics = validate_model(model, val_dataset, device, loss_fn, epoch=0, rggb=rggb)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Gated feed-forward unit built from 1x1 convolutions.

    Computes ``w3(silu(w1(x)) * w2(x))``: ``w1`` produces the gate, ``w2`` the
    value, and ``w3`` projects their product back to ``input_dim`` channels.

    Args:
        input_dim: number of input (and output) channels.
        hidden_dim: number of channels of the gate and value branches.
        dropout: accepted but not used by this implementation.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        # All three projections are pointwise (1x1) convolutions.
        pointwise = dict(kernel_size=1, stride=1, padding=0, dilation=1)
        self.w1 = nn.Conv2d(input_dim, hidden_dim, **pointwise)
        self.w2 = nn.Conv2d(input_dim, hidden_dim, **pointwise)
        self.w3 = nn.Conv2d(hidden_dim, input_dim, **pointwise)

    def forward(self, x):
        """Apply the gated projection; the output has the same shape as ``x``."""
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.w3(gated)

In [None]:
class ConditionedChannelAttention(nn.Module):
    """Per-channel weights computed from pooled features plus a conditioning vector.

    The feature map is average-pooled to one value per channel, concatenated
    with the conditioning vector, and passed through a single linear layer.
    No activation is applied to the result.

    Args:
        dims: number of feature channels (also the number of output channels).
        cat_dims: length of the conditioning vector appended to the pooled features.
    """

    def __init__(self, dims, cat_dims):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dims + cat_dims, dims))
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x, conditioning):
        """Return a ``(B, dims, 1, 1)`` tensor for ``x`` of shape ``(B, dims, H, W)``.

        ``conditioning`` gets two trailing singleton axes so it can be
        concatenated with the pooled map along the channel axis.
        """
        pooled = self.pool(x)
        cond_map = conditioning[..., None, None]
        # Channels-last so the linear layer acts on the channel axis.
        features = torch.cat((pooled, cond_map), dim=1).permute(0, 2, 3, 1)
        return self.mlp(features).permute(0, 3, 1, 2)

In [None]:
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """Gated feed-forward unit built from 1x1 convolutions.

    NOTE(review): this class is also defined, with the same body, in an earlier
    cell of this notebook; this later definition is the one in effect for the
    cells that follow.

    Computes ``w3(silu(w1(x)) * w2(x))``: ``w1`` produces the gate, ``w2`` the
    value, and ``w3`` projects their product back to ``input_dim`` channels.

    Args:
        input_dim: number of input (and output) channels.
        hidden_dim: number of channels of the gate and value branches.
        dropout: accepted but not used by this implementation.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        # All three projections are pointwise (1x1) convolutions.
        pointwise = dict(kernel_size=1, stride=1, padding=0, dilation=1)
        self.w1 = nn.Conv2d(input_dim, hidden_dim, **pointwise)
        self.w2 = nn.Conv2d(input_dim, hidden_dim, **pointwise)
        self.w3 = nn.Conv2d(hidden_dim, input_dim, **pointwise)

    def forward(self, x):
        """Apply the gated projection; the output has the same shape as ``x``."""
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.w3(gated)
    
class AttnBlock(nn.Module):
    """Residual block: conditioned channel attention followed by a SwiGLU feed-forward.

    Both residual branches are scaled by learnable per-channel factors (``alpha``
    and ``beta``) that start at zero, so a freshly constructed block returns
    ``norm(inp)``.

    Args:
        c: number of feature channels.
        FFN_Expand: hidden width of the SwiGLU, as a multiple of ``c``.
        drop_out_rate: accepted but not used by this implementation.
        cond_chans: length of the conditioning vector.
    """

    def __init__(self, c, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0):
        super().__init__()

        # Depthwise 3x3 convolution: one filter per channel, spatial size preserved.
        self.dw = nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1, groups=c, bias=True)
        self.sca = ConditionedChannelAttention(c, cond_chans)
        # A single group normalises over all channels together.
        self.norm = nn.GroupNorm(1, c)
        self.swiglu = SwiGLU(c, int(c * FFN_Expand))

        self.alpha = nn.Parameter(torch.zeros(1, c, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, c, 1, 1))

    def forward(self, input):
        """Process an ``(features, conditioning)`` pair.

        Returns a pair of the same form, with the conditioning passed through
        unchanged.
        """
        inp, cond = input[0], input[1]

        # Attention branch: re-weight the depthwise features channel by channel.
        spatial = self.dw(inp)
        attended = self.sca(spatial, cond) * spatial
        y = self.norm(inp + self.alpha * attended)

        # Feed-forward branch on the normalised features.
        return (y + self.beta * self.swiglu(y), cond)

In [None]:
# Smoke test: push one random (features, conditioning) pair through an AttnBlock
# and display the shape of the returned feature map.
block = AttnBlock(64, 2, cond_chans=1)
smoke_features = torch.rand(1, 64, 32, 32)
smoke_cond = torch.rand(1, 1)
block((smoke_features, smoke_cond))[0].shape

In [None]:
class NAFBlock0(nn.Module):
    """Conditioned residual block with a gated depthwise branch and a gated 1x1 branch.

    Both residual branches are scaled by learnable per-channel factors (``beta``
    and ``gamma``) that start at zero, so a freshly constructed block is the
    identity on its feature input.

    ``SimpleGate`` and ``LayerNorm2d`` are defined outside this cell. The
    convolutions that follow each gate take half as many input channels as the
    gate receives, so the gate presumably halves the channel count (TODO confirm).

    Args:
        c: number of feature channels in and out.
        DW_Expand: channel multiplier of the depthwise branch.
        FFN_Expand: channel multiplier of the channel-mixing branch.
        drop_out_rate: dropout probability after each branch; dropout is
            replaced by an identity when it is not positive.
        cond_chans: length of the conditioning vector.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.0, cond_chans=0):
        super().__init__()

        def pointwise(n_in, n_out):
            # 1x1 convolution with bias.
            return nn.Conv2d(
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=1,
                padding=0,
                stride=1,
                groups=1,
                bias=True,
            )

        def make_dropout():
            return nn.Dropout(drop_out_rate) if drop_out_rate > 0.0 else nn.Identity()

        dw_channel = c * DW_Expand
        ffn_channel = FFN_Expand * c

        # Spatial branch: expand -> depthwise 3x3 -> gate -> attention -> project back.
        self.conv1 = pointwise(c, dw_channel)
        self.conv2 = nn.Conv2d(
            dw_channel,
            dw_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=dw_channel,
            bias=True,
        )
        self.conv3 = pointwise(dw_channel // 2, c)

        # Simplified Channel Attention, conditioned on the extra vector
        self.sca = ConditionedChannelAttention(dw_channel // 2, cond_chans)

        # SimpleGate (shared by both branches)
        self.sg = SimpleGate()

        # Channel-mixing branch: expand -> gate -> project back.
        self.conv4 = pointwise(c, ffn_channel)
        self.conv5 = pointwise(ffn_channel // 2, c)

        # self.grn = GRN(ffn_channel // 2)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = make_dropout()
        self.dropout2 = make_dropout()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, input):
        """Process an ``(features, conditioning)`` pair.

        Returns a pair of the same form, with the conditioning passed through
        unchanged.
        """
        inp, cond = input[0], input[1]

        # Spatial branch
        x = self.conv2(self.conv1(self.norm1(inp)))
        x = self.sg(x)
        x = x * self.sca(x, cond)
        x = self.dropout1(self.conv3(x))

        y = inp + x * self.beta

        # Channel Mixing
        x = self.sg(self.conv4(self.norm2(y)))
        # x = self.grn(x)
        x = self.dropout2(self.conv5(x))

        return (y + x * self.gamma, cond)