In [11]:
!pip install torch
!pip install torchvision
!pip install tensorboard
!pip install numpy
!pip install matplotlib



In [12]:
!pip install imagecodecs
!pip install einops

Collecting imagecodecs
  Downloading imagecodecs-2024.9.22-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Downloading imagecodecs-2024.9.22-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (43.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.3/43.3 MB[0m [31m38.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: imagecodecs
Successfully installed imagecodecs-2024.9.22
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [110]:
!pip install opencv-python
!pip install scikit-image



# HDRNet

## Utils

In [149]:
import json
import matplotlib.pyplot as plt
import numpy as np
import os
import re
import torch
import torch.nn.functional as F
from matplotlib.ticker import MaxNLocator
import cv2
from skimage.metrics import structural_similarity as ssim



def psnr(pred, target):
    """Peak signal-to-noise ratio in dB for signals scaled to [0, 1].

    Computes ``10 * log10(1 / MSE(pred, target))``. Identical inputs give an
    MSE of 0 and therefore ``inf``.
    """
    mse = F.mse_loss(pred, target)
    return 10 * torch.log10(1 / mse)

def calculate_ssim(image1, image2, data_range=None):
    """Structural similarity (SSIM) of two images, compared in grayscale.

    3-D inputs are converted from BGR to single-channel grayscale with OpenCV
    first; 2-D inputs are compared as-is.

    Args:
        image1, image2: numpy images of identical shape, (H, W) or (H, W, 3)
            in BGR channel order.
        data_range: optional dynamic range of the pixel values (e.g. 1.0 for
            floats in [0, 1], 255 for uint8), forwarded to
            ``skimage.metrics.structural_similarity``. Newer scikit-image
            releases raise for floating-point inputs when it is omitted. With
            the default ``None`` the call is unchanged and scikit-image infers
            the range.

    Returns:
        The mean SSIM value.
    """
    def _to_gray(img):
        # Only colour (3-D) images need conversion.
        if len(img.shape) == 3:
            return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        return img

    gray1 = _to_gray(image1)
    gray2 = _to_gray(image2)
    extra = {} if data_range is None else {'data_range': data_range}
    ssim_value, _ = ssim(gray1, gray2, full=True, **extra)
    return ssim_value


def print_params(params):
    """Pretty-print a parameter dict.

    Prints a header line, one ``  key = value`` line per entry (in dict order),
    then a blank line.
    """
    entries = [f'  {key} = {value!s}' for key, value in params.items()]
    print('Training parameters: ')
    print('\n'.join(entries))
    print()

def get_files(path):
    """Return the joined paths of every entry in directory ``path``, sorted lexicographically."""
    return sorted(os.path.join(path, name) for name in os.listdir(path))

def load_train_ckpt(model, ckpt_dir, map_location=None):
    """Load the most recent checkpoint in ``ckpt_dir`` into ``model``.

    The "most recent" file is the one whose *file name* yields the largest
    integer when all of its digits are concatenated. The epoch count is parsed
    from the second ``_``-separated field of that file name (a name such as
    ``epoch_12_....pth`` is assumed — confirm against the code that names
    the checkpoints).

    Args:
        model: ``nn.Module`` that receives the weights.
        ckpt_dir: directory of checkpoints written by ``save_model_stats``.
        map_location: optional ``torch.load`` ``map_location`` (e.g. ``'cpu'``)
            for loading GPU-saved checkpoints on another device.

    Returns:
        The parsed epoch number, or ``None`` if ``ckpt_dir`` is empty.
    """
    names = os.listdir(ckpt_dir)
    if not names:
        return None

    def _digits_key(name):
        # Use only the file name's digits; names without any digit (e.g.
        # Jupyter's '.ipynb_checkpoints') sort first instead of crashing int('').
        digits = re.sub(r'\D', '', name)
        return int(digits) if digits else -1

    latest = sorted(names, key=_digits_key)[-1]
    ckpt_path = os.path.join(ckpt_dir, latest)
    # Parse the epoch from the file name only: splitting the full path broke
    # whenever ckpt_dir itself contained an underscore.
    prev_epochs = int(latest.split('_')[1])
    print("epochs ", prev_epochs)
    print('Loading:', ckpt_path)
    state_dict = torch.load(ckpt_path, map_location=map_location)
    # save_model_stats stores the training params alongside the weights; they
    # are not module parameters, so drop them before load_state_dict.
    state_dict.pop('params', None)
    model.load_state_dict(state_dict)
    return prev_epochs

def load_test_ckpt(ckpt_path):
    """Load a checkpoint and split it into ``(model_state_dict, params)``.

    ``params`` is the training-parameter entry stored next to the weights. It is
    removed from the returned dict, so the dict can go straight into
    ``load_state_dict``. Raises ``KeyError`` if the entry is missing.
    """
    checkpoint = torch.load(ckpt_path)
    params = checkpoint.pop('params')
    return checkpoint, params

def save_model_stats(model, params, ckpt_fname, stats):
    """Persist the model weights (plus ``params``) and the ``stats`` dict.

    The checkpoint goes to ``params['ckpt_dir']/ckpt_fname``. It holds the
    model ``state_dict`` with an extra ``'params'`` entry. ``stats`` is written
    as indented JSON to ``params['stats_dir']/stats.json``, overwriting any
    previous file.
    """
    checkpoint = model.state_dict()
    checkpoint['params'] = params
    torch.save(checkpoint, os.path.join(params['ckpt_dir'], ckpt_fname))

    stats_file = os.path.join(params['stats_dir'], 'stats.json')
    with open(stats_file, 'w') as handle:
        json.dump(stats, handle, indent=2)


class AvgMeter(object):
    """Accumulate values and track their running (optionally weighted) mean.

    Attributes:
        val: most recent value passed to ``update``.
        sum: weighted sum of all values (``val * n`` per update).
        count: total weight (sum of all ``n``).
        avg: ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics (``avg`` as a float, the rest as ints)."""
        self.val, self.avg, self.sum, self.count = 0, 0., 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## Layers

In [112]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image

def conv_layer(in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, activation=nn.ReLU, batch_norm=False):
    """Build ``Conv2d -> [BatchNorm2d] -> [activation]`` as an ``nn.Sequential``.

    ``activation`` is a module class, instantiated with no arguments. Pass a
    falsy value such as ``None`` to leave it out. ``batch_norm`` adds a
    ``BatchNorm2d`` over ``out_channels`` between the conv and the activation.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=bias)
    norm = [nn.BatchNorm2d(out_channels)] if batch_norm else []
    act = [activation()] if activation else []
    return nn.Sequential(conv, *norm, *act)

def fc_layer(in_channels, out_channels, bias=True, activation=nn.ReLU, batch_norm=False):
    """Build ``Linear -> [BatchNorm1d] -> [activation]`` as an ``nn.Sequential``.

    Feature counts are cast to ``int`` once and used for every layer, so sizes
    produced by float arithmetic (e.g. ``c / 2``) work with and without
    ``batch_norm``. ``activation`` is a module class, or a falsy value to omit it.
    """
    in_features = int(in_channels)
    out_features = int(out_channels)
    layers = [nn.Linear(in_features, out_features, bias=bias)]
    if batch_norm:
        # Must receive the int count too; BatchNorm1d(8.0) raises a TypeError.
        layers.append(nn.BatchNorm1d(out_features))
    if activation:
        layers.append(activation())
    return nn.Sequential(*layers)

def slicing(grid, guide):
    """Sample a bilateral grid at every full-resolution pixel.

    Each output pixel (y, x) reads the grid at normalized coordinates
    (x, y, guide[y, x]) with trilinear ``grid_sample`` and border padding.

    Args:
        grid: tensor (N, C, D, Hg, Wg); in this model C=12, D=8, Hg=Wg=16.
        guide: tensor (N, 1, H, W) whose values (expected in [0, 1]) pick the
            depth coordinate. Assumed on the same device as ``grid``.

    Returns:
        Tensor (N, C, H, W).
    """
    N, _, H, W = guide.shape
    # Pixel index coordinates. 'ij' makes rows vary along H and columns along W
    # (torch's default, stated explicitly to avoid the deprecation warning).
    rows = torch.arange(H, device=grid.device, dtype=guide.dtype)
    cols = torch.arange(W, device=grid.device, dtype=guide.dtype)
    hh, ww = torch.meshgrid(rows, cols, indexing='ij')  # H, W
    # Normalize to grid_sample's [-1, 1]. max(.., 1) avoids 0/0 = NaN for
    # single-row or single-column images.
    # NOTE(review): pixel 0 maps to -1, which under align_corners=False is the
    # grid's outer edge rather than a cell centre; kept as trained.
    hh = hh / max(H - 1, 1) * 2 - 1
    ww = ww / max(W - 1, 1) * 2 - 1
    depth = guide.permute(0, 2, 3, 1) * 2 - 1  # N, H, W, 1

    # grid_sample expects the last dim ordered (x, y, z) = (W, H, D).
    coords = torch.cat([
        ww[None, :, :, None].expand(N, H, W, 1),
        hh[None, :, :, None].expand(N, H, W, 1),
        depth,
    ], dim=3)  # N, H, W, 3
    coords = coords.unsqueeze(1)  # N, 1, H, W, 3 -- one output depth slice
    sliced = F.grid_sample(grid, coords, align_corners=False, padding_mode="border")  # N, C, 1, H, W
    return sliced.squeeze(2)  # N, C, H, W

def apply(sliced, fullres):
    """Apply per-pixel 3x4 affine colour transforms to the full-res image.

    Channels of ``sliced`` are read in groups of four per output channel k:
    three weights then a bias, so ``out_k = sum_j w_kj * fullres_j + b_k``
    (e.g. r' = w1*r + w2*g + w3*b + w4). Only the first 12 channels are used.

    Args:
        sliced: tensor (N, >=12, H, W) of sliced grid coefficients.
        fullres: tensor (N, 3, H, W).

    Returns:
        Tensor (N, 3, H, W).
    """
    channels = []
    for k in range(3):
        start = 4 * k
        weights = sliced[:, start:start + 3, :, :]  # N, 3, H, W
        offset = sliced[:, start + 3, :, :]          # N, H, W
        channels.append(torch.sum(fullres * weights, dim=1) + offset)
    return torch.stack(channels, dim=1)

## Modules

In [113]:
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms.functional import resize
from einops import rearrange

class SPSA_Attention(nn.Module):
    """Channel attention block fusing RGB and (optionally) spectral features.

    q, k, v come from a 1x1 conv followed by a depthwise 3x3 conv. Attention is
    computed across channels: after flattening space, each head builds a
    (c x c) matrix from L2-normalized q and k, scaled by a learnable per-head
    temperature.

    Args:
        dim: channels seen by the qkv projection. In spectral mode
            (``is_spec``) the RGB and spectral streams carry ``dim // 2``
            channels each and are concatenated.
        num_heads: number of heads (must divide ``dim``).
        is_material_mask: modulate q and k with ResBlock_SFT conditioned on a
            1-channel material mask.
        is_spec: fuse a spectral stream with the RGB stream.
        bias: bias flag for all convolutions.
    """
    def __init__(self, dim, num_heads,is_material_mask, is_spec, bias):
        super(SPSA_Attention, self).__init__()
        self.num_heads = num_heads
        self.is_material_mask = is_material_mask
        self.is_spec = is_spec
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias)
        if self.is_spec:
            self.project_out1_x = nn.Conv2d(dim//2, dim//2,  kernel_size=3, stride=1, padding=1, bias=bias)
            self.project_out1_spec = nn.Conv2d(dim//2, dim//2,  kernel_size=3, stride=1, padding=1, bias=bias)
            self.project_out2_x = nn.Conv2d(dim, dim//2,  kernel_size=3, stride=1, padding=1, bias=bias)
            self.project_out2_spec = nn.Conv2d(dim, dim//2,  kernel_size=3, stride=1, padding=1, bias=bias)
        else:
            self.project_out1 = nn.Conv2d(dim, dim,  kernel_size=3, stride=1, padding=1, bias=bias)
            self.project_out2 = nn.Conv2d(dim, dim,  kernel_size=3, stride=1, padding=1, bias=bias)

        # ===========================mask condition===========================
        self.ResBlock_SFTk = ResBlock_SFT(input_channel = dim,input_mask_dim=1)
        self.ResBlock_SFTq = ResBlock_SFT(input_channel = dim,input_mask_dim=1)
        # out_sft1..out_sft4 are not used in forward; kept so that existing
        # checkpoints still match the module's state_dict keys.
        self.out_sft1 = ResBlock_SFT(input_channel = dim,input_mask_dim=1)
        self.out_sft2 = ResBlock_SFT(input_channel = dim,input_mask_dim=1)
        self.out_sft3 = ResBlock_SFT(input_channel = dim//2,input_mask_dim=1)
        self.out_sft4 = ResBlock_SFT(input_channel = dim//2,input_mask_dim=1)

    def forward(self, x_in, spec, material_mask):
        """Attend over the (optionally spec-concatenated) features.

        Returns:
            ``(x_out, spec_out)``. In spectral mode both are freshly projected
            (``dim // 2`` channels each). Otherwise ``spec`` is returned
            unchanged and ``x_out`` has ``dim`` channels.
        """
        # Spectral mode: spectral features are concatenated in front of RGB.
        x = torch.cat([spec, x_in], dim=1) if self.is_spec else x_in
        _, _, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # Material semantic prior.
        if self.is_material_mask:
            # NOTE: the module names are swapped relative to their inputs
            # (ResBlock_SFTk conditions q, ResBlock_SFTq conditions k). They are
            # left as-is so trained checkpoints keep loading.
            q = self.ResBlock_SFTk(q, material_mask)
            k = self.ResBlock_SFTq(k, material_mask)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature  # b, head, c, c
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        if self.is_spec:
            x_out = self.project_out2_x(out) + self.project_out1_x(x_in)
            spec_out = self.project_out2_spec(out) + self.project_out1_spec(spec)
            return x_out, spec_out
        # Non-spectral mode: only project_out1/project_out2 exist here, so the
        # *_x / *_spec projections above would raise AttributeError.
        x_out = self.project_out2(out) + self.project_out1(x_in)
        return x_out, spec

class SFTLayer(nn.Module):
    """Spatial feature transform: modulate features with a mask-predicted scale and shift.

    The conditioning map is bilinearly resized to the feature resolution. Two
    independent 1x1-conv branches (input_mask_dim -> dim//2 -> dim, LeakyReLU
    slope 0.1 in between) then predict ``scale`` and ``shift``, and the output
    is ``x * (scale + 1) + shift``.
    """
    def __init__(self, dim, input_mask_dim):
        super(SFTLayer, self).__init__()
        hidden = dim // 2
        self.SFT_scale_conv0 = nn.Conv2d(input_mask_dim, hidden, kernel_size=1, stride=1, padding=0, bias=True)
        self.SFT_scale_conv1 = nn.Conv2d(hidden, dim, kernel_size=1, stride=1, padding=0, bias=True)
        self.SFT_shift_conv0 = nn.Conv2d(input_mask_dim, hidden, kernel_size=1, stride=1, padding=0, bias=True)
        self.SFT_shift_conv1 = nn.Conv2d(hidden, dim, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x, seg):
        _, _, height, width = x.shape
        cond = resize(seg, (height, width), Image.BILINEAR)
        scale_hidden = F.leaky_relu(self.SFT_scale_conv0(cond), 0.1, inplace=True)
        scale = self.SFT_scale_conv1(scale_hidden)
        shift_hidden = F.leaky_relu(self.SFT_shift_conv0(cond), 0.1, inplace=True)
        shift = self.SFT_shift_conv1(shift_hidden)
        return x * (scale + 1) + shift
class ResBlock_SFT(nn.Module):
    """Residual block of two SFT-modulated 3x3 convolutions.

    ``forward(x, seg) = x + conv1(sft1(relu(conv0(sft0(x, seg))), seg))``.
    Channel count and spatial size are preserved.
    """
    def __init__(self, input_channel, input_mask_dim):
        super(ResBlock_SFT, self).__init__()
        self.sft0 = SFTLayer(dim=input_channel, input_mask_dim=input_mask_dim)
        self.conv0 = nn.Conv2d(input_channel, input_channel, kernel_size=3, stride=1, padding=1, bias=True)
        self.sft1 = SFTLayer(dim=input_channel, input_mask_dim=input_mask_dim)
        self.conv1 = nn.Conv2d(input_channel, input_channel, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x, seg):
        residual = self.conv0(self.sft0(x, seg))
        residual = F.relu(residual, inplace=True)
        residual = self.conv1(self.sft1(residual, seg))
        return x + residual


class SegExtract(nn.Module):
    """Turn a material mask into a 16x16 weighting map.

    The mask is bilinearly resized to 16x16. A small conv encoder (16 -> 8 -> 4
    via max-pooling) and decoder (4 -> 8 -> 16 via bilinear upsampling) then
    predict a 1-channel ``gain``, and the output is
    ``1.0 + mask16 * (1 + gain)``, shape (N, c_in, 16, 16).
    """
    def __init__(self, params, c_in=1):
        super(SegExtract, self).__init__()
        self.params = params
        self.relu = nn.ReLU()  # unused in forward

        self.splat1 = nn.Conv2d(c_in, 8, kernel_size=3, stride=1, padding=1, bias=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(3,3),stride=(2,2),padding=(1,1))
        self.splat2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1, bias=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

        self.splat1_up = nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1, bias=True)
        self.splat2_up = nn.Conv2d(8, 1, kernel_size=3, stride=1, padding=1, bias=True)
        self.sigmoid = nn.Sigmoid()  # unused in forward

    def forward(self, x_in):
        mask16 = resize(x_in, (16, 16), Image.BILINEAR)

        # Encoder: 16x16 -> 8x8 -> 4x4.
        feat = self.maxpool1(self.splat1(mask16))
        feat = self.maxpool2(self.splat2(feat))

        # Decoder: 4x4 -> 8x8 -> 16x16.
        feat = F.interpolate(self.splat1_up(feat), size=(8, 8), mode='bilinear')
        gain = F.interpolate(self.splat2_up(feat), size=(16, 16), mode='bilinear')

        return 1.0 + mask16 * (1 + gain)

class BrightnessAdaptation(nn.Module):
    def __init__(self, params, c_in=1):
        super(BrightnessAdaptation, self).__init__()
        self.params = params
        self.relu = nn.ReLU()

        self.splat1 = nn.Conv2d(c_in, 8, kernel_size=3, stride=1, padding=1, bias=True)
        self.splat1_2 = nn.Conv2d(c_in, 8, kernel_size=3, stride=1, padding=1, bias=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(3,3),stride=(2,2),padding=(1,1))#
        self.splat2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1, bias=True)
        self.splat2_2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1, bias=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))  #

        self.splat1_up = nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1, bias=True)
        self.splat1_up2 = nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1, bias=True)
        self.upsamp1 = torch.nn.Upsample(size=(params['output_res'][0]//2,params['output_res'][1]//2), mode='bilinear')
        self.splat2_up = nn.Conv2d(8, 1, kernel_size=3, stride=1, padding=1, bias=True)
        self.splat2_up2 = nn.Conv2d(8, 1, kernel_size=3, stride=1, padding=1, bias=True)
        self.upsamp2 = torch.nn.Upsample(size=(params['output_res'][0],params['output_res'][1]), mode='bilinear')
        self.sigmoid = nn.Sigmoid()
    def forward(self, x_in,fullres):
        x1 = self.splat1(x_in)
        x1 = self.maxpool1(x1)
        x1 = self.splat2(x1)
        x_low1 = self.maxpool2(x1)

        x1 = self.splat1_up(x_low1)
        x1 = F.interpolate(x1, size=(fullres.size()[-2:][0]//2,fullres.size()[-2:][1]//2), mode='bilinear')
        x1 = self.splat2_up(x1)
        a = F.interpolate(x1, size=(fullres.size()[-2:][0],fullres.size()[-2:][1]), mode='bilinear')

        x2 = self.splat1_2(x_in)
        x2 = self.maxpool1(x2)
        x2 = self.splat2_2(x2)
        x_low2 = self.maxpool2(x2)

        x2 = self.splat1_up2(x_low2)
        x2 = F.interpolate(x2, size=(fullres.size()[-2:][0]//2,fullres.size()[-2:][1]//2), mode='bilinear')
        x2 = self.splat2_up2(x2)
        b = F.interpolate(x2, size=(fullres.size()[-2:][0],fullres.size()[-2:][1]), mode='bilinear')
        out = x_in*(1+a)+b
        out = torch.clamp(out, 0.01, 1)
        return out

## Models

In [114]:
import numpy as np

class FeatureExtract(nn.Module):
    """Predict a bilateral grid of affine colour coefficients from a low-res image.

    Stages:
    - Splat: four stride-2 convs. With ``params['spec']``, a parallel stream
      built from 10 spectral bands is downsampled alongside and fused with the
      RGB stream by an SPSA_Attention block at each level.
    - Global and local branches, summed and ReLU'd.
    - A 1x1 conv to 96 = 12 x 8 channels.

    forward returns (N, 12, 8, h, w), where (h, w) is the splat output size
    (input size / 16; 16 x 16 for a 256 x 256 input). The global branch must
    collapse to 1 x 1 to broadcast over the local features, which holds for
    splat outputs up to 32 x 32.
    """
    def __init__(self, params, c_in = 3):
        super(FeatureExtract, self).__init__()
        self.params = params
        self.relu = nn.ReLU()
        # ===========================attention===========================
        # The non-spec attention blocks are constructed but never called.
        if self.params['spec']:
            self.attn1 = SPSA_Attention(dim=8*2, num_heads=1, is_material_mask = self.params['material_mask'], is_spec = self.params['spec'], bias=False)
            self.attn2 = SPSA_Attention(dim=16*2, num_heads=1,is_material_mask = self.params['material_mask'], is_spec = self.params['spec'], bias=False)
            self.attn3 = SPSA_Attention(dim=32*2, num_heads=1,is_material_mask = self.params['material_mask'], is_spec = self.params['spec'], bias=False)
            self.attn4 = SPSA_Attention(dim=64*2, num_heads=1,is_material_mask = self.params['material_mask'], is_spec = self.params['spec'], bias=False)
        else:
            self.attn1 = SPSA_Attention(dim=8, num_heads=1, is_material_mask = self.params['material_mask'], is_spec = self.params['spec'], bias=False)
            self.attn2 = SPSA_Attention(dim=16, num_heads=1,is_material_mask = self.params['material_mask'], is_spec = self.params['spec'], bias=False)
            self.attn3 = SPSA_Attention(dim=32, num_heads=1,is_material_mask = self.params['material_mask'], is_spec = self.params['spec'], bias=False)
            self.attn4 = SPSA_Attention(dim=64, num_heads=1,is_material_mask = self.params['material_mask'], is_spec = self.params['spec'], bias=False)
        # ===========================Fusion===========================
        # fusion1..4 are not used in forward; kept for checkpoint compatibility.
        self.fusion1 = conv_layer(16, 8,  kernel_size=3, stride=1, padding=1, batch_norm=params['batch_norm'])
        self.fusion2 = conv_layer(32, 16,  kernel_size=3, stride=1, padding=1, batch_norm=params['batch_norm'])
        self.fusion3 = conv_layer(64, 32,  kernel_size=3, stride=1, padding=1, batch_norm=params['batch_norm'])
        self.fusion4 = conv_layer(128, 64,  kernel_size=3, stride=1, padding=1, batch_norm=params['batch_norm'])
        # ===========================Splat===========================
        self.splat1 = conv_layer(c_in, 8,  kernel_size=3, stride=2, padding=1, batch_norm=False)
        self.splat2 = conv_layer(8,    16, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        self.splat3 = conv_layer(16,   32, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        self.splat4 = conv_layer(32,   64, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])

        self.splat1_spec = conv_layer(10,    8, kernel_size=3, stride=2, padding=1, batch_norm=False) # 10 spectral bands in
        self.splat2_spec = conv_layer(8,    16, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        self.splat3_spec = conv_layer(16,   32, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        self.splat4_spec = conv_layer(32,   64, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        # ===========================Global===========================
        # Five stride-2 convs reduce the splat output to 1x1.
        self.global1 = conv_layer(64, 128, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        self.global2 = conv_layer(128, 256, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        self.global3 = conv_layer(256, 128, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        self.global4 = conv_layer(128, 64, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])
        self.global5 = conv_layer(64, 64, kernel_size=3, stride=2, padding=1, batch_norm=params['batch_norm'])

        # ===========================Local===========================
        self.local1 = conv_layer(64, 64, kernel_size=3, padding=1, batch_norm=params['batch_norm'])
        self.local2 = conv_layer(64, 64, kernel_size=3, padding=1, bias=False, activation=None)

        # ===========================Prediction===========================
        self.pred = conv_layer(64, 96, kernel_size=1, activation=None) # 64 -> 96 = 12 coeffs x 8 depth bins

    def forward(self, x, spec, material_mask):
        N = x.shape[0]
        # ===========================Splat===========================
        # Each stride-2 conv halves H and W (256 -> 128 -> 64 -> 32 -> 16).
        x = self.splat1(x)  # N, 8, H/2, W/2
        if self.params['spec']:
            spec = self.splat1_spec(spec)  # N, 8, H/2, W/2
            x, spec = self.attn1(x, spec, material_mask)

        x = self.splat2(x)  # N, 16, H/4, W/4
        if self.params['spec']:
            spec = self.splat2_spec(spec)  # N, 16, H/4, W/4
            x, spec = self.attn2(x, spec, material_mask)

        x = self.splat3(x)  # N, 32, H/8, W/8
        if self.params['spec']:
            spec = self.splat3_spec(spec)  # N, 32, H/8, W/8
            x, spec = self.attn3(x, spec, material_mask)

        x = self.splat4(x)  # N, 64, H/16, W/16
        if self.params['spec']:
            spec = self.splat4_spec(spec)  # N, 64, H/16, W/16
            x, spec = self.attn4(x, spec, material_mask)

        splat_out = x  # N, 64, h, w
        # ===========================Global===========================
        x = self.global1(x)
        x = self.global2(x)
        x = self.global3(x)
        x = self.global4(x)
        x = self.global5(x)  # N, 64, 1, 1
        global_out = x.squeeze(2).squeeze(2)  # N, 64
        # ===========================Local===========================
        local_out = self.local2(self.local1(splat_out))  # N, 64, h, w
        # ===========================Fusion===========================
        global_out = global_out[:, :, None, None]  # N, 64, 1, 1
        fusion = self.relu(local_out + global_out)  # N, 64, h, w
        # ===========================Prediction===========================
        x = self.pred(fusion)  # N, 96, h, w
        # Grid size follows the feature map instead of a hard-coded 16x16, so
        # other input resolutions no longer fail in view().
        grid_h, grid_w = x.shape[-2:]
        x = x.view(N, 12, 8, grid_h, grid_w)  # N, C=12, D=8, h, w
        return x

class Coefficients(nn.Module):
    """Predict the bilateral grid, optionally as a material-weighted mixture.

    With ``params['material_mask']`` true, six FeatureExtract/SegExtract pairs
    are evaluated, pair i conditioned on material-mask channel i. Each grid
    (N, 12, 8, 16, 16) is weighted by its SegExtract map (broadcast as
    (N, 1, 1, 16, 16)), and the six products are summed. Otherwise only
    FeatureExtract0 runs, conditioned on mask channel 0.
    """
    _NUM_MATERIALS = 6

    def __init__(self, params, c_in=3):
        super(Coefficients, self).__init__()
        self.params = params
        self.relu = nn.ReLU()
        # Attribute names FeatureExtract{i} / SegExtract{i} are checkpoint keys.
        # Registration order (all FeatureExtracts, then all SegExtracts) is
        # preserved so random initialisation is unchanged.
        # NOTE(review): c_in is accepted but every branch is built with c_in=3.
        for i in range(self._NUM_MATERIALS):
            setattr(self, f'FeatureExtract{i}', FeatureExtract(params, c_in=3))
        for i in range(self._NUM_MATERIALS):
            setattr(self, f'SegExtract{i}', SegExtract(params))

    def forward(self, x, spec, material_mask):
        if not self.params['material_mask']:
            return self.FeatureExtract0(x, spec, material_mask[:, 0, :, :].unsqueeze(1))

        mixed = None
        for i in range(self._NUM_MATERIALS):
            mask_i = material_mask[:, i, :, :].unsqueeze(1)
            grid_i = getattr(self, f'FeatureExtract{i}')(x, spec, mask_i)
            weight_i = getattr(self, f'SegExtract{i}')(mask_i).unsqueeze(1)
            term = grid_i * weight_i
            mixed = term if mixed is None else mixed + term
        return mixed


class Guide(nn.Module):
    """Learned pointwise guidance map in [0, 1] for slicing.

    Per pixel:
    1. Mix colours with ``M`` and add ``M_bias``.
    2. Pass each channel through a piecewise-linear curve, a sum of
       ``nrelus`` ReLUs with learnable knots and slopes.
    3. Average the channels, add ``bias``, and clamp to [0, 1].

    At initialisation each curve is the identity for inputs >= 0 (only the
    first ReLU, with knot 0, has a nonzero slope).
    """

    def __init__(self, params, c_in=3):
        super(Guide, self).__init__()
        self.params = params
        self.nrelus = 16  # control points per channel curve
        self.c_in = c_in
        # Identity plus one tiny random scalar (randn(1)) broadcast to every entry.
        self.M = nn.Parameter(torch.eye(c_in, dtype=torch.float32) + torch.randn(1, dtype=torch.float32) * 1e-4)  # (c_in, c_in)
        self.M_bias = nn.Parameter(torch.zeros(c_in, dtype=torch.float32))  # (c_in,)
        # ReLU knots evenly spaced on [0, 1), one copy per channel.
        knots = torch.tensor(np.linspace(0, 1, self.nrelus, endpoint=False, dtype=np.float32))
        self.thresholds = nn.Parameter(knots.view(1, 1, 1, -1).repeat(1, 1, c_in, 1))  # (1, 1, c_in, nrelus)
        slopes = torch.zeros(1, 1, 1, c_in, self.nrelus, dtype=torch.float32)
        slopes[..., 0] = 1.0
        self.slopes = nn.Parameter(slopes)  # (1, 1, 1, c_in, nrelus)

        self.relu = nn.ReLU()
        self.bias = nn.Parameter(torch.tensor(0, dtype=torch.float32))

    def forward(self, x, material_mask, nir):
        # material_mask and nir are accepted for interface compatibility; unused.
        pixels = x.permute(0, 2, 3, 1)  # N, H, W, C
        nhwc_shape = pixels.shape
        mixed = pixels.reshape(-1, self.c_in) @ self.M + self.M_bias  # N*H*W, C
        mixed = mixed.reshape(nhwc_shape).unsqueeze(-1)  # N, H, W, C, 1
        curved = (self.slopes * self.relu(mixed - self.thresholds)).sum(dim=4)  # N, H, W, C
        guide = curved.permute(0, 3, 1, 2).sum(dim=1, keepdim=True) / self.c_in  # N, 1, H, W
        return torch.clamp(guide + self.bias, 0, 1)


class JDMHDRnetModel(nn.Module):
    """Full enhancement model: brightness adaptation + bilateral-grid HDRNet.

    Submodules are registered in the order coefficients, BrightnessAdaptation1-3,
    guide (this fixes the random-init sequence and state_dict key order).
    """
    def __init__(self, params):
        super(JDMHDRnetModel, self).__init__()
        self.coefficients = Coefficients(params)
        self.BrightnessAdaptation1 = BrightnessAdaptation(params)
        self.BrightnessAdaptation2 = BrightnessAdaptation(params)
        self.BrightnessAdaptation3 = BrightnessAdaptation(params)
        self.guide = Guide(params)

    def forward(self, lowres, fullres, spec, material_mask, nir):
        """Enhance ``fullres``.

        1. Three BrightnessAdaptation heads on ``nir`` produce an input divisor,
           an output multiplier, and a spectral divisor, all at ``fullres`` size.
           The input and spectral divisors are also resized to ``lowres``'s
           size where needed.
        2. The normalised low-res image and spec predict the bilateral grid.
        3. The guide map is computed from the normalised full-res image.
        4. The grid is sliced with the guide, and the per-pixel affine
           transforms are applied to the normalised full-res image.
        5. The result is scaled by the output multiplier.
        """
        low_hw = lowres.size()[-2:]
        gain_in = self.BrightnessAdaptation1(nir, fullres)
        gain_out = self.BrightnessAdaptation2(nir, fullres)
        gain_spec = self.BrightnessAdaptation3(nir, fullres)

        fullres = fullres / gain_in
        lowres = lowres / F.interpolate(gain_in, size=low_hw, mode='bilinear')
        spec = spec / F.interpolate(gain_spec, size=low_hw, mode='bilinear')

        grid = self.coefficients(lowres, spec, material_mask)  # N, 12, 8, h, w
        guide_map = self.guide(fullres, material_mask, nir)  # N, 1, H, W
        return apply(slicing(grid, guide_map), fullres) * gain_out

## Datasets

In [115]:
import numpy as np
import os
import torch
from PIL import Image
from skimage import io
from torchvision import transforms
from torchvision.transforms.functional import resize
from torch.utils.data import Dataset

class BaseDataset(Dataset):
    """Shared sample-loading logic for the HDR datasets.

    Files read under the data directory (per-sample file stem ``<stem>``):
      - ``source/<stem>.tif`` and ``target/<stem>.tif``: input and ground truth.
      - ``nir/<stem>.tif``, or ``nir_jdm/<stem>.png`` when
        ``params['jdm_predict']`` is set.
      - ``spec_npy10band/<stem>.npy``: spectral data.
      - ``seg/<stem>.png``, or ``seg_jdm/<stem>.png`` when
        ``params['jdm_predict']`` is set: colour-coded segmentation.

    Subclasses set ``data_path``, ``params`` and ``input_paths`` (and
    ``memory_tif`` when using the in-memory cache).
    """

    # Segmentation colour (R, G, B) for each material-mask channel, in channel order.
    _SEG_PALETTE = (
        (19, 19, 194),    # sky
        (43, 139, 3),     # tree
        (248, 232, 109),  # building
        (78, 50, 12),     # trunk
        (102, 102, 100),  # road
        (8, 8, 8),        # others
    )

    @staticmethod
    def _stem(fname):
        """File name without its final extension (inner dots are kept, unlike split('.')[0])."""
        return os.path.splitext(fname)[0]

    @staticmethod
    def _read_nir(data_path, stem, is_jdm_predict):
        """Read the NIR map for ``stem``.

        JDM-predicted NIR comes from the first PNG channel. It is quantised as
        ``value // 32 + 1`` (1..8 for uint8 data) and divided by 8
        (0.125..1.0). Otherwise ``nir/<stem>.tif`` is returned as read.
        """
        if is_jdm_predict:
            nir = io.imread(os.path.join(data_path, 'nir_jdm', stem + '.png'))[:, :, 0]
            return (nir // 32 + 1) / 8.0
        return io.imread(os.path.join(data_path, 'nir', stem + '.tif'))

    @staticmethod
    def _read_spec(data_path, stem):
        """Load the spectral array ``spec_npy10band/<stem>.npy`` (pickles disallowed)."""
        return np.load(os.path.join(data_path, 'spec_npy10band', stem + '.npy'), mmap_mode=None,
                       allow_pickle=False, fix_imports=True, encoding='ASCII')

    @classmethod
    def _material_mask(cls, seg):
        """Binary mask per palette colour, stacked to (H, W, 6) float32; unlisted colours are all-zero."""
        rgb = seg[:, :, :3]
        channels = [np.all(rgb == np.asarray(color), axis=-1) for color in cls._SEG_PALETTE]
        return np.stack(channels, axis=2).astype(np.float32)

    def get_tif(self, path, is_jdm_predict):
        """Preload every sample listed in ``path/target`` into memory.

        Returns:
            ``[inputs, targets, nirs, specs]``: four dicts keyed by file stem.
        """
        inputs, targets, nirs, specs = {}, {}, {}, {}
        for file_name in os.listdir(os.path.join(path, 'target')):
            stem = self._stem(file_name)
            inputs[stem] = io.imread(os.path.join(path, 'source', stem + '.tif'))
            targets[stem] = io.imread(os.path.join(path, 'target', stem + '.tif'))
            nirs[stem] = self._read_nir(path, stem, is_jdm_predict)
            specs[stem] = self._read_spec(path, stem)
        return [inputs, targets, nirs, specs]

    def load_img_hdr(self, fname, read_memory=False):
        """Load one sample as channel-first float32 tensors.

        Args:
            fname: sample file name (extension is stripped to get the stem).
            read_memory: take input/target/nir/spec from ``self.memory_tif``
                (built by ``get_tif``) instead of disk. The segmentation is
                always read from disk.

        Returns:
            ``(input, output, spec, material_mask, nir)`` tensors. material_mask
            has 6 channels and nir has 1. Input/output are presumably (3, H, W)
            — TODO confirm against the data.
        """
        stem = self._stem(fname)
        if read_memory:
            img_in = self.memory_tif[0][stem]
            img_out = self.memory_tif[1][stem]
            nir_ori = self.memory_tif[2][stem]
            spec = self.memory_tif[3][stem]
        else:
            img_in = io.imread(os.path.join(self.data_path, 'source', stem + '.tif'))
            img_out = io.imread(os.path.join(self.data_path, 'target', stem + '.tif'))
            nir_ori = self._read_nir(self.data_path, stem, self.params['jdm_predict'])
            spec = self._read_spec(self.data_path, stem)
        seg_dir = 'seg_jdm' if self.params['jdm_predict'] else 'seg'
        seg = io.imread(os.path.join(self.data_path, seg_dir, stem + '.png'))

        # Swap the spectral array's first two axes, then flip the new second axis.
        spec = np.ascontiguousarray(spec.transpose((1, 0, 2))[:, ::-1, :])
        material_mask = self._material_mask(seg)

        spec = np.asarray(spec, dtype=np.float32)
        img_in = np.asarray(img_in, dtype=np.float32)
        img_out = np.asarray(img_out, dtype=np.float32)
        nir_ori = np.asarray(nir_ori, dtype=np.float32)

        # Constrain NIR from below: nir >= max over channels of input / max(input).
        peak = np.max(img_in)
        if peak > 0:
            nir_floor = np.max(img_in, axis=2) / peak
        else:
            # All-black input: 0/0 would give NaN, which np.maximum propagates.
            nir_floor = np.zeros(img_in.shape[:2], dtype=np.float32)
        nir = np.maximum(nir_floor[:, :, np.newaxis], nir_ori[:, :, np.newaxis])

        def to_chw(arr):
            return torch.from_numpy(arr.transpose((2, 0, 1)))

        return to_chw(img_in), to_chw(img_out), to_chw(spec), to_chw(material_mask), to_chw(nir)

    def __len__(self):
        return len(self.input_paths)


class Train_Dataset(BaseDataset):
    """Training split: samples cached in memory, with joint random crop and flips."""

    def __init__(self, params=None):
        self.data_path = params['train_data_dir']
        self.input_paths = get_files(os.path.join(self.data_path, 'source'))
        self.memory_tif = self.get_tif(self.data_path,params['jdm_predict'])
        self.input_res = params['input_res']
        self.output_res = params['output_res']

        # Applied to one channel-stacked tensor, so input, target, spec, masks
        # and nir all receive the same crop and flips.
        self.augment = transforms.Compose([
            transforms.RandomCrop(self.output_res),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
        ])
        self.params = params

    def __getitem__(self, idx):
        """Return ``(low, full, output, spec_low, material_mask, nir)`` for sample ``idx``.

        ``full``/``output``/``material_mask``/``nir`` are random crops of size
        ``output_res``. ``low`` and ``spec_low`` are resized to
        ``input_res x input_res``; the spec goes through ``spec_size`` first.
        """
        fname = os.path.basename(self.input_paths[idx])
        if not self.params['hdr']:
            # Only the HDR loader exists; fail clearly rather than with an
            # UnboundLocalError on the variables below.
            raise NotImplementedError("Train_Dataset only supports params['hdr'] = True")
        img_in, output, spec, material_mask, nir = self.load_img_hdr(fname, read_memory=True)
        # Check dimensions before the crop. Tensors are (C, H, W) and
        # output_res is (h, w), as consumed by RandomCrop.
        assert img_in.shape == output.shape
        assert self.output_res[0] <= img_in.shape[1]
        assert self.output_res[1] <= img_in.shape[2]

        stacked = torch.cat([img_in, output, spec, material_mask, nir], dim=0)
        stacked = self.augment(stacked)

        n_spec = spec.shape[0]
        full = stacked[:3]
        output = stacked[3:6]
        spec = stacked[6:6 + n_spec]
        material_mask = stacked[6 + n_spec:12 + n_spec]
        nir = stacked[12 + n_spec:13 + n_spec]

        low = resize(full, (self.input_res, self.input_res), Image.BILINEAR)
        spec_tmp = resize(spec, (self.params['spec_size'], self.params['spec_size']), Image.BILINEAR)
        spec_low = resize(spec_tmp, (self.input_res, self.input_res), Image.BILINEAR)

        return low, full, output, spec_low, material_mask, nir

class Eval_Dataset(BaseDataset):
    """Full-resolution evaluation samples from params['eval_data_dir'] (no cropping or flipping)."""

    def __init__(self, params=None):
        self.data_path = params['eval_data_dir']
        self.input_paths = get_files(os.path.join(self.data_path, 'source'))
        # No memory_tif cache here: __getitem__ loads with read_memory=False.
        self.input_res = params['input_res']
        self.params = params

    def __getitem__(self, idx):
        """Return (low, full, output, spec_low, material_mask, nir) for sample `idx`.

        `low` and `spec_low` are resized to input_res x input_res (spec goes through
        spec_size x spec_size first). The other tensors are returned at full resolution.

        Raises:
            ValueError: if params['hdr'] is False (this class only uses load_img_hdr).
        """
        if not self.params['hdr']:
            raise ValueError("Eval_Dataset only supports params['hdr'] = True")
        fname = os.path.basename(self.input_paths[idx])
        full, output, spec, material_mask, nir = self.load_img_hdr(fname, read_memory=False)
        low = resize(full, (self.input_res, self.input_res), Image.BILINEAR)
        spec_tmp = resize(spec, (self.params['spec_size'], self.params['spec_size']), Image.BILINEAR)
        spec_low = resize(spec_tmp, (self.input_res, self.input_res), Image.BILINEAR)
        return low, full, output, spec_low, material_mask, nir

## Train

In [116]:
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from argparse import ArgumentParser
from torch.optim import Adam, lr_scheduler
from torchvision.transforms.functional import vflip, hflip
from torch.utils.data import DataLoader
from torchvision.utils import save_image



# Expose only GPU 0 to this process (takes effect only if set before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"]='0'


def _save_debug_batch(params, batch_idx, full, target, spec, material_mask, nir):
    """Write the first sample of a raw (not yet normalised) batch to params['eval_out'] for inspection."""

    def out_path(suffix):
        return os.path.join(params['eval_out'], str(batch_idx) + suffix)

    def min_max(img):
        # Stretch to [0, 1]; a constant image divides by zero.
        return (img - torch.min(img)) / (torch.max(img) - torch.min(img))

    full_image = full[0, :, :, :]
    nir_image = nir[0, :, :, :]
    target_image = target[0, :, :, :]
    save_image(min_max(full_image), out_path('_inputfull.tif'))
    save_image(nir_image, out_path('_nir.tif'))
    save_image(min_max(nir_image * full_image), out_path('_ori.tif'))
    save_image(target_image / nir_image / 255, out_path('_targetR.png'))
    save_image(target_image / 255, out_path('_target.png'))
    save_image(spec[0, 0, :, :] / 65535, out_path('_spec.png'))
    # material_mask channel order, as labelled by this debug dump.
    for channel, name in enumerate(('sky', 'tree', 'building', 'trunk', 'road', 'others')):
        save_image(material_mask[0, channel, :, :], out_path('_material_' + name + '_mask.png'))


def train(params, train_loader, valid_loader, model, ep, device):
    """Train `model` with Adam + MSE, validating/checkpointing every params['ckpt_interval'] iterations.

    Args:
        params: config dict (see parse_params). An optional 'target_psnr' key
            (default 30) sets the validation PSNR at which training stops early.
        train_loader: yields (low, full, target, spec, material_mask, nir) batches.
        valid_loader: passed to eval() for validation.
        model: called as model(low, full, spec, material_mask, nir).
        ep: last completed epoch (-1 for a fresh run). Training resumes at epoch
            ep + 1 and the iteration counter at (ep + 1) * len(train_loader).
        device: torch.device to train on.
    """
    model = model.to(device)

    optimizer = Adam(model.parameters(), params['learning_rate'], weight_decay=1e-8)
    scheduler = None
    if not params['material_mask']:
        # Stepped with every batch loss below, so `patience` counts iterations, not epochs.
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=params['epochs'] * 3, factor=0.5)
    criterion = nn.MSELoss()

    # The meters only sample loss/PSNR every summary_interval iterations.
    train_loss_meter = AvgMeter()
    train_psnr_meter = AvgMeter()
    stats = {'train_loss': [],
             'train_psnr': [],
             'valid_psnr': []}
    target_psnr = params.get('target_psnr', 30)
    # Integer counter. A float (e.g. (ep+1)*160/batch_size) breaks the `%`
    # interval checks below whenever the division is not exact.
    iteration = (ep + 1) * len(train_loader)
    old_time = time.time()

    for epoch in range(ep + 1, params['epochs']):
        for batch_idx, (low, full, target, spec, material_mask, nir) in enumerate(train_loader):
            iteration += 1
            # eval() switches the model to eval mode, so re-enable training mode each step.
            model.train()

            low = low.to(device)
            full = full.to(device)
            target = target.to(device)
            spec = spec.to(device)
            material_mask = material_mask.to(device)
            nir = nir.to(device)

            if params['debugsave']:
                _save_debug_batch(params, batch_idx, full, target, spec, material_mask, nir)

            # Scale to [0, 1]: HDR inputs/spec by 65535 (presumably 16-bit), LDR inputs and targets by 255.
            if params['hdr']:
                low = torch.div(low, 65535.0)
                full = torch.div(full, 65535.0)
                spec = torch.div(spec, 65535.0)
            else:
                low = torch.div(low, 255.0)
                full = torch.div(full, 255.0)
            target = torch.div(target, 255.0)

            output = model(low, full, spec, material_mask, nir)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step(loss)

            if iteration % params['summary_interval'] == 0:
                train_loss_meter.update(loss.item())
                train_psnr = psnr(output, target).item()
                train_psnr_meter.update(train_psnr)
                new_time = time.time()
                print('[%d/%d] Iteration: %d | Loss: %.4f | PSNR: %.4f | lr: %.8f | Time: %.2fs' %
                      (epoch + 1, params['epochs'], iteration, loss.item(), train_psnr,
                       optimizer.param_groups[0]['lr'], new_time - old_time))
                old_time = new_time

            if iteration % params['ckpt_interval'] == 0:
                stats['train_loss'].append(train_loss_meter.avg)
                train_loss_meter.reset()
                stats['train_psnr'].append(train_psnr_meter.avg)
                train_psnr_meter.reset()
                # NOTE(review): `eval` is resolved at call time. Once the Test cell has run,
                # its redefinition (which also writes PNGs to params['eval_out']) is used here.
                valid_psnr = eval(params, valid_loader, model, device, epoch)
                stats['valid_psnr'].append(valid_psnr)
                ckpt_fname = "epoch_" + str(epoch) + '_iter_' + str(iteration) + ".pt"
                save_model_stats(model, params, ckpt_fname, stats)
                if valid_psnr >= target_psnr:
                    return

def eval(params, valid_loader, model, device,epoch):
    """Compute the mean PSNR of `model` over `valid_loader` without saving any images.

    Prints the PSNR of each batch and the overall average. `epoch` is accepted
    for call-site compatibility but not used.
    """
    model.eval()
    meter = AvgMeter()
    # Inputs are scaled by 65535 in HDR mode and 255 otherwise; targets always by 255.
    input_scale = 65535.0 if params['hdr'] else 255.0
    with torch.no_grad():
        for idx, batch in enumerate(valid_loader):
            low, full, target, spec, material_mask, nir = (item.to(device) for item in batch)

            low = low / input_scale
            full = full / input_scale
            if params['hdr']:
                spec = spec / 65535.0
            target = target / 255.0

            prediction = model(low, full, spec, material_mask, nir)

            score = psnr(prediction, target).item()
            print(str(idx) + '.png', score)
            meter.update(score)

    print ("Validation PSNR: ", meter.avg)
    return meter.avg


def train_main(params, first_time=True):
    """Build the datasets, loaders and JDMHDRnetModel from `params`, then run train().

    Args:
        params: config dict (see parse_params).
        first_time: True for a fresh run: no checkpoint is loaded and training starts
            at epoch 0. False restores weights and the epoch count from
            params['ckpt_dir'] via load_train_ckpt.
    """
    # Seed torch/numpy and request deterministic cuDNN kernels for reproducibility.
    seed = 0
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    print_params(params)

    os.makedirs(params['ckpt_dir'], exist_ok=True)
    os.makedirs(params['stats_dir'], exist_ok=True)
    os.makedirs(params['eval_out'], exist_ok=True)

    train_dataset = Train_Dataset(params)
    train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)

    valid_dataset = Eval_Dataset(params)
    valid_loader = DataLoader(valid_dataset, batch_size=1)

    model = JDMHDRnetModel(params)
    if first_time:
        # Don't silently initialise a "fresh" run from stale weights left in ckpt_dir.
        prev_epochs = -1
    else:
        prev_epochs = load_train_ckpt(model, params['ckpt_dir'])
    print("prev_epochs ", prev_epochs)

    device = torch.device("cuda" if params['cuda'] else "cpu")
    model.to(device)

    train(params, train_loader, valid_loader, model, prev_epochs, device)

## Test

In [117]:
import numpy as np
import os
import torch
from argparse import ArgumentParser
from torch.utils.data import DataLoader
from torchvision.utils import save_image

# Expose only GPU 0 to this process (takes effect only if set before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"]='0'

def eval(params, valid_loader, model, device,epoch):
    """Evaluate `model` on `valid_loader`, saving each prediction as a PNG in params['eval_out'].

    Predictions are written as zero-padded names (0000.png, 0001.png, ...), so
    lexicographic order matches batch order. Prints each batch's PSNR and the mean.
    `epoch` is accepted for call-site compatibility but not used.

    NOTE(review): this redefines the `eval` from the training cell. Whichever cell ran
    last wins, including for the call inside train().

    Returns:
        Mean PSNR over all batches.
    """
    model.eval()
    psnr_meter = AvgMeter()
    with torch.no_grad():
        for batch_idx, (low, full, target, spec, material_mask, nir) in enumerate(valid_loader):
            low = low.to(device)
            full = full.to(device)
            target = target.to(device)
            spec = spec.to(device)
            nir = nir.to(device)
            material_mask = material_mask.to(device)

            # Scale to [0, 1]: HDR inputs/spec by 65535, LDR inputs and targets by 255.
            if params['hdr']:
                low = torch.div(low, 65535.0)
                full = torch.div(full, 65535.0)
                spec = torch.div(spec, 65535.0)
            else:
                low = torch.div(low, 255.0)
                full = torch.div(full, 255.0)
            target = torch.div(target, 255.0)

            output = model(low, full, spec, material_mask, nir)
            # Fixed-width 4-digit padding. The old manual padding produced "01000.png"
            # at index 1000 and broke sorted-name pairing with the target images.
            save_image(output, os.path.join(params['eval_out'], f'{batch_idx:04d}.png'))

            eval_psnr = psnr(output, target).item()
            print(str(batch_idx) + '.png', eval_psnr)
            psnr_meter.update(eval_psnr)

    print ("Validation PSNR: ", psnr_meter.avg)

    return psnr_meter.avg



def test(params):
    """Load the checkpoint from params['ckpt_dir'] and evaluate it on params['eval_data_dir'].

    Predictions and per-image PSNRs are produced by eval().

    Returns:
        Mean validation PSNR. It was previously computed and discarded; returning it
        is backward compatible with callers that ignored the old None.
    """
    # Seed torch/numpy and request deterministic cuDNN kernels for reproducibility.
    seed = 0
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    print_params(params)

    os.makedirs(params['ckpt_dir'], exist_ok=True)
    os.makedirs(params['stats_dir'], exist_ok=True)
    os.makedirs(params['eval_out'], exist_ok=True)

    valid_dataset = Eval_Dataset(params)
    valid_loader = DataLoader(valid_dataset, batch_size=1)

    model = JDMHDRnetModel(params)
    load_train_ckpt(model, params['ckpt_dir'])
    device = torch.device("cuda" if params['cuda'] else "cpu")
    model.to(device)

    valid_psnr = eval(params, valid_loader, model, device, epoch=params['epochs'])
    return valid_psnr

In [118]:
def parse_params(jdm_predict=False):
    """Return a fresh dict of default training/testing settings.

    Args:
        jdm_predict: stored under the 'jdm_predict' key (predicted vs. ideal priors).

    Returns:
        A new dict. Keys keep the order below, which is the order print_params lists them in.
    """
    return dict(
        cuda=True,
        ckpt_interval=600,
        ckpt_dir="./ckptsJdm",    # Make this different for Ideal and predicted case
        stats_dir="./stats_jdm",  # Make this different for Ideal and predicted case
        epochs=6000,
        learning_rate=1e-4,
        summary_interval=10,
        batch_size=16,
        train_data_dir='/kaggle/input/mobile-spec/Mobile-Spec/train',  # Change this to your path
        eval_data_dir='/kaggle/input/mobile-spec/Mobile-Spec/eval',    # Change this to your path
        eval_out="./outputs_jdm",
        hdr=True,
        jdm_predict=jdm_predict,
        batch_norm=False,
        input_res=256,
        output_res=(512, 512),
        spec_size=16,
        spec=True,
        material_mask=True,
        debugsave=False,
    )


In [119]:
# def delete_all_files(directory_path):
#     # Loop through all files in the directory
#     for filename in os.listdir(directory_path):
#         file_path = os.path.join(directory_path, filename)
        
#         # Check if it's a file before deleting
#         if os.path.isfile(file_path):
#             os.remove(file_path)
#             print(f"Deleted: {file_path}")
#         else:
#             print(f"Skipped (not a file): {file_path}")

# # Example usage
# directory_path = '/kaggle/working/outputs_jdm'  # Replace with your directory path
# delete_all_files(directory_path)

### Predicted priors training

In [120]:
# Predicted-priors run: jdm_predict=True, using the default ./ckptsJdm, ./stats_jdm and ./outputs_jdm directories.
params = parse_params(True)

In [121]:
train_main(params, True)  # first_time=True restarts the epoch/iteration counters; pass False to resume them from the checkpoint in params['ckpt_dir'].

In [123]:
# Evaluate the checkpoint in params['ckpt_dir'] on the eval split; per-image PSNRs are printed below.
test(params)

Training parameters: 
  cuda = True
  ckpt_interval = 600
  ckpt_dir = ./ckptsJdm
  stats_dir = ./stats_jdm
  epochs = 6000
  learning_rate = 0.0001
  summary_interval = 10
  batch_size = 16
  train_data_dir = /kaggle/input/mobile-spec/Mobile-Spec/train
  eval_data_dir = /kaggle/input/mobile-spec/Mobile-Spec/eval
  eval_out = ./outputs_jdm
  hdr = True
  jdm_predict = True
  batch_norm = False
  input_res = 256
  output_res = (512, 512)
  spec_size = 16
  spec = True
  material_mask = True
  debugsave = False

epochs  2879
Loading: ./ckptsJdm/epoch_2879_iter_28800.0_jdm.pt


  state_dict = torch.load(ckpt_path)


0.png 27.73189353942871
1.png 31.836633682250977
2.png 30.904296875
3.png 30.982990264892578
4.png 32.085533142089844
5.png 28.85420036315918
6.png 28.37204360961914
7.png 34.19792556762695
8.png 23.507020950317383
9.png 29.33184051513672
10.png 30.91350746154785
11.png 29.12506103515625
12.png 30.225200653076172
13.png 25.5830078125
14.png 24.38079833984375
15.png 29.393585205078125
16.png 32.894622802734375
17.png 25.306655883789062
18.png 30.861827850341797
19.png 26.626699447631836
20.png 25.449974060058594
21.png 27.393892288208008
22.png 24.431962966918945
23.png 28.57193946838379
24.png 28.185894012451172
25.png 27.369670867919922
26.png 28.384002685546875
27.png 34.350948333740234
28.png 26.63776397705078
29.png 26.087993621826172
30.png 33.71211242675781
31.png 33.79047393798828
32.png 27.36471176147461
33.png 30.91245460510254
34.png 29.494464874267578
35.png 26.7781925201416
36.png 31.977962493896484
37.png 30.213830947875977
38.png 26.34627914428711
39.png 29.91012573242187

### Ideal priors training

In [124]:
params = parse_params(False)
# Give the ideal-priors run its own directories so it does not overwrite the predicted-priors run.
params.update(
    ckpt_dir='./ckptsIdeal',
    eval_out='./outputs_ideal',
    stats_dir='./stats_ideal',
)

In [125]:
# Ideal-priors training; first_time=True restarts the epoch/iteration counters.
train_main(params, True)

In [126]:
# Evaluate the ideal-priors checkpoint in params['ckpt_dir']; per-image PSNRs are printed below.
test(params)

Training parameters: 
  cuda = True
  ckpt_interval = 600
  ckpt_dir = ./ckptsIdeal
  stats_dir = ./stats_ideal
  epochs = 6000
  learning_rate = 0.0001
  summary_interval = 10
  batch_size = 16
  train_data_dir = /kaggle/input/mobile-spec/Mobile-Spec/train
  eval_data_dir = /kaggle/input/mobile-spec/Mobile-Spec/eval
  eval_out = ./outputs_ideal
  hdr = True
  jdm_predict = False
  batch_norm = False
  input_res = 256
  output_res = (512, 512)
  spec_size = 16
  spec = True
  material_mask = True
  debugsave = False

epochs  2879
Loading: ./ckptsIdeal/epoch_2879_iter_28800.0.pt


  state_dict = torch.load(ckpt_path)


0.png 26.598548889160156
1.png 30.439321517944336
2.png 30.95160484313965
3.png 30.279176712036133
4.png 33.38441467285156
5.png 33.31548309326172
6.png 25.933650970458984
7.png 31.057437896728516
8.png 23.274181365966797
9.png 29.386669158935547
10.png 29.240585327148438
11.png 29.924623489379883
12.png 30.9648494720459
13.png 27.624462127685547
14.png 24.656953811645508
15.png 26.830293655395508
16.png 30.957992553710938
17.png 26.00284767150879
18.png 30.350610733032227
19.png 30.518394470214844
20.png 26.626794815063477
21.png 26.125925064086914
22.png 23.82185935974121
23.png 28.421512603759766
24.png 25.644432067871094
25.png 26.469913482666016
26.png 33.29463195800781
27.png 31.12220573425293
28.png 25.500797271728516
29.png 28.88542938232422
30.png 31.094627380371094
31.png 34.26521682739258
32.png 26.461318969726562
33.png 33.399662017822266
34.png 30.7695369720459
35.png 29.26552963256836
36.png 31.84986114501953
37.png 29.237306594848633
38.png 29.804649353027344
39.png 25.7

## Calculating SSIM

In [132]:
from skimage import io, img_as_float
from skimage.metrics import structural_similarity as ssim
import cv2

In [150]:

# Function to load images from two directories and calculate SSIM for each pair
def load_and_calculate_ssim(dir1, dir2, extensions=('.jpg', '.jpeg', '.png', '.bmp', '.tif')):
    """Compute SSIM (via calculate_ssim) for image pairs drawn from two directories.

    Images are paired by position after sorting filenames, not by name. Both
    directories must therefore list corresponding images in the same sorted
    order (e.g. zero-padded indices).

    Args:
        dir1, dir2: directories to compare.
        extensions: filename suffixes treated as images (matched case-insensitively).

    Returns:
        List of SSIM values, one per pair where both images loaded. Pairs that
        cv2.imread cannot read are reported and skipped.

    Raises:
        ValueError: if the directories contain different numbers of images.
    """
    suffixes = tuple(ext.lower() for ext in extensions)

    def _list_images(directory):
        # Sorted so the positional pairing below is deterministic.
        return sorted(f for f in os.listdir(directory) if f.lower().endswith(suffixes))

    files1 = _list_images(dir1)
    files2 = _list_images(dir2)
    if len(files1) != len(files2):
        raise ValueError(
            f"The number of images in both directories must be the same "
            f"({dir1}: {len(files1)}, {dir2}: {len(files2)})."
        )

    ssim_values = []
    for file1, file2 in zip(files1, files2):
        image1 = cv2.imread(os.path.join(dir1, file1))
        image2 = cv2.imread(os.path.join(dir2, file2))
        if image1 is None or image2 is None:
            print(f"Error loading images {file1} or {file2}")
            continue
        ssim_values.append(calculate_ssim(image1, image2))

    return ssim_values

# Compare the predicted-priors outputs with the ground-truth eval targets.
dir1 = '/kaggle/working/outputs_jdm'  # Replace with your directory path
dir2 = '/kaggle/input/mobile-spec/Mobile-Spec/eval/target'  # Replace with your directory path

ssim_results = load_and_calculate_ssim(dir1, dir2)
# Use a distinct name. Assigning to `ssim` would overwrite skimage's ssim function,
# which calculate_ssim looks up as a global, and the cell would fail on re-run.
mean_ssim = np.mean(ssim_results)
print(f'SSIM : {mean_ssim}')




SSIM : 0.9629230426738514
