In [1]:
import os, sys 
import itertools
import collections
import random
import time 
import logging

from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
from matplotlib.pyplot import imread

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch.optim import Adam, Adadelta, SGD
from torch.profiler import profile, record_function, ProfilerActivity

#########################################################################################################
# Global runtime configuration for the notebook.
# Select the 'high' float32 matmul precision mode in PyTorch.
torch.set_float32_matmul_precision('high')
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): absolute, machine-specific paths; the notebook only runs as-is on a host
# that has this directory layout.
ROOT_PROJECT = "/home/jovyan/shared/Thuc/hoodsgatedrive/projects/ImageRestoration-Development-Unrolling/"
ROOT_DATASET = "/home/jovyan/shared/Thuc/hoodsgatedrive/projects/"

#########################################################################################################

# Extend the import path with the project's `lib` folder before importing from it
# (presumably where the `dataloader` module lives; verify against the repository).
sys.path.append(os.path.join(ROOT_PROJECT, 'exploration/model_multiscale_mixture_GLR/lib'))
from dataloader import ImageSuperResolution




In [2]:

# Output locations of this experiment (log files and model checkpoints).
LOG_DIR = os.path.join(ROOT_PROJECT, "exploration/model_multiscale_mixture_GLR/result/model_test11/logs/")
CHECKPOINT_DIR = os.path.join(ROOT_PROJECT, "exploration/model_multiscale_mixture_GLR/result/model_test11/checkpoints/")

# logging.basicConfig(filename=...) opens the log file right away and fails when the
# folder is missing, so create the output folders before anything writes to them.
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

LOGGER = logging.getLogger("main")
logging.basicConfig(
    format='%(asctime)s: %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
    filename=os.path.join(LOG_DIR, 'training00.log'),
    level=logging.INFO,
    # Without force=True, basicConfig silently does nothing when the root logger
    # already has handlers (e.g. this cell, or another logging cell, ran before in
    # the same kernel), and messages would keep going to the previous destination.
    force=True,
)

# Iteration period used by the training code for checkpointing / reporting.
VERBOSE_RATE = 1000

# Patch sizes (height, width) used for each data split.
(H_train, W_train) = (128, 128)
(H_val, W_val) = (128, 128)
(H_test, W_test) = (496, 496)

# Keyword arguments shared by the three datasets below; only the CSV index file and
# the patch size differ between the splits.
# NOTE(review): every split, including the 496-pixel test split, uses an overlap of
# H_train//2; confirm that this is intended for validation/testing.
_DATASET_COMMON_KWARGS = dict(
    dist_mode="addictive_noise_scale",
    lambda_noise=25.0,
    patch_overlap_size=H_train//2,
    max_num_patchs=1000000,
    root_folder=ROOT_DATASET,
    logger=LOGGER,
    device=torch.device("cpu"),
)

train_dataset = ImageSuperResolution(
    csv_path=os.path.join(ROOT_DATASET, "dataset/DFWB_training_data_info.csv"),
    patch_size=H_train,
    **_DATASET_COMMON_KWARGS,
)

validation_dataset = ImageSuperResolution(
    csv_path=os.path.join(ROOT_DATASET, "dataset/CBSD68_testing_data_info.csv"),
    patch_size=H_val,
    **_DATASET_COMMON_KWARGS,
)

test_dataset = ImageSuperResolution(
    csv_path=os.path.join(ROOT_DATASET, "dataset/McMaster_testing_data_info.csv"),
    patch_size=H_test,
    **_DATASET_COMMON_KWARGS,
)


# Batched loaders. None of them passes shuffle=True, so batches follow the dataset's
# own ordering.
# NOTE(review): confirm that the training dataset randomises patch order itself.
data_train_batched = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, num_workers=4
)

data_valid_batched = torch.utils.data.DataLoader(
    validation_dataset, batch_size=16, num_workers=4
)

data_test_batched = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, num_workers=4
)

In [3]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Create a 2-D convolution whose padding is half of the kernel size.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size: integer side length of the square kernel.
        bias: whether the layer gets a learnable bias term.

    Returns:
        An ``nn.Conv2d`` with ``padding = kernel_size // 2`` (for odd kernel
        sizes and the default stride this keeps the spatial size unchanged).
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=half_kernel,
        bias=bias,
    )

class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that normalises (or de-normalises) three channels.

    The kernel is the 3x3 identity with row ``c`` divided by ``rgb_std[c]`` and the
    bias is ``sign * rgb_range * rgb_mean / rgb_std``. With the default ``sign=-1``
    the layer therefore subtracts the scaled mean; ``sign=1`` adds it back.
    All parameters are excluded from gradient computation.

    Args:
        rgb_range: scalar multiplied onto ``rgb_mean`` when building the bias.
        rgb_mean: per-channel mean (three values).
        rgb_std: per-channel standard deviation (three values).
        sign: multiplier of the bias, ``-1`` or ``1`` in the visible callers.
    """

    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super().__init__(3, 3, kernel_size=1)

        channel_std = torch.Tensor(rgb_std)
        channel_mean = torch.Tensor(rgb_mean)

        # Identity mixing between channels, each output channel scaled by 1/std.
        identity_kernel = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data = identity_kernel / channel_std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * channel_mean / channel_std

        # The layer is a fixed transform: freeze weight and bias.
        self.requires_grad_(False)

class BasicBlock(nn.Sequential):
    """Convolution optionally followed by batch normalisation and an activation.

    Args:
        conv: factory called as ``conv(in_channels, out_channels, kernel_size, bias=bias)``.
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution (and of the batch norm).
        kernel_size: kernel size handed to ``conv``.
        stride: accepted for interface compatibility; it is not forwarded to ``conv``.
        bias: whether the convolution has a bias term.
        bn: append ``nn.BatchNorm2d(out_channels)`` when true.
        act: activation module appended last; ``None`` disables it. The default
            instance is created once and shared by every block using the default.
    """

    def __init__(
        self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
        bn=True, act=nn.ReLU(True)):

        layers = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        layers += [nn.BatchNorm2d(out_channels)] if bn else []
        layers += [act] if act is not None else []

        super().__init__(*layers)

class ResBlock(nn.Module):
    """Residual block: ``x + res_scale * body(x)``.

    ``body`` is conv -> [batch norm] -> activation -> conv -> [batch norm], where
    both convolutions map ``n_feats`` channels to ``n_feats`` channels.

    Args:
        conv: factory called as ``conv(n_feats, n_feats, kernel_size, bias=bias)``.
        n_feats: number of feature channels.
        kernel_size: kernel size handed to ``conv``.
        bias: whether the convolutions have a bias term.
        bn: insert ``nn.BatchNorm2d(n_feats)`` after each convolution when true.
        act: activation module placed after the first convolution (and its batch
            norm, if any). The default instance is shared across blocks.
        res_scale: factor applied to the body output before the skip connection.
    """

    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super().__init__()

        # First half: convolution, optional normalisation, activation.
        layers = [conv(n_feats, n_feats, kernel_size, bias=bias)]
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))
        layers.append(act)

        # Second half: convolution and optional normalisation, no activation.
        layers.append(conv(n_feats, n_feats, kernel_size, bias=bias))
        if bn:
            layers.append(nn.BatchNorm2d(n_feats))

        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return the scaled body output with the input added on top."""
        # `mul` allocates a new tensor, so the in-place add never touches `x`.
        return self.body(x).mul(self.res_scale).add_(x)

class Upsampler(nn.Sequential):
    """Pixel-shuffle upsampling head.

    For a ``scale`` that passes the power-of-two bit test, ``log2(scale)`` stages
    of (conv to ``4 * n_feats`` channels, ``PixelShuffle(2)``) are stacked;
    ``scale == 1`` therefore yields an empty sequence. ``scale == 3`` uses one
    stage of (conv to ``9 * n_feats`` channels, ``PixelShuffle(3)``). Each stage
    is optionally followed by batch normalisation and an activation.

    Args:
        conv: factory called positionally as ``conv(n_feats, out_channels, 3, bias)``.
        scale: integer upscaling factor.
        n_feats: number of feature channels entering and leaving every stage.
        bn: append ``nn.BatchNorm2d(n_feats)`` to every stage when true.
        act: ``'relu'`` or ``'prelu'`` to append that activation to every stage;
            any other value (default ``False``) appends none.
        bias: bias flag passed to ``conv``.

    Raises:
        NotImplementedError: if ``scale`` is neither accepted by the power-of-two
            test nor equal to 3.
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        def _stage(factor):
            # One upsampling step by `factor`: widen channels, then rearrange them
            # into space.
            modules = [
                conv(n_feats, factor * factor * n_feats, 3, bias),
                nn.PixelShuffle(factor),
            ]
            if bn:
                modules.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                modules.append(nn.ReLU(True))
            elif act == 'prelu':
                modules.append(nn.PReLU(n_feats))
            return modules

        is_power_of_two = (scale & (scale - 1)) == 0
        if is_power_of_two:
            num_stages = int(math.log(scale, 2))
            layers = [module for _ in range(num_stages) for module in _stage(2)]
        elif scale == 3:
            layers = _stage(3)
        else:
            raise NotImplementedError

        super().__init__(*layers)



class MDSR(nn.Module):
    """Multi-scale residual network with one pre-process/upsample branch per scale.

    Pipeline of ``forward``: mean subtraction -> head conv -> scale-specific
    pre-processing -> shared residual body with a long skip connection ->
    scale-specific upsampler -> tail conv -> mean addition.

    Args:
        args: mapping with the keys ``"n_resblocks"``, ``"n_feats"``,
            ``"rgb_range"``, ``"n_colors"`` and ``"scale"``. ``"scale"`` is iterated
            twice; each element is converted with ``int`` and produces one branch.
        conv: convolution factory used for every convolution in the network.
    """

    def __init__(self, args, conv=default_conv):
        super().__init__()
        n_resblocks = args["n_resblocks"]
        n_feats = args["n_feats"]
        trunk_kernel = 3
        relu = nn.ReLU(True)

        # Index of the active branch in `pre_process` / `upsample`.
        self.scale_idx = 0
        self.sub_mean = MeanShift(args["rgb_range"])
        self.add_mean = MeanShift(args["rgb_range"], sign=1)

        head_layers = [conv(args["n_colors"], n_feats, trunk_kernel)]

        # One pair of 5x5 residual blocks per configured scale.
        branches = []
        for _ in args["scale"]:
            branches.append(
                nn.Sequential(
                    ResBlock(conv, n_feats, 5, act=relu),
                    ResBlock(conv, n_feats, 5, act=relu),
                )
            )
        self.pre_process = nn.ModuleList(branches)

        # Shared trunk: residual blocks closed by a plain convolution.
        body_layers = []
        for _ in range(n_resblocks):
            body_layers.append(ResBlock(conv, n_feats, trunk_kernel, act=relu))
        body_layers.append(conv(n_feats, n_feats, trunk_kernel))

        upsamplers = [
            Upsampler(conv, int(s), n_feats, act=False) for s in args["scale"]
        ]
        self.upsample = nn.ModuleList(upsamplers)

        tail_layers = [conv(n_feats, args["n_colors"], trunk_kernel)]

        self.head = nn.Sequential(*head_layers)
        self.body = nn.Sequential(*body_layers)
        self.tail = nn.Sequential(*tail_layers)

    def forward(self, x):
        """Run the network on ``x`` using the branch selected by ``scale_idx``."""
        features = self.head(self.sub_mean(x))
        features = self.pre_process[self.scale_idx](features)

        # Long skip connection around the shared trunk.
        residual = self.body(features)
        residual += features

        upscaled = self.upsample[self.scale_idx](residual)
        return self.add_mean(self.tail(upscaled))

    def set_scale(self, scale_idx):
        """Select which pre-process/upsample branch ``forward`` uses."""
        self.scale_idx = scale_idx

In [17]:
# Build an MDSR with a single x1 branch (no spatial upsampling) and move it to DEVICE.
# MDSR iterates over args["scale"] and converts each element with int(), so the
# scales are passed as a list: the previous string "1" only worked because a string is
# iterated character by character (a value such as "16" would have been read as the
# two scales 1 and 6).
# NOTE(review): rgb_range=255 makes MeanShift subtract 255 * rgb_mean; the training
# code elsewhere in this notebook clips images to [0, 1]. Confirm the intended range.
model = MDSR({
    "n_resblocks": 16,
    "n_feats": 64,
    "rgb_range": 255,
    "n_colors": 3,
    "scale": [1],
}).to(DEVICE)
model

MDSR(
  (sub_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (add_mean): MeanShift(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (pre_process): ModuleList(
    (0): Sequential(
      (0): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        )
      )
      (1): ResBlock(
        (body): Sequential(
          (0): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
          (1): ReLU(inplace=True)
          (2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        )
      )
    )
  )
  (upsample): ModuleList(
    (0): Upsampler()
  )
  (head): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (body): Sequential(
    (0): ResBlock(
      (body): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1

In [18]:
# Take the first (noisy, clean) batch of the test loader for a quick sanity check.
# next(iter(...)) raises StopIteration immediately when the loader is empty, whereas
# a for/break loop would leave both names undefined and fail later with a NameError.
img_test_noisy, img_test_clean = next(iter(data_test_batched))
img_test_noisy = img_test_noisy.to(DEVICE)
img_test_clean = img_test_clean.to(DEVICE)


In [19]:
# Forward pass only: no gradients are needed for this sanity check.
with torch.no_grad():
    # The batch is laid out as (N, H, W, C), as the shapes printed below show; the
    # permute moves the channel axis to position 1 before calling the model.
    out = model(img_test_noisy.permute(0, 3, 1, 2))

In [20]:
out.shape

torch.Size([1, 3, 496, 496])

In [21]:
img_test_noisy.shape

torch.Size([1, 496, 496, 3])

In [22]:
1984 / 496

4.0

In [None]:
import os, sys 
import itertools
import collections
import random
import time 
import logging

from PIL import Image
import numpy as np
import pandas as pd
import matplotlib.pylab as plt
from matplotlib.pyplot import imread

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch.optim import Adam, AdamW

#########################################################################################################
# Global runtime configuration for the training script.
# Select the 'high' float32 matmul precision mode in PyTorch.
torch.set_float32_matmul_precision('high')
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# NOTE(review): absolute, machine-specific paths; the script only runs as-is on a host
# that has this directory layout.
ROOT_PROJECT = "/home/jovyan/shared/Thuc/hoodsgatedrive/projects/ImageRestoration-Development-Unrolling/"
ROOT_DATASET = "/home/jovyan/shared/Thuc/hoodsgatedrive/projects/"

#########################################################################################################

# Extend the import path with the project's `lib` folder before importing from it
# (presumably where `dataloader` and `baselineRestormer` live; verify against the repository).
sys.path.append(os.path.join(ROOT_PROJECT, 'exploration/model_multiscale_mixture_GLR/lib'))
from dataloader import ImageSuperResolution
import baselineRestormer as model_structure


# Output locations of this experiment (log files and model checkpoints).
LOG_DIR = os.path.join(ROOT_PROJECT, "exploration/model_multiscale_mixture_GLR/result/model_test_unet/logs/")
CHECKPOINT_DIR = os.path.join(ROOT_PROJECT, "exploration/model_multiscale_mixture_GLR/result/model_test_unet/checkpoints/")

# logging.basicConfig(filename=...) opens the log file right away and fails when the
# folder is missing, so create the output folders before anything writes to them.
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

LOGGER = logging.getLogger("main")
logging.basicConfig(
    format='%(asctime)s: %(message)s',
    datefmt='%m/%d/%Y %I:%M:%S %p',
    filename=os.path.join(LOG_DIR, 'training00.log'),
    level=logging.INFO,
    # basicConfig is a silent no-op once the root logger has handlers. An earlier cell
    # of this notebook already configures logging for a different experiment folder,
    # so without force=True this run's messages would land in that other log file.
    force=True,
)

# Iteration period used by the training loop for checkpointing / test evaluation.
VERBOSE_RATE = 1000

# Patch sizes (height, width) used for each data split.
(H_train, W_train) = (128, 128)
(H_val, W_val) = (128, 128)
(H_test, W_test) = (496, 496)

# Keyword arguments shared by the three datasets below; the CSV index file, the
# distortion mode, the patch size and the patch overlap are set per split.
_DATASET_COMMON_KWARGS = dict(
    lambda_noise=25.0,
    max_num_patchs=1000000,
    root_folder=ROOT_DATASET,
    logger=LOGGER,
    device=torch.device("cpu"),
)

train_dataset = ImageSuperResolution(
    csv_path=os.path.join(ROOT_DATASET, "dataset/DFWB_training_data_info.csv"),
    dist_mode="addictive_noise_scale",
    patch_size=H_train,
    patch_overlap_size=H_train//2,
    **_DATASET_COMMON_KWARGS,
)

validation_dataset = ImageSuperResolution(
    csv_path=os.path.join(ROOT_DATASET, "dataset/CBSD68_testing_data_info.csv"),
    dist_mode="addictive_noise",
    patch_size=H_val,
    patch_overlap_size=H_train//2,
    **_DATASET_COMMON_KWARGS,
)

test_dataset = ImageSuperResolution(
    csv_path=os.path.join(ROOT_DATASET, "dataset/McMaster_testing_data_info.csv"),
    dist_mode="addictive_noise",
    patch_size=H_test,
    patch_overlap_size=0,
    **_DATASET_COMMON_KWARGS,
)


# Batched loaders. None of them passes shuffle=True, so batches follow the dataset's
# own ordering.
# NOTE(review): confirm that the training dataset randomises patch order itself.
data_train_batched = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, num_workers=4
)

data_valid_batched = torch.utils.data.DataLoader(
    validation_dataset, batch_size=16, num_workers=4
)

data_test_batched = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, num_workers=4
)

# Number of passes over the training loader.
NUM_EPOCHS = 45

# 3x3 integer mask: every entry is 1 except the centre, which is 0.
CONNECTION_FLAGS = np.array([
    [1, 1, 1],
    [1, 0, 1],
    [1, 1, 1],
])


# Restormer baseline, built with explicit keyword arguments and moved to DEVICE.
model = model_structure.Restormer(
    inp_channels=3,
    out_channels=3,
    dim=48,
    num_blocks=[4, 6, 6, 8],
    num_refinement_blocks=4,
    heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66,
    bias=False,
    # NOTE(review): an earlier inline comment named 'BiasFree' as the *other* option
    # although it is the value selected here; the alternative is presumably
    # 'WithBias'. Confirm against baselineRestormer.
    LayerNorm_type='BiasFree',
    # True for dual-pixel defocus deblurring only; that mode also needs inp_channels=6.
    dual_pixel_task=False,
).to(DEVICE)



In [None]:

def _clipped_float64(batch):
    """Return `batch` as a float64 numpy array on the host, clipped to [0, 1]."""
    return np.clip(batch.detach().cpu().numpy(), a_min=0.0, a_max=1.0).astype(np.float64)


def _evaluate_psnr(net, loader, device, first_item_only=False):
    """Evaluate `net` on every batch of `loader` and return the per-batch PSNR values.

    Args:
        net: model called with the batch permuted by (0, 3, 1, 2); its output is
            permuted back with (0, 2, 3, 1) before comparison.
        loader: iterable of (noisy, clean) tensor batches.
        device: device the batches are moved to.
        first_item_only: when true, only element 0 of each batch enters the MSE
            (this is how the test loop, which uses batch_size=1, was written).

    Returns:
        numpy array with one PSNR value per batch, computed as 10*log10(1/MSE) on
        images clipped to [0, 1].

    The model is switched to eval mode for the evaluation and back to train mode
    afterwards, even when the evaluation raises.
    """
    net.eval()
    batch_mse = []
    try:
        with torch.no_grad():
            for noisy, clean in loader:
                noisy = noisy.to(device)
                clean = clean.to(device)
                recon = net(noisy.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
                if first_item_only:
                    clean, recon = clean[0], recon[0]
                squared_error = np.square(_clipped_float64(clean) - _clipped_float64(recon))
                batch_mse.append(squared_error.mean())
    finally:
        net.train()
    return 10 * np.log10(1/np.array(batch_mse))


# Total number of parameter elements. numel() stays an exact integer even for 0-d
# parameters, where np.prod of an empty shape would have turned the sum into a float.
num_params = sum(p.numel() for p in model.parameters())

LOGGER.info(f"Init model with total parameters: {num_params}")

criterian = nn.L1Loss()
optimizer = AdamW(
    model.parameters(),
    lr=0.0003,
    eps=1e-08
)

# Validation runs five times per VERBOSE_RATE window. Integer arithmetic is used on
# purpose: the float modulus `i % (VERBOSE_RATE/5)` stops firing as soon as
# VERBOSE_RATE is not a multiple of 5. max(1, ...) avoids a modulo by zero.
VALIDATION_RATE = max(1, VERBOSE_RATE // 5)

### TRAINING
LOGGER.info("######################################################################################")
LOGGER.info("BEGIN TRAINING PROCESS")
# Optional resume from a saved training state (left disabled):
# training_state_path = os.path.join(CHECKPOINT_DIR, 'checkpoints_epoch00_iter0024k.pt')
# training_state = torch.load(training_state_path)
# model.load_state_dict(training_state["model"])
# optimizer.load_state_dict(training_state["optimizer"])
# i_checkpoint=training_state["i"]


for epoch in range(NUM_EPOCHS):

    model.train()

    ### TRAINING
    for i, (patchs_noisy, patchs_true) in enumerate(data_train_batched):
        iter_start = time.time()
        optimizer.zero_grad()
        patchs_noisy = patchs_noisy.to(DEVICE)
        patchs_true = patchs_true.to(DEVICE)
        # Batches are permuted with (0, 3, 1, 2) for the model and back afterwards.
        reconstruct_patchs = model(patchs_noisy.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        loss_value = criterian(reconstruct_patchs, patchs_true)
        loss_value.backward()
        optimizer.step()

        # Training PSNR of the current batch (images clipped to [0, 1], peak value 1).
        img_true = _clipped_float64(patchs_true)
        img_recon = _clipped_float64(reconstruct_patchs)
        train_mse_value = np.square(img_true- img_recon).mean()
        train_psnr = 10 * np.log10(1/train_mse_value)
        LOGGER.info(f"iter={i} time={time.time()-iter_start} psnr={train_psnr} mse={train_mse_value}")

        # Periodic checkpoint of model and optimizer state.
        if i % VERBOSE_RATE == 0:
            checkpoint = {
                'i': i,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'checkpoints_epoch{str(epoch).zfill(2)}_iter{str(i//VERBOSE_RATE).zfill(4)}k.pt'))

        ### VALIDATING
        if i % VALIDATION_RATE == 0:
            psnr_validation = _evaluate_psnr(model, data_valid_batched, DEVICE)
            LOGGER.info(f"FINISH VALIDATION EPOCH {epoch} - iter={i} -  psnr_validation={np.mean(psnr_validation)}")

        ### TESTING
        if i % VERBOSE_RATE == 0:
            psnr_testing = _evaluate_psnr(model, data_test_batched, DEVICE, first_item_only=True)
            LOGGER.info(f"FINISH TESING EPOCH {epoch} - iter={i} -  psnr_testing={np.mean(psnr_testing)}")