# Konfiguracja
Przed uruchomieniem należy sprawdzic sciezki i odpowiednio ustawic stałe

In [1]:
MAIN_FOLDER = '/content/drive/MyDrive'

data_folder = MAIN_FOLDER + '/data/prepared' # zip z plikami ref i input

INPUT_95_DATASET_FOLDER = '/content/input_95_dataset'
REFERENCE_DATASET_FOLDER = '/content/reference_dataset'
INPUT_97_DATASET_FOLDER = '/content/input_97_dataset'

RESULTS_FOLDER = "/content/drive/MyDrive/SIGK_P2/results" # folder w którym będą obrazki wyjściowe z treningu oraz z testu (w podfolderze /TEST)
CHECKPOINTS_FOLDER = "/content/drive/MyDrive/SIGK_P2/checkpoints"

COMPARE_FOLDER = '/content/drive/MyDrive/SIGK_P2/compare'

In [2]:
unzip_files = False # czy wypakowywac pliki (z data_folder do DATASET_FOLDER)

In [3]:
train_model = True # czy trenować model (konfiguracja w sekcji Trening)

In [4]:
test_model = False # czy testować model (zapisac pliki wyjsciowe do folderu)
#PRZED TESTOWANIEM KONIECZNIE WCZYTAJ WAGI MODELU (CHECKPOINT) - nie korzysta z tego samego modelu co w treningu

In [5]:
calc_metrics = False # czy wyliczac metryki

# Moduły

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
!pip install brisque

Collecting brisque
  Downloading brisque-0.0.17-py3-none-any.whl.metadata (2.4 kB)
Collecting libsvm-official (from brisque)
  Downloading libsvm-official-3.36.0.tar.gz (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.2/40.2 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading brisque-0.0.17-py3-none-any.whl (140 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.3/140.3 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: libsvm-official
  Building wheel for libsvm-official (setup.py) ... [?25l[?25hdone
  Created wheel for libsvm-official: filename=libsvm_official-3.36.0-cp312-cp312-linux_x86_64.whl size=124634 sha256=00c830703946fc897bdaac641e9e9553e03401def31e719592489c2b98e7972d
  Stored in directory: /root/.cache/pip/wheels/df/65/4b/c3cdece6e5fa7eebef116be2d5a309f7ac50c90183cbe12c92
Successfully built libsvm-official
Installing colle

In [8]:
import zipfile
import os
import glob

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
import torch.nn.functional as F
import numpy as np

In [10]:
import cv2
from numpy import ndarray
from brisque import BRISQUE
from skimage.metrics import structural_similarity as ssim

In [11]:
from datetime import datetime

In [12]:
# from google.colab.patches import cv2_imshow

# Funkcje pomocnicze - metryki oraz przygotowanie plików

## Rozpakowanie archiwum

In [13]:
def unzip(source_path, target_path, with_main_folder=True):
  if os.path.exists(source_path):
    os.makedirs(target_path, exist_ok=True)

    with zipfile.ZipFile(source_path, 'r') as zip_ref:
      if with_main_folder:
        zip_ref.extractall(target_path)
      else:
        for content in zip_ref.namelist():
          name_archive = content.split('/', 1)[0]
          content = content.split('/', 1)[-1] if '/' in content else content

          if content == '' or content == '/':
            continue

          with zip_ref.open(name_archive + '/' + content) as source, open(os.path.join(target_path, content), "wb") as target:
            target.write(source.read())

    print(f'File unziped: source={source_path}, target={target_path}')
  else:
    print(f'Cannot unzip (file not found): source={source_path}, target={target_path}')

## Operacje na plikach EXR i operatory

In [14]:
 # enable using OpenEXR with OpenCV
 os.environ['OPENCV_IO_ENABLE_OPENEXR'] = "1"

In [15]:
def read_exr(im_path: str)-> ndarray:
  return cv2.imread(
  filename=im_path,
  flags=cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH
  )

In [16]:
def tone_map_reinhard(image: ndarray)-> ndarray:
  tonemap_operator = cv2.createTonemapReinhard(
  gamma=2.2,
  intensity=0.0,
  light_adapt=0.0,
  color_adapt=0.0
  )
  result = tonemap_operator.process(src=image)
  return result

In [17]:
def tone_map_mantiuk(image: ndarray)-> ndarray:
  tonemap_operator = cv2.createTonemapMantiuk(
  gamma=2.2,
  scale=0.85,
  saturation=1.2
  )
  result = tonemap_operator.process(src=image)
  return result

## Metryka BRISQUE

In [18]:
def evaluate_image(image: ndarray)-> float:
 metric = BRISQUE(url=False)
 return metric.score(img=image)

# Wczytanie danych

In [19]:
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [20]:
data_folder_contents = os.listdir(data_folder)

print(f"Zawartość folderu {data_folder}: {data_folder_contents}")

Zawartość folderu /content/drive/MyDrive/data/prepared: ['reference.zip', 'input_97.zip', 'input_95.zip']


In [22]:
#if unzip_files:
reference_path = data_folder + '/reference.zip'
unzip(reference_path, REFERENCE_DATASET_FOLDER, False)

input_95_path = data_folder + '/input_97.zip'
unzip(input_95_path, INPUT_97_DATASET_FOLDER, False)

input_95_path = data_folder + '/input_95.zip'
unzip(input_95_path, INPUT_95_DATASET_FOLDER, False)

ref_files = sorted(os.listdir(REFERENCE_DATASET_FOLDER))
print(f"Unzipped reference dataset, total files: {len(ref_files)}")

File unziped: source=/content/drive/MyDrive/data/prepared/reference.zip, target=/content/reference_dataset
File unziped: source=/content/drive/MyDrive/data/prepared/input_97.zip, target=/content/input_97_dataset
File unziped: source=/content/drive/MyDrive/data/prepared/input_95.zip, target=/content/input_95_dataset
Unzipped reference dataset, total files: 181


# Model

## Helpery, dataset, zapisywanie rezultatu/utils

In [23]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [24]:
# -------------------- helpery dla sieci --------------------
def norm(x):
    x_max = np.max(x)
    x_min = np.min(x)
    scale = x_max - x_min
    x_norm = (x - x_min)/scale
    return x_norm

def norm_mean(img):
    img = 0.5 * img / img.mean()
    return img

def ulaw_np(img, scale = 10.0):
    median_value = np.median(img)
    scale = 8.759 * np.power(median_value, 2.148) + 0.1494 * np.power(median_value, -2.067)
    out = np.log(1 + scale*img) / np.log(1 + scale)
    return out.astype(np.float32), scale

def load_hdr_ldr_norm_ulaw(name_hdr):
    y = read_exr(name_hdr)
    y_rgb = np.maximum(cv2.cvtColor(y, cv2.COLOR_BGR2RGB), 0.0)
    y_rgb = norm_mean(y_rgb)
    y_ulaw, scale = ulaw_np(y_rgb)
    return scale, y_ulaw, y_rgb

In [25]:
# -------------------- DATASET --------------------
class HDRDataset(Dataset):
    def __init__(self, hdr_folder, limit=None, from_file=None):
        self.files = [os.path.join(hdr_folder, f) for f in os.listdir(hdr_folder) if f.endswith('.exr')]
        self.files.sort()
        if from_file:
            self.files = self.files[from_file:]
        if limit:
            self.files = self.files[:limit]


    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        hdr_file = self.files[idx]
        scale, hdr_ulaw, hdr_rgb = load_hdr_ldr_norm_ulaw(hdr_file)

        # do Pytorch: (C,H,W)
        hdr_ulaw = torch.from_numpy(hdr_ulaw).permute(2,0,1)
        hdr_rgb   = torch.from_numpy(hdr_rgb).permute(2,0,1)

        return hdr_ulaw, hdr_rgb, hdr_file

In [26]:
# -------------------- UTIL dla sieci --------------------
''' v1
def mul_exp(img):
    # img: (B,C,H,W)
    B,C,H,W = img.shape
    x_p = 1.21497
    max_val = img.view(B,-1).max(dim=1)[0].view(B,1,1,1)
    med_val = img.view(B,-1).median(dim=1)[0].view(B,1,1,1)
    c_start = torch.log(x_p/max_val)/torch.log(torch.tensor(2.0))
    c_end = torch.log(x_p/med_val)/torch.log(torch.tensor(2.0))
    exp_values = [c_start, (c_start+c_end)/2, c_end]
    output_list = []
    for sc in exp_values:
        img_out = img * (2.0**0.5)**sc
        img_out = torch.clamp(img_out, 0.0, 1.0)
        output_list.append(img_out)
    return output_list
'''

#v2
def mul_exp(img):
    # img: (B,C,H,W)
    B,C,H,W = img.shape

    x_p = 1.21497

    max_val = img.view(B,-1).max(dim=1)[0].view(B,1,1,1)
    med_val = img.view(B,-1).median(dim=1)[0].view(B,1,1,1)

    c_start = torch.log(x_p / max_val) / torch.log(torch.tensor(2.0))
    c_end = torch.log(x_p / med_val) / torch.log(torch.tensor(2.0))

    exp_values = [c_start, (c_start + c_end)/2.0, c_end]

    output_list = []
    for c in exp_values:
        sc = (2.0**0.5) ** c

        img_out = img * sc
        img_out = torch.clamp(img_out, 0.0, 1.0)
        output_list.append(img_out)

    return output_list

def writeLDR(img, path):
    # tensor, konwertuj na numpy i CPU
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()

    # (C,H,W) -> (H,W,C)
    img = np.transpose(img, (1,2,0))

    # clamp i 0-255
    img = np.clip(img, 0.0, 1.0)
    img = (img * 255).astype(np.uint8)

    # RGB -> BGR dla OpenCV
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(path, img)

## Architektura sieci

In [27]:
# -------------------- NETWORK --------------------
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3,16,3,padding=1)
        self.conv2 = nn.Conv2d(16,32,3,padding=1)
        self.conv3 = nn.Conv2d(32,64,3,padding=1)
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        return x

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(64,32,3,padding=1)
        self.conv2 = nn.Conv2d(32,16,3,padding=1)
        self.conv3 = nn.Conv2d(16,3,3,padding=1)
    def forward(self,x, img1,img2,img3):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        x = torch.sigmoid(x + img1 + img2 + img3)
        return x

class TMONet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        # fusion convs
        self.fusion_conv1 = nn.Conv2d(64*3, 64, 3, padding=1) # zmniejszone ze względu na zbyt duze uzycie RAM w porowaniu do oryginalnej arhcitektury
        self.fusion_conv2 = nn.Conv2d(64, 64, 1)

    def forward(self, i0, i1, i2):
        # encode each input
        o0 = self.encoder(i0)
        o1 = self.encoder(i1)
        o2 = self.encoder(i2)

        # concat features
        o0 = torch.cat([o0, o1, o2], dim=1) # nadpisywanie dla optymalizacji wykorzystania RAM
        o1 = o2 = None

        # fusion block
        o0 = F.relu(self.fusion_conv1(o0))

        o0 = self.fusion_conv2(o0)

        # decode
        o0 = self.decoder(o0, i0, i1, i2)
        return o0

## Funkcja straty

In [28]:
# ---------------------------
# VGG feature extractor
vgg_model = models.vgg19(pretrained=True).features.to(device).eval()
for p in vgg_model.parameters():
    p.requires_grad = False

def vgg_features(x, layers=['0','5','10']):  # odpowiada VGG11,21,31 w TF
    features = []
    h = x
    for idx, layer in enumerate(vgg_model):
        h = layer(h)
        if str(idx) in layers:
            features.append(h)
    return features  # lista tensorow

# ---------------------------
# Gaussian kernel
_gaussian_cache = {} # optymalizacja aby raz wczytac
def gaussian_kernel(size=13, sigma=2.0, channels=3, device='cpu'):
    key = (size, sigma, channels)
    if key in _gaussian_cache:
        return _gaussian_cache[key].to(device)
    coords = np.arange(size) - size//2
    x, y = np.meshgrid(coords, coords)
    kernel = np.exp(-(x**2 + y**2)/(2*sigma**2))
    kernel = kernel / kernel.sum()
    kernel = torch.tensor(kernel, dtype=torch.float32, device=device)
    kernel = kernel.view(1,1,size,size).repeat(channels,1,1,1)
    _gaussian_cache[key] = kernel
    return kernel

def local_mean_std(x, kernel_size=13, sigma=2.0):
    C = x.shape[1]
    w = gaussian_kernel(kernel_size, sigma, C).to(x.device)
    x_pad = F.pad(x, (kernel_size//2,)*4, mode='reflect')
    mean_local = F.conv2d(x_pad, w, groups=C)
    mean_sq = F.conv2d(x_pad**2, w, groups=C)
    std_local = torch.sqrt(torch.clamp(mean_sq - mean_local**2, min=1e-8))
    return mean_local, std_local

# ---------------------------
# Feature contrast masking
def sign_num_den(x, gamma=0.5, beta=0.5, sigma=2.0, kernel_size=13):
    mean_local, std_local = local_mean_std(x, kernel_size, sigma)
    norm_num = torch.sign(x - mean_local) * torch.abs((x - mean_local)/(mean_local.abs()+1e-8))**gamma
    norm_den = 1.0 + (std_local / (mean_local.abs()+1e-8))**beta
    return norm_num, norm_den

def feature_contrast_masking(x, gamma=0.5, beta=0.5, sigma=2.0, kernel_size=13):
    num, den = sign_num_den(x, gamma, beta, sigma, kernel_size)
    return num / den

def masking_loss(pred, target, gamma=0.5, beta=0.5, sigma=2.0, kernel_size=13):
    f_pred = feature_contrast_masking(pred, gamma=1.0, beta=beta, sigma=sigma, kernel_size=kernel_size)
    f_target = feature_contrast_masking(target, gamma=gamma, beta=beta, sigma=sigma, kernel_size=kernel_size)
    return F.l1_loss(f_pred, f_target)

# ---------------------------
# FCM loss
def fcm_loss(pred, target, gamma=0.5, beta=0.5, sigma=2.0, kernel_size=13):
    feats_pred = vgg_features(pred, layers=['0','5','10'])   # odpowiada VGG11,21,31
    with torch.no_grad():
      feats_target = vgg_features(target, layers=['0','5','10'])
    loss_total = 0.0
    for f_pred, f_target in zip(feats_pred, feats_target):
        loss_total += masking_loss(f_pred, f_target, gamma=gamma, beta=beta, sigma=sigma, kernel_size=kernel_size)
    loss_total /= len(feats_pred)
    return loss_total



Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|██████████| 548M/548M [00:03<00:00, 158MB/s]


# TRENING

## Konfiguracja treningu

parametry zaawansowane loss w argumentach domyslnych funkcji loss

In [30]:
epochs = 49

train_load = True # True jesli chcemy wczytac model - wybierz checkpoint nizej
checkpoint_path_train = CHECKPOINTS_FOLDER + "/model_11_16_4_25_47_49.pth"

## Inicjalizacja datasetu
odpowiednio ustawiony w argumentach podział danych

In [None]:
dataset = HDRDataset(REFERENCE_DATASET_FOLDER, limit=127)

loader = DataLoader(dataset, batch_size=1, shuffle=True)
val_dataset = HDRDataset(REFERENCE_DATASET_FOLDER, from_file=127, limit=18)

#val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [None]:
model = TMONet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

In [None]:
if train_load:
  model.load_state_dict(torch.load(checkpoint_path_train, map_location=device))

In [None]:
os.makedirs(RESULTS_FOLDER, exist_ok=True)
os.makedirs(CHECKPOINTS_FOLDER, exist_ok=True)

## Trenowanie

In [None]:
if train_model:
  for epoch in range(epochs):
      print("")
      print(f"Epoch {epoch}")
      now = datetime.now()
      epoch_loss = 0.0
      step = 0

      # -------------------------
      # Trening
      for hdr_ulaw, hdr_rgb, fname in loader:
          step += 1

          hdr_ulaw = hdr_ulaw.to(device)
          hdr_rgb = hdr_rgb.to(device)

          hdr_rgb = F.interpolate(hdr_rgb, scale_factor=0.25, mode='bilinear', align_corners=False)
          hdr_ulaw = F.interpolate(hdr_ulaw, scale_factor=0.25, mode='bilinear', align_corners=False)

          imgs_exp = mul_exp(hdr_rgb)
          ldr_pred = model(imgs_exp[0], imgs_exp[1], imgs_exp[2])


          # dla loss (TYLKO dla loss) interpolujemy do mniejszej resolution dla optymalizacji zuzycia RAM
          #pred_small = F.interpolate(ldr_pred, scale_factor=0.25, mode='bilinear', align_corners=False)
          #hdr_ulaw = F.interpolate(hdr_ulaw, scale_factor=0.25, mode='bilinear', align_corners=False)
          pred_small = ldr_pred

          loss = fcm_loss(pred_small, hdr_ulaw)

          optimizer.zero_grad()

          loss.backward()

          optimizer.step()

          epoch_loss += loss.item()
          img_out = ldr_pred[0].detach()  # (C,H,W)

          #zapisanie obrazka
          img_out = torch.clamp(img_out, 0.0, 1.0)
          exten = f"_{epoch}.png"
          writeLDR(img_out, os.path.join(
              RESULTS_FOLDER,
              os.path.basename(fname[0]).replace(".exr", exten) ##UWAGA! Dodano zapisywanie po epoch, nie nadpisuja sie
          ))


          #if step % 10 == 0:
          print(step, fname[0])
          print(f"Step {step}, Loss: {loss.item():.6f}")

          if isinstance(img_out, torch.Tensor):
              img_out = img_out.detach().cpu().numpy()
          # (C,H,W) -> (H,W,C)
          img_out = np.transpose(img_out, (1,2,0))

          # clamp i 0-255
          img_out = np.clip(img_out, 0.0, 1.0)
          img_out = (img_out * 255).astype(np.uint8)

          # RGB -> BGR dla OpenCV
          img_out = cv2.cvtColor(img_out, cv2.COLOR_RGB2BGR)
          brisque_sdr = evaluate_image(img_out)
          print(brisque_sdr)


      #zapisanie modelu
      now = datetime.now()
      model_name = f'model_{now.month}_{now.day}_{now.hour}_{now.minute}_{now.second}_{epoch}.pth'
      torch.save(model.state_dict(), CHECKPOINTS_FOLDER + "/" + model_name)

      avg_loss = epoch_loss / step
      print("")
      print(f"Epoch {epoch} finished, Average Loss: {avg_loss:.6f}")

      # -------------------------
      # Walidacja - kontroluj! czy nie ma overfittingu
      if 'val_loader' in globals():  # tylko jeśli loader walidacyjny istnieje
          model.eval()
          val_loss = 0.0
          val_steps = 0
          # nie uczymy
          with torch.no_grad():
              for val_ulaw, val_rgb, val_fname in val_loader:
                  val_steps += 1
                  val_ulaw = val_ulaw.to(device)
                  val_rgb = val_rgb.to(device)

                  imgs_exp_val = mul_exp(val_rgb)
                  ldr_pred_val = model(imgs_exp_val[0], imgs_exp_val[1], imgs_exp_val[2])
                  # analogiczna optymalizacja dla RAM co w treningu
                  pred_small_val = F.interpolate(ldr_pred_val, scale_factor=0.25, mode='bilinear', align_corners=False)
                  target_small_val = F.interpolate(val_ulaw, scale_factor=0.25, mode='bilinear', align_corners=False)
                  loss_val = fcm_loss(pred_small_val, target_small_val)
                  val_loss += loss_val.item()

          avg_val_loss = val_loss / val_steps
          print(f"Validation Loss after Epoch {epoch}: {avg_val_loss:.6f}")
          model.train()

# Testowanie (zapisanie obrazów z sieci)
NIE KORZYSTA Z TEGO SAMEGO MODELU CO W TRENINGU! -> wczytaj wagi modelu w konfiguracji

## Konfiguracja testowania
USTAW MODEL (CHECKPOINT)!

In [29]:
checkpoint_path = CHECKPOINTS_FOLDER + "/model_11_16_4_25_47_49.pth"

## Inicjalizacja

In [31]:
test_dataset = HDRDataset(REFERENCE_DATASET_FOLDER, from_file=145, limit=35) # podział danych (zaczynamy od pliku indeks (nie nazwa!) 147 (wlacznie) max: 36 plików)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [32]:
model = TMONet().to(device) # wczytanie checkpointa odbywa się w sekcji predykcji

In [None]:
os.makedirs(RESULTS_FOLDER + '/TEST', exist_ok=True) # tutaj beda zapisane obrazki testowe

## Predykcja
zapisywane są results z wytrenowanego modelu do RESULTS_FOLDER + '/TEST/'

In [None]:
if test_model and 'test_loader' in globals():  # tylko jeśli loader testowy istnieje
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)) #wczytanie wag modelu
    model.eval()
    test_loss = 0.0
    test_steps = 0
    with torch.no_grad():
        for test_ulaw, test_rgb, test_fname in test_loader:
            print(f"Predykcja: {test_fname[0]}")
            test_steps += 1
            test_ulaw = test_ulaw.to(device)
            test_rgb = test_rgb.to(device)

            imgs_exp_test = mul_exp(test_ulaw)
            ldr_pred_test = model(imgs_exp_test[0], imgs_exp_test[1], imgs_exp_test[2])

            img_out = ldr_pred_test[0].detach()  # (C,H,W)

            img_out = torch.clamp(img_out, 0.0, 1.0)
            writeLDR(img_out, os.path.join(
                RESULTS_FOLDER + '/TEST',
                os.path.basename(test_fname[0]).replace(".exr", ".png")
            ))
            print(f"Zapisano: {test_fname[0]}")
    model.train()

# Wyniki
Wymagane są rozpakowane dane z input_97, reference (hdr), oraz results (subfolder TEST)

## Inicjalizacja

In [33]:
# ZBIERAMY WSZYSTKIE SDR Z NASZEGO MODELU
sdr_paths = sorted(glob.glob(RESULTS_FOLDER + "/TEST/*.png"))
print(f"Found {len(sdr_paths)} PNG files.")

count = 0

Found 36 PNG files.


## Obliczenie metryk

In [34]:
calc_metrics = True

if calc_metrics:
  brisque_reinhard_sum = 0.0
  ssim_reinhard_sum = 0.0
  brisque_mantiuk_sum = 0.0
  ssim_mantiuk_sum = 0.0
  brisque_sdr_sum = 0.0
  ssim_sdr_sum = 0.0
  brisque_input97_sum = 0.0
  ssim_input97_sum = 0.0

  count = 0

  for path in sdr_paths:

      base = os.path.splitext(os.path.basename(path))[0]   # np '006'
      hdr_path = os.path.join(REFERENCE_DATASET_FOLDER, base + ".exr")

      print(f"\nProcessing SDR: {path}")
      print(f"Matching HDR:   {hdr_path}")

      hdr = read_exr(hdr_path)
      if hdr is None:
          print("Could not load HDR!")
          continue

      hdr_norm = np.zeros_like(hdr, dtype=np.float32)
      for c in range(hdr.shape[2]):
          ch_max = hdr[:,:,c].max()
          hdr_norm[:,:,c] = np.clip(hdr[:,:,c] / (ch_max + 1e-8), 0.0, 1.0)

      # SDR z input_97
      input97_path = os.path.join(INPUT_97_DATASET_FOLDER, base + ".png")
      if os.path.exists(input97_path):
          sdr_in = cv2.imread(input97_path, cv2.IMREAD_COLOR)
          if sdr_in is not None:
              sdr_in_rgb = cv2.cvtColor(sdr_in, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
              print(f"HDR shape: {hdr_norm.shape}, input_97 SDR shape: {sdr_in_rgb.shape}")
              brisque_in = evaluate_image(sdr_in_rgb)
              ssim_in = ssim(sdr_in_rgb, sdr_in_rgb, channel_axis=2, data_range=1.0)
              ssim_in = np.nan_to_num(ssim_in, nan=1.0)
              brisque_input97_sum += brisque_in
              ssim_input97_sum += ssim_in
              print("BRISQUE (input_97):", brisque_in)
              print(f"SSIM (input_97 vs input_97): {ssim_in}")
      # ------------------------------
      # Reinhard tonemap
      ldr_reinhard = tone_map_reinhard(hdr)
      print(f"Reinhard SDR shape: {ldr_reinhard.shape}")
      brisque_reinhard = evaluate_image(ldr_reinhard)
      ssim_reinhard = ssim(sdr_in_rgb, ldr_reinhard, channel_axis=2, data_range=1.0)
      ssim_reinhard = np.nan_to_num(ssim_reinhard, nan=1.0)
      brisque_reinhard_sum += brisque_reinhard
      ssim_reinhard_sum += ssim_reinhard
      print("BRISQUE (HDR→Reinhard):", brisque_reinhard)
      print("SSIM (input_97 vs Reinhard):", ssim_reinhard)

      # ------------------------------
      # Mantiuk tonemap
      ldr_mantiuk = tone_map_mantiuk(hdr)
      print(f"Mantiuk SDR shape: {ldr_mantiuk.shape}")
      brisque_mantiuk = evaluate_image(ldr_mantiuk)
      ssim_mantiuk = ssim(sdr_in_rgb, ldr_mantiuk, channel_axis=2, data_range=1.0)
      ssim_mantiuk = np.nan_to_num(ssim_mantiuk, nan=1.0)
      brisque_mantiuk_sum += brisque_mantiuk
      ssim_mantiuk_sum += ssim_mantiuk
      print("BRISQUE (HDR→Mantiuk):", brisque_mantiuk)
      print("SSIM (input_97 vs Mantiuk):", ssim_mantiuk)

      # ------------------------------
      # SDR z wytrenowanej sieci
      ldr_net = cv2.imread(path, cv2.IMREAD_COLOR)
      if ldr_net is None:
          print("Could not load PNG!")
          continue
      ldr_rgb = cv2.cvtColor(ldr_net, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
      print(f"HDR shape: {hdr_norm.shape}, SDR net shape: {ldr_rgb.shape}")
      brisque_sdr = evaluate_image(ldr_rgb)
      ssim_sdr = ssim(sdr_in_rgb, ldr_rgb, channel_axis=2, data_range=1.0)
      ssim_sdr = np.nan_to_num(ssim_sdr, nan=1.0)
      brisque_sdr_sum += brisque_sdr
      ssim_sdr_sum += ssim_sdr
      print("BRISQUE (SDR→net(ours)):", brisque_sdr)
      print(f"SSIM (input_97 vs SDR net(ours)): {ssim_sdr}")


      count += 1


Processing SDR: /content/drive/MyDrive/SIGK_P2/results/TEST/156.png
Matching HDR:   /content/reference_dataset/156.exr
HDR shape: (1280, 1888, 3), input_97 SDR shape: (1280, 1888, 3)
BRISQUE (input_97): 12.834332443768147
SSIM (input_97 vs input_97): 1.0
Reinhard SDR shape: (1280, 1888, 3)
BRISQUE (HDR→Reinhard): 105.28229246996906
SSIM (input_97 vs Reinhard): 1.0
Mantiuk SDR shape: (1280, 1888, 3)
BRISQUE (HDR→Mantiuk): 110.31718800074972
SSIM (input_97 vs Mantiuk): 1.0
HDR shape: (1280, 1888, 3), SDR net shape: (1280, 1888, 3)
BRISQUE (SDR→net(ours)): 85.9477888892196
SSIM (input_97 vs SDR net(ours)): 0.8204138875007629

Processing SDR: /content/drive/MyDrive/SIGK_P2/results/TEST/158.png
Matching HDR:   /content/reference_dataset/158.exr
HDR shape: (1280, 1888, 3), input_97 SDR shape: (1280, 1888, 3)
BRISQUE (input_97): 20.144943599592096
SSIM (input_97 vs input_97): 1.0
Reinhard SDR shape: (1280, 1888, 3)
BRISQUE (HDR→Reinhard): 17.440496987475825
SSIM (input_97 vs Reinhard): 0.358

## Wyświetlenie metryk

In [36]:
# -------------------------------------------------------
# ŚREDNIE
if count > 0:
    print("\n=============================================")
    print("                 AVERAGE METRICS")
    print(f"Files processed: {count}")

    print("\n--- Reinhard ---")
    print("Average BRISQUE:", brisque_reinhard_sum / count)
    print("Average SSIM:   ", ssim_reinhard_sum / count)

    print("\n--- Mantiuk ---")
    print("Average BRISQUE:", brisque_mantiuk_sum / count)
    print("Average SSIM:   ", ssim_mantiuk_sum / count)

    print("\n--- SDR (network output (ours)) ---")
    print("Average BRISQUE:", brisque_sdr_sum / count)
    print("Average SSIM:   ", ssim_sdr_sum / count)

    print("\n--- SDR (input_97) ---")
    print("Average BRISQUE:", brisque_input97_sum / count)
    print("Average SSIM:   ", ssim_input97_sum / count)
else:
    print("No files processed!")


                 AVERAGE METRICS
Files processed: 36

--- Reinhard ---
Average BRISQUE: 55.72110349068631
Average SSIM:    0.7163193

--- Mantiuk ---
Average BRISQUE: 103.63161421264304
Average SSIM:    0.8374267

--- SDR (network output (ours)) ---
Average BRISQUE: 57.79757917388013
Average SSIM:    0.63296425

--- SDR (input_97) ---
Average BRISQUE: 19.65719957425216
Average SSIM:    1.0


# Porównanie wizualne
Dla podanego pliku w zmiennej file_base (sekcja Wizualizacja) -> generowane są pliki w COMPARE_FOLDER

pliki z sufiksami dla każdej metody

HDR_norm uzywane dla SSIM jako obraz referencyjny

## Inicjalizacja


In [None]:
os.makedirs(COMPARE_FOLDER, exist_ok=True)

## Wizualizacja


In [None]:
# Podaj nazwę pliku, np '006'
file_base = '157'

hdr_path = os.path.join(REFERENCE_DATASET_FOLDER, file_base + ".exr")
sdr_net_path = os.path.join(RESULTS_FOLDER, "TEST", file_base + ".png")
sdr_input97_path = os.path.join(INPUT_97_DATASET_FOLDER, file_base + ".png")

# 1. Wczytaj HDR
hdr = read_exr(hdr_path)

# normalizacja kanałowa HDR -> HDR_norm (używane przez metryki SSIM jako obraz referencyjny)
hdr_norm = np.zeros_like(hdr, dtype=np.float32)
for c in range(hdr.shape[2]):
    ch_max = hdr[:,:,c].max()
    hdr_norm[:,:,c] = np.clip(hdr[:,:,c] / (ch_max + 1e-8), 0.0, 1.0)

print("hdr min/max:", hdr.min(), hdr.max())
print("hdr_norm   min/max:", hdr_norm.min(), hdr_norm.max())

print("hdr mean/std:", hdr.mean(), hdr.std())
print("hdr_norm   mean/std:", hdr_norm.mean(), hdr_norm.std())

# 2. Tonemap Reinhard
ldr_reinhard = tone_map_reinhard(hdr)

# 3. Tonemap Mantiuk
ldr_mantiuk = tone_map_mantiuk(hdr)

print("Reinhard min/max:", ldr_reinhard.min(), ldr_reinhard.max())
print("Mantiuk   min/max:", ldr_mantiuk.min(), ldr_mantiuk.max())

print("Reinhard mean/std:", ldr_reinhard.mean(), ldr_reinhard.std())
print("Mantiuk   mean/std:", ldr_mantiuk.mean(), ldr_mantiuk.std())

# 4. SDR sieci
ldr_net = cv2.imread(sdr_net_path, cv2.IMREAD_COLOR)
ldr_net = cv2.cvtColor(ldr_net, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

# 5. SDR z input_97
sdr_input97 = None
if os.path.exists(sdr_input97_path):
    sdr_input97 = cv2.imread(sdr_input97_path, cv2.IMREAD_COLOR)
    sdr_input97 = cv2.cvtColor(sdr_input97, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

# Wyświetlanie(problemy w colab)/zapisywanie do folderu
def show_img(window_name, img):
    #img_show = np.clip(img, 0.0, 1.0)
    #img_show = (img_show * 255).astype(np.uint8)
    #print(window_name)
    #cv2_imshow(img)
    filename = f"{file_base}_{window_name}.png"
    if img.ndim == 3 and img.shape[2] == 3:
        img_to_save = np.transpose(img, (2,0,1))
    else:
        img_to_save = img

    img_to_save = np.clip(img_to_save, 0.0, 1.0).astype(np.float32)
    writeLDR(img_to_save, os.path.join(COMPARE_FOLDER, filename))
    print(f"Saved: {filename}")

# Wyświetl wszystkie obrazy
show_img("HDR", hdr)

show_img("HDR norm", hdr_norm)
show_img("Reinhard", ldr_reinhard)
show_img("Mantiuk", ldr_mantiuk)
show_img("SDR net (ours)", ldr_net)

if sdr_input97 is not None:
    show_img("SDR input_97", sdr_input97)
    print("input97 min/max:", sdr_input97.min(), sdr_input97.max())
    print("sdr net   min/max:", ldr_net.min(), ldr_net.max())

    print("input97 mean/std:", sdr_input97.mean(), sdr_input97.std())
    print("sdr net   mean/std:", ldr_net.mean(), ldr_net.std())

cv2.waitKey(0)
cv2.destroyAllWindows()