In [1]:
import torchOptics.optics as tt
import warnings
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torchvision
import datetime
import tqdm
import time
import pandas as pd


warnings.filterwarnings('ignore')


class SignFunction(torch.autograd.Function):
    """Hard 0/1 binarization with a straight-through gradient.

    Forward maps every element strictly greater than the threshold to 1.0
    and every other element to 0.0. Backward hands the incoming gradient
    through unchanged, so the step can be used inside a trained network.
    """

    @staticmethod
    def forward(ctx, input, th=0.5):
        """Binarize ``input`` against ``th``.

        Args:
            ctx: autograd context; nothing has to be saved for backward.
            input: tensor to binarize.
            th: threshold, a Python number or a tensor broadcastable to
                ``input``. Defaults to 0.5, the value that used to be
                hard-coded, so ``SignFunction.apply(x)`` behaves as before
                while ``SignFunction.apply(x, th)`` is now accepted too.

        Returns:
            Float tensor holding only 0.0 and 1.0.
        """
        return (input > th).float()

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the gradient straight through to ``input``.

        The trailing ``None`` stands for the threshold, which receives no
        gradient; autograd discards surplus ``None`` entries when ``apply``
        was called with ``input`` only.
        """
        return grad_output, None

class RealSign(torch.autograd.Function):
    """Element-wise ``torch.sign`` with a straight-through gradient."""

    @staticmethod
    def forward(ctx, input):
        """Return ``torch.sign(input)``; nothing is stored on ``ctx``."""
        return torch.sign(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Pass the incoming gradient through unchanged.

        The previous version unpacked ``ctx.saved_tensors`` although
        ``forward`` never saved anything, so every backward pass raised.
        The identity gradient does not need the input at all.
        """
        return grad_output


class Dataset512(Dataset):
    """Dataset of grayscale PNG images delivered as 1024 x 1024 crops.

    Files are discovered with ``glob`` and sorted by path. Each item is read
    through ``tt.imread`` with this dataset's ``meta``, upscaled when either
    spatial side is below 1024, cropped (random crop for training, centre
    crop otherwise) and finally padded by ``padding`` on every side.
    """

    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        """Collect the PNG paths and set up the crop transforms.

        Args:
            target_dir: path prefix that is concatenated with ``'*.png'``,
                so a directory has to end with a path separator.
            meta: metadata forwarded to ``tt.imread`` for every image.
            transform: stored on the instance but not applied anywhere in
                this class.
            isTrain: ``True`` selects the random crop, ``False`` the centre
                crop.
            padding: padding added to each of the four sides after cropping.
        """
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        self.target_list = sorted(glob.glob(target_dir+'*.png'))
        self.center_crop = torchvision.transforms.CenterCrop(1024)
        self.random_crop = torchvision.transforms.RandomCrop((1024, 1024))
        self.padding = padding

    def __len__(self):
        """Return the number of PNG files that were found."""
        return len(self.target_list)

    def __getitem__(self, idx):
        """Load, resize if needed, crop and pad the image at ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Use the metadata given to this dataset; the previous code read a
        # module-level ``meta`` global and ignored the constructor argument.
        target = tt.imread(self.target_list[idx], meta=self.meta, gray=True).unsqueeze(0)
        if target.shape[-1] < 1024 or target.shape[-2] < 1024:
            target = torchvision.transforms.Resize(1024)(target)
        crop = self.random_crop if self.isTrain else self.center_crop
        target = crop(target)
        pad = self.padding
        return torchvision.transforms.functional.pad(target, (pad, pad, pad, pad))


def initialize_weights(m):
    """Re-initialize a single module in place (meant for ``model.apply``).

    * ``nn.Conv2d``: Kaiming-uniform weights with ``nonlinearity='relu'``,
      bias set to 0.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: Kaiming-uniform weights with the default settings,
      bias set to 0.

    Modules of any other type are left untouched. A parameter that does not
    exist (``bias=False``, ``affine=False``) is skipped instead of raising.

    NOTE: an identical function with the same name is defined again later
    in this file; that later definition is the one bound at run time.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)

# ---- Data pipeline -------------------------------------------------------
# Earlier runs pointed at the relative folders 'dataset/DIV2K/DIV2K_train_HR/'
# and 'dataset/DIV2K/DIV2K_valid_HR/' (and a 'binary_dataset/simulated/'
# folder); the absolute NFS locations below are the ones in use.
batch_size = 1
padding = 0

# Optical metadata handed to tt.imread for every image: 'wl' and 'dx'
# (presumably wavelength and pixel pitch in metres - TODO confirm).
meta = dict(wl=515e-9, dx=(7.56e-6, 7.56e-6))

target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'

# Training items use a random crop, validation items a centre crop.
train_dataset = Dataset512(target_dir, meta, isTrain=True, padding=padding)
valid_dataset = Dataset512(valid_dir, meta, isTrain=False, padding=padding)

# NOTE(review): the training loader is not shuffled - confirm this is intended.
trainloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
validloader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)


def initialize_weights(m):
    """Re-initialize a single module in place (meant for ``model.apply``).

    * ``nn.Conv2d``: Kaiming-uniform weights with ``nonlinearity='relu'``,
      bias set to 0.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: Kaiming-uniform weights with the default settings,
      bias set to 0.

    Modules of any other type are left untouched. A parameter that does not
    exist (``bias=False``, ``affine=False``) is skipped instead of raising.

    NOTE: this repeats a definition of the same name found earlier in the
    file; being the later one, it is the version bound at run time.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)


def train(model, num_hologram=10, custom_name='fcn', epochs=100, lr=1e-4, z=2e-3):
    """Optimise ``model`` on ``trainloader`` and validate on ``validloader``.

    For every batch the model output is binarized with ``SignFunction``,
    passed through ``tt.simulate`` with ``z``, turned into intensity
    (``abs() ** 2``), averaged over the channel axis and compared with the
    target through ``tt.relativeLoss(..., F.mse_loss)``.

    After each epoch the mean validation PSNR and MSE are appended to the
    history, written as ``.npy`` files and plotted to PNG. The model
    ``state_dict`` is saved whenever the mean PSNR beats the best so far.

    Args:
        model: network invoked as ``model(batch)``.
        num_hologram: used only in the run name and the pickled run summary.
        custom_name: free-form tag that becomes part of the run name.
        epochs: number of passes over ``trainloader``.
        lr: learning rate of the Adam optimiser.
        z: second argument forwarded to ``tt.simulate``.

    Raises:
        FileExistsError: if the run's result folder already exists.
    """
    # NOTE(review): the fixed 15 h shift looks like a manual time-zone
    # correction - confirm.
    date = datetime.datetime.now() - datetime.timedelta(hours=15)
    mean_psnr_list = []
    mean_mse_list = []
    max_epoch = 0
    max_psnr = 0
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sign_function = SignFunction.apply

    model_name = f'{date}_{custom_name}_{num_hologram}_{z}'
    result_folder = 'result/'+model_name
    # makedirs also creates the parent 'result/' folder on a first run;
    # os.mkdir would fail there.
    os.makedirs(result_folder)
    model_path = os.path.join(result_folder, model_name)
    psnr_result_path = os.path.join(result_folder, 'meanPSNR.npy')
    mse_result_path = os.path.join(result_folder, 'meanMSE.npy')

    # Run summary; kept under its own name so the loops below cannot
    # overwrite it with a tensor.
    run_info = {'model_path': model_path, 'psnr_result_path': psnr_result_path,
                'mse_result_path': mse_result_path, 'z': z,
                'num_hologram': num_hologram, 'lr': lr, 'epochs': epochs}
    with open(os.path.join(result_folder, 'result_dict.pickle'), 'wb') as fw:
        pickle.dump(run_info, fw)

    for epoch in range(epochs):
        model.train()
        for target in tqdm.tqdm(trainloader):
            out = model(target)
            binary = sign_function(out)
            sim = tt.simulate(binary, z).abs()**2
            recon = torch.mean(sim, dim=1, keepdim=True)
            loss = tt.relativeLoss(recon, target, F.mse_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            psnrList = []
            mseList = []
            for reference in validloader:
                out = model(reference)
                binary = sign_function(out)
                sim = tt.simulate(binary, z).abs()**2
                recon = torch.mean(sim, dim=1, keepdim=True)
                psnrList.append(tt.relativeLoss(recon, reference, tm.get_PSNR))
                mse = tt.relativeLoss(recon, reference, F.mse_loss).detach().cpu().numpy()
                mseList.append(mse)
            mean_psnr = sum(psnrList)/len(psnrList)
            mean_mse = sum(mseList)/len(mseList)
            mean_psnr_list.append(mean_psnr)
            mean_mse_list.append(mean_mse)

            # Persist the history and refresh both curves every epoch.
            np.save(psnr_result_path, mean_psnr_list)
            np.save(mse_result_path, mean_mse_list)
            plt.plot(np.arange(len(mean_psnr_list)), mean_psnr_list,
                     label=f'max psnr: {max(mean_psnr_list)}')
            plt.title(model_name+'_psnr')
            plt.legend()
            plt.savefig(os.path.join(result_folder, 'meanPSNR.png'))
            plt.clf()
            # The legend shows the smallest MSE; it used to be mislabelled
            # as "max psnr".
            plt.plot(np.arange(len(mean_mse_list)), mean_mse_list,
                     label=f'min mse: {min(mean_mse_list)}')
            plt.title(model_name+'_mse')
            plt.legend()
            plt.savefig(os.path.join(result_folder, 'meanMSE.png'))
            plt.clf()
            print(f'mean PSNR : {mean_psnr} {epoch}/{epochs}')
            print(f'mean MSE : {mean_mse} {epoch}/{epochs}')

        # Keep only the weights of the best epoch so far.
        if mean_psnr > max_psnr:
            max_psnr = mean_psnr
            max_epoch = epoch
            torch.save(model.state_dict(), model_path)
        print(f'max_psnr: {max_psnr}, epoch: {max_epoch}')


def binary_sim(out, z=2e-3):
    """Binarize ``out`` and return it with its channel-averaged intensity.

    ``out`` is binarized with ``SignFunction``, handed to ``tt.simulate``
    together with ``z``, converted to intensity (``abs() ** 2``) and
    averaged over dim 1 while keeping that dimension.

    Returns:
        Tuple ``(binary, mean_intensity)``.
    """
    hologram = SignFunction.apply(out)
    intensity = tt.simulate(hologram, z).abs()**2
    return hologram, torch.mean(intensity, dim=1, keepdim=True)


def valid(model):
    """Print and return the mean PSNR of ``model`` over ``validloader``.

    The model is switched to evaluation mode for the pass, matching the
    validation done inside the training functions, and its previous
    train/eval mode is restored afterwards.

    Args:
        model: network invoked as ``model(batch)``.

    Returns:
        Mean of the per-batch values returned by
        ``tt.relativeLoss(res, target, tm.get_PSNR)``.
    """
    was_training = model.training
    model.eval()
    psnr_list = []
    try:
        with torch.no_grad():
            for target in validloader:
                out = model(target)
                _, res = binary_sim(out)
                psnr_list.append(tt.relativeLoss(res, target, tm.get_PSNR))
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    mean_psnr = sum(psnr_list)/len(psnr_list)
    print(mean_psnr)
    return mean_psnr


def check_order(bmodel, classifier, target):
    """Print the classifier's arg-max class for the first 10 binary channels.

    ``target`` gets a leading batch dimension when it is 3-D. It is run
    through ``bmodel`` and ``binary_sim`` without gradient tracking, and the
    resulting binary stack is handed to ``check_order_without_model``, which
    performs the per-channel classification and the printing.
    """
    batch = target.unsqueeze(0) if len(target.shape) == 3 else target
    with torch.no_grad():
        binary, _ = binary_sim(bmodel(batch))
    check_order_without_model(binary, classifier)


def check_order_without_model(binary, classifier):
    """Classify channels 0-9 of the first sample in ``binary`` and print.

    Each channel ``binary[0][i]`` is given two leading singleton dimensions
    and passed to ``classifier``; the ``torch.argmax`` of every result is
    collected and the list is printed. Nothing is returned.
    """
    with torch.no_grad():
        predictions = [
            torch.argmax(classifier(binary[0][channel].unsqueeze(0).unsqueeze(0)))
            for channel in range(10)
        ]
    print(predictions)


class SimpleCNN(nn.Module):
    """Three conv + ReLU + max-pool stages followed by a 10-way linear layer.

    Every convolution keeps the spatial size (3x3, stride 1, padding 1) and
    every pool halves it, so the ``64 * 64 * 64`` features expected by the
    linear layer correspond to a single-channel 512 x 512 input.
    """

    def __init__(self):
        super().__init__()
        conv_cfg = dict(kernel_size=3, stride=1, padding=1)
        pool_cfg = dict(kernel_size=2, stride=2, padding=0)
        self.conv1 = nn.Conv2d(1, 16, **conv_cfg)
        self.pool1 = nn.MaxPool2d(**pool_cfg)
        self.conv2 = nn.Conv2d(16, 32, **conv_cfg)
        self.pool2 = nn.MaxPool2d(**pool_cfg)
        self.conv3 = nn.Conv2d(32, 64, **conv_cfg)
        self.pool3 = nn.MaxPool2d(**pool_cfg)
        self.fc = nn.Linear(64 * 64 * 64, 10)  # FC layer for the final classification

    def forward(self, x):
        """Return the 10 class scores for each flattened feature row."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(torch.relu(conv(x)))
        flat = x.view(-1, 64 * 64 * 64)  # flatten
        return self.fc(flat)


def rgb_binary_sim(out, z, th):
    """Binarize a stacked hologram tensor and simulate it per colour band.

    The channel axis (dim 1) of ``out`` is split into three consecutive
    groups: the first third is treated as red, the second third as green and
    the remainder as blue. Each group is wrapped in a ``tt.Tensor`` carrying
    its own wavelength, passed through ``tt.simulate`` with ``z``, converted
    to intensity (``abs() ** 2``) and averaged over its channels.

    Args:
        out: network output with channels on dim 1.
        z: second argument forwarded to ``tt.simulate``.
        th: binarization threshold; elements greater than ``th`` become 1.

    Returns:
        Tuple ``(binary, rgb)``: the full binarized stack and the
        three-channel mean-intensity image, both ``tt.Tensor`` objects
        tagged with the three-wavelength metadata.
    """
    pixel_pitch = 7.56e-6
    dx = (pixel_pitch, pixel_pitch)
    wavelengths = (638e-9, 515e-9, 450e-9)  # red, green, blue
    meta = {'wl': wavelengths, 'dx': dx}

    # SignFunction.apply(x) thresholds at a fixed 0.5 and takes no threshold
    # argument in its one-input form (the old ``sign(out, th)`` call raised
    # TypeError). Shifting the input by ``0.5 - th`` makes that fixed
    # comparison equivalent to ``out > th`` while keeping the
    # straight-through gradient with respect to ``out``.
    binary = SignFunction.apply(out - th + 0.5)

    channel = out.shape[1]
    bounds = (0, channel // 3, 2 * channel // 3, channel)
    band_means = []
    for wl, lo, hi in zip(wavelengths, bounds[:-1], bounds[1:]):
        band = tt.Tensor(binary[:, lo:hi, :, :], meta={'wl': wl, 'dx': dx})
        intensity = tt.simulate(band, z).abs()**2
        band_means.append(torch.mean(intensity, dim=1, keepdim=True))

    rgb = tt.Tensor(torch.cat(band_means, dim=1), meta=meta)
    binary = tt.Tensor(binary, meta=meta)
    return binary, rgb

def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` is true are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [2]:
class BinaryNet(nn.Module):
    """U-Net-style encoder/decoder that outputs ``num_hologram`` sigmoid maps.

    The encoder has five levels of two 3x3 convolution blocks each, with a
    stride-2 convolution block ("pool") between levels. The decoder mirrors
    it: a 2x2 stride-2 transposed convolution, concatenation with the
    matching encoder feature map, then two convolution blocks. A final 3x3
    convolution maps to ``num_hologram`` channels and a sigmoid is applied.

    Because the resolution is halved four times and doubled four times,
    input height and width should be multiples of 16 so that the skip
    connections line up.

    Args:
        num_hologram: number of output channels.
        final: accepted for compatibility; not used by this implementation
            (a sigmoid is always applied).
        in_planes: number of input channels.
        channels: feature widths per level; only the first five entries are
            read.
        convReLU, convBN: enable the activation / batch norm in the regular
            convolution blocks.
        poolReLU, poolBN: same for the stride-2 down-sampling blocks.
        deconvReLU, deconvBN: same for the transposed-convolution blocks.

    NOTE: despite the ``*ReLU`` flag names, the convolution and pooling
    blocks insert ``nn.Tanh``; only the transposed-convolution blocks use
    ``nn.ReLU``.
    """

    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=(32, 64, 128, 256, 512, 1024, 2048, 4096),
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super().__init__()

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            # Conv2d -> optional Tanh -> optional BatchNorm2d.
            layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            if relu:
                layers.append(nn.Tanh())
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            return nn.Sequential(*layers)

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            # ConvTranspose2d -> optional BatchNorm2d -> optional ReLU.
            # kernel_size / stride / bias are honoured now; they used to be
            # silently replaced by the constants 2 / 2 / True (which are
            # still the defaults).
            layers = [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride, padding=0,
                                         bias=bias)]
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            if relu:
                layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        c0, c1, c2, c3, c4 = channels[:5]

        # ---- Encoder ----
        self.enc1_1 = CRB2d(in_planes, c0, relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(c0, c0, relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(c0, c0, stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(c0, c1, relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(c1, c1, relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(c1, c1, stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(c1, c2, relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(c2, c2, relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(c2, c2, stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(c2, c3, relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(c3, c3, relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(c3, c3, stride=2, relu=poolReLU, bn=poolBN)

        self.enc5_1 = CRB2d(c3, c4, relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(c4, c4, relu=convReLU, bn=convBN)

        # ---- Decoder ----
        # The first block of each level sees the up-sampled map concatenated
        # with the encoder skip of equal width, i.e. 2 * width channels.
        # With the default (doubling) widths this equals the old value.
        self.deconv4 = TRB2d(c4, c3, relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(2 * c3, c3, relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(c3, c3, relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(c3, c2, relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(2 * c2, c2, relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(c2, c2, relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(c2, c1, relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(2 * c1, c1, relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(c1, c1, relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(c1, c0, relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(2 * c0, c0, relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(c0, c0, relu=convReLU, bn=convBN)

        self.classifier = CRB2d(c0, num_hologram, relu=False, bn=False)

    def forward(self, x):
        """Return sigmoid activations with ``num_hologram`` channels."""
        # Encoder
        enc1_2 = self.enc1_2(self.enc1_1(x))
        pool1 = self.pool1(enc1_2)

        enc2_2 = self.enc2_2(self.enc2_1(pool1))
        pool2 = self.pool2(enc2_2)

        enc3_2 = self.enc3_2(self.enc3_1(pool2))
        pool3 = self.pool3(enc3_2)

        enc4_2 = self.enc4_2(self.enc4_1(pool3))
        pool4 = self.pool4(enc4_2)

        enc5_2 = self.enc5_2(self.enc5_1(pool4))

        # Decoder: up-sample, concatenate the skip connection, refine.
        concat4 = torch.cat((self.deconv4(enc5_2), enc4_2), dim=1)
        dec4_2 = self.dec4_2(self.dec4_1(concat4))

        concat3 = torch.cat((self.deconv3(dec4_2), enc3_2), dim=1)
        dec3_2 = self.dec3_2(self.dec3_1(concat3))

        concat2 = torch.cat((self.deconv2(dec3_2), enc2_2), dim=1)
        dec2_2 = self.dec2_2(self.dec2_1(concat2))

        concat1 = torch.cat((self.deconv1(dec2_2), enc1_2), dim=1)
        dec1_2 = self.dec1_2(self.dec1_1(concat1))

        # Final classifier
        return torch.sigmoid(self.classifier(dec1_2))


# Build the 8-hologram, single-channel-input network with every optional
# activation and batch-norm layer switched off, move it to the GPU, and
# smoke-test it on one random 1 x 1 x 1024 x 1024 input.
model = BinaryNet(
    num_hologram=8,
    in_planes=1,
    convReLU=False, convBN=False,
    poolReLU=False, poolBN=False,
    deconvReLU=False, deconvBN=False,
)
model = model.cuda()
test = torch.randn(1, 1, 1024, 1024)
test = test.cuda()
out = model(test)
print(out.shape)

torch.Size([1, 8, 1024, 1024])


In [3]:
def rgb_binary_train_padding(model, num_hologram=24, custom_name='0703Unet', epochs=500, lr=1e-4, z=2e-3):
    """Train ``model`` with a centre-cropped loss and validate every epoch.

    The model is first re-initialized with ``initialize_weights``, so any
    weights it held before the call are overwritten. For each batch the
    output is binarized with ``SignFunction``, passed through
    ``tt.simulate`` with ``z``, converted to intensity (``abs() ** 2``) and
    averaged over the channel axis. Both that result and the target are
    centre-cropped to 896 pixels before ``tt.relativeLoss`` is evaluated.

    After each epoch the mean validation PSNR and MSE are recorded, a
    comparison image of the first validation sample is written to
    ``penguin.png``, and the model ``state_dict`` is saved whenever the mean
    PSNR beats the best so far. The metric history and plots are written
    every 10th epoch and on the final epoch.

    Args:
        model: network invoked as ``model(batch)``.
        num_hologram: used only in the run name and the pickled run summary.
        custom_name: free-form tag that becomes part of the run name.
        epochs: number of passes over ``trainloader``.
        lr: learning rate of the Adam optimiser.
        z: second argument forwarded to ``tt.simulate`` (previously ignored
            in favour of a hard-coded ``2e-3``, which is still the default).

    Raises:
        FileExistsError: if the run's result folder already exists.
    """
    # NOTE(review): the fixed 15 h shift looks like a manual time-zone
    # correction - confirm.
    date = datetime.datetime.now() - datetime.timedelta(hours=15)
    model.apply(initialize_weights)
    mean_psnr_list = []
    mean_mse_list = []
    max_epoch = 0
    max_psnr = 0
    crop = torchvision.transforms.CenterCrop(896)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    sign = SignFunction.apply

    model_name = f'{date}_{custom_name}_{num_hologram}_{z}'
    result_folder = 'result/'+model_name
    # makedirs also creates the parent 'result/' folder on a first run;
    # os.mkdir would fail there.
    os.makedirs(result_folder)
    num_parameters = count_parameters(model)
    model_path = os.path.join(result_folder, model_name)
    psnr_result_path = os.path.join(result_folder, 'meanPSNR.npy')
    mse_result_path = os.path.join(result_folder, 'meanMSE.npy')
    result = {'model_path': model_path, 'psnr_result_path': psnr_result_path,
              'mse_result_path': mse_result_path, 'z': z,
              'num_hologram': num_hologram, 'lr': lr, 'epochs': epochs,
              'num_parameters': num_parameters}
    with open(os.path.join(result_folder, 'result_dict.pickle'), 'wb') as fw:
        pickle.dump(result, fw)

    for epoch in range(epochs):
        model.train()
        for target in tqdm.tqdm(trainloader):
            out = model(target)
            binary = sign(out)
            sim = tt.simulate(binary, z).abs()**2
            mean_sim = torch.mean(sim, dim=1, keepdim=True)
            rgb = crop(mean_sim)
            croped_target = crop(target)
            loss = tt.relativeLoss(rgb, croped_target, F.mse_loss)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            psnrList = []
            mseList = []
            for idx, reference in enumerate(validloader):
                out = model(reference)
                binary = sign(out)
                sim = tt.simulate(binary, z).abs()**2
                mean_sim = torch.mean(sim, dim=1, keepdim=True)
                rgb = crop(mean_sim)
                croped_target = crop(reference)
                psnrList.append(tt.relativeLoss(rgb, croped_target, tm.get_PSNR))
                mse = tt.relativeLoss(rgb, croped_target, F.mse_loss).detach().cpu().numpy()
                mseList.append(mse)
                if idx == 0:
                    # Visual check on the first validation sample,
                    # overwritten every epoch.
                    pl = tt.show_with_insets(rgb, croped_target, correct_colorwise=True)
                    pl.save(os.path.join(result_folder, 'penguin.png'))
            mean_psnr = sum(psnrList)/len(psnrList)
            mean_mse = sum(mseList)/len(mseList)
            mean_psnr_list.append(mean_psnr)
            mean_mse_list.append(mean_mse)

            # Write history and plots every 10th epoch, and also on the last
            # epoch so the tail of the run is never lost.
            if epoch % 10 == 0 or epoch == epochs - 1:
                np.save(psnr_result_path, mean_psnr_list)
                np.save(mse_result_path, mean_mse_list)
                plt.plot(np.arange(len(mean_psnr_list)), mean_psnr_list,
                         label=f'max psnr: {max(mean_psnr_list)}')
                plt.title(model_name+'_psnr')
                plt.legend()
                plt.savefig(os.path.join(result_folder, 'meanPSNR.png'))
                plt.clf()
                # The legend shows the smallest MSE; it used to be
                # mislabelled as "max psnr".
                plt.plot(np.arange(len(mean_mse_list)), mean_mse_list,
                         label=f'min mse: {min(mean_mse_list)}')
                plt.title(model_name+'_mse')
                plt.legend()
                plt.savefig(os.path.join(result_folder, 'meanMSE.png'))
                plt.clf()
            print(f'mean PSNR : {mean_psnr} {epoch}/{epochs}')
            print(f'mean MSE : {mean_mse} {epoch}/{epochs}')

        # Keep only the weights of the best epoch so far.
        if mean_psnr > max_psnr:
            max_psnr = mean_psnr
            max_epoch = epoch
            torch.save(model.state_dict(), model_path)
        print(f'max_psnr: {max_psnr}, epoch: {max_epoch}')
        # Printed every epoch: this is the checkpoint location, not a sign
        # that a new checkpoint was written in this epoch.
        print(f'Model saved at: {model_path}')

In [4]:
# Launch the 500-epoch run for the 8-hologram model defined above; results
# are written to a 'result/<timestamp>_pre_reinforce_8_0.002' folder.
rgb_binary_train_padding(
    model,
    num_hologram=8,
    custom_name='pre_reinforce',
    epochs=500,
    lr=1e-4,
    z=2e-3,
)

100%|██████████| 800/800 [00:54<00:00, 14.78it/s]


mean PSNR : 22.13637659072876 0/500
mean MSE : 0.0065378099004738035 0/500
max_psnr: 22.13637659072876, epoch: 0
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 20.78962423324585 1/500
mean MSE : 0.008564953352324664 1/500
max_psnr: 22.13637659072876, epoch: 0
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 15.08it/s]


mean PSNR : 24.275135803222657 2/500
mean MSE : 0.00416958402027376 2/500
max_psnr: 24.275135803222657, epoch: 2
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.14it/s]


mean PSNR : 23.931255645751953 3/500
mean MSE : 0.0044228369696065785 3/500
max_psnr: 24.275135803222657, epoch: 2
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 16.167649354934692 4/500
mean MSE : 0.025029034204781055 4/500
max_psnr: 24.275135803222657, epoch: 2
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 20.001152400970458 5/500
mean MSE : 0.010204670946113764 5/500
max_psnr: 24.275135803222657, epoch: 2
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 23.999080142974854 6/500
mean MSE : 0.004269842273788527 6/500
max_psnr: 24.275135803222657, epoch: 2
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 23.61373010635376 7/500
mean MSE : 0.004658002919750288 7/500
max_psnr: 24.275135803222657, epoch: 2
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 21.71242723464966 8/500
mean MSE : 0.006879117630887777 8/500
max_psnr: 24.275135803222657, epoch: 2
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 24.3826837348938 9/500
mean MSE : 0.003953990472946316 9/500
max_psnr: 24.3826837348938, epoch: 9
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 19.094439811706543 10/500
mean MSE : 0.012195304306223988 10/500
max_psnr: 24.3826837348938, epoch: 9
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.19it/s]


mean PSNR : 24.64353775024414 11/500
mean MSE : 0.0037372187251457946 11/500
max_psnr: 24.64353775024414, epoch: 11
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 22.898166217803954 12/500
mean MSE : 0.005481080097379163 12/500
max_psnr: 24.64353775024414, epoch: 11
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.14it/s]


mean PSNR : 23.728864345550537 13/500
mean MSE : 0.004501070246333256 13/500
max_psnr: 24.64353775024414, epoch: 11
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.14it/s]


mean PSNR : 24.49860969543457 14/500
mean MSE : 0.003817832192289643 14/500
max_psnr: 24.64353775024414, epoch: 11
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 24.685976161956788 15/500
mean MSE : 0.0036924881592858582 15/500
max_psnr: 24.685976161956788, epoch: 15
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 24.763236312866212 16/500
mean MSE : 0.003637375277467072 16/500
max_psnr: 24.763236312866212, epoch: 16
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 21.69372495651245 17/500
mean MSE : 0.007154474869603291 17/500
max_psnr: 24.763236312866212, epoch: 16
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.18it/s]


mean PSNR : 22.51586519241333 18/500
mean MSE : 0.0058001624699682 18/500
max_psnr: 24.763236312866212, epoch: 16
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 23.93928367614746 19/500
mean MSE : 0.0043119530833791945 19/500
max_psnr: 24.763236312866212, epoch: 16
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.23it/s]


mean PSNR : 24.751645183563234 20/500
mean MSE : 0.0036355965450638903 20/500
max_psnr: 24.763236312866212, epoch: 16
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 24.813671646118163 21/500
mean MSE : 0.0036021030391566455 21/500
max_psnr: 24.813671646118163, epoch: 21
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 25.041873416900636 22/500
mean MSE : 0.0034504395665135236 22/500
max_psnr: 25.041873416900636, epoch: 22
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.20it/s]


mean PSNR : 25.250986671447755 23/500
mean MSE : 0.0033137684472603723 23/500
max_psnr: 25.250986671447755, epoch: 23
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.411866741180418 24/500
mean MSE : 0.0032086913054808976 24/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 25.20108652114868 25/500
mean MSE : 0.003302661555353552 25/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.16it/s]


mean PSNR : 24.99763990402222 26/500
mean MSE : 0.003466446092352271 26/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 25.12537691116333 27/500
mean MSE : 0.0033697804424446076 27/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.10it/s]


mean PSNR : 24.868938159942626 28/500
mean MSE : 0.0035065648739691824 28/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 24.805587635040283 29/500
mean MSE : 0.003513851275201887 29/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 24.70513889312744 30/500
mean MSE : 0.0036301304621156306 30/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.10it/s]


mean PSNR : 25.352430248260497 31/500
mean MSE : 0.003177870282670483 31/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.14it/s]


mean PSNR : 25.299350757598877 32/500
mean MSE : 0.003224102850072086 32/500
max_psnr: 25.411866741180418, epoch: 24
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.16it/s]


mean PSNR : 25.529957962036132 33/500
mean MSE : 0.0031006389483809473 33/500
max_psnr: 25.529957962036132, epoch: 33
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.18it/s]


mean PSNR : 25.7607328414917 34/500
mean MSE : 0.0029809973074588925 34/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 15.06it/s]


mean PSNR : 25.233497638702392 35/500
mean MSE : 0.003226804898586124 35/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.12it/s]


mean PSNR : 25.66297399520874 36/500
mean MSE : 0.003058739284751937 36/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.20it/s]


mean PSNR : 24.375363788604737 37/500
mean MSE : 0.003942338770721108 37/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.591721725463866 38/500
mean MSE : 0.0030925471213413404 38/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 24.41408893585205 39/500
mean MSE : 0.003927964543690905 39/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.19it/s]


mean PSNR : 25.001346588134766 40/500
mean MSE : 0.003443584047490731 40/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.24it/s]


mean PSNR : 24.861484317779542 41/500
mean MSE : 0.003540429410641082 41/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 24.79914171218872 42/500
mean MSE : 0.0035652004613075405 42/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 25.010485572814943 43/500
mean MSE : 0.0033926455944310874 43/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 24.813898029327394 44/500
mean MSE : 0.0034997303411364554 44/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.28it/s]


mean PSNR : 24.871786918640137 45/500
mean MSE : 0.003502971944399178 45/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.024860076904297 46/500
mean MSE : 0.0033635654172394424 46/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.22it/s]


mean PSNR : 25.329705028533937 47/500
mean MSE : 0.0031882908893749117 47/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.25it/s]


mean PSNR : 24.790946197509765 48/500
mean MSE : 0.0034893351537175475 48/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.15it/s]


mean PSNR : 24.923219203948975 49/500
mean MSE : 0.0033917736052535476 49/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.20it/s]


mean PSNR : 25.41774896621704 50/500
mean MSE : 0.0031191810197196902 50/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.21it/s]


mean PSNR : 25.42978401184082 51/500
mean MSE : 0.0031330551079008727 51/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:52<00:00, 15.17it/s]


mean PSNR : 25.10968734741211 52/500
mean MSE : 0.003290433498332277 52/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:54<00:00, 14.56it/s]


mean PSNR : 25.248604679107665 53/500
mean MSE : 0.0032275102927815168 53/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.13869180679321 54/500
mean MSE : 0.003276826456421986 54/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.14308187484741 55/500
mean MSE : 0.003289193355012685 55/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.20544376373291 56/500
mean MSE : 0.003265189903322607 56/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.25973850250244 57/500
mean MSE : 0.003241741225356236 57/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.332769470214842 58/500
mean MSE : 0.0031993829226121305 58/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.91it/s]


mean PSNR : 25.445468139648437 59/500
mean MSE : 0.0031354835629463194 59/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 24.445939674377442 60/500
mean MSE : 0.003878006951417774 60/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.12560531616211 61/500
mean MSE : 0.0033114560978719965 61/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.98it/s]


mean PSNR : 25.026461715698243 62/500
mean MSE : 0.0033589206845499576 62/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.04756134033203 63/500
mean MSE : 0.003359043060336262 63/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.470106716156007 64/500
mean MSE : 0.003132090167491697 64/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.46799436569214 65/500
mean MSE : 0.003100528015056625 65/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.401438617706297 66/500
mean MSE : 0.0031527108850423246 66/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.98it/s]


mean PSNR : 25.54346803665161 67/500
mean MSE : 0.0030403991136699914 67/500
max_psnr: 25.7607328414917, epoch: 34
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.83650407791138 68/500
mean MSE : 0.0028875227639218793 68/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.89it/s]


mean PSNR : 25.60610704421997 69/500
mean MSE : 0.003012095708400011 69/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.571160526275634 70/500
mean MSE : 0.00301090610853862 70/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.257707901000977 71/500
mean MSE : 0.003190375647391193 71/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.326909923553465 72/500
mean MSE : 0.003170099378330633 72/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.550672855377197 73/500
mean MSE : 0.003088587191305123 73/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.90it/s]


mean PSNR : 25.502693519592285 74/500
mean MSE : 0.0031161794520448895 74/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 25.420060806274414 75/500
mean MSE : 0.003173069563345052 75/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.478099002838135 76/500
mean MSE : 0.0031138048024149613 76/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.299492321014405 77/500
mean MSE : 0.003191329719265923 77/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.591969985961914 78/500
mean MSE : 0.003038151778746396 78/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 15.01it/s]


mean PSNR : 25.23492027282715 79/500
mean MSE : 0.003233962310478091 79/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 15.00it/s]


mean PSNR : 25.211135807037355 80/500
mean MSE : 0.003305961191654205 80/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 23.807051486968994 81/500
mean MSE : 0.00439478532760404 81/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 24.813819522857667 82/500
mean MSE : 0.003511802885332145 82/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.01289535522461 83/500
mean MSE : 0.003375809204299003 83/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 23.908821258544922 84/500
mean MSE : 0.0041766654816456136 84/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 24.897420749664306 85/500
mean MSE : 0.003453332631615922 85/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.90it/s]


mean PSNR : 25.451338233947755 86/500
mean MSE : 0.003096228000940755 86/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 25.33116807937622 87/500
mean MSE : 0.003181017800234258 87/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.529249286651613 88/500
mean MSE : 0.003100309225264937 88/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.16164051055908 89/500
mean MSE : 0.003293280806392431 89/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.98it/s]


mean PSNR : 25.436471977233886 90/500
mean MSE : 0.0030609059100970624 90/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 15.00it/s]


mean PSNR : 25.707589740753175 91/500
mean MSE : 0.0029834722995292396 91/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.570883350372313 92/500
mean MSE : 0.0030399518099147827 92/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.039690608978272 93/500
mean MSE : 0.0033505007962230595 93/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.40150785446167 94/500
mean MSE : 0.0031022967211902144 94/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.790175971984862 95/500
mean MSE : 0.002915997394011356 95/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 24.23941246032715 96/500
mean MSE : 0.003929126755101606 96/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.529109897613527 97/500
mean MSE : 0.0030737482843687758 97/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 25.32884515762329 98/500
mean MSE : 0.0031907116482034325 98/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.19740238189697 99/500
mean MSE : 0.0032398876606021076 99/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.50059778213501 100/500
mean MSE : 0.003097311024321243 100/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.08154207229614 101/500
mean MSE : 0.003382558959419839 101/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 24.824951782226563 102/500
mean MSE : 0.0035366592032369227 102/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 24.26704200744629 103/500
mean MSE : 0.003935151157202199 103/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.355292510986327 104/500
mean MSE : 0.0032127268356271087 104/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.30411840438843 105/500
mean MSE : 0.0032746138877701014 105/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.90it/s]


mean PSNR : 24.981683502197267 106/500
mean MSE : 0.003421987968031317 106/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 24.94388328552246 107/500
mean MSE : 0.0034398921485990284 107/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.17996639251709 108/500
mean MSE : 0.0032807573222089557 108/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.89it/s]


mean PSNR : 25.279450855255128 109/500
mean MSE : 0.003194427457638085 109/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 24.77078649520874 110/500
mean MSE : 0.0035503701434936373 110/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 24.960076999664306 111/500
mean MSE : 0.003419813960790634 111/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.91it/s]


mean PSNR : 24.8326629447937 112/500
mean MSE : 0.0035363035020418467 112/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.98it/s]


mean PSNR : 24.71471549987793 113/500
mean MSE : 0.0036040445999242364 113/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 24.762698154449463 114/500
mean MSE : 0.003564292733790353 114/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 24.79501434326172 115/500
mean MSE : 0.003538432940840721 115/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 25.018031673431395 116/500
mean MSE : 0.003392740380950272 116/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.05194055557251 117/500
mean MSE : 0.003354734251042828 117/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.37637948989868 118/500
mean MSE : 0.003151432384038344 118/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 24.093040580749513 119/500
mean MSE : 0.004112269476754591 119/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.19155755996704 120/500
mean MSE : 0.0032476511993445455 120/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 24.854823684692384 121/500
mean MSE : 0.0034509692282881587 121/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 25.254732532501222 122/500
mean MSE : 0.0032099626923445613 122/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.89it/s]


mean PSNR : 25.156085681915282 123/500
mean MSE : 0.0033313517307396977 123/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 25.31719186782837 124/500
mean MSE : 0.0032048349687829616 124/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 24.58794143676758 125/500
mean MSE : 0.003669843786628917 125/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 24.8180082321167 126/500
mean MSE : 0.0034788731939624997 126/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.400397777557373 127/500
mean MSE : 0.003123279593419284 127/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 24.960079040527344 128/500
mean MSE : 0.0034062489692587407 128/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.91it/s]


mean PSNR : 24.821347732543945 129/500
mean MSE : 0.003610315365949646 129/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 20.110278806686402 130/500
mean MSE : 0.010679748782422394 130/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 22.781205425262453 131/500
mean MSE : 0.005474795566406101 131/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 24.864282398223878 132/500
mean MSE : 0.0034987781033851205 132/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 24.67673807144165 133/500
mean MSE : 0.0036089298449223863 133/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.210799827575684 134/500
mean MSE : 0.003267357109580189 134/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 24.40278377532959 135/500
mean MSE : 0.003919171171728522 135/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 24.236486473083495 136/500
mean MSE : 0.004035697581712156 136/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 24.143772315979003 137/500
mean MSE : 0.004119324494386092 137/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 23.970263118743897 138/500
mean MSE : 0.004268818770069629 138/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.82it/s]


mean PSNR : 25.16465229034424 139/500
mean MSE : 0.003388955192640424 139/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.06699089050293 140/500
mean MSE : 0.0033971055631991476 140/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.88it/s]


mean PSNR : 24.686084232330323 141/500
mean MSE : 0.0036322238459251823 141/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 24.439687156677245 142/500
mean MSE : 0.003823642106726766 142/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.123363819122314 143/500
mean MSE : 0.0033741472219116987 143/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.98it/s]


mean PSNR : 25.12785541534424 144/500
mean MSE : 0.003365832141134888 144/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 15.00it/s]


mean PSNR : 24.737143249511718 145/500
mean MSE : 0.003634274109499529 145/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.98it/s]


mean PSNR : 24.802282600402833 146/500
mean MSE : 0.003554442349122837 146/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.98it/s]


mean PSNR : 25.002894344329835 147/500
mean MSE : 0.0034440702898427844 147/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.323158187866213 148/500
mean MSE : 0.0032288950809743256 148/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.90it/s]


mean PSNR : 25.358228282928465 149/500
mean MSE : 0.0031779609713703393 149/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.89it/s]


mean PSNR : 25.1366676902771 150/500
mean MSE : 0.0032829415146261452 150/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.358336162567138 151/500
mean MSE : 0.0031716083292849364 151/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 24.921023044586182 152/500
mean MSE : 0.003439852768788114 152/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 24.75617691040039 153/500
mean MSE : 0.003557960846228525 153/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 25.32322296142578 154/500
mean MSE : 0.0031970251782331618 154/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.27320804595947 155/500
mean MSE : 0.0032001778029371055 155/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.91it/s]


mean PSNR : 25.171339912414552 156/500
mean MSE : 0.0033276508725248277 156/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.90it/s]


mean PSNR : 25.451709480285643 157/500
mean MSE : 0.003100753725739196 157/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.91it/s]


mean PSNR : 25.00136478424072 158/500
mean MSE : 0.0033780676405876873 158/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 25.170076904296874 159/500
mean MSE : 0.0033103644626680763 159/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 25.510104064941405 160/500
mean MSE : 0.003090205459157005 160/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.084309940338134 161/500
mean MSE : 0.0033004950114991515 161/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.281513900756835 162/500
mean MSE : 0.0031969721033237876 162/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.91it/s]


mean PSNR : 25.432672691345214 163/500
mean MSE : 0.003101135992910713 163/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.92it/s]


mean PSNR : 25.396686515808106 164/500
mean MSE : 0.003133061206317507 164/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 15.00it/s]


mean PSNR : 25.415657691955566 165/500
mean MSE : 0.0030993868253426627 165/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 25.60994647979736 166/500
mean MSE : 0.002980071163037792 166/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.98it/s]


mean PSNR : 25.158938179016115 167/500
mean MSE : 0.0033042364206630737 167/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.494061374664305 168/500
mean MSE : 0.003099635348189622 168/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.275293617248536 169/500
mean MSE : 0.0032229177025146784 169/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.185358963012696 170/500
mean MSE : 0.0032820781390182673 170/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.56148603439331 171/500
mean MSE : 0.003015607389388606 171/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.96it/s]


mean PSNR : 25.24755084991455 172/500
mean MSE : 0.0032515559589955958 172/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.210118160247802 173/500
mean MSE : 0.003214677380165085 173/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 25.595103874206544 174/500
mean MSE : 0.003015784912277013 174/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 25.793313903808595 175/500
mean MSE : 0.0029137152939802038 175/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.93it/s]


mean PSNR : 25.622179374694824 176/500
mean MSE : 0.002993072756798938 176/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.95it/s]


mean PSNR : 12.968310766220092 177/500
mean MSE : 0.053207371048629284 177/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.91it/s]


mean PSNR : 12.968310766220092 178/500
mean MSE : 0.053207371048629284 178/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.94it/s]


mean PSNR : 12.968310766220092 179/500
mean MSE : 0.053207371048629284 179/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 11.316968421936036 180/500
mean MSE : 0.07687519416213036 180/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 12.982070217132568 181/500
mean MSE : 0.052786914575845 181/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 12.938253135681153 182/500
mean MSE : 0.05343703282997012 182/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 9.755820608139038 183/500
mean MSE : 0.10934087831526995 183/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.99it/s]


mean PSNR : 13.002557458877563 184/500
mean MSE : 0.05271723570302129 184/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


100%|██████████| 800/800 [00:53<00:00, 14.97it/s]


mean PSNR : 13.002557458877563 185/500
mean MSE : 0.05271723570302129 185/500
max_psnr: 25.83650407791138, epoch: 68
Model saved at: result/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002/2024-12-12 17:12:48.697095_pre_reinforce_8_0.002


 57%|█████▋    | 454/800 [00:30<00:23, 14.96it/s]

KeyboardInterrupt



<Figure size 640x480 with 0 Axes>

In [None]:
# Inspect the shape of the first item returned by valid_dataset.
valid_dataset[0].shape

In [None]:
# Run the network on one validation sample, binarize the output and propagate
# it with tt.simulate; `mean_sim` stays defined for later cells.
sign = SignFunction.apply
# The two unsqueeze(0) calls prepend two leading singleton dimensions.
out = model(valid_dataset[0].unsqueeze(0).unsqueeze(0))
# SignFunction.forward takes only the input tensor (its 0.5 threshold is
# hard-coded inside forward), so a second argument raises a TypeError.
binary = sign(out)
# NOTE(review): BinaryHologramEnv.step applies .abs()**2 to the simulate
# result before averaging; confirm whether intensity is intended here as well.
sim = tt.simulate(binary, 2e-3)
# Average over dim 1 (presumably the hologram channels -- TODO confirm).
mean_sim = torch.mean(sim, dim=1, keepdim=True)

In [None]:
# Peek at one batch: print the shape of the first item yielded by trainloader,
# then stop. The loop variable `target` stays bound after the break.
for target in trainloader:
    print(target.shape)
    break

In [None]:
# Display `target`, the batch left bound by the trainloader loop cell
# (assuming the cells are run in order).
target

In [None]:
# Display `mean_sim`, the dim-1 average of the simulation result computed in
# an earlier cell.
mean_sim

In [None]:
import gymnasium as gym
import numpy as np
from gymnasium import spaces
import torch
import torch.nn.functional as F

class BinaryHologramEnv(gym.Env):
    """Gymnasium environment whose actions are stacks of binary holograms.

    Every step reshapes the action to ``HOLOGRAM_SHAPE``, propagates it with
    ``tt.simulate``, averages the resulting intensity over dim 1 and scores it
    against a target through ``tt.relativeLoss`` (MSE drives the reward, PSNR
    drives termination).
    """

    # Shape shared by the observation, the internal state and the reshaped
    # action: (1, 8, 1024, 1024), i.e. 8 channels of 1024x1024.
    HOLOGRAM_SHAPE = (1, 8, 1024, 1024)

    def __init__(self, target_function, max_steps=1000000, T_PSNR=30, T_steps=100):
        """Set up the spaces and the episode bookkeeping.

        Args:
            target_function: Callable intended to compute the loss (MSE or
                PSNR) against the target image. It is stored on the instance
                but not called by any method of this class.
            max_steps: Step count at which an episode is truncated.
            T_PSNR: PSNR value the agent has to reach.
            T_steps: Number of consecutive steps the PSNR must stay at or
                above ``T_PSNR`` before the episode terminates.
        """
        super().__init__()

        # Observation space: continuous values in [0, 1].
        self.observation_space = spaces.Box(
            low=0,
            high=1,
            shape=self.HOLOGRAM_SHAPE,
            dtype=np.float32,
        )

        # Action space: one binary decision per hologram pixel, flattened.
        self.action_space = spaces.MultiBinary(int(np.prod(self.HOLOGRAM_SHAPE)))

        # Scoring callable (kept for callers; unused inside the class).
        self.target_function = target_function

        # Episode settings.
        self.max_steps = max_steps
        self.T_PSNR = T_PSNR
        self.T_steps = T_steps

        # Episode state.
        self.state = None  # binary (int8) environment state
        self.observation = None  # what the agent gets to observe
        self.steps = 0
        self.psnr_sustained_steps = 0  # consecutive steps with PSNR >= T_PSNR

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so that the
                random initial observation is reproducible.
            options: Accepted for API compatibility; not used.

        Returns:
            A ``(observation, info)`` tuple; ``info`` is an empty dict.
        """
        # Seed self.np_random as the Gymnasium API expects; without this call
        # the `seed` argument would have no effect.
        super().reset(seed=seed)

        # Random continuous observation in [0, 1), drawn from the env's own
        # generator directly as float32.
        self.observation = self.np_random.random(self.HOLOGRAM_SHAPE, dtype=np.float32)

        # Binary state: the observation thresholded at 0.5.
        self.state = (self.observation >= 0.5).astype(np.int8)

        # Restart the step counters.
        self.steps = 0
        self.psnr_sustained_steps = 0

        return self.observation, {}

    def step(self, action, lr=1e-4, z=2e-3):
        """Score one binary action and advance the episode.

        Args:
            action: MultiBinary action (0/1 values); reshaped to
                ``HOLOGRAM_SHAPE``.
            lr: Not used by this method; kept for interface compatibility.
            z: Second argument passed to ``tt.simulate`` (presumably the
                propagation distance -- TODO confirm).

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``reward`` is the negated MSE and ``info`` carries ``mse``,
            ``psnr`` and ``psnr_sustained_steps``.
        """
        # Bring the flat action into the hologram layout.
        action = np.reshape(action, self.HOLOGRAM_SHAPE).astype(np.int8)

        # NOTE(review): the target is the module-level `train_dataset` object
        # itself rather than a single target image -- confirm this is what
        # tt.relativeLoss should compare against.
        target = train_dataset

        # Simulate the hologram and average the intensity over dim 1.
        # NOTE(review): `binary` is a NumPy int8 array here, while the other
        # tt.simulate call in this notebook passes a torch tensor -- confirm
        # tt.simulate accepts it or convert first.
        binary = action
        sim = tt.simulate(binary, z).abs()**2
        result = torch.mean(sim, dim=1, keepdim=True)

        # Convert both metrics to plain Python floats so the reward satisfies
        # the float contract of step() and `info` holds no tensors.
        mse = float(tt.relativeLoss(result, target, F.mse_loss).detach().cpu().numpy())
        psnr = float(tt.relativeLoss(result, target, tm.get_PSNR))

        # Reward: lower MSE is better.
        reward = -mse

        # State update: XOR the current binary state with the action.
        self.state = np.logical_xor(self.state, action).astype(np.int8)

        self.steps += 1

        # Time limit reached.
        truncated = self.steps >= self.max_steps

        # Count consecutive steps at or above the PSNR goal.
        if psnr >= self.T_PSNR:
            self.psnr_sustained_steps += 1
        else:
            self.psnr_sustained_steps = 0

        # Holding the goal PSNR long enough ends the episode.
        terminated = self.psnr_sustained_steps >= self.T_steps

        # The observation is independent of the action and is returned unchanged.
        observation = self.observation

        info = {
            "mse": mse,
            "psnr": psnr,
            "psnr_sustained_steps": self.psnr_sustained_steps,
        }

        return observation, reward, terminated, truncated, info


In [None]:
from stable_baselines3.common.env_checker import check_env

def target_function(state, action):
    """Return the MSE between the last-axis means of ``action`` and ``state``.

    Args:
        state: NumPy array; its mean over the last axis is the target.
        action: NumPy array; its mean over the last axis is compared with
            the target.

    Returns:
        The mean of the squared differences as a NumPy scalar.
    """
    # Reduce both inputs the same way so their shapes line up exactly. Using
    # keepdims=True followed by a bare squeeze() removes *every* singleton
    # axis from the target, which can silently broadcast against the action
    # means (e.g. inputs of shape (2, 1, 3) produce a (2, 2) difference).
    target = state.mean(axis=-1)
    return np.mean((action.mean(axis=-1) - target) ** 2)

# Build the environment around the scoring callable defined above.
env = BinaryHologramEnv(target_function=target_function, max_steps=1_000_000, T_PSNR=30, T_steps=100)

# Validate the environment with Stable-Baselines3's checker, warnings enabled.
check_env(env, warn=True)


In [None]:
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import RecurrentPPO

# Rebuild the hologram network, restore its weights from a saved checkpoint
# and switch it to evaluation mode.
model = BinaryNet(num_hologram=8, in_planes=1, convReLU=False,
                  convBN=False, poolReLU=False, poolBN=False,
                  deconvReLU=False, deconvBN=False).cuda()
model.load_state_dict(torch.load('result/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002'))
model.eval()

# Create the custom Gym environment.
# BinaryHologramEnv.__init__ as defined in this notebook requires
# `target_function` and has no `model` / `validloader` parameters, so passing
# those keywords raises a TypeError.
# NOTE(review): if the environment is meant to consume `model` and
# `validloader`, the class itself has to be extended -- confirm the intent.
env = BinaryHologramEnv(target_function=target_function, max_steps=1000, T_PSNR=30, T_steps=100)

# Create a vectorized environment with observation and reward normalization.
venv = make_vec_env(lambda: env, n_envs=1)
venv = VecNormalize(venv, norm_obs=True, norm_reward=True, clip_obs=10.0)


# Configure the Recurrent PPO agent.
ppo_model = RecurrentPPO(
    "MlpLstmPolicy",
    venv,
    verbose=1,
    n_steps=2048,
    batch_size=64,
    gamma=0.99,
    learning_rate=3e-4,
    tensorboard_log="./ppo_lstm/"
)

# Train the agent.
ppo_model.learn(total_timesteps=100000)

# Save the trained agent.
ppo_model.save("recurrent_ppo_binary_hologram")
