In [4]:
import torchOptics.optics as tt
import warnings
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torchvision
import datetime
import tqdm
import time
import pandas as pd
# import HHS.model


warnings.filterwarnings('ignore')


class SignFunction(torch.autograd.Function):
    """Binarize a tensor with a hard threshold and a straight-through gradient.

    Forward: elements strictly greater than the threshold become 1.0, all
    others 0.0. Backward: the incoming gradient is passed through unchanged,
    as if the forward pass were the identity.

    Usage: ``SignFunction.apply(x)`` thresholds at 0.5;
    ``SignFunction.apply(x, th)`` thresholds at ``th``.
    """

    @staticmethod
    def forward(ctx, input, th=0.5):
        """Return ``(input > th)`` as a float tensor.

        Args:
            input: tensor to binarize.
            th: threshold, either a Python number or a tensor that broadcasts
                against ``input``. Defaults to 0.5.
        """
        # Only the dtype is needed in backward, so the input itself is not
        # kept alive via save_for_backward.
        ctx.input_dtype = input.dtype
        if torch.is_tensor(th):
            threshold = th.to(input.device)
        else:
            threshold = torch.tensor([th], device=input.device)
        return (input > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: forward the gradient unchanged."""
        # The trailing None is the (non-existent) gradient for `th`. Autograd
        # discards surplus None entries, so this also works when `apply` was
        # called with `input` only.
        return grad_output.to(ctx.input_dtype), None

class RealSign(torch.autograd.Function):
    """Element-wise ``torch.sign`` with a straight-through gradient.

    Forward returns -1, 0 or +1 per element. Backward passes the incoming
    gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``torch.sign(input)``."""
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: forward the gradient unchanged."""
        # Nothing is saved in forward, so ctx.saved_tensors must not be read
        # here; the pass-through gradient does not need the input anyway.
        return grad_output


class Dataset512(Dataset):
    """Dataset of grayscale PNG images cropped to 512x512 and optionally padded.

    Images are discovered with ``glob(target_dir + '*.png')``, so
    ``target_dir`` must end with a path separator. In training mode a random
    512x512 crop is taken; otherwise a centre crop is used.
    """

    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        """
        Args:
            target_dir: directory prefix (with trailing separator) holding PNGs.
            meta: metadata dict forwarded to ``tt.imread`` for every image.
            transform: stored but not applied anywhere in this class.
            isTrain: True selects random cropping, False centre cropping.
            padding: number of pixels added on each of the four sides after
                cropping.
        """
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        self.target_list = sorted(glob.glob(target_dir+'*.png'))
        self.center_crop = torchvision.transforms.CenterCrop(512)
        self.random_crop = torchvision.transforms.RandomCrop((512, 512))
        self.padding = padding

    def __len__(self):
        """Number of PNG files found under ``target_dir``."""
        return len(self.target_list)

    def __getitem__(self, idx):
        """Load image ``idx``, crop it to 512x512 and apply the padding."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Use the metadata given to this dataset, not a module-level global.
        target = tt.imread(self.target_list[idx], meta=self.meta, gray=True).unsqueeze(0)
        # Upscale images whose height or width is below the crop size.
        if target.shape[-1] < 512 or target.shape[-2] < 512:
            target = torchvision.transforms.Resize(512)(target)
        crop = self.random_crop if self.isTrain else self.center_crop
        target = crop(target)
        pad = self.padding
        return torchvision.transforms.functional.pad(target, (pad, pad, pad, pad))


def initialize_weights(m):
    """Initialise one module in place; intended for ``model.apply(...)``.

    * ``nn.Conv2d``: Kaiming-uniform weights (ReLU gain), zero bias.
    * ``nn.BatchNorm2d``: weight 1, bias 0.
    * ``nn.Linear``: Kaiming-uniform weights, zero bias.

    Any other module type is left untouched. Parameters that are ``None``
    (layers built with ``bias=False`` or ``affine=False``) are skipped.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # affine=False leaves both weight and bias as None.
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

# Module-level dataset / loader setup shared by the training and validation code.
batch_size = 1
padding = 0
# Alternative relative locations kept for reference:
#   'dataset/DIV2K/DIV2K_train_HR/', 'dataset/DIV2K/DIV2K_valid_HR/'
#   sim_dir = 'binary_dataset/simulated/'
target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'
# Metadata handed to tt.imread; presumably wavelength ('wl') and pixel pitch
# ('dx') in metres -- confirm against torchOptics.
meta = {'wl': 515e-9, 'dx': (7.56e-6, 7.56e-6)}

train_dataset = Dataset512(target_dir, meta, isTrain=True, padding=padding)
valid_dataset = Dataset512(valid_dir, meta, isTrain=False, padding=padding)

# NOTE(review): the training loader does not shuffle; per-sample randomness
# comes only from Dataset512's random crop.
trainloader = DataLoader(train_dataset, shuffle=False, batch_size=batch_size)
validloader = DataLoader(valid_dataset, shuffle=False, batch_size=batch_size)


def initialize_weights(m):
    """Initialise one module in place; intended for ``model.apply(...)``.

    * ``nn.Conv2d``: Kaiming-uniform weights (ReLU gain), zero bias.
    * ``nn.BatchNorm2d``: weight 1, bias 0.
    * ``nn.Linear``: Kaiming-uniform weights, zero bias.

    Any other module type is left untouched. Parameters that are ``None``
    (layers built with ``bias=False`` or ``affine=False``) are skipped.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # affine=False leaves both weight and bias as None.
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)


def train(model, num_hologram=10, custom_name='fcn', epochs=100, lr=1e-4, z=2e-3):
    """Train ``model`` on the global ``trainloader`` and validate every epoch.

    Per batch: the model output is binarized with ``SignFunction``, passed
    through ``tt.simulate(..., z)``, converted to intensity (``abs()**2``),
    averaged over the channel dimension and compared to the target with
    ``tt.relativeLoss`` using MSE.

    After every epoch the mean validation PSNR / MSE histories are written to
    ``result/<model_name>/`` as ``.npy`` files and plots, and the model
    ``state_dict`` is saved whenever the mean PSNR exceeds the best so far.

    Args:
        model: network to optimise (Adam, learning rate ``lr``).
        num_hologram: only recorded in the run name and result dict.
        custom_name: free-form tag used in the run name.
        epochs: number of passes over ``trainloader``.
        lr: Adam learning rate.
        z: second argument of ``tt.simulate`` (presumably a propagation
            distance -- confirm against torchOptics).
    """
    # NOTE(review): the 15 h shift looks like a manual timezone correction -- confirm.
    date = datetime.datetime.now() - datetime.timedelta(hours=15)
    mean_psnr_list = []
    mean_mse_list = []
    max_epoch = 0
    max_psnr = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sign_function = SignFunction.apply
    model_name = f'{date}_{custom_name}_{num_hologram}_{z}'
    # makedirs also creates the parent 'result/' directory when it is missing.
    result_folder = 'result/'+model_name
    os.makedirs(result_folder, exist_ok=True)
    model_path = os.path.join(result_folder, model_name)
    psnr_result_path = os.path.join(result_folder, 'meanPSNR.npy')
    mse_result_path = os.path.join(result_folder, 'meanMSE.npy')
    run_info = {'model_path': model_path, 'psnr_result_path': psnr_result_path,
                'mse_result_path': mse_result_path, 'z': z,
                'num_hologram': num_hologram, 'lr': lr, 'epochs': epochs}
    # Persist the run configuration next to the results.
    with open(os.path.join(result_folder, 'result_dict.pickle'), 'wb') as fw:
        pickle.dump(run_info, fw)
    for epoch in range(epochs):
        model.train()
        for target in tqdm.tqdm(trainloader):
            out = model(target)
            binary = sign_function(out)
            sim = tt.simulate(binary, z).abs()**2
            recon = torch.mean(sim, dim=1, keepdim=True)
            loss = tt.relativeLoss(recon, target, F.mse_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():
            psnrList = []
            mseList = []
            for valid_target in validloader:
                out = model(valid_target)
                binary = sign_function(out)
                sim = tt.simulate(binary, z).abs()**2
                recon = torch.mean(sim, dim=1, keepdim=True)
                psnrList.append(tt.relativeLoss(recon, valid_target, tm.get_PSNR))
                mse = tt.relativeLoss(recon, valid_target, F.mse_loss).detach().cpu().numpy()
                mseList.append(mse)
            mean_psnr = sum(psnrList)/len(psnrList)
            mean_mse = sum(mseList)/len(mseList)
            mean_psnr_list.append(mean_psnr)
            mean_mse_list.append(mean_mse)
            # Save the metric histories and refresh both plots.
            np.save(psnr_result_path, mean_psnr_list)
            np.save(mse_result_path, mean_mse_list)
            plt.plot(np.arange(len(mean_psnr_list)), mean_psnr_list,
                     label=f'max psnr: {max(mean_psnr_list)}')
            plt.title(model_name+'_psnr')
            plt.legend()
            plt.savefig(os.path.join(result_folder, 'meanPSNR.png'))
            plt.clf()
            # The MSE legend reports the best (minimum) MSE.
            plt.plot(np.arange(len(mean_mse_list)), mean_mse_list,
                     label=f'min mse: {min(mean_mse_list)}')
            plt.title(model_name+'_mse')
            plt.legend()
            plt.savefig(os.path.join(result_folder, 'meanMSE.png'))
            plt.clf()
            print(f'mean PSNR : {mean_psnr} {epoch}/{epochs}')
            print(f'mean MSE : {mean_mse} {epoch}/{epochs}')

        # Keep only the checkpoint with the best validation PSNR.
        if mean_psnr > max_psnr:
            max_psnr = mean_psnr
            max_epoch = epoch
            torch.save(model.state_dict(), model_path)
        print(f'max_psnr: {max_psnr}, epoch: {max_epoch}')


def binary_sim(out, z=2e-3):
    """Binarize ``out`` and return ``(binary, channel-averaged intensity)``.

    The intensity is ``abs()**2`` of ``tt.simulate(binary, z)``, averaged over
    dimension 1 with that dimension kept.
    """
    binarize = SignFunction.apply
    binary = binarize(out)
    intensity = tt.simulate(binary, z).abs()**2
    return binary, torch.mean(intensity, dim=1, keepdim=True)


def valid(model):
    """Print the mean PSNR of ``model`` over the global ``validloader``."""
    scores = []
    with torch.no_grad():
        for target in validloader:
            # Only the averaged intensity is needed for the metric.
            _, res = binary_sim(model(target))
            scores.append(tt.relativeLoss(res, target, tm.get_PSNR))
    print(sum(scores)/len(scores))


def check_order(bmodel, classifier, target):
    """Print the classifier's arg-max prediction for every binary channel.

    ``target`` is run through ``bmodel`` and ``binary_sim``; each channel of
    the first sample's binary output is then classified individually.

    Args:
        bmodel: model producing the pre-binarization output.
        classifier: callable applied to one channel at a time.
        target: input tensor; a 3-D tensor gets a leading batch dimension.
    """
    if len(target.shape) == 3:
        target = target.unsqueeze(0)
    result = []
    with torch.no_grad():
        out = bmodel(target)
        binary, res = binary_sim(out)
        # Iterate over the channels actually present instead of assuming 10.
        for i in range(binary.shape[1]):
            binary_input = binary[0][i].unsqueeze(0).unsqueeze(0)
            prop_order = classifier(binary_input)
            result.append(torch.argmax(prop_order))
    print(result)


def check_order_without_model(binary, classifier):
    """Print the classifier's arg-max prediction for every channel of ``binary``.

    Args:
        binary: tensor indexed as ``binary[0][i]``; each such slice is given
            two leading singleton dimensions before classification.
        classifier: callable applied to one channel at a time.
    """
    with torch.no_grad():
        result = []
        # Iterate over the channels actually present instead of assuming 10.
        for i in range(binary.shape[1]):
            binary_input = binary[0][i].unsqueeze(0).unsqueeze(0)
            prop_order = classifier(binary_input)
            result.append(torch.argmax(prop_order))
    print(result)


class SimpleCNN(nn.Module):
    """Three conv + max-pool stages followed by a 10-way linear classifier.

    The flatten size is fixed at 64 * 64 * 64, so the last pooling stage must
    yield 64 channels of 64x64 (e.g. a 1-channel 512x512 input).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, 1, 1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(32, 64, 3, 1, 1)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Fully connected layer for the final classification.
        self.fc = nn.Linear(64 * 64 * 64, 10)

    def forward(self, x):
        """Return the 10 class scores for ``x``."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(torch.relu(conv(x)))
        flat = x.view(-1, 64 * 64 * 64)  # flatten
        return self.fc(flat)


def rgb_binary_sim(out, z, th):
    """Binarize ``out`` at ``th`` and simulate it as three colour groups.

    The channels of the binarized tensor are split into thirds (red, green,
    blue, in that order). Each third is wrapped in a ``tt.Tensor`` with its
    own wavelength, passed through ``tt.simulate(..., z)``, converted to
    intensity and averaged over its channels. The three averages are
    concatenated along dimension 1.

    Args:
        out: pre-binarization model output, indexed as (N, C, H, W).
        z: second argument of ``tt.simulate`` (presumably a propagation
            distance -- confirm against torchOptics).
        th: binarization threshold.

    Returns:
        ``(binary, rgb)``, both ``tt.Tensor`` objects carrying the
        three-wavelength metadata.
    """
    pixel_pitch = 7.56e-6
    wavelengths = (638e-9, 515e-9, 450e-9)  # red, green, blue
    dx = (pixel_pitch, pixel_pitch)
    meta = {'wl': wavelengths, 'dx': dx}
    # SignFunction thresholds at a fixed 0.5 when given a single input, so
    # shift the data instead: out > th  <=>  out - th + 0.5 > 0.5 (up to
    # floating-point rounding). The straight-through gradient is unaffected
    # by the constant shift.
    binary = SignFunction.apply(out - th + 0.5)
    channel = out.shape[1]
    bounds = (0, channel // 3, (2 * channel) // 3, channel)
    means = []
    for wl, lo, hi in zip(wavelengths, bounds[:-1], bounds[1:]):
        field = tt.Tensor(binary[:, lo:hi, :, :], meta={'wl': wl, 'dx': dx})
        intensity = tt.simulate(field, z).abs()**2
        means.append(torch.mean(intensity, dim=1, keepdim=True))
    rgb = tt.Tensor(torch.cat(means, dim=1), meta=meta)
    binary = tt.Tensor(binary, meta=meta)
    return binary, rgb

def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

In [5]:
class BinaryNet(nn.Module):
    """U-Net style encoder/decoder with a sigmoid output of ``num_hologram`` channels.

    Four stride-2 convolutions downsample in the encoder, four transposed
    convolutions upsample in the decoder, and each decoder level concatenates
    the matching encoder feature map along the channel dimension.

    Args:
        num_hologram: number of output channels.
        final: accepted for compatibility; not used by this implementation
            (the output activation is always a sigmoid).
        in_planes: number of input channels.
        channels: feature widths per level; only the first five entries are
            used.
        convReLU / convBN: add the activation / batch norm to encoder and
            decoder conv blocks.
        poolReLU / poolBN: same, for the stride-2 downsampling convs.
        deconvReLU / deconvBN: same, for the transposed-conv upsampling blocks.
    """

    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=(32, 64, 128, 256, 512, 1024, 2048, 4096),
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super().__init__()

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            """Conv2d -> (Tanh) -> (BatchNorm2d).

            NOTE: the ``relu`` flag inserts ``nn.Tanh``, not ReLU; this is kept
            as-is so existing behaviour and checkpoints do not change.
            """
            layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            if relu:
                layers.append(nn.Tanh())
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            # Unpack the list into positional Sequential arguments.
            return nn.Sequential(*layers)

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            """ConvTranspose2d -> (BatchNorm2d) -> (ReLU)."""
            # kernel_size / stride / bias are honoured here; every call in this
            # class uses the defaults (2, 2, True).
            layers = [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride, padding=0,
                                         bias=bias)]
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            if relu:
                layers.append(nn.ReLU())
            # Unpack the list into positional Sequential arguments.
            return nn.Sequential(*layers)

        # Encoder: two conv blocks per level, then a stride-2 conv to downsample.
        self.enc1_1 = CRB2d(in_planes, channels[0], relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(channels[0], channels[0], stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(channels[0], channels[1], relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(channels[1], channels[1], stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(channels[1], channels[2], relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(channels[2], channels[2], stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(channels[2], channels[3], relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(channels[3], channels[3], stride=2, relu=poolReLU, bn=poolBN)

        # Bottleneck.
        self.enc5_1 = CRB2d(channels[3], channels[4], relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(channels[4], channels[4], relu=convReLU, bn=convBN)

        # Decoder: upsample, concatenate the skip connection (which doubles
        # the channel count seen by dec*_1), then two conv blocks.
        self.deconv4 = TRB2d(channels[4], channels[3], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(channels[4], channels[3], relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(channels[3], channels[2], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(channels[3], channels[2], relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(channels[2], channels[1], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(channels[2], channels[1], relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(channels[1], channels[0], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(channels[1], channels[0], relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)

        self.classifier = CRB2d(channels[0], num_hologram, relu=False, bn=False)

    def forward(self, x):
        """Return the sigmoid-activated output with ``num_hologram`` channels."""
        # Encoder
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)
        enc5_2 = self.enc5_2(enc5_1)

        # Decoder with skip connections
        deconv4 = self.deconv4(enc5_2)
        concat4 = torch.cat((deconv4, enc4_2), dim=1)
        dec4_1 = self.dec4_1(concat4)
        dec4_2 = self.dec4_2(dec4_1)

        deconv3 = self.deconv3(dec4_2)
        concat3 = torch.cat((deconv3, enc3_2), dim=1)
        dec3_1 = self.dec3_1(concat3)
        dec3_2 = self.dec3_2(dec3_1)

        deconv2 = self.deconv2(dec3_2)
        concat2 = torch.cat((deconv2, enc2_2), dim=1)
        dec2_1 = self.dec2_1(concat2)
        dec2_2 = self.dec2_2(dec2_1)

        deconv1 = self.deconv1(dec2_2)
        concat1 = torch.cat((deconv1, enc1_2), dim=1)
        dec1_1 = self.dec1_1(concat1)
        dec1_2 = self.dec1_2(dec1_1)

        # Final classifier; functional sigmoid avoids building a module per call.
        out = self.classifier(dec1_2)
        return torch.sigmoid(out)


# Build the 8-hologram network with every activation / batch-norm switch
# turned off, move it to the GPU and check the output shape on a random input.
model = BinaryNet(
    num_hologram=8,
    in_planes=1,
    convReLU=False,
    convBN=False,
    poolReLU=False,
    poolBN=False,
    deconvReLU=False,
    deconvBN=False,
)
model = model.cuda()
test = torch.randn(1, 1, 512, 512).cuda()
out = model(test)
print(out.shape)

torch.Size([1, 8, 512, 512])


In [6]:
def rgb_binary_train_padding(model, num_hologram=24, custom_name='0703Unet', epochs=500, lr=1e-4, z=2e-3):
    """Train ``model`` with a centre-cropped loss and validate every epoch.

    The model is first re-initialised with ``initialize_weights``. Per batch:
    the output is binarized with ``SignFunction``, passed through
    ``tt.simulate(..., z)``, converted to intensity, averaged over the channel
    dimension, and both reconstruction and target are centre-cropped to 896
    before ``tt.relativeLoss`` with MSE is computed.

    Results go to ``result/<model_name>/``: the run configuration, a
    comparison image for the first validation sample (every epoch), metric
    histories and plots (every 10th epoch), and the ``state_dict`` whenever
    the mean validation PSNR improves.

    Args:
        model: network to optimise (Adam, learning rate ``lr``).
        num_hologram: only recorded in the run name and result dict.
        custom_name: free-form tag used in the run name.
        epochs: number of passes over the global ``trainloader``.
        lr: Adam learning rate.
        z: second argument of ``tt.simulate`` (presumably a propagation
            distance -- confirm against torchOptics).
    """
    # NOTE(review): the 15 h shift looks like a manual timezone correction -- confirm.
    date = datetime.datetime.now() - datetime.timedelta(hours=15)
    model.apply(initialize_weights)
    mean_psnr_list = []
    mean_mse_list = []
    max_epoch = 0
    max_psnr = 0
    crop = torchvision.transforms.CenterCrop(896)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sign = SignFunction.apply
    model_name = f'{date}_{custom_name}_{num_hologram}_{z}'
    # makedirs also creates the parent 'result/' directory when it is missing.
    result_folder = 'result/'+model_name
    os.makedirs(result_folder, exist_ok=True)
    num_parameters = count_parameters(model)
    model_path = os.path.join(result_folder, model_name)
    psnr_result_path = os.path.join(result_folder, 'meanPSNR.npy')
    mse_result_path = os.path.join(result_folder, 'meanMSE.npy')
    run_info = {'model_path': model_path, 'psnr_result_path': psnr_result_path,
                'mse_result_path': mse_result_path, 'z': z,
                'num_hologram': num_hologram, 'lr': lr, 'epochs': epochs,
                'num_parameters': num_parameters}
    # Persist the run configuration next to the results.
    with open(os.path.join(result_folder, 'result_dict.pickle'), 'wb') as fw:
        pickle.dump(run_info, fw)
    for epoch in range(epochs):
        model.train()
        for target in tqdm.tqdm(trainloader):
            out = model(target)
            binary = sign(out)
            # Use the caller's z so the simulation matches the recorded run
            # name and result dict.
            sim = tt.simulate(binary, z).abs()**2
            mean_sim = torch.mean(sim, dim=1, keepdim=True)
            rgb = crop(mean_sim)
            croped_target = crop(target)
            loss = tt.relativeLoss(rgb, croped_target, F.mse_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        model.eval()
        with torch.no_grad():
            psnrList = []
            mseList = []
            for idx, valid_target in enumerate(validloader):
                out = model(valid_target)
                binary = sign(out)
                sim = tt.simulate(binary, z).abs()**2
                mean_sim = torch.mean(sim, dim=1, keepdim=True)
                rgb = crop(mean_sim)
                croped_target = crop(valid_target)
                psnrList.append(tt.relativeLoss(rgb, croped_target, tm.get_PSNR))
                mse = tt.relativeLoss(rgb, croped_target, F.mse_loss).detach().cpu().numpy()
                mseList.append(mse)
                # Save a visual comparison for the first validation sample.
                if idx == 0:
                    pl = tt.show_with_insets(rgb, croped_target, correct_colorwise=True)
                    pl.save(os.path.join(result_folder, 'penguin.png'))
            mean_psnr = sum(psnrList)/len(psnrList)
            mean_mse = sum(mseList)/len(mseList)
            mean_psnr_list.append(mean_psnr)
            mean_mse_list.append(mean_mse)
            # Metric histories and plots are refreshed every 10th epoch.
            if epoch % 10 == 0:
                np.save(psnr_result_path, mean_psnr_list)
                np.save(mse_result_path, mean_mse_list)
                plt.plot(np.arange(len(mean_psnr_list)), mean_psnr_list,
                         label=f'max psnr: {max(mean_psnr_list)}')
                plt.title(model_name+'_psnr')
                plt.legend()
                plt.savefig(os.path.join(result_folder, 'meanPSNR.png'))
                plt.clf()
                # The MSE legend reports the best (minimum) MSE.
                plt.plot(np.arange(len(mean_mse_list)), mean_mse_list,
                         label=f'min mse: {min(mean_mse_list)}')
                plt.title(model_name+'_mse')
                plt.legend()
                plt.savefig(os.path.join(result_folder, 'meanMSE.png'))
                plt.clf()
            print(f'mean PSNR : {mean_psnr} {epoch}/{epochs}')
            print(f'mean MSE : {mean_mse} {epoch}/{epochs}')

        # Keep only the checkpoint with the best validation PSNR.
        if mean_psnr > max_psnr:
            max_psnr = mean_psnr
            max_epoch = epoch
            torch.save(model.state_dict(), model_path)
        print(f'max_psnr: {max_psnr}, epoch: {max_epoch}')

In [None]:
# Launch training of the 8-hologram model defined above; results are written
# under result/<timestamp>_pre_reinforce_8_0.002/.
rgb_binary_train_padding(model, num_hologram=8, custom_name='pre_reinforce', epochs=500, lr=1e-4, z=2e-3)

100%|██████████| 800/800 [00:43<00:00, 18.35it/s]


mean PSNR : 25.88376012802124 0/500
mean MSE : 0.0026824115216732025 0/500
max_psnr: 25.88376012802124, epoch: 0


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 27.181291637420653 1/500
mean MSE : 0.0020303982321638616 1/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 25.026188793182374 2/500
mean MSE : 0.003174679662333801 2/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.32it/s]


mean PSNR : 24.687065505981444 3/500
mean MSE : 0.0034312394668813795 3/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 21.63189739227295 4/500
mean MSE : 0.0068854451621882615 4/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.41it/s]


mean PSNR : 26.388931789398193 5/500
mean MSE : 0.002369480135967024 5/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.40it/s]


mean PSNR : 26.286740436553956 6/500
mean MSE : 0.002353205403778702 6/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.38it/s]


mean PSNR : 25.31306324005127 7/500
mean MSE : 0.0029746763029834256 7/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.33it/s]


mean PSNR : 26.911027126312256 8/500
mean MSE : 0.0020930603938177227 8/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.36it/s]


mean PSNR : 25.498642387390138 9/500
mean MSE : 0.0028761975234374406 9/500
max_psnr: 27.181291637420653, epoch: 1


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 27.48881212234497 10/500
mean MSE : 0.0018544390209717676 10/500
max_psnr: 27.48881212234497, epoch: 10


100%|██████████| 800/800 [00:43<00:00, 18.40it/s]


mean PSNR : 25.145859909057616 11/500
mean MSE : 0.003052564135286957 11/500
max_psnr: 27.48881212234497, epoch: 10


100%|██████████| 800/800 [00:43<00:00, 18.40it/s]


mean PSNR : 26.735845832824708 12/500
mean MSE : 0.0021839013078715653 12/500
max_psnr: 27.48881212234497, epoch: 10


100%|██████████| 800/800 [00:43<00:00, 18.38it/s]


mean PSNR : 28.151326694488525 13/500
mean MSE : 0.0016188524116296321 13/500
max_psnr: 28.151326694488525, epoch: 13


100%|██████████| 800/800 [00:43<00:00, 18.48it/s]


mean PSNR : 28.545452346801756 14/500
mean MSE : 0.0014947145062615163 14/500
max_psnr: 28.545452346801756, epoch: 14


100%|██████████| 800/800 [00:43<00:00, 18.45it/s]


mean PSNR : 28.673571014404295 15/500
mean MSE : 0.0014607561973389238 15/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:43<00:00, 18.45it/s]


mean PSNR : 28.64277437210083 16/500
mean MSE : 0.0014367180014960468 16/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:43<00:00, 18.42it/s]


mean PSNR : 27.909598903656004 17/500
mean MSE : 0.0016387359699001536 17/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:43<00:00, 18.39it/s]


mean PSNR : 28.31265506744385 18/500
mean MSE : 0.0015143160632578657 18/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:43<00:00, 18.54it/s]


mean PSNR : 28.274423007965087 19/500
mean MSE : 0.0015259090485051274 19/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:43<00:00, 18.38it/s]


mean PSNR : 28.53729959487915 20/500
mean MSE : 0.0014510244864504784 20/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 28.378928604125978 21/500
mean MSE : 0.0015053828791133127 21/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:43<00:00, 18.32it/s]


mean PSNR : 28.52612964630127 22/500
mean MSE : 0.0014613966864999383 22/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:43<00:00, 18.24it/s]


mean PSNR : 28.53830759048462 23/500
mean MSE : 0.0014573704486247152 23/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:44<00:00, 18.17it/s]


mean PSNR : 28.653489055633546 24/500
mean MSE : 0.0014270206430228426 24/500
max_psnr: 28.673571014404295, epoch: 15


100%|██████████| 800/800 [00:44<00:00, 18.18it/s]


mean PSNR : 29.0212815284729 25/500
mean MSE : 0.0013185615083784797 25/500
max_psnr: 29.0212815284729, epoch: 25


100%|██████████| 800/800 [00:43<00:00, 18.32it/s]


mean PSNR : 29.15760150909424 26/500
mean MSE : 0.001273899934021756 26/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.45it/s]


mean PSNR : 29.125227756500244 27/500
mean MSE : 0.0012802847969578578 27/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.46it/s]


mean PSNR : 28.96103956222534 28/500
mean MSE : 0.0013122287672013044 28/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.41it/s]


mean PSNR : 29.12606285095215 29/500
mean MSE : 0.0012724379240535199 29/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.46it/s]


mean PSNR : 28.719655570983885 30/500
mean MSE : 0.001380485650151968 30/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.40it/s]


mean PSNR : 29.008481197357177 31/500
mean MSE : 0.001318571188603528 31/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.42it/s]


mean PSNR : 29.0833270072937 32/500
mean MSE : 0.0012854650698136537 32/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.42it/s]


mean PSNR : 29.129273262023926 33/500
mean MSE : 0.00128206298511941 33/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 28.99323661804199 34/500
mean MSE : 0.0013237855589250103 34/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.38it/s]


mean PSNR : 28.275166015625 35/500
mean MSE : 0.001537835409399122 35/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.38it/s]


mean PSNR : 28.733605403900146 36/500
mean MSE : 0.0013644274545367807 36/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.35it/s]


mean PSNR : 29.04674165725708 37/500
mean MSE : 0.001295647833030671 37/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 28.79576246261597 38/500
mean MSE : 0.001348915655980818 38/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 28.95197244644165 39/500
mean MSE : 0.0013261905655963346 39/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.38it/s]


mean PSNR : 28.699788608551025 40/500
mean MSE : 0.0013769030518596991 40/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.38it/s]


mean PSNR : 28.926370372772215 41/500
mean MSE : 0.0013232313637854532 41/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.37it/s]


mean PSNR : 29.061722850799562 42/500
mean MSE : 0.0012894705595681445 42/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.46it/s]


mean PSNR : 28.969200229644777 43/500
mean MSE : 0.001303663341095671 43/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.47it/s]


mean PSNR : 29.00114809036255 44/500
mean MSE : 0.0012998463166877627 44/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.40it/s]


mean PSNR : 28.994009857177733 45/500
mean MSE : 0.0013049644784769044 45/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.36it/s]


mean PSNR : 29.064603214263915 46/500
mean MSE : 0.0012875211745267735 46/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.33it/s]


mean PSNR : 29.053187770843508 47/500
mean MSE : 0.0012878329021623358 47/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.36it/s]


mean PSNR : 28.862167377471923 48/500
mean MSE : 0.0013352457026485354 48/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.31it/s]


mean PSNR : 28.934907722473145 49/500
mean MSE : 0.001324486555531621 49/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.28it/s]


mean PSNR : 29.13701765060425 50/500
mean MSE : 0.0012691342661855743 50/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.30it/s]


mean PSNR : 28.94570375442505 51/500
mean MSE : 0.0013089653372298927 51/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.31it/s]


mean PSNR : 29.084835605621336 52/500
mean MSE : 0.0012728230317588895 52/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.30it/s]


mean PSNR : 29.08795259475708 53/500
mean MSE : 0.0012784638057928533 53/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.33it/s]


mean PSNR : 29.12446002960205 54/500
mean MSE : 0.0012704085407312959 54/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.38it/s]


mean PSNR : 28.856700325012206 55/500
mean MSE : 0.0013399215490790084 55/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.34it/s]


mean PSNR : 29.079863624572752 56/500
mean MSE : 0.001279713055700995 56/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.35it/s]


mean PSNR : 29.047989921569823 57/500
mean MSE : 0.0012873808236327023 57/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.24it/s]


mean PSNR : 29.079512729644776 58/500
mean MSE : 0.001284636798663996 58/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.36it/s]


mean PSNR : 29.01369306564331 59/500
mean MSE : 0.001295064883888699 59/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.34it/s]


mean PSNR : 29.0309073638916 60/500
mean MSE : 0.0013122070854296908 60/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.28it/s]


mean PSNR : 29.1401939201355 61/500
mean MSE : 0.001263966306578368 61/500
max_psnr: 29.15760150909424, epoch: 26


100%|██████████| 800/800 [00:43<00:00, 18.29it/s]


mean PSNR : 29.15787853240967 62/500
mean MSE : 0.001268502573366277 62/500
max_psnr: 29.15787853240967, epoch: 62


100%|██████████| 800/800 [00:43<00:00, 18.29it/s]


mean PSNR : 29.184666290283204 63/500
mean MSE : 0.0012576903437729924 63/500
max_psnr: 29.184666290283204, epoch: 63


100%|██████████| 800/800 [00:43<00:00, 18.27it/s]


mean PSNR : 29.182659854888914 64/500
mean MSE : 0.0012614312168443575 64/500
max_psnr: 29.184666290283204, epoch: 63


100%|██████████| 800/800 [00:43<00:00, 18.22it/s]


mean PSNR : 29.178766822814943 65/500
mean MSE : 0.0012530937406700104 65/500
max_psnr: 29.184666290283204, epoch: 63


100%|██████████| 800/800 [00:43<00:00, 18.39it/s]


mean PSNR : 29.071361465454103 66/500
mean MSE : 0.0012818804703420028 66/500
max_psnr: 29.184666290283204, epoch: 63


100%|██████████| 800/800 [00:43<00:00, 18.41it/s]


mean PSNR : 29.192692432403565 67/500
mean MSE : 0.001253333006752655 67/500
max_psnr: 29.192692432403565, epoch: 67


100%|██████████| 800/800 [00:43<00:00, 18.39it/s]


mean PSNR : 29.144377307891844 68/500
mean MSE : 0.0012659483670722693 68/500
max_psnr: 29.192692432403565, epoch: 67


100%|██████████| 800/800 [00:43<00:00, 18.40it/s]


mean PSNR : 28.982876682281493 69/500
mean MSE : 0.0012984734086785466 69/500
max_psnr: 29.192692432403565, epoch: 67


100%|██████████| 800/800 [00:43<00:00, 18.36it/s]


mean PSNR : 29.036448154449463 70/500
mean MSE : 0.001285245582112111 70/500
max_psnr: 29.192692432403565, epoch: 67


100%|██████████| 800/800 [00:43<00:00, 18.30it/s]


mean PSNR : 29.06844543457031 71/500
mean MSE : 0.0012854268471710383 71/500
max_psnr: 29.192692432403565, epoch: 67


100%|██████████| 800/800 [00:44<00:00, 18.17it/s]


mean PSNR : 29.104327583312987 72/500
mean MSE : 0.0012769985193153844 72/500
max_psnr: 29.192692432403565, epoch: 67


100%|██████████| 800/800 [00:44<00:00, 18.02it/s]


mean PSNR : 29.07910036087036 73/500
mean MSE : 0.0012765087239677087 73/500
max_psnr: 29.192692432403565, epoch: 67


100%|██████████| 800/800 [00:44<00:00, 18.08it/s]


mean PSNR : 29.20198259353638 74/500
mean MSE : 0.0012510640278924256 74/500
max_psnr: 29.20198259353638, epoch: 74


100%|██████████| 800/800 [00:44<00:00, 18.17it/s]


mean PSNR : 29.1276212310791 75/500
mean MSE : 0.0012610405730083584 75/500
max_psnr: 29.20198259353638, epoch: 74


100%|██████████| 800/800 [00:44<00:00, 18.15it/s]


mean PSNR : 29.22773124694824 76/500
mean MSE : 0.0012455109113943762 76/500
max_psnr: 29.22773124694824, epoch: 76


100%|██████████| 800/800 [00:44<00:00, 18.17it/s]


mean PSNR : 29.16984827041626 77/500
mean MSE : 0.0012572733504930512 77/500
max_psnr: 29.22773124694824, epoch: 76


100%|██████████| 800/800 [00:44<00:00, 18.06it/s]


mean PSNR : 29.280434970855712 78/500
mean MSE : 0.0012415126926498487 78/500
max_psnr: 29.280434970855712, epoch: 78


100%|██████████| 800/800 [00:44<00:00, 18.02it/s]


mean PSNR : 29.290249481201172 79/500
mean MSE : 0.0012383096967823804 79/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.08it/s]


mean PSNR : 29.20439832687378 80/500
mean MSE : 0.001253361823328305 80/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.07it/s]


mean PSNR : 29.096036739349366 81/500
mean MSE : 0.0012743352219695225 81/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.09it/s]


mean PSNR : 29.155616455078125 82/500
mean MSE : 0.0012755087684490718 82/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.10it/s]


mean PSNR : 29.161908988952636 83/500
mean MSE : 0.0012649504051660187 83/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.14it/s]


mean PSNR : 29.156709785461427 84/500
mean MSE : 0.0012637620331952348 84/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.08it/s]


mean PSNR : 29.207701511383057 85/500
mean MSE : 0.001259602900245227 85/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.05it/s]


mean PSNR : 29.211926136016846 86/500
mean MSE : 0.0012476309295743704 86/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.14it/s]


mean PSNR : 28.860502033233644 87/500
mean MSE : 0.0013271833502221852 87/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.10it/s]


mean PSNR : 29.135169467926026 88/500
mean MSE : 0.0012736264057457447 88/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.05it/s]


mean PSNR : 29.149328231811523 89/500
mean MSE : 0.0012665492619271391 89/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.05it/s]


mean PSNR : 29.2572367477417 90/500
mean MSE : 0.001240575131669175 90/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.00it/s]


mean PSNR : 29.148971614837645 91/500
mean MSE : 0.0012707665754714981 91/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 18.13it/s]


mean PSNR : 29.230412940979004 92/500
mean MSE : 0.001246733739390038 92/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:43<00:00, 18.21it/s]


mean PSNR : 29.200875682830812 93/500
mean MSE : 0.0012547624835860915 93/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:43<00:00, 18.21it/s]


mean PSNR : 29.238610877990723 94/500
mean MSE : 0.0012422132247593253 94/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:44<00:00, 17.82it/s]


mean PSNR : 29.16754560470581 95/500
mean MSE : 0.0012663867682567798 95/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [01:04<00:00, 12.42it/s]


mean PSNR : 29.064156074523925 96/500
mean MSE : 0.001282368460088037 96/500
max_psnr: 29.290249481201172, epoch: 79


100%|██████████| 800/800 [00:52<00:00, 15.19it/s]


In [None]:
# Inspect the shape of the first item returned by `valid_dataset`.
valid_dataset[0].shape

In [None]:
# Run the model on the first item of `valid_dataset`, binarize the output and
# pass it through the optical simulation.
sign = SignFunction.apply
# The two unsqueeze(0) calls prepend two singleton dims before the forward pass.
out = model(valid_dataset[0].unsqueeze(0).unsqueeze(0))
# SignFunction.forward (defined at the top of this notebook) accepts only the
# input tensor and hard-codes its 0.5 threshold. Passing a second positional
# argument, as the previous `sign(out, 0.5)` did, makes autograd's `apply`
# forward it to `forward` and raises a TypeError.
binary = sign(out)
# NOTE(review): the meaning/units of 2e-3 are not visible here — confirm
# against the tt.simulate signature.
sim = tt.simulate(binary, 2e-3)
# Average over dim 1 while keeping that dim (size 1) in the result.
mean_sim = torch.mean(sim, dim=1, keepdim=True)

In [None]:
# Pull a single batch from `trainloader` and print its shape; the `break`
# stops after the first iteration, leaving `target` bound to that batch.
for target in trainloader:
    print(target.shape)
    break

In [None]:
# Display `target` (bound by the `for target in trainloader` loop in this notebook).
target

In [None]:
# Display `mean_sim`, the dim=1 mean of `sim` computed in an earlier cell.
mean_sim