In [1]:
import torchOptics.optics as tt
import warnings
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torchvision
import datetime
import tqdm
import time
import pandas as pd
# import HHS.model

IPS = 1024
CH = 8

warnings.filterwarnings('ignore')


class SignFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        t = torch.Tensor([0.5]).to(input.device)  # threshold
        output = (input > t).float() * 1
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output * torch.ones_like(input)  # Replace with your custom gradient computation
        return grad_input

class RealSign(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.sign(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output * torch.ones_like(input)  # Replace with your custom gradient computation
        return grad_input


class Dataset512(Dataset):
    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        self.target_list = sorted(glob.glob(target_dir+'*.png'))
        self.center_crop = torchvision.transforms.CenterCrop(IPS)
        self.random_crop = torchvision.transforms.RandomCrop((IPS, IPS))
        self.padding = padding

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        target = tt.imread(self.target_list[idx], meta=meta, gray=True).unsqueeze(0)
        if target.shape[-1] < IPS or target.shape[-2] < IPS:
            target = torchvision.transforms.Resize(IPS)(target)
        if self.isTrain:
            target = self.random_crop(target)
            target = torchvision.transforms.functional.pad(target, (self.padding, self.padding, self.padding, self.padding))
        else:
            target = self.center_crop(target)
            target = torchvision.transforms.functional.pad(target, (self.padding, self.padding, self.padding, self.padding))
        return target


def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0)

# load data
batch_size = 1
# target_dir = 'dataset/DIV2K/DIV2K_train_HR/'
# valid_dir = 'dataset/DIV2K/DIV2K_valid_HR/'
target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'
# sim_dir = 'binary_dataset/simulated/'
meta = {'wl' : (515e-9), 'dx':(7.56e-6, 7.56e-6)}
padding = 0
train_dataset = Dataset512(target_dir=target_dir, meta=meta, isTrain=True, padding=padding)
valid_dataset = Dataset512(target_dir=valid_dir, meta=meta, isTrain=False, padding=padding)

trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
validloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)


def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0)


def train(model, num_hologram=10, custom_name='fcn', epochs=100, lr=1e-4, z=2e-3):
    date = datetime.datetime.now() - datetime.timedelta(hours=15)
    mean_psnr_list = []
    mean_mse_list = []
    max_epoch = 0
    max_psnr = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sign_function = SignFunction.apply
    # path = 'models/'
    model_name = f'{date}_{custom_name}_{num_hologram}_{z}'
    # make result folder
    result_folder = 'result/'+model_name
    os.mkdir(result_folder)
    model_path = os.path.join(result_folder, model_name)
    psnr_result_path = os.path.join(result_folder, 'meanPSNR.npy')
    mse_result_path = os.path.join(result_folder, 'meanMSE.npy')
    result = {'model_path': model_path, 'psnr_result_path': psnr_result_path,
          'mse_result_path': mse_result_path, 'z': z,
          'num_hologram': num_hologram, 'lr': lr, 'epochs': epochs}
    # save result dict
    with open(os.path.join(result_folder, 'result_dict.pickle'), 'wb') as fw:
        pickle.dump(result, fw)
    for epoch in range(epochs):
        pbar = tqdm.tqdm(trainloader)
        model.train()
        for target in pbar:
            out = model(target)
            binary = sign_function(out)
            sim = tt.simulate(binary, z).abs()**2
            result = torch.mean(sim, dim=1, keepdim=True)
            loss = tt.relativeLoss(result, target, F.mse_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():
            psnrList = []
            mseList = []
            for valid in validloader:
                out = model(valid)
                binary = sign_function(out)
                sim = tt.simulate(binary, z).abs()**2
                result = torch.mean(sim, dim=1, keepdim=True)
                psnr = tt.relativeLoss(result, valid, tm.get_PSNR)
                psnrList.append(psnr)
                mse = tt.relativeLoss(result, valid, F.mse_loss).detach().cpu().numpy()
                mseList.append(mse)
            mean_psnr = sum(psnrList)/len(psnrList)
            mean_mse = sum(mseList)/len(mseList)
            mean_psnr_list.append(mean_psnr)
            mean_mse_list.append(mean_mse)
            # save psnr and mse
            np.save(psnr_result_path, mean_psnr_list)
            np.save(mse_result_path, mean_mse_list)
            plt.plot(np.arange(len(mean_psnr_list)), mean_psnr_list,
                     label=f'max psnr: {max(mean_psnr_list)}')
            plt.title(model_name+'_psnr')
            plt.legend()
            plt.savefig(os.path.join(result_folder, 'meanPSNR.png'))
            plt.clf()
            plt.plot(np.arange(len(mean_mse_list)), mean_mse_list,
                     label=f'max psnr: {min(mean_mse_list)}')
            plt.title(model_name+'_mse')
            plt.legend()
            plt.savefig(os.path.join(result_folder, 'meanMSE.png'))
            plt.clf()
            print(f'mean PSNR : {mean_psnr} {epoch}/{epochs}')
            print(f'mean MSE : {mean_mse} {epoch}/{epochs}')

        if mean_psnr > max_psnr:
            max_psnr = mean_psnr
            max_epoch = epoch
            torch.save(model.state_dict(), model_path)
        print(f'max_psnr: {max_psnr}, epoch: {max_epoch}')


def binary_sim(out, z=2e-3):
    binary = SignFunction.apply(out)
    sim = tt.simulate(binary, z).abs()**2
    res = torch.mean(sim, dim=1, keepdim=True)
    return binary, res


def valid(model):
    with torch.no_grad():
        psnr_list = []
        for target in validloader:
            out = model(target)
            binary, res = binary_sim(out)
            psnr = tt.relativeLoss(res, target, tm.get_PSNR)
            psnr_list.append(psnr)
    print(sum(psnr_list)/len(psnr_list))


def check_order(bmodel, classifier, target):
    if len(target.shape) == 3:
        target = target.unsqueeze(0)
    result = []
    with torch.no_grad():
        out = bmodel(target)
        binary, res = binary_sim(out)
        for i in range(10):
            binary_input = binary[0][i].unsqueeze(0).unsqueeze(0)
            prop_order = classifier(binary_input)
            result.append(torch.argmax(prop_order))
    print(result)


def check_order_without_model(binary, classifier):
    with torch.no_grad():
        result = []
        for i in range(10):
            binary_input = binary[0][i].unsqueeze(0).unsqueeze(0)
            prop_order = classifier(binary_input)
            result.append(torch.argmax(prop_order))
    print(result)


class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc = nn.Linear(64 * 64 * 64, 10)  # 최종 분류를 위한 FC 레이어

    def forward(self, x):
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.pool3(torch.relu(self.conv3(x)))
        x = x.view(-1, 64 * 64 * 64)  # flatten
        x = self.fc(x)
        return x


def rgb_binary_sim(out, z, th):
    pixel_pitch = 7.56e-6
    meta = {'wl' : (638e-9, 515e-9, 450e-9), 'dx':(pixel_pitch, pixel_pitch)}
    rmeta = {'wl': (638e-9), 'dx': (pixel_pitch, pixel_pitch)}
    gmeta = {'wl': (515e-9), 'dx': (pixel_pitch, pixel_pitch)}
    bmeta = {'wl': (450e-9), 'dx': (pixel_pitch, pixel_pitch)}
    sign = SignFunction.apply
    binary = sign(out, th)
    channel = out.shape[1]
    rchannel = int(channel/3)
    gchannel = int(channel*2/3)
    red = binary[:, :rchannel, :, :]
    green = binary[:, rchannel:gchannel, :, :]
    blue = binary[:, gchannel:, :, :]
    red = tt.Tensor(red, meta=rmeta)
    green = tt.Tensor(green, meta=gmeta)
    blue = tt.Tensor(blue, meta=bmeta)
    rsim = tt.simulate(red, z).abs()**2
    gsim = tt.simulate(green, z).abs()**2
    bsim = tt.simulate(blue, z).abs()**2
    rmean = torch.mean(rsim, dim=1, keepdim=True)
    gmean = torch.mean(gsim, dim=1, keepdim=True)
    bmean = torch.mean(bsim, dim=1, keepdim=True)
    rgb = torch.cat([rmean, gmean, bmean], dim=1)
    rgb = tt.Tensor(rgb, meta=meta)
    binary = tt.Tensor(binary, meta=meta)
    return binary, rgb

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [2]:
class BinaryNet(nn.Module):
    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=[32, 64, 128, 256, 512, 1024, 2048, 4096],
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super(BinaryNet, self).__init__()

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            if relu:
                layers += [nn.Tanh()]
            if bn:
                layers += [nn.BatchNorm2d(num_features=out_channels)]

            cbr = nn.Sequential(*layers) # *으로 list unpacking 

            return cbr

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            layers = []
            layers += [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=2, stride=2, padding=0,
                                          bias=True)]
            if bn:
                layers += [nn.BatchNorm2d(num_features=out_channels)]
            if relu:
                layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers) # *으로 list unpacking 

            return cbr

        self.enc1_1 = CRB2d(in_planes, channels[0], relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(channels[0], channels[0], stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(channels[0], channels[1], relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(channels[1], channels[1], stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(channels[1], channels[2], relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(channels[2], channels[2], stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(channels[2], channels[3], relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(channels[3], channels[3], stride=2, relu=poolReLU, bn=poolBN)

        self.enc5_1 = CRB2d(channels[3], channels[4], relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(channels[4], channels[4], relu=convReLU, bn=convBN)

        self.deconv4 = TRB2d(channels[4], channels[3], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(channels[4], channels[3], relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(channels[3], channels[2], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(channels[3], channels[2], relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(channels[2], channels[1], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(channels[2], channels[1], relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(channels[1], channels[0], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(channels[1], channels[0], relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)

        self.classifier = CRB2d(channels[0], num_hologram, relu=False, bn=False)

    def forward(self, x):
        # Encoder
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)
        enc5_2 = self.enc5_2(enc5_1)

        deconv4 = self.deconv4(enc5_2)
        concat4 = torch.cat((deconv4, enc4_2), dim=1)
        dec4_1 = self.dec4_1(concat4)
        dec4_2 = self.dec4_2(dec4_1)

        deconv3 = self.deconv3(dec4_2)
        concat3 = torch.cat((deconv3, enc3_2), dim=1)
        dec3_1 = self.dec3_1(concat3)
        dec3_2 = self.dec3_2(dec3_1)

        deconv2 = self.deconv2(dec3_2)
        concat2 = torch.cat((deconv2, enc2_2), dim=1)
        dec2_1 = self.dec2_1(concat2)
        dec2_2 = self.dec2_2(dec2_1)

        deconv1 = self.deconv1(dec2_2)
        concat1 = torch.cat((deconv1, enc1_2), dim=1)
        dec1_1 = self.dec1_1(concat1)
        dec1_2 = self.dec1_2(dec1_1)

        # Final classifier
        out = self.classifier(dec1_2)
        out = nn.Sigmoid()(out)
        return out


model = BinaryNet(num_hologram=CH, in_planes=1, convReLU=False,
                  convBN=False, poolReLU=False, poolBN=False,
                  deconvReLU=False, deconvBN=False).cuda()
test = torch.randn(1, 1, IPS, IPS).cuda()
out = model(test)
print(out.shape)

torch.Size([1, 8, 1024, 1024])


In [3]:
def rgb_binary_train_padding(model, num_hologram=24, custom_name='0703Unet', epochs=500, lr=1e-4, z=2e-3):
    date = datetime.datetime.now() - datetime.timedelta(hours=15)
    model.apply(initialize_weights)
    mean_psnr_list = []
    mean_mse_list = []
    max_epoch = 0
    max_psnr = 0
    crop = torchvision.transforms.CenterCrop(896)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sign = SignFunction.apply
    # path = 'models/'
    model_name = f'{date}_{custom_name}_{num_hologram}_{z}'
    # make result folder
    result_folder = 'result/'+model_name
    os.mkdir(result_folder)
    num_parameters = count_parameters(model)
    model_path = os.path.join(result_folder, model_name)
    psnr_result_path = os.path.join(result_folder, 'meanPSNR.npy')
    mse_result_path = os.path.join(result_folder, 'meanMSE.npy')
    result = {'model_path': model_path, 'psnr_result_path': psnr_result_path,
          'mse_result_path': mse_result_path, 'z': z,
          'num_hologram': num_hologram, 'lr': lr, 'epochs': epochs,
            'num_parameters': num_parameters}
    # save result dict
    with open(os.path.join(result_folder, 'result_dict.pickle'), 'wb') as fw:
        pickle.dump(result, fw)
    for epoch in range(epochs):
        pbar = tqdm.tqdm(trainloader)
        model.train()
        for target in pbar:
            target = target
            out = model(target)
            binary = sign(out)
            sim = tt.simulate(binary, 2e-3).abs()**2
            mean_sim = torch.mean(sim, dim=1, keepdim=True)
            rgb = crop(mean_sim)
            croped_target = crop(target)
            loss = tt.relativeLoss(rgb, croped_target, F.mse_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():
            psnrList = []
            mseList = []
            for idx, valid in enumerate(validloader):
                valid = valid
                out = model(valid)
                binary = sign(out)
                sim = tt.simulate(binary, 2e-3).abs()**2
                mean_sim = torch.mean(sim, dim=1, keepdim=True)
                rgb = crop(mean_sim)
                croped_target = crop(valid)
                psnr = tt.relativeLoss(rgb, croped_target, tm.get_PSNR)
                psnrList.append(psnr)
                mse = tt.relativeLoss(rgb, croped_target, F.mse_loss).detach().cpu().numpy()
                mseList.append(mse)
                if idx == 0:
                    pl = tt.show_with_insets(rgb, croped_target, correct_colorwise=True)
                    pl.save(os.path.join(result_folder, 'penguin.png'))
            mean_psnr = sum(psnrList)/len(psnrList)
            mean_mse = sum(mseList)/len(mseList)
            mean_psnr_list.append(mean_psnr)
            mean_mse_list.append(mean_mse)
            # save psnr and mse
            if epoch % 10 == 0:
                np.save(psnr_result_path, mean_psnr_list)
                np.save(mse_result_path, mean_mse_list)
                plt.plot(np.arange(len(mean_psnr_list)), mean_psnr_list,
                         label=f'max psnr: {max(mean_psnr_list)}')
                plt.title(model_name+'_psnr')
                plt.legend()
                plt.savefig(os.path.join(result_folder, 'meanPSNR.png'))
                plt.clf()
                plt.plot(np.arange(len(mean_mse_list)), mean_mse_list,
                         label=f'max psnr: {min(mean_mse_list)}')
                plt.title(model_name+'_mse')
                plt.legend()
                plt.savefig(os.path.join(result_folder, 'meanMSE.png'))
                plt.clf()
            print(f'mean PSNR : {mean_psnr} {epoch}/{epochs}')
            print(f'mean MSE : {mean_mse} {epoch}/{epochs}')

        if mean_psnr > max_psnr:
            max_psnr = mean_psnr
            max_epoch = epoch
            torch.save(model.state_dict(), model_path)
        print(f'max_psnr: {max_psnr}, epoch: {max_epoch}')

In [4]:
rgb_binary_train_padding(model, num_hologram=CH, custom_name='pre_reinforce', epochs=500, lr=1e-4, z=2e-3)

100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 22.81658069610596 0/500
mean MSE : 0.0055587810545694085 0/500
max_psnr: 22.81658069610596, epoch: 0


100%|██████████| 800/800 [00:51<00:00, 15.57it/s]


mean PSNR : 21.670237197875977 1/500
mean MSE : 0.0072337186918593945 1/500
max_psnr: 22.81658069610596, epoch: 0


100%|██████████| 800/800 [00:51<00:00, 15.55it/s]


mean PSNR : 25.42415870666504 15/500
mean MSE : 0.0031933411897625773 15/500
max_psnr: 25.42415870666504, epoch: 15


100%|██████████| 800/800 [00:51<00:00, 15.57it/s]


mean PSNR : 25.211367816925048 16/500
mean MSE : 0.0033057986514177175 16/500
max_psnr: 25.42415870666504, epoch: 15


100%|██████████| 800/800 [00:51<00:00, 15.60it/s]


mean PSNR : 25.496622619628905 17/500
mean MSE : 0.0031583408848382534 17/500
max_psnr: 25.496622619628905, epoch: 17


100%|██████████| 800/800 [00:51<00:00, 15.57it/s]


mean PSNR : 25.551422080993653 18/500
mean MSE : 0.0030636929778847842 18/500
max_psnr: 25.551422080993653, epoch: 18


100%|██████████| 800/800 [00:51<00:00, 15.55it/s]


mean PSNR : 25.200543632507323 19/500
mean MSE : 0.003317460144171491 19/500
max_psnr: 25.551422080993653, epoch: 18


100%|██████████| 800/800 [00:51<00:00, 15.57it/s]


mean PSNR : 25.42896738052368 20/500
mean MSE : 0.0031556541193276644 20/500
max_psnr: 25.551422080993653, epoch: 18


100%|██████████| 800/800 [00:51<00:00, 15.59it/s]


mean PSNR : 25.45256633758545 21/500
mean MSE : 0.00314883139799349 21/500
max_psnr: 25.551422080993653, epoch: 18


 96%|█████████▌| 768/800 [00:49<00:02, 15.72it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 800/800 [00:51<00:00, 15.57it/s]


mean PSNR : 24.567076511383057 351/500
mean MSE : 0.0037286372587550434 351/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.54it/s]


mean PSNR : 25.63066474914551 352/500
mean MSE : 0.0029878778161946686 352/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.60it/s]


mean PSNR : 25.539926204681397 353/500
mean MSE : 0.003051755939377472 353/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.57it/s]


mean PSNR : 25.603135471343993 354/500
mean MSE : 0.0030147780454717575 354/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.59it/s]


mean PSNR : 25.49263189315796 355/500
mean MSE : 0.0030751583707751707 355/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.59it/s]


mean PSNR : 25.434596576690673 356/500
mean MSE : 0.0031125043757492675 356/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.57it/s]


mean PSNR : 25.698878841400145 357/500
mean MSE : 0.0029370045469840986 357/500
max_psnr: 26.06439317703247, epoch: 264


 46%|████▌     | 364/800 [00:23<00:28, 15.11it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

100%|██████████| 800/800 [00:51<00:00, 15.55it/s]


mean PSNR : 25.229074211120604 364/500
mean MSE : 0.003248410096857697 364/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.52it/s]


mean PSNR : 25.274718360900877 365/500
mean MSE : 0.0032457676681224255 365/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.55it/s]


mean PSNR : 24.974621963500976 366/500
mean MSE : 0.0034063043538481 366/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.50it/s]


mean PSNR : 24.984821643829346 367/500
mean MSE : 0.0033927940996363757 367/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.50it/s]


mean PSNR : 24.33100471496582 368/500
mean MSE : 0.0038514892768580465 368/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.59it/s]


mean PSNR : 24.959243335723876 369/500
mean MSE : 0.003418365928228013 369/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.54it/s]


mean PSNR : 25.32785556793213 370/500
mean MSE : 0.0032090782519662753 370/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.50it/s]


mean PSNR : 25.118787899017335 371/500
mean MSE : 0.003371197273954749 371/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.53it/s]


mean PSNR : 24.741100006103515 372/500
mean MSE : 0.0035896848794072867 372/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 24.93101619720459 373/500
mean MSE : 0.0034368896088562907 373/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 25.037002182006837 374/500
mean MSE : 0.003365889369742945 374/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 25.15826171875 375/500
mean MSE : 0.0032745793438516556 375/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.51it/s]


mean PSNR : 25.212085762023925 376/500
mean MSE : 0.0032500659849029035 376/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 25.47067050933838 377/500
mean MSE : 0.0030913738894741984 377/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 25.630239448547364 378/500
mean MSE : 0.003010152163915336 378/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 25.60356128692627 379/500
mean MSE : 0.003033597865724005 379/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.53it/s]


mean PSNR : 25.192995567321777 380/500
mean MSE : 0.003243379060877487 380/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.56it/s]


mean PSNR : 25.58726568222046 381/500
mean MSE : 0.0030422493303194644 381/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.54it/s]


mean PSNR : 25.205943202972414 382/500
mean MSE : 0.0032293714181287213 382/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.54it/s]


mean PSNR : 25.495267868041992 383/500
mean MSE : 0.0030530628928681835 383/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 24.972506198883057 384/500
mean MSE : 0.0033977239415980878 384/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 24.56971565246582 385/500
mean MSE : 0.0037134720390895382 385/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.51it/s]


mean PSNR : 24.867336463928222 386/500
mean MSE : 0.0035053724993485955 386/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 24.93383871078491 387/500
mean MSE : 0.003448840818600729 387/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 24.9675136756897 388/500
mean MSE : 0.0033897838869597764 388/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.51it/s]


mean PSNR : 24.9004744720459 389/500
mean MSE : 0.003466175100766122 389/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 24.76541160583496 390/500
mean MSE : 0.0034957835439126937 390/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.55it/s]


mean PSNR : 25.52624921798706 391/500
mean MSE : 0.0030424864054657517 391/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.50it/s]


mean PSNR : 24.33595775604248 392/500
mean MSE : 0.003813735553994775 392/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.51it/s]


mean PSNR : 25.178013572692873 393/500
mean MSE : 0.0032774551486363635 393/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 25.447124042510985 394/500
mean MSE : 0.0031032181676710026 394/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 24.90296787261963 395/500
mean MSE : 0.0034156918560620396 395/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 25.573102893829347 396/500
mean MSE : 0.0030187476990977302 396/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 25.112965965270995 397/500
mean MSE : 0.0032735015300568195 397/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.51it/s]


mean PSNR : 25.410344886779786 398/500
mean MSE : 0.0031399334501475097 398/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.53it/s]


mean PSNR : 25.28378204345703 399/500
mean MSE : 0.0032398215477587654 399/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 25.070310096740723 400/500
mean MSE : 0.00339429299114272 400/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 25.401585483551024 401/500
mean MSE : 0.0031425561511423438 401/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 24.839025688171386 402/500
mean MSE : 0.003560072148684412 402/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 24.87261070251465 403/500
mean MSE : 0.0035336441185791047 403/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 24.143796691894533 404/500
mean MSE : 0.004080291987629607 404/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.51it/s]


mean PSNR : 24.579258785247802 405/500
mean MSE : 0.0036865989584475758 405/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 24.693927764892578 406/500
mean MSE : 0.0036768818029668184 406/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.52it/s]


mean PSNR : 24.699600524902344 407/500
mean MSE : 0.0036031383648514746 407/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.55it/s]


mean PSNR : 25.22807819366455 408/500
mean MSE : 0.00332926619972568 408/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 24.423466281890867 409/500
mean MSE : 0.0038395061902701854 409/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 25.422402534484863 410/500
mean MSE : 0.0031306560477241873 410/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.51it/s]


mean PSNR : 25.197124366760253 411/500
mean MSE : 0.0032560347532853483 411/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.54it/s]


mean PSNR : 24.64380016326904 412/500
mean MSE : 0.0036642722762189805 412/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.33it/s]


mean PSNR : 25.605997886657715 413/500
mean MSE : 0.003007934804772958 413/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 25.6096000289917 414/500
mean MSE : 0.003015433991095051 414/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.39it/s]


mean PSNR : 25.456971054077147 415/500
mean MSE : 0.0030985429632710295 415/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 24.62374231338501 416/500
mean MSE : 0.003634044856298715 416/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 25.568501014709472 417/500
mean MSE : 0.0030284265975933523 417/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.42it/s]


mean PSNR : 25.163401374816896 418/500
mean MSE : 0.0033262408600421623 418/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 25.656374530792235 419/500
mean MSE : 0.0029490186274051667 419/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 25.35620101928711 420/500
mean MSE : 0.003161165512865409 420/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 25.512462043762206 421/500
mean MSE : 0.00307389072549995 421/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 25.041872005462647 422/500
mean MSE : 0.003353947241557762 422/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 25.695339488983155 423/500
mean MSE : 0.002907008174806833 423/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 25.440362720489503 424/500
mean MSE : 0.003097084722830914 424/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 25.71968879699707 425/500
mean MSE : 0.0029282642778707666 425/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.42it/s]


mean PSNR : 24.868009643554686 426/500
mean MSE : 0.003452918989351019 426/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 25.040607223510744 427/500
mean MSE : 0.0033257421874441206 427/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 25.32773229598999 428/500
mean MSE : 0.0031570806331001223 428/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 24.916000366210938 429/500
mean MSE : 0.00343538468121551 429/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.42it/s]


mean PSNR : 25.25294734954834 430/500
mean MSE : 0.0032070984435267746 430/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 24.887283363342284 431/500
mean MSE : 0.003449012152850628 431/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 25.168030242919922 432/500
mean MSE : 0.0032820657000411303 432/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.41it/s]


mean PSNR : 24.835892734527587 433/500
mean MSE : 0.0034850333037320523 433/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 24.880945014953614 434/500
mean MSE : 0.0034786801680456846 434/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.35it/s]


mean PSNR : 25.670610885620118 435/500
mean MSE : 0.0030254016217077153 435/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 25.247664737701417 436/500
mean MSE : 0.0032435601809993387 436/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 25.633361682891845 437/500
mean MSE : 0.002996795795625076 437/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.41it/s]


mean PSNR : 25.701972599029542 438/500
mean MSE : 0.0029411654616706075 438/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 25.002516193389894 439/500
mean MSE : 0.0033614161418518054 439/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 25.056938247680662 440/500
mean MSE : 0.0033916723617585377 440/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 25.129334411621095 441/500
mean MSE : 0.003335391697473824 441/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.31it/s]


mean PSNR : 25.274898014068604 442/500
mean MSE : 0.003233376123243943 442/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.12it/s]


mean PSNR : 24.765250244140624 443/500
mean MSE : 0.003563632977893576 443/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.15it/s]


mean PSNR : 25.14735080718994 444/500
mean MSE : 0.0032574261305853723 444/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 25.34294984817505 445/500
mean MSE : 0.0031440950179239735 445/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.16it/s]


mean PSNR : 25.51223653793335 446/500
mean MSE : 0.003029988214257173 446/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.38it/s]


mean PSNR : 25.560094089508056 447/500
mean MSE : 0.0030047816364094614 447/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 25.698781757354737 448/500
mean MSE : 0.0029629162955097854 448/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.41it/s]


mean PSNR : 24.505906848907472 449/500
mean MSE : 0.003806670616613701 449/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.40it/s]


mean PSNR : 24.707480430603027 450/500
mean MSE : 0.0035929219680838286 450/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.40it/s]


mean PSNR : 25.50457260131836 451/500
mean MSE : 0.0030788265843875707 451/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 25.838055305480957 452/500
mean MSE : 0.002875617050449364 452/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.39it/s]


mean PSNR : 23.885280799865722 453/500
mean MSE : 0.004315193044021726 453/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.39it/s]


mean PSNR : 25.061710777282716 454/500
mean MSE : 0.003332987338071689 454/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.42it/s]


mean PSNR : 25.4795112991333 455/500
mean MSE : 0.0031073246063897386 455/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 24.504277420043945 456/500
mean MSE : 0.0037033736635930835 456/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 25.556337146759034 457/500
mean MSE : 0.0030242120317416268 457/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.42it/s]


mean PSNR : 25.67868274688721 458/500
mean MSE : 0.00298406278132461 458/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.39it/s]


mean PSNR : 24.810231952667237 459/500
mean MSE : 0.003500644530868158 459/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.40it/s]


mean PSNR : 25.08111740112305 460/500
mean MSE : 0.0033609306369908154 460/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 25.73740686416626 461/500
mean MSE : 0.0029289769078604875 461/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 25.182139739990234 462/500
mean MSE : 0.0032385272765532136 462/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.35it/s]


mean PSNR : 25.144338817596434 463/500
mean MSE : 0.0032578144734725354 463/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 25.065560455322267 464/500
mean MSE : 0.0033817536872811616 464/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.33it/s]


mean PSNR : 24.507924690246583 465/500
mean MSE : 0.0037304383795708418 465/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 23.3701780128479 466/500
mean MSE : 0.004702489071059972 466/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 12.968298711776733 467/500
mean MSE : 0.05320761422626674 467/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.40it/s]


mean PSNR : 12.968298778533935 468/500
mean MSE : 0.05320761380717158 468/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 12.968298664093018 469/500
mean MSE : 0.0532076152972877 469/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 12.968315525054932 470/500
mean MSE : 0.05320732584223151 470/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 12.968315525054932 471/500
mean MSE : 0.05320732584223151 471/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 12.968315525054932 472/500
mean MSE : 0.05320732584223151 472/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 12.968315525054932 473/500
mean MSE : 0.05320732584223151 473/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 12.968315525054932 474/500
mean MSE : 0.05320732584223151 474/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 12.968315525054932 475/500
mean MSE : 0.05320732584223151 475/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 12.968315525054932 476/500
mean MSE : 0.05320732584223151 476/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 12.96831789970398 477/500
mean MSE : 0.05320729680359364 477/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 12.968317890167237 478/500
mean MSE : 0.053207296691834924 478/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 12.968317890167237 479/500
mean MSE : 0.053207296691834924 479/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 12.968317890167237 480/500
mean MSE : 0.053207296691834924 480/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 12.968317890167237 481/500
mean MSE : 0.053207296691834924 481/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 12.968317890167237 482/500
mean MSE : 0.053207296691834924 482/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 12.968317890167237 483/500
mean MSE : 0.053207296691834924 483/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 12.968317890167237 484/500
mean MSE : 0.053207296691834924 484/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 12.968317890167237 485/500
mean MSE : 0.053207296691834924 485/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 12.968317890167237 486/500
mean MSE : 0.053207296691834924 486/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 12.968317890167237 487/500
mean MSE : 0.053207296691834924 487/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 12.968317890167237 488/500
mean MSE : 0.053207296691834924 488/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.42it/s]


mean PSNR : 12.968317890167237 489/500
mean MSE : 0.053207296691834924 489/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.43it/s]


mean PSNR : 12.968317890167237 490/500
mean MSE : 0.053207296691834924 490/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.45it/s]


mean PSNR : 12.968317890167237 491/500
mean MSE : 0.053207296691834924 491/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.48it/s]


mean PSNR : 12.968317890167237 492/500
mean MSE : 0.053207296691834924 492/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 12.968317890167237 493/500
mean MSE : 0.053207296691834924 493/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.46it/s]


mean PSNR : 12.968317890167237 494/500
mean MSE : 0.053207296691834924 494/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.47it/s]


mean PSNR : 12.968317890167237 495/500
mean MSE : 0.053207296691834924 495/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 12.968317890167237 496/500
mean MSE : 0.053207296691834924 496/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.44it/s]


mean PSNR : 12.96831829071045 497/500
mean MSE : 0.053207289176061746 497/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 12.96831829071045 498/500
mean MSE : 0.053207289176061746 498/500
max_psnr: 26.06439317703247, epoch: 264


100%|██████████| 800/800 [00:51<00:00, 15.49it/s]


mean PSNR : 12.96831829071045 499/500
mean MSE : 0.053207289176061746 499/500
max_psnr: 26.06439317703247, epoch: 264


<Figure size 640x480 with 0 Axes>

In [5]:
valid_dataset[0].shape

torch.Size([1, 1024, 1024])

In [6]:
sign = SignFunction.apply
out = model(valid_dataset[0].unsqueeze(0).unsqueeze(0))
binary = sign(out, 0.5)
sim = tt.simulate(binary, 2e-3)
mean_sim = torch.mean(sim, dim=1, keepdim=True)

RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [1, 1, 1, 1024, 1024]

In [None]:
for target in trainloader:
    print(target.shape)
    break

In [None]:
target

In [None]:
mean_sim