In [14]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load in 

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# Any results you write to the current directory are saved as output.

/kaggle/input/league-of-legends-diamond-ranked-games-10-min/high_diamond_ranked_10min.csv
/kaggle/input/dl-project/0001_NOISY_SRGB_010.PNG
/kaggle/input/dl-project/0001_GT_SRGB_011.PNG
/kaggle/input/dl-project/0001_GT_SRGB_010.PNG
/kaggle/input/dl-project/0001_NOISY_SRGB_011.PNG


In [15]:
import os
import time
import argparse
import torch.optim as optim
import torchvision.utils as utils
from torch.utils.data import DataLoader

In [16]:
ground_truth = list()
noise_images = list()
for dirname, _, filenames in os.walk('/kaggle/input/dl-project'):
    for filename in filenames:
        if filename[5:].startswith('GT'):
            ground_truth.append(filename)
            noise_images.append(filename.replace('GT', 'NOISY'))
print(ground_truth)
print(noise_images)

['0001_GT_SRGB_011.PNG', '0001_GT_SRGB_010.PNG']
['0001_NOISY_SRGB_011.PNG', '0001_NOISY_SRGB_010.PNG']


In [17]:
val_split = int(0.9*len(ground_truth))
gt_train, gt_val = ground_truth[:val_split], ground_truth[val_split:]
noise_train, noise_val = noise_images[:val_split], noise_images[val_split:]
print(gt_train, gt_val)
print(noise_train, noise_val)

['0001_GT_SRGB_011.PNG'] ['0001_GT_SRGB_010.PNG']
['0001_NOISY_SRGB_011.PNG'] ['0001_NOISY_SRGB_010.PNG']


In [18]:
import os
import os.path
import numpy as np
import random
import h5py
import torch
#import cv2
from PIL import Image
import glob
import torch.utils.data as udata
#from utils import data_augmentation

import torchvision.transforms as transforms

def normalize(data):
    return data/255.

def Im2Patch(img, win, stride=1):
    k = 0
    endc = img.shape[0]
    endw = img.shape[1]
    endh = img.shape[2]
    patch = img[:, 0:endw-win+0+1:stride, 0:endh-win+0+1:stride]
    TotalPatNum = patch.shape[1] * patch.shape[2]
    Y = np.zeros([endc, win*win,TotalPatNum], np.float32)
    for i in range(win):
        for j in range(win):
            patch = img[:,i:endw-win+i+1:stride,j:endh-win+j+1:stride]
            Y[:,k,:] = np.array(patch[:]).reshape(endc, TotalPatNum)
            k = k + 1
    return Y.reshape([endc, win, win, TotalPatNum])

def prepare_data(root, noise_list , gt_list,  train, patch_size = 256, stride = 200, aug_times=1):
    if train == True:
        # train
        print('processing training data')
        scales = [1]

        root = '../input/dl-project/'
        h5f = h5py.File('train.h5', 'w')
    else:
        # validation
        print('processing validation data')
        scales = [1]

        root = '../input/dl-project/'
        h5f = h5py.File('validation.h5', 'w')

    train_num = 0
    for i in range(len(noise_list)):

        n_img = Image.open(os.path.join(root, noise_list[i]))
        gt_img = Image.open(os.path.join(root, gt_list[i]))

        # Convert to numpy
        n_img = np.array(n_img, dtype=np.float16)
        gt_img = np.array(gt_img, dtype=np.float16)

        h, w, c = n_img.shape
        for k in range(len(scales)):
            n_img = n_img.transpose(2, 0, 1)
            n_img = np.float16(normalize(n_img))
            n_patches = Im2Patch(n_img, win=patch_size, stride=stride)

            gt_img = gt_img.transpose(2, 0, 1)
            gt_img = np.float16(normalize(gt_img))
            gt_patches = Im2Patch(gt_img, win=patch_size, stride=stride)

            #print("file: %s scale %.1f # samples: %d" % (noise_list[i], scales[k], n_patches.shape[3]*aug_times))
            for n in range(n_patches.shape[3]):
                patches = np.concatenate((n_patches[:,:,:,n], gt_patches[:,:,:,n]), axis = 0)
                data = patches.copy()
                for m in range(aug_times):
                    data_aug = data
                    h5f.create_dataset(str(train_num)+"_aug_%d" % (m), data=data_aug)
                    train_num += 1
                    #print(train_num)
    h5f.close()
    if train == True:
        print('training set, # samples %d\n' % train_num)
    else:
        print('validation set, # samples %d\n' % train_num)

In [19]:
#prepare_data(root = "a" , noise_list = noise_list , gt_list  = gt_list )

In [20]:
class Dataset(udata.Dataset):
    def __init__(self, train=True):
        super(Dataset, self).__init__()
        self.train = train
        self.output_size = (48, 48)
        if self.train:
            h5f = h5py.File('/kaggle/working/train.h5', 'r')
        else:
            h5f = h5py.File('/kaggle/working/validation.h5', 'r')
        self.keys = list(h5f.keys())
        random.shuffle(self.keys)
        h5f.close()
    def __len__(self):
        return len(self.keys)
        
    def preprocess(self, x):
        dh = np.random.randint(1, 200, size=1)
        dw = np.random.randint(1, 200, size=1)
        x = x[:, dh[0]:dh[0]+48, dw[0]:dw[0]+48]    #48, 200
        #mode = np.random.randint(0, 2, size=1)
        #x = data_augmentation(x, mode)
        y = x
        return y
        
    def __getitem__(self, index):
        if self.train:
            h5f = h5py.File('/kaggle/working/train.h5', 'r')
        else:
            h5f = h5py.File('/kaggle/working/validation.h5', 'r')
        key = self.keys[index]
        data = np.array(h5f[key])
        data = self.preprocess(data)
        h5f.close()
        return torch.Tensor(data)


In [21]:
import torch 
import torch.nn as nn
import torch.nn.init as init
  
class RDB_Conv(nn.Module):
    def __init__(self, in_C, growRate, kSize=3):
        super(RDB_Conv, self).__init__()
        self.conv = nn.Sequential(*[
            nn.Conv2d(in_C, growRate, kSize, padding=(kSize-1)//2, stride=1),
            nn.ReLU()
        ])
        # self.conv = nn.Sequential(
        #     nn.BatchNorm2d(in_C),
        #     Modulecell(in_channels=in_C,out_channels=growRate,kernel_size=kSize))
        # self.conv = nn.Sequential(*[
        #     nn.Conv2d(in_C, growRate, kSize, padding=(kSize-1)//2, stride=1),
        #     nn.BatchNorm2d(growRate),
        #     nn.PReLU()
        # ])
    def forward(self, x):
        out = self.conv(x)
        return torch.cat((x, out), 1)
        
class RDB(nn.Module):
    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super(RDB, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers
        convs = []
        for i in range(C):
            convs.append(RDB_Conv(G0 + i*G, G, kSize))
        self.convs = nn.Sequential(*convs)
        #Local Feature Fusion
        self.lff = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
        
    def forward(self, x):
        y = self.lff(self.convs(x)) + x
        return y
    
class NonLocalBlock2D(nn.Module):
    def __init__(self, in_channels, inter_channels):
        super(NonLocalBlock2D, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = inter_channels
        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0)
        nn.init.constant(self.W.weight, 0)
        nn.init.constant(self.W.bias, 0)
        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        batch_size = x.size(0)
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)        
        g_x = g_x.permute(0,2,1)
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0,2,1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=1)
        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0,2,1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x
        return z
    
class NLMaskBranchDownUp(nn.Module):
    def __init__(self,growRate0, growRate, nConvLayers, nRDBs, kSize=3):
        super(NLMaskBranchDownUp, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers
        m = int(nRDBs/4)
        MB_RB1 = []
        MB_RB1.append(NonLocalBlock2D(G0, G0))
        for i in range(m):
            MB_RB1.append(RDB(G0, G, C, kSi1ze))
        MB_Down = []
        MB_Down.append(nn.Conv2d(G0,G0, 3, stride=2, padding=1))        
        MB_RB2 = []
        for i in range(2*m):
            MB_RB2.append(RDB(G0, G, C, kSize))
        MB_Up = []
        MB_Up.append(nn.ConvTranspose2d(G0,G0, 6, stride=2, padding=2))   
        MB_RB3 = []
        for i in range(m):
            MB_RB3.append(RDB(G0, G, C, kSize))
        MB_1x1conv = []
        MB_1x1conv.append(nn.Conv2d(G0,G0, 1, padding=0, bias=True))
        MB_sigmoid = []
        MB_sigmoid.append(nn.Sigmoid())
        self.MB_RB1 = nn.Sequential(*MB_RB1)
        self.MB_Down = nn.Sequential(*MB_Down)
        self.MB_RB2 = nn.Sequential(*MB_RB2)
        self.MB_Up  = nn.Sequential(*MB_Up)
        self.MB_RB3 = nn.Sequential(*MB_RB3)
        self.MB_1x1conv = nn.Sequential(*MB_1x1conv)
        self.MB_sigmoid = nn.Sequential(*MB_sigmoid)
    
    def forward(self, x):
        x_RB1 = self.MB_RB1(x)
        x_Down = self.MB_Down(x_RB1)
        x_RB2 = self.MB_RB2(x_Down)
        x_Up = self.MB_Up(x_RB2)
        x_preRB3 = x_RB1 + x_Up
        x_RB3 = self.MB_RB3(x_preRB3)
        x_1x1 = self.MB_1x1conv(x_RB3)
        mx = self.MB_sigmoid(x_1x1)
        return mx
       
class RDN(nn.Module):
    def __init__(self, growRate0,growRate, RDBkSize,nConvLayers,nRDBs):
        super(RDN, self).__init__()
        G0 = growRate0
        kSize = RDBkSize
        G = growRate
        C = nConvLayers
        D = nRDBs
        #D, C, G = (20, 6, 32)
        D, C, G = (16, 8, 64)
        self.RDB_num = D
        #Shallow Feature Extraction
        self.sfe1 = nn.Conv2d(3, G0, kSize, padding=(kSize-1)//2, stride=1)
        self.sfe2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
        #Residual Dense Blocks
        self.RDBs = nn.ModuleList()
        for i in range(D):
            self.RDBs.append(RDB(G0, G, C, kSize))
        #self.Mask = []
        self.Mask = (NLMaskBranchDownUp(G0, G, C,nRDBs, kSize=3))
        #Global Feature Fusion
        self.gff = nn.Sequential(*[
            nn.Conv2d(D*G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
        ])
        #self.non_local = Self_Att(G0, 'relu')
        self.out_conv = nn.Conv2d(G0, 3, kSize, padding=(kSize-1)//2, stride=1)
        self.init_params()

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                #init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                #init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
    def forward(self, x):
        f1 = self.sfe1(x)
        y = self.sfe2(f1)
        f2= self.Mask(y)
        RDB_out = []
        for i in range(self.RDB_num):
            y = self.RDBs[i](y)
            RDB_out.append(y)
        y = self.gff(torch.cat(RDB_out, 1))
        #y = self.non_local(y)
        y = y*f2
        y = self.out_conv(f1 + y)
        
        y += x
        return y

In [22]:
import math
import torch
import torch.nn as nn
import numpy as np
from skimage.measure.simple_metrics import compare_psnr

def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        # nn.init.uniform(m.weight.data, 1.0, 0.02)
        m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
        nn.init.constant(m.bias.data, 0.0)

def batch_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = 0
    for i in range(Img.shape[0]):
        PSNR += compare_psnr(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)
    return (PSNR/Img.shape[0])

def data_augmentation(image, mode):
    # #out = np.transpose(image, (1,2,0))
    # out = image
    # if mode == 0:
    #     # original
    #     out = out
    # elif mode == 1:
    #     # flip up and down
    #     out = np.flipud(out)
    # elif mode == 2:
    #     # rotate counterwise 90 degree
    #     out = np.rot90(out)
    # elif mode == 3:
    #     # rotate 90 degree and flip up and down
    #     out = np.rot90(out)
    #     out = np.flipud(out)
    # elif mode == 4:
    #     # rotate 180 degree
    #     out = np.rot90(out, k=2)
    # elif mode == 5:
    #     # rotate 180 degree and flip
    #     out = np.rot90(out, k=2)
    #     out = np.flipud(out)
    # elif mode == 6:
    #     # rotate 270 degree
    #     out = np.rot90(out, k=3)
    # elif mode == 7:
    #     # rotate 270 degree and flip
    #     out = np.rot90(out, k=3)
    #     out = np.flipud(out)
    # #return np.transpose(out, (2,0,1))
    # return out

    if mode == 0:
        # original
        out = image
    elif mode == 1:
        # flip up and down
        out = np.flipud(image)
    elif mode == 2:
        # rotate counterwise 90 degree
        out = np.rot90(image, k=2)

    return out

def SSIM(x, y):
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    mu_x = nn.AvgPool2d(3, 1)(x)
    mu_y = nn.AvgPool2d(3, 1)(y)
    mu_x_mu_y = mu_x * mu_y
    mu_x_sq = mu_x.pow(2)
    mu_y_sq = mu_y.pow(2)

    sigma_x = nn.AvgPool2d(3, 1)(x * x) - mu_x_sq
    sigma_y = nn.AvgPool2d(3, 1)(y * y) - mu_y_sq
    sigma_xy = nn.AvgPool2d(3, 1)(x * y) - mu_x_mu_y

    SSIM_n = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2)
    SSIM_d = (mu_x_sq + mu_y_sq + C1) * (sigma_x + sigma_y + C2)
    SSIM = SSIM_n / SSIM_d

    return torch.clamp((1 - SSIM) / 2, 0, 1)

def update_lr(ori_lr, epoch):
    #current_lr = ori_lr * (1 - epoch/40.0)**0.9
    current_lr = ori_lr * (0.1 ** ((epoch) // 40))
    return current_lr

In [23]:
import os
import time
import argparse
import torch.optim as optim
import torchvision.utils as utils
from torch.utils.data import DataLoader

In [24]:
model_name = "RDN_e40_16"
batch_size = 4
lr = 1e-4
epochs = 3

In [25]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
model_dir = os.path.join("/kaggle/working/", model_name)
print('create checkpoint directory %s...' % model_dir)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

create checkpoint directory /kaggle/working/RDN_e40_16...


In [26]:
prepare_data(root = "adarsh" , noise_list = noise_train , gt_list = gt_train , train = True)
prepare_data(root = "adarsh" , noise_list = noise_val , gt_list = gt_val, train = False)

processing training data
training set, # samples 364

processing validation data
validation set, # samples 364



In [27]:
print('Loading training dataset ...\n')
dataset_train = Dataset(train=True)
loader_train = DataLoader(dataset=dataset_train, num_workers=4 , batch_size = batch_size, shuffle=True)
num = len(dataset_train)
print("# of training samples: %d\n" % int(num))

Loading training dataset ...

# of training samples: 364



In [28]:
print('Loading validation dataset ...\n')
dataset_test = Dataset(train=False)
loader_test = DataLoader(dataset=dataset_test, num_workers=4 , batch_size = batch_size, shuffle=True)
num = len(dataset_test)
print("# of validation samples: %d\n" % int(num))

Loading validation dataset ...

# of validation samples: 364



## Training

In [32]:
# Build model
net = RDN(64,64,3,8, 3)

num_params = 0
for parm in net.parameters():
    num_params += parm.numel()
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (model_name, num_params / 1e6))
# Move to GPU
device_ids = [0]
model = nn.DataParallel(net, device_ids=device_ids).cuda()
#model.load_state_dict(torch.load(os.path.join('logs/', opt.name, '40_net.pth')))  # !!!

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)  #weight_decay=opt.weight_decay



RDN(
  (sfe1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (sfe2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (RDBs): ModuleList(
    (0): RDB(
      (convs): Sequential(
        (0): RDB_Conv(
          (conv): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU()
          )
        )
        (1): RDB_Conv(
          (conv): Sequential(
            (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU()
          )
        )
        (2): RDB_Conv(
          (conv): Sequential(
            (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU()
          )
        )
        (3): RDB_Conv(
          (conv): Sequential(
            (0): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): ReLU()
          )
        )
        (4): RDB_Conv(
          (conv): Sequentia

In [33]:
def validator(model, loader_data = loader_test):
    ave_loss_val = 0
    ave_psnr_val = 0
    ave_ssim_val = 0

    for j, data in enumerate(loader_test, 0):
        # training step
        time1 = time.time()
        noise_img_val = data[:, :3, :, :]
        gt_img_val = data[:, 3:, :, :]

        noise_img_val, gt_img_val = noise_img_val.cuda(), gt_img_val.cuda()
        res_val = noise_img_val - gt_img_val

        # Validation set
        pred_res_val = model(noise_img_val)

        loss1_val = torch.mean(torch.abs(pred_res_val - gt_img_val))
        loss2_val = torch.mean(SSIM(pred_res_val, gt_img_val))
        loss_val = loss1_val #0.75*loss1 + 0.25*loss2
        #Validation set

        #Validation set
        result_val = torch.clamp(pred_res_val, 0., 1.)
        psnr_val = batch_PSNR(result_val, gt_img_val, 1.)

        ave_loss_val = (ave_loss_val*j + loss_val.item()) / (j+1)
        ave_psnr_val = (ave_psnr_val*j + psnr_val) / (j+1)
        ave_ssim_val = (ave_ssim_val*j + 1-loss2_val.item()*2) / (j+1)
        #Validation set

    return ave_loss_val, ave_psnr_val, ave_ssim_val

In [34]:
import torch.nn.functional as F
step = 0
print(f"Training on {len(loader_train)} samples and validating on {len(loader_test)} samples")

for epoch in range(epochs):
    # set learning rate
    current_lr = update_lr(lr, epoch)
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr
    print('learning rate %f' % current_lr)

    # train
    model.train()
    start_time = time.time()
    
    ave_loss = 0
    ave_psnr = 0
    ave_ssim = 0

    '''
    for j, data in enumerate(loader_test, 0):
    data = data[:, :3, :, :]
    # training step
    time1 = time.time()
    noise_img_val = data[:, :3, :, :]
    gt_img_val = data[:, 3:, :, :]

    noise_img_val, gt_img_val = noise_img_val.cuda(), gt_img_val.cuda()
    res_val = noise_img_val - gt_img_val
    #res_val = res_val.cuda()
    '''
    
    for i, data in enumerate(loader_train, 0):
        # training step
        time1 = time.time()
        model.zero_grad()
        optimizer.zero_grad()

        noise_img = data[:, :3, :, :]
        gt_img = data[:, 3:, :, :]

        noise_img, gt_img = noise_img.cuda(), gt_img.cuda()
        res = noise_img - gt_img
        pred_res = model(noise_img)

        loss1 = torch.mean(torch.abs(pred_res - gt_img))
        loss2 = torch.mean(SSIM(pred_res, gt_img))
        loss = loss1 #0.75*loss1 + 0.25*loss2
        
        loss.backward()
        optimizer.step()

        # evaluate
        #result = torch.clamp(noise_img-pred_res, 0., 1.)
        result = torch.clamp(pred_res, 0., 1.)
        psnr_train = batch_PSNR(result, gt_img, 1.)

        ave_loss = (ave_loss*i + loss.item()) / (i+1)
        ave_psnr = (ave_psnr*i + psnr_train) / (i+1)
        ave_ssim = (ave_ssim*i + 1-loss2.item()*2) / (i+1)
 

        if i == len(loader_train)-1:
            time2 = time.time()
            print("Training details: [epoch %d][%d/%d] Time taken: %.3f loss: %.4f PSNR_train: %.4f SSIM_train: %.4f" %
                (epoch+1, i+1, len(loader_train), (time2 - start_time), ave_loss, ave_psnr, ave_ssim))
            ave_loss_val, ave_psnr_val, ave_ssim_val = validator(model)
            time3 = time.time()
            print("Testing Details: Time taken: %.3f loss: %.4f PSNR_test: %.4f SSIM_test: %.4f" %
                ( (time3 - time2), ave_loss_val, ave_psnr_val, ave_ssim_val))
            print(f'Total time for epoch: {time3 - start_time}')
        if step % 1000 == 0:
            torch.save(model.state_dict(), os.path.join(model_dir, 'latest_net.pth'))
        step += 1
        #result = torch.clamp(noise_img-pred_res, 0., 1.)

    # save model
    if epoch%2 == 0:
        save_name = '%d_net.pth' % (epoch+1)
        torch.save(model.state_dict(), os.path.join(model_dir, save_name))

Training on 91 samples and validating on 91 samples
learning rate 0.000100




Training details: [epoch 1][91/91] Time taken: 18.378 loss: 0.0181 PSNR_train: 32.7643 SSIM_train: 0.7529
Testing Details: Time taken: 6.350 loss: 0.0137 PSNR_test: 35.1370 SSIM_test: 0.8478
Total time for epoch: 24.72796630859375
learning rate 0.000100
Training details: [epoch 2][91/91] Time taken: 18.197 loss: 0.0122 PSNR_train: 36.2488 SSIM_train: 0.8818
Testing Details: Time taken: 6.243 loss: 0.0116 PSNR_test: 36.6396 SSIM_test: 0.8919
Total time for epoch: 24.440308332443237
learning rate 0.000100
Training details: [epoch 3][91/91] Time taken: 18.095 loss: 0.0115 PSNR_train: 36.7668 SSIM_train: 0.8935
Testing Details: Time taken: 6.218 loss: 0.0112 PSNR_test: 37.0263 SSIM_test: 0.8990
Total time for epoch: 24.31334352493286
