In [2]:
import torchOptics.optics as tt
import warnings
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torchvision
import datetime
import tqdm
import time
import pandas as pd
# import HHS.model


# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings from
# torch/torchvision; consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')


class SignFunction(torch.autograd.Function):
    """Binarise a tensor at the fixed threshold 0.5 with a straight-through gradient.

    Forward yields a float32 tensor holding 1.0 wherever ``input > 0.5`` and
    0.0 elsewhere.  Backward ignores the step and passes the incoming
    gradient through unchanged (multiplied by ones shaped like the input).
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Threshold kept as a 1-element tensor on the input's device.
        cutoff = torch.Tensor([0.5]).to(input.device)
        above = input > cutoff
        return above.float()

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # Straight-through estimator: treat the step as the identity.
        passthrough = torch.ones_like(saved_input)
        return grad_output * passthrough

class RealSign(torch.autograd.Function):
    """Element-wise ``torch.sign`` with a straight-through gradient.

    Forward returns -1, 0 or +1 per element.  Backward passes the incoming
    gradient through unchanged (multiplied by ones shaped like the input).
    """

    @staticmethod
    def forward(ctx, input):
        # backward() reads the input back from ctx.saved_tensors, so it must
        # be saved here; without this, backward fails on an empty tuple.
        ctx.save_for_backward(input)
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output * torch.ones_like(input)
        return grad_input


class Dataset512(Dataset):
    """Grayscale PNG dataset returning 1024x1024 crops (despite the class name).

    Args:
        target_dir: directory containing the ``*.png`` images (with or
            without a trailing separator).
        meta: metadata dict forwarded to ``tt.imread``.
        transform: optional callable applied to every returned tensor.
        isTrain: random 1024x1024 crop when True, centre crop otherwise.
        padding: number of zero pixels added on every side after cropping.
    """

    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        # os.path.join works whether or not target_dir ends with a separator;
        # plain concatenation silently matched nothing without a trailing '/'.
        self.target_list = sorted(glob.glob(os.path.join(target_dir, '*.png')))
        self.center_crop = torchvision.transforms.CenterCrop(1024)
        self.random_crop = torchvision.transforms.RandomCrop((1024, 1024))
        # Resize with an int scales the shorter side to 1024.
        self.resize = torchvision.transforms.Resize(1024)
        self.padding = padding

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Use this dataset's own meta rather than the module-level ``meta``.
        target = tt.imread(self.target_list[idx], meta=self.meta, gray=True).unsqueeze(0)
        # Upscale images smaller than the crop size so the crop is valid.
        if target.shape[-1] < 1024 or target.shape[-2] < 1024:
            target = self.resize(target)
        crop = self.random_crop if self.isTrain else self.center_crop
        target = crop(target)
        pad = self.padding
        target = torchvision.transforms.functional.pad(target, (pad, pad, pad, pad))
        if self.transform is not None:
            target = self.transform(target)
        return target


def initialize_weights(m):
    """Initialise one module in place; intended for ``model.apply(initialize_weights)``.

    - ``nn.Conv2d``: Kaiming-uniform weights (ReLU gain), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0.
    - ``nn.Linear``: Kaiming-uniform weights (PyTorch default gain args), zero bias.

    Other module types (including ``nn.ConvTranspose2d``, which is not an
    ``nn.Conv2d`` subclass) are left untouched.  Absent parameters
    (``bias=False`` layers, ``affine=False`` BatchNorm) are skipped instead of
    raising on ``None``.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

# --- Load data ---------------------------------------------------------------
# Builds the module-level train/validation datasets and loaders; the globals
# `meta`, `trainloader` and `validloader` are used by the functions below.
batch_size = 1
# target_dir = 'dataset/DIV2K/DIV2K_train_HR/'
# valid_dir = 'dataset/DIV2K/DIV2K_valid_HR/'
target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'
# sim_dir = 'binary_dataset/simulated/'
# Metadata passed to tt.imread.  'wl' presumably is the wavelength in metres
# (515 nm) and 'dx' the pixel pitch in metres -- TODO confirm with torchOptics.
meta = {'wl' : (515e-9), 'dx':(7.56e-6, 7.56e-6)}
padding = 0
train_dataset = Dataset512(target_dir=target_dir, meta=meta, isTrain=True, padding=padding)
valid_dataset = Dataset512(target_dir=valid_dir, meta=meta, isTrain=False, padding=padding)

# NOTE(review): the training loader is not shuffled (shuffle=False), so every
# epoch sees images in the same sorted order; only the random crop varies.
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
validloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)


def initialize_weights(m):
    """Initialise one module in place; intended for ``model.apply(initialize_weights)``.

    NOTE(review): this notebook defines ``initialize_weights`` twice; being
    later in the file, this definition is the one in effect.

    - ``nn.Conv2d``: Kaiming-uniform weights (ReLU gain), zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0.
    - ``nn.Linear``: Kaiming-uniform weights (PyTorch default gain args), zero bias.

    Other module types (including ``nn.ConvTranspose2d``, which is not an
    ``nn.Conv2d`` subclass) are left untouched.  Absent parameters
    (``bias=False`` layers, ``affine=False`` BatchNorm) are skipped instead of
    raising on ``None``.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)


def train(model, num_hologram=10, custom_name='fcn', epochs=100, lr=1e-4, z=2e-3):
    """Train ``model`` end-to-end through binarisation and propagation.

    Each step binarises the model output with ``SignFunction``, propagates it
    with ``tt.simulate`` over distance ``z``, averages the intensity
    ``|field|**2`` over dim 1 and minimises ``tt.relativeLoss(..., F.mse_loss)``
    against the input image.  After every epoch the mean PSNR/MSE over the
    module-level ``validloader`` is computed; metric arrays and curves are
    written to ``result/<model_name>/`` and the best-PSNR weights are saved
    there.

    Args:
        model: network mapping an image batch to hologram planes.
        num_hologram: only used to label the run (folder name and config).
        custom_name: free-form tag included in the run name.
        epochs: number of training epochs.
        lr: Adam learning rate.
        z: propagation distance passed to ``tt.simulate``.

    Raises:
        OSError: if the result folder cannot be created (``os.mkdir``: it
            already exists, or ``result/`` is missing).
    """
    # Timestamp shifted back 15 h (presumably a host-timezone correction --
    # TODO confirm).
    date = datetime.datetime.now() - datetime.timedelta(hours=15)
    mean_psnr_list = []
    mean_mse_list = []
    max_epoch = 0
    # Start at -inf so the first validated model is always checkpointed,
    # even when its PSNR is not positive.
    max_psnr = float('-inf')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sign_function = SignFunction.apply
    model_name = f'{date}_{custom_name}_{num_hologram}_{z}'
    result_folder = 'result/' + model_name
    os.mkdir(result_folder)
    model_path = os.path.join(result_folder, model_name)
    psnr_result_path = os.path.join(result_folder, 'meanPSNR.npy')
    mse_result_path = os.path.join(result_folder, 'meanMSE.npy')
    # Run configuration, pickled next to the results.  Kept under its own
    # name so it is not overwritten by the per-batch reconstruction below.
    config = {'model_path': model_path, 'psnr_result_path': psnr_result_path,
              'mse_result_path': mse_result_path, 'z': z,
              'num_hologram': num_hologram, 'lr': lr, 'epochs': epochs}
    with open(os.path.join(result_folder, 'result_dict.pickle'), 'wb') as fw:
        pickle.dump(config, fw)

    for epoch in range(epochs):
        model.train()
        for target in tqdm.tqdm(trainloader):
            out = model(target)
            binary = sign_function(out)
            sim = tt.simulate(binary, z).abs()**2
            reconstruction = torch.mean(sim, dim=1, keepdim=True)
            loss = tt.relativeLoss(reconstruction, target, F.mse_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            psnr_list = []
            mse_list = []
            for valid_target in validloader:
                out = model(valid_target)
                binary = sign_function(out)
                sim = tt.simulate(binary, z).abs()**2
                reconstruction = torch.mean(sim, dim=1, keepdim=True)
                psnr_list.append(tt.relativeLoss(reconstruction, valid_target, tm.get_PSNR))
                mse = tt.relativeLoss(reconstruction, valid_target, F.mse_loss)
                mse_list.append(mse.detach().cpu().numpy())
            mean_psnr = sum(psnr_list) / len(psnr_list)
            mean_mse = sum(mse_list) / len(mse_list)
            mean_psnr_list.append(mean_psnr)
            mean_mse_list.append(mean_mse)

            # Persist metric histories and refresh the curve plots.
            np.save(psnr_result_path, mean_psnr_list)
            np.save(mse_result_path, mean_mse_list)
            plt.plot(np.arange(len(mean_psnr_list)), mean_psnr_list,
                     label=f'max psnr: {max(mean_psnr_list)}')
            plt.title(model_name + '_psnr')
            plt.legend()
            plt.savefig(os.path.join(result_folder, 'meanPSNR.png'))
            plt.clf()
            # Lower MSE is better, so the legend reports the minimum.
            plt.plot(np.arange(len(mean_mse_list)), mean_mse_list,
                     label=f'min mse: {min(mean_mse_list)}')
            plt.title(model_name + '_mse')
            plt.legend()
            plt.savefig(os.path.join(result_folder, 'meanMSE.png'))
            plt.clf()
            print(f'mean PSNR : {mean_psnr} {epoch}/{epochs}')
            print(f'mean MSE : {mean_mse} {epoch}/{epochs}')

        # Checkpoint whenever validation PSNR improves.
        if mean_psnr > max_psnr:
            max_psnr = mean_psnr
            max_epoch = epoch
            torch.save(model.state_dict(), model_path)
        print(f'max_psnr: {max_psnr}, epoch: {max_epoch}')


def binary_sim(out, z=2e-3):
    """Binarise ``out`` and simulate its propagated intensity.

    Args:
        out: model output, thresholded at 0.5 by ``SignFunction``.
        z: propagation distance passed to ``tt.simulate``.

    Returns:
        ``(binary, res)``: the binarised tensor and ``|field|**2`` averaged
        over dim 1 (dimension kept).
    """
    hologram = SignFunction.apply(out)
    intensity = tt.simulate(hologram, z).abs() ** 2
    averaged = torch.mean(intensity, dim=1, keepdim=True)
    return hologram, averaged


def valid(model, z=2e-3):
    """Print and return the mean PSNR of ``model`` over the module-level ``validloader``.

    Each validation image is passed through ``model`` and then through
    ``binary_sim`` with distance ``z``.  PSNR is measured with
    ``tt.relativeLoss(..., tm.get_PSNR)``.

    The model is switched to eval mode first, matching the validation phase
    of ``train``.  Otherwise BatchNorm layers would use batch statistics and
    update their running statistics while being "validated".

    Args:
        model: hologram generator.
        z: propagation distance (default matches ``binary_sim``).

    Returns:
        The mean PSNR (also printed).
    """
    model.eval()
    psnr_list = []
    with torch.no_grad():
        for target in validloader:
            out = model(target)
            _, res = binary_sim(out, z)
            psnr_list.append(tt.relativeLoss(res, target, tm.get_PSNR))
    mean_psnr = sum(psnr_list) / len(psnr_list)
    print(mean_psnr)
    return mean_psnr


def check_order(bmodel, classifier, target):
    """Print the classifier's predicted class for every hologram plane.

    ``target`` is run through ``bmodel`` and binarised/propagated with
    ``binary_sim``.  Each channel ``binary[0][i]`` of the first batch element
    is reshaped to (1, 1, H, W) and fed to ``classifier``; the argmax of its
    output is collected.

    Args:
        bmodel: hologram generator.
        classifier: module mapping a (1, 1, H, W) tensor to class scores.
        target: image tensor; a 3-D tensor gets a batch dim prepended.

    Returns:
        List of argmax tensors, one per plane (also printed).
    """
    if len(target.shape) == 3:
        target = target.unsqueeze(0)
    result = []
    with torch.no_grad():
        out = bmodel(target)
        binary, _ = binary_sim(out)
        # Visit every plane the model produced.  A fixed count of 10 raised
        # IndexError for models with fewer output planes (e.g. 8).
        for i in range(binary.shape[1]):
            binary_input = binary[0][i].unsqueeze(0).unsqueeze(0)
            prop_order = classifier(binary_input)
            result.append(torch.argmax(prop_order))
    print(result)
    return result


def check_order_without_model(binary, classifier):
    """Print the classifier's predicted class for every plane of ``binary``.

    Each channel ``binary[0][i]`` (first batch element) is reshaped to
    (1, 1, H, W), passed to ``classifier`` under ``torch.no_grad()``, and the
    argmax of the output is collected.

    Args:
        binary: 4-D tensor indexed as (batch, plane, H, W); only batch
            element 0 is used.
        classifier: callable mapping a (1, 1, H, W) tensor to class scores.

    Returns:
        List of argmax tensors, one per plane (also printed).
    """
    result = []
    with torch.no_grad():
        # Iterate over the actual plane count instead of assuming 10.
        for i in range(binary.shape[1]):
            binary_input = binary[0][i].unsqueeze(0).unsqueeze(0)
            prop_order = classifier(binary_input)
            result.append(torch.argmax(prop_order))
    print(result)
    return result


class SimpleCNN(nn.Module):
    """Small CNN classifier: three conv/ReLU/max-pool stages and one linear layer.

    Takes single-channel input.  ``self.fc`` expects a 64-channel 64x64
    feature map after three 2x down-samplings, i.e. a 512x512 input.
    Produces 10 class scores per sample.
    """

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(64 * 64 * 64, 10)  # FC layer for the final classification

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.pool3(torch.relu(self.conv3(x)))
        # Flatten per sample.  ``view(-1, 64*64*64)`` silently folded a larger
        # feature map into extra "batch" rows and returned the wrong number of
        # predictions; flattening from dim 1 keeps the batch size.
        x = torch.flatten(x, 1)
        if x.shape[1] != self.fc.in_features:
            raise RuntimeError(
                f'SimpleCNN expects {self.fc.in_features} features per sample '
                f'(512x512 input), got {x.shape[1]}')
        x = self.fc(x)
        return x


class ThresholdSign(torch.autograd.Function):
    """Binarise at a caller-supplied threshold with a straight-through gradient.

    Forward returns a float32 tensor with 1.0 where ``input > threshold`` and
    0.0 elsewhere.  Backward passes the incoming gradient through unchanged
    and returns no gradient for ``threshold``.
    """

    @staticmethod
    def forward(ctx, input, threshold):
        return (input > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def rgb_binary_sim(out, z, th=0.5):
    """Binarise ``out`` and simulate a three-wavelength (R, G, B) reconstruction.

    The channels of ``out`` are split into three consecutive groups (first,
    middle, last third).  Each group is propagated with ``tt.simulate`` at its
    own wavelength (638 / 515 / 450 nm, pitch 7.56 um), and its intensity is
    averaged over the group.

    Args:
        out: model output; dim 1 holds the hologram planes.
        z: propagation distance passed to ``tt.simulate``.
        th: binarisation threshold (default 0.5, matching ``SignFunction``).
            The previous code passed it to ``SignFunction.apply``, whose
            forward accepts a single input, so every call raised TypeError.

    Returns:
        ``(binary, rgb)``: the binarised planes and the 3-channel averaged
        intensity, both wrapped as ``tt.Tensor`` with the RGB meta.
    """
    pixel_pitch = 7.56e-6
    meta = {'wl' : (638e-9, 515e-9, 450e-9), 'dx':(pixel_pitch, pixel_pitch)}
    rmeta = {'wl': (638e-9), 'dx': (pixel_pitch, pixel_pitch)}
    gmeta = {'wl': (515e-9), 'dx': (pixel_pitch, pixel_pitch)}
    bmeta = {'wl': (450e-9), 'dx': (pixel_pitch, pixel_pitch)}
    binary = ThresholdSign.apply(out, th)
    channel = out.shape[1]
    # Integer split points for the three colour groups.
    rchannel = channel // 3
    gchannel = channel * 2 // 3
    red = tt.Tensor(binary[:, :rchannel, :, :], meta=rmeta)
    green = tt.Tensor(binary[:, rchannel:gchannel, :, :], meta=gmeta)
    blue = tt.Tensor(binary[:, gchannel:, :, :], meta=bmeta)
    rsim = tt.simulate(red, z).abs()**2
    gsim = tt.simulate(green, z).abs()**2
    bsim = tt.simulate(blue, z).abs()**2
    rmean = torch.mean(rsim, dim=1, keepdim=True)
    gmean = torch.mean(gsim, dim=1, keepdim=True)
    bmean = torch.mean(bsim, dim=1, keepdim=True)
    rgb = torch.cat([rmean, gmean, bmean], dim=1)
    rgb = tt.Tensor(rgb, meta=meta)
    binary = tt.Tensor(binary, meta=meta)
    return binary, rgb

def count_parameters(model):
    """Return the total number of scalar entries in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
  warn(f"Failed to load image Python extension: {e}")


In [3]:
class BinaryNet(nn.Module):
    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=[32, 64, 128, 256, 512, 1024, 2048, 4096],
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super(BinaryNet, self).__init__()

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            if relu:
                layers += [nn.Tanh()]
            if bn:
                layers += [nn.BatchNorm2d(num_features=out_channels)]

            cbr = nn.Sequential(*layers) # *으로 list unpacking 

            return cbr

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            layers = []
            layers += [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=2, stride=2, padding=0,
                                          bias=True)]
            if bn:
                layers += [nn.BatchNorm2d(num_features=out_channels)]
            if relu:
                layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers) # *으로 list unpacking 

            return cbr

        self.enc1_1 = CRB2d(in_planes, channels[0], relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(channels[0], channels[0], stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(channels[0], channels[1], relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(channels[1], channels[1], stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(channels[1], channels[2], relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(channels[2], channels[2], stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(channels[2], channels[3], relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(channels[3], channels[3], stride=2, relu=poolReLU, bn=poolBN)

        self.enc5_1 = CRB2d(channels[3], channels[4], relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(channels[4], channels[4], relu=convReLU, bn=convBN)

        self.deconv4 = TRB2d(channels[4], channels[3], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(channels[4], channels[3], relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(channels[3], channels[2], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(channels[3], channels[2], relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(channels[2], channels[1], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(channels[2], channels[1], relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(channels[1], channels[0], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(channels[1], channels[0], relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)

        self.classifier = CRB2d(channels[0], num_hologram, relu=False, bn=False)

    def forward(self, x):
        # Encoder
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)
        enc5_2 = self.enc5_2(enc5_1)

        deconv4 = self.deconv4(enc5_2)
        concat4 = torch.cat((deconv4, enc4_2), dim=1)
        dec4_1 = self.dec4_1(concat4)
        dec4_2 = self.dec4_2(dec4_1)

        deconv3 = self.deconv3(dec4_2)
        concat3 = torch.cat((deconv3, enc3_2), dim=1)
        dec3_1 = self.dec3_1(concat3)
        dec3_2 = self.dec3_2(dec3_1)

        deconv2 = self.deconv2(dec3_2)
        concat2 = torch.cat((deconv2, enc2_2), dim=1)
        dec2_1 = self.dec2_1(concat2)
        dec2_2 = self.dec2_2(dec2_1)

        deconv1 = self.deconv1(dec2_2)
        concat1 = torch.cat((deconv1, enc1_2), dim=1)
        dec1_1 = self.dec1_1(concat1)
        dec1_2 = self.dec1_2(dec1_1)

        # Final classifier
        out = self.classifier(dec1_2)
        out = nn.Sigmoid()(out)
        return out


# Smoke test: build an 8-plane BinaryNet for 1-channel input with every
# optional activation/BatchNorm disabled, move it to the GPU, and print the
# output shape for a random 1024x1024 input.  `model` is reused by the
# training call below.
model = BinaryNet(num_hologram=8, in_planes=1, convReLU=False,
                  convBN=False, poolReLU=False, poolBN=False,
                  deconvReLU=False, deconvBN=False).cuda()
test = torch.randn(1, 1, 1024, 1024).cuda()
out = model(test)
print(out.shape)

torch.Size([1, 8, 1024, 1024])


In [4]:
def rgb_binary_train_padding(model, num_hologram=24, custom_name='0703Unet', epochs=500, lr=1e-4, z=2e-3):
    """Re-initialise and train ``model``, scoring on an 896x896 centre crop.

    The model is re-initialised with ``initialize_weights``.  Each step then:
    1. binarises the output with ``SignFunction``;
    2. propagates it with ``tt.simulate`` over distance ``z``;
    3. averages the intensity over dim 1;
    4. minimises ``tt.relativeLoss(..., F.mse_loss)`` between the centre
       crops of the reconstruction and the target.

    Cropping presumably discards propagation edge artefacts (TODO confirm).
    After each epoch, mean PSNR/MSE over ``validloader`` are computed; the
    first validation sample is rendered to ``penguin.png``.  Metric arrays
    and curves are written every 10 epochs and after the final epoch.  The
    best-PSNR weights are saved to ``result/<model_name>/``.

    Args:
        model: network mapping an image batch to hologram planes.
        num_hologram: only labels the run (folder name and config); the
            plane count comes from ``model`` itself.
        custom_name: free-form tag included in the run name.
        epochs: number of training epochs.
        lr: Adam learning rate.
        z: propagation distance passed to ``tt.simulate``.

    Raises:
        OSError: if the result folder cannot be created (``os.mkdir``).
    """
    # Timestamp shifted back 15 h (presumably a host-timezone correction --
    # TODO confirm).
    date = datetime.datetime.now() - datetime.timedelta(hours=15)
    model.apply(initialize_weights)
    mean_psnr_list = []
    mean_mse_list = []
    max_epoch = 0
    # Start at -inf so the first validated model is always checkpointed.
    max_psnr = float('-inf')
    crop = torchvision.transforms.CenterCrop(896)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sign = SignFunction.apply
    model_name = f'{date}_{custom_name}_{num_hologram}_{z}'
    result_folder = 'result/' + model_name
    os.mkdir(result_folder)
    num_parameters = count_parameters(model)
    model_path = os.path.join(result_folder, model_name)
    psnr_result_path = os.path.join(result_folder, 'meanPSNR.npy')
    mse_result_path = os.path.join(result_folder, 'meanMSE.npy')
    result = {'model_path': model_path, 'psnr_result_path': psnr_result_path,
              'mse_result_path': mse_result_path, 'z': z,
              'num_hologram': num_hologram, 'lr': lr, 'epochs': epochs,
              'num_parameters': num_parameters}
    with open(os.path.join(result_folder, 'result_dict.pickle'), 'wb') as fw:
        pickle.dump(result, fw)

    for epoch in range(epochs):
        pbar = tqdm.tqdm(trainloader)
        model.train()
        for target in pbar:
            out = model(target)
            binary = sign(out)
            # Propagate over the requested distance (was hard-coded to 2e-3).
            sim = tt.simulate(binary, z).abs()**2
            mean_sim = torch.mean(sim, dim=1, keepdim=True)
            recon = crop(mean_sim)
            croped_target = crop(target)
            loss = tt.relativeLoss(recon, croped_target, F.mse_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            psnrList = []
            mseList = []
            for idx, valid_target in enumerate(validloader):
                out = model(valid_target)
                binary = sign(out)
                sim = tt.simulate(binary, z).abs()**2
                mean_sim = torch.mean(sim, dim=1, keepdim=True)
                recon = crop(mean_sim)
                croped_target = crop(valid_target)
                psnrList.append(tt.relativeLoss(recon, croped_target, tm.get_PSNR))
                mse = tt.relativeLoss(recon, croped_target, F.mse_loss)
                mseList.append(mse.detach().cpu().numpy())
                if idx == 0:
                    # Visual snapshot of the first validation sample.
                    pl = tt.show_with_insets(recon, croped_target, correct_colorwise=True)
                    pl.save(os.path.join(result_folder, 'penguin.png'))
            mean_psnr = sum(psnrList) / len(psnrList)
            mean_mse = sum(mseList) / len(mseList)
            mean_psnr_list.append(mean_psnr)
            mean_mse_list.append(mean_mse)

            # Persist metrics every 10 epochs and after the final epoch so
            # the tail of a run is not lost.
            if epoch % 10 == 0 or epoch == epochs - 1:
                np.save(psnr_result_path, mean_psnr_list)
                np.save(mse_result_path, mean_mse_list)
                plt.plot(np.arange(len(mean_psnr_list)), mean_psnr_list,
                         label=f'max psnr: {max(mean_psnr_list)}')
                plt.title(model_name + '_psnr')
                plt.legend()
                plt.savefig(os.path.join(result_folder, 'meanPSNR.png'))
                plt.clf()
                # Lower MSE is better, so the legend reports the minimum.
                plt.plot(np.arange(len(mean_mse_list)), mean_mse_list,
                         label=f'min mse: {min(mean_mse_list)}')
                plt.title(model_name + '_mse')
                plt.legend()
                plt.savefig(os.path.join(result_folder, 'meanMSE.png'))
                plt.clf()
            print(f'mean PSNR : {mean_psnr} {epoch}/{epochs}')
            print(f'mean MSE : {mean_mse} {epoch}/{epochs}')

        # Checkpoint whenever validation PSNR improves.
        if mean_psnr > max_psnr:
            max_psnr = mean_psnr
            max_epoch = epoch
            torch.save(model.state_dict(), model_path)
        print(f'max_psnr: {max_psnr}, epoch: {max_epoch}')

In [None]:
# Launch training of the global `model` for 6000 epochs.  num_hologram=8 only
# labels the run (result folder name / saved config); the plane count comes
# from the model built above.
rgb_binary_train_padding(model, num_hologram=8, custom_name='pre_reinforce', epochs=6000, lr=1e-4, z=2e-3)

100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 22.03728144645691 0/6000
mean MSE : 0.006638044661376625 0/6000
max_psnr: 22.03728144645691, epoch: 0


100%|██████████| 800/800 [00:52<00:00, 15.34it/s]


mean PSNR : 21.97874744415283 1/6000
mean MSE : 0.006457836385816335 1/6000
max_psnr: 22.03728144645691, epoch: 0


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 23.906861572265626 2/6000
mean MSE : 0.00443095731199719 2/6000
max_psnr: 23.906861572265626, epoch: 2


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 23.96236581802368 3/6000
mean MSE : 0.0043372797966003416 3/6000
max_psnr: 23.96236581802368, epoch: 3


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 23.916285934448243 4/6000
mean MSE : 0.004401192873483523 4/6000
max_psnr: 23.96236581802368, epoch: 3


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 22.39017587661743 5/6000
mean MSE : 0.005989363581174984 5/6000
max_psnr: 23.96236581802368, epoch: 3


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 19.640231800079345 6/6000
mean MSE : 0.011135856520850211 6/6000
max_psnr: 23.96236581802368, epoch: 3


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 19.398634214401245 7/6000
mean MSE : 0.011880428255535662 7/6000
max_psnr: 23.96236581802368, epoch: 3


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 23.35358179092407 8/6000
mean MSE : 0.0049772866093553605 8/6000
max_psnr: 23.96236581802368, epoch: 3


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 24.514981822967528 9/6000
mean MSE : 0.003912150895921514 9/6000
max_psnr: 24.514981822967528, epoch: 9


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 24.935651111602784 10/6000
mean MSE : 0.003609868086059578 10/6000
max_psnr: 24.935651111602784, epoch: 10


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 24.78624349594116 11/6000
mean MSE : 0.0036504956730641423 11/6000
max_psnr: 24.935651111602784, epoch: 10


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 24.497220344543457 12/6000
mean MSE : 0.0037993930414086208 12/6000
max_psnr: 24.935651111602784, epoch: 10


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 20.702145195007326 13/6000
mean MSE : 0.008565525196027012 13/6000
max_psnr: 24.935651111602784, epoch: 10


100%|██████████| 800/800 [00:52<00:00, 15.20it/s]


mean PSNR : 24.46669891357422 14/6000
mean MSE : 0.0038628290535416454 14/6000
max_psnr: 24.935651111602784, epoch: 10


100%|██████████| 800/800 [00:52<00:00, 15.19it/s]


mean PSNR : 24.650519256591796 15/6000
mean MSE : 0.003706236408906989 15/6000
max_psnr: 24.935651111602784, epoch: 10


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 24.69864677429199 16/6000
mean MSE : 0.003689855542033911 16/6000
max_psnr: 24.935651111602784, epoch: 10


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 24.807001495361327 17/6000
mean MSE : 0.003644610038609244 17/6000
max_psnr: 24.935651111602784, epoch: 10


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 24.969462985992433 18/6000
mean MSE : 0.0034973694558721037 18/6000
max_psnr: 24.969462985992433, epoch: 18


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 24.234669132232668 19/6000
mean MSE : 0.003973785851849243 19/6000
max_psnr: 24.969462985992433, epoch: 18


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 24.85430486679077 20/6000
mean MSE : 0.0035363407351542264 20/6000
max_psnr: 24.969462985992433, epoch: 18


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.282709217071535 21/6000
mean MSE : 0.0032670425507239998 21/6000
max_psnr: 25.282709217071535, epoch: 21


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 24.670210227966308 22/6000
mean MSE : 0.0036876039835624396 22/6000
max_psnr: 25.282709217071535, epoch: 21


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.108134117126465 23/6000
mean MSE : 0.0033846725738840176 23/6000
max_psnr: 25.282709217071535, epoch: 21


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 24.505263347625732 24/6000
mean MSE : 0.0037538823077920822 24/6000
max_psnr: 25.282709217071535, epoch: 21


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 24.36070966720581 25/6000
mean MSE : 0.003915156937437132 25/6000
max_psnr: 25.282709217071535, epoch: 21


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 24.604191761016846 26/6000
mean MSE : 0.0036830578302033245 26/6000
max_psnr: 25.282709217071535, epoch: 21


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 24.78536962509155 27/6000
mean MSE : 0.0036475595220690593 27/6000
max_psnr: 25.282709217071535, epoch: 21


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 25.179871101379394 28/6000
mean MSE : 0.0032991423201747237 28/6000
max_psnr: 25.282709217071535, epoch: 21


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 25.381358489990234 29/6000
mean MSE : 0.0031714298261795192 29/6000
max_psnr: 25.381358489990234, epoch: 29


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 24.883095226287843 30/6000
mean MSE : 0.0035169779043644667 30/6000
max_psnr: 25.381358489990234, epoch: 29


100%|██████████| 800/800 [00:52<00:00, 15.20it/s]


mean PSNR : 24.660496330261232 31/6000
mean MSE : 0.003683248453307897 31/6000
max_psnr: 25.381358489990234, epoch: 29


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 25.211856060028076 32/6000
mean MSE : 0.003286557836108841 32/6000
max_psnr: 25.381358489990234, epoch: 29


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 25.54079948425293 33/6000
mean MSE : 0.003100142360199243 33/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 25.040035190582277 34/6000
mean MSE : 0.003414222025312483 34/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 24.93385679244995 35/6000
mean MSE : 0.003484380804002285 35/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 25.13222543716431 36/6000
mean MSE : 0.0033625487150857223 36/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 25.438427257537843 37/6000
mean MSE : 0.0031598626438062638 37/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.50435110092163 38/6000
mean MSE : 0.003114918847568333 38/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.274472579956054 39/6000
mean MSE : 0.003222104535670951 39/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.32it/s]


mean PSNR : 25.231894073486327 40/6000
mean MSE : 0.003244192658457905 40/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 25.211969509124756 41/6000
mean MSE : 0.0032877331058261917 41/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 25.228032932281494 42/6000
mean MSE : 0.003238713286118582 42/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 25.140839099884033 43/6000
mean MSE : 0.0033081707352539524 43/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 24.93156764984131 44/6000
mean MSE : 0.0034753472916781904 44/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 25.349231853485108 45/6000
mean MSE : 0.003199945039814338 45/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 23.5404243850708 46/6000
mean MSE : 0.004605259746313095 46/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 25.08658338546753 47/6000
mean MSE : 0.0033533480134792628 47/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 25.287408962249756 48/6000
mean MSE : 0.0032362993125570937 48/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 25.43170639038086 49/6000
mean MSE : 0.0031370695255463943 49/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.21855094909668 50/6000
mean MSE : 0.0032587647438049316 50/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 24.37605073928833 51/6000
mean MSE : 0.003855058659100905 51/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 25.510064029693602 52/6000
mean MSE : 0.003093028133152984 52/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 25.479468002319337 53/6000
mean MSE : 0.0030726159113692118 53/6000
max_psnr: 25.54079948425293, epoch: 33


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 25.60849588394165 54/6000
mean MSE : 0.002977599332807586 54/6000
max_psnr: 25.60849588394165, epoch: 54


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 25.088536548614503 55/6000
mean MSE : 0.0033606809156481177 55/6000
max_psnr: 25.60849588394165, epoch: 54


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 25.031818923950194 56/6000
mean MSE : 0.003409540003631264 56/6000
max_psnr: 25.60849588394165, epoch: 54


100%|██████████| 800/800 [00:52<00:00, 15.33it/s]


mean PSNR : 25.589228019714355 57/6000
mean MSE : 0.0030405252240598203 57/6000
max_psnr: 25.60849588394165, epoch: 54


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 25.345158920288085 58/6000
mean MSE : 0.003166968692094088 58/6000
max_psnr: 25.60849588394165, epoch: 54


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 25.458261737823488 59/6000
mean MSE : 0.0030992957146372645 59/6000
max_psnr: 25.60849588394165, epoch: 54


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 25.373403911590575 60/6000
mean MSE : 0.0031205899245105684 60/6000
max_psnr: 25.60849588394165, epoch: 54


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 25.793882331848145 61/6000
mean MSE : 0.0029171850020065903 61/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 25.59766721725464 62/6000
mean MSE : 0.003030964984791353 62/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.32it/s]


mean PSNR : 25.256607322692872 63/6000
mean MSE : 0.00321991455857642 63/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.15it/s]


mean PSNR : 22.21592056274414 64/6000
mean MSE : 0.006119374674744904 64/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 24.23127897262573 65/6000
mean MSE : 0.003980916952132247 65/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 24.67693634033203 66/6000
mean MSE : 0.0036657416506204756 66/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 24.84567808151245 67/6000
mean MSE : 0.0035258828790392725 67/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 24.406199779510498 68/6000
mean MSE : 0.0037695187015924603 68/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.147864818572998 69/6000
mean MSE : 0.0032869275915436446 69/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 24.66177091598511 70/6000
mean MSE : 0.00356784753035754 70/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.45509853363037 71/6000
mean MSE : 0.0030919467378407715 71/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 25.75502555847168 72/6000
mean MSE : 0.0029297980800038204 72/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 25.182478542327882 73/6000
mean MSE : 0.003313625759910792 73/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 25.38324628829956 74/6000
mean MSE : 0.0031219265022082252 74/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 25.242467555999756 75/6000
mean MSE : 0.0032105441275052725 75/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 24.631947174072266 76/6000
mean MSE : 0.003627435304224491 76/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.407069854736328 77/6000
mean MSE : 0.0031139559636358173 77/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.20it/s]


mean PSNR : 25.75929262161255 78/6000
mean MSE : 0.0029270545934559776 78/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 25.023108501434326 79/6000
mean MSE : 0.0033480793063063174 79/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 24.237736892700195 80/6000
mean MSE : 0.003934654698241502 80/6000
max_psnr: 25.793882331848145, epoch: 61


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.804392910003664 81/6000
mean MSE : 0.0029334665730129925 81/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.37807050704956 82/6000
mean MSE : 0.003189203884685412 82/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 24.382769527435304 83/6000
mean MSE : 0.0038691810227464885 83/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 24.086415195465086 84/6000
mean MSE : 0.004062313056783751 84/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 25.259319496154784 85/6000
mean MSE : 0.003209452524315566 85/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.33it/s]


mean PSNR : 25.499962005615235 86/6000
mean MSE : 0.0030725023709237574 86/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 24.715947227478026 87/6000
mean MSE : 0.003528460366651416 87/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.32it/s]


mean PSNR : 25.116290187835695 88/6000
mean MSE : 0.003281987182563171 88/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 25.482149333953856 89/6000
mean MSE : 0.003042305804556236 89/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.30it/s]


mean PSNR : 25.470037746429444 90/6000
mean MSE : 0.0030545256414916365 90/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.254979248046876 91/6000
mean MSE : 0.003272330572362989 91/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 25.600383358001707 92/6000
mean MSE : 0.003008781843818724 92/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.636582832336426 93/6000
mean MSE : 0.002949726687511429 93/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.33it/s]


mean PSNR : 25.741397495269776 94/6000
mean MSE : 0.0028985102829756216 94/6000
max_psnr: 25.804392910003664, epoch: 81


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 25.841898040771483 95/6000
mean MSE : 0.0028294504445511848 95/6000
max_psnr: 25.841898040771483, epoch: 95


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.945326251983644 96/6000
mean MSE : 0.0027851725439541042 96/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 25.835227928161622 97/6000
mean MSE : 0.0028644623898435384 97/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.345450973510744 98/6000
mean MSE : 0.0031819085904862734 98/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.386157722473143 99/6000
mean MSE : 0.00318255785969086 99/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.67453773498535 100/6000
mean MSE : 0.002960509607801214 100/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 25.749931411743162 101/6000
mean MSE : 0.0029224537074333057 101/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.680740756988527 102/6000
mean MSE : 0.0029735552310012283 102/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 25.473303546905516 103/6000
mean MSE : 0.0030793055694084614 103/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.217316646575927 104/6000
mean MSE : 0.0032130593294277788 104/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 25.498050575256347 105/6000
mean MSE : 0.0030635558045469225 105/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 24.909021663665772 106/6000
mean MSE : 0.003470298664178699 106/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.497368202209472 107/6000
mean MSE : 0.0030662860802840443 107/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 24.423972873687745 108/6000
mean MSE : 0.0038317884446587413 108/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 25.205170154571533 109/6000
mean MSE : 0.003235921349842101 109/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.29it/s]


mean PSNR : 25.7706729888916 110/6000
mean MSE : 0.002872916788328439 110/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.548708610534668 111/6000
mean MSE : 0.0029803725064266474 111/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.553807163238524 112/6000
mean MSE : 0.0029904957383405416 112/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.62148693084717 113/6000
mean MSE : 0.0029728971334407107 113/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.32it/s]


mean PSNR : 25.188183879852296 114/6000
mean MSE : 0.0033135193376801907 114/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 25.193369693756104 115/6000
mean MSE : 0.0032998665201012045 115/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.196955013275147 116/6000
mean MSE : 0.0032491762144491077 116/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 23.608417510986328 117/6000
mean MSE : 0.004516332041239366 117/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.26it/s]


mean PSNR : 25.38012643814087 118/6000
mean MSE : 0.0031324188242433594 118/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 25.16573377609253 119/6000
mean MSE : 0.003232721290551126 119/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.27it/s]


mean PSNR : 25.516492824554444 120/6000
mean MSE : 0.0030335185455624015 120/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.544137210845946 121/6000
mean MSE : 0.003019687962369062 121/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 25.499785442352294 122/6000
mean MSE : 0.0030665900389431045 122/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 25.904875965118407 123/6000
mean MSE : 0.0028165074466960506 123/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 25.486857204437257 124/6000
mean MSE : 0.0031247280375100673 124/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 25.458634662628175 125/6000
mean MSE : 0.0030894057650584727 125/6000
max_psnr: 25.945326251983644, epoch: 96


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.64068161010742 126/6000
mean MSE : 0.0029431192873744295 126/6000
max_psnr: 25.945326251983644, epoch: 96


 78%|███████▊  | 622/800 [00:40<00:12, 14.80it/s]

In [None]:
# Echo the shape of the item at index 0 of valid_dataset (notebook display of the bare expression).
valid_dataset[0].shape

In [None]:
# Run one validation item through the model, binarize the output, simulate it,
# and average the simulation over dim 1.
# NOTE: SignFunction.forward takes only the input tensor (its 0.5 threshold is
# hard-coded inside forward), so it must be applied with a single argument;
# calling it as `sign(out, 0.5)` raises a TypeError.
sign = SignFunction.apply
with torch.no_grad():  # inspection only: don't build/retain an autograd graph
    # Adds two leading dims to the item at index 0.
    # NOTE(review): confirm the resulting rank matches what `model` expects.
    out = model(valid_dataset[0].unsqueeze(0).unsqueeze(0))
    binary = sign(out)  # 1.0 where out > 0.5, else 0.0
    # NOTE(review): meaning/units of 2e-3 depend on tt.simulate's signature — confirm.
    sim = tt.simulate(binary, 2e-3)
    mean_sim = torch.mean(sim, dim=1, keepdim=True)  # average over dim 1, keep rank

In [None]:
# Print the shape of the first batch yielded by trainloader; `target` stays
# bound at notebook scope so it can be displayed afterwards.
_batch_iter = iter(trainloader)
for target in _batch_iter:
    print(target.shape)
    break
del _batch_iter  # drop the iterator reference promptly (frees any loader resources)

In [None]:
# Echo `target` — presumably the first trainloader batch bound by the loop cell above.
target

In [None]:
# Echo `mean_sim`, the dim-1 mean of the simulated binarized model output computed earlier.
mean_sim