In [None]:
import sys
import logging
from datetime import datetime
import os

# Directory where log files are stored (created if it does not exist).
log_dir = "log"
os.makedirs(log_dir, exist_ok=True)

# Base the log file name on the running script's name, or on "interactive"
# when __file__ is not defined (interpreter / notebook), plus a timestamp.
if '__file__' in globals():
    current_file = os.path.splitext(os.path.basename(__file__))[0]  # script name without extension
else:
    current_file = "interactive"

current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_filename = os.path.join(log_dir, f"{current_file}_{current_datetime}.log")

# Root logger: bare messages go to the log file and to the console
# (StreamHandler() writes to sys.stderr by default).
# The file handler is opened as UTF-8 explicitly: this notebook logs
# non-ASCII (Korean) text, and the platform default encoding may not be
# able to represent it.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[
        logging.FileHandler(log_filename, encoding="utf-8"),
        logging.StreamHandler(),
    ],
)

class Tee:
    """Write-only text stream that duplicates every write to several streams.

    Intended to replace ``sys.stdout`` so printed text reaches the console and
    a log file at the same time. Every write is flushed immediately so the
    log file stays current during long runs.

    Attributes not defined on this class (``encoding``, ``isatty``,
    ``fileno``, ...) are looked up on the first stream. Code that inspects
    ``sys.stdout`` therefore sees the primary (console) stream's properties
    instead of an ``AttributeError``.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, data):
        """Write ``data`` to every stream, flushing each one; return ``len(data)``."""
        for file in self.files:
            file.write(data)
            file.flush()  # keep the log file up to date in real time
        # io.TextIOBase.write returns the number of characters written.
        return len(data)

    def flush(self):
        """Flush every underlying stream."""
        for file in self.files:
            file.flush()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Read ``files`` from
        # __dict__ directly so a half-initialised instance (e.g. one created
        # via __new__ during copying) raises instead of recursing forever.
        files = self.__dict__.get("files")
        if not files:
            raise AttributeError(name)
        return getattr(files[0], name)


# Mirror everything printed to stdout into the log file as well as the console.
# The file is opened as UTF-8 because printed messages contain non-ASCII text.
#
# If this cell is re-run (e.g. in a notebook), sys.stdout may already be a Tee
# from the previous run. Unwrap it so output is not also copied into the
# previous run's log file. The match is by class name because re-running the
# cell redefines Tee, so isinstance() against the new class would fail.
console_stream = sys.stdout
while type(console_stream).__name__ == "Tee" and getattr(console_stream, "files", None):
    console_stream = console_stream.files[0]

log_file = open(log_filename, "a", encoding="utf-8")
sys.stdout = Tee(console_stream, log_file)

# Test output: should appear on the console and in the log file.
print("이 메시지는 콘솔과 파일에 동시에 기록됩니다.")

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import MaskablePPO
from stable_baselines3.common.policies import ActorCriticPolicy
import gymnasium as gym
from gymnasium import spaces
from datetime import datetime
import glob
import torchOptics.optics as tt
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import matplotlib.pyplot as plt
import pickle
import torchvision
import tqdm
import time
import pandas as pd
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3 import PPO
import warnings

import os
import torch

# Side length (pixels) of the square crops used for images / holograms, and
# the number of hologram channels predicted per image.
IPS, CH = 256, 8

# cuDNN is switched off globally for all subsequent torch operations.
torch.backends.cudnn.enabled = False

# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Timestamp of this run, formatted YYYY-MM-DD_HH-MM-SS.
current_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

class BinaryNet(nn.Module):
    """U-Net style encoder/decoder that maps an image to ``num_hologram`` maps in [0, 1].

    Four stride-2 convolution "pool" stages downsample by 16 in total, and four
    transposed-convolution stages upsample back. Encoder features are
    concatenated into the decoder at every resolution. The classifier output
    goes through a sigmoid. Input height/width must therefore be divisible
    by 16.

    Args:
        num_hologram: number of output channels.
        final: name of the output activation. Only ``'Sigmoid'`` is
            implemented and the output is always passed through a sigmoid;
            the argument is kept for interface compatibility.
        in_planes: number of input channels.
        channels: feature widths per level. Only the first five are used,
            and each must be twice the previous one so the skip
            concatenations match the decoder input widths.
        convReLU, convBN: add an activation / BatchNorm after each encoder
            and decoder convolution. NOTE: despite its name, the activation
            inserted by these conv blocks is ``nn.Tanh``.
        poolReLU, poolBN: same for the stride-2 downsampling convolutions.
        deconvReLU, deconvBN: add BatchNorm / ``nn.ReLU`` after each
            transposed convolution.
    """

    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=(32, 64, 128, 256, 512, 1024, 2048, 4096),
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super(BinaryNet, self).__init__()

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            """Conv2d, then optional Tanh, then optional BatchNorm2d."""
            layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            if relu:
                layers.append(nn.Tanh())
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            return nn.Sequential(*layers)

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            """ConvTranspose2d, then optional BatchNorm2d, then optional ReLU."""
            # kernel_size / stride / bias are passed through. They used to be
            # hard-coded to the same values as these defaults, so existing
            # callers and checkpoints are unaffected.
            layers = [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride, padding=0,
                                         bias=bias)]
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            if relu:
                layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        self.enc1_1 = CRB2d(in_planes, channels[0], relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(channels[0], channels[0], stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(channels[0], channels[1], relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(channels[1], channels[1], stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(channels[1], channels[2], relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(channels[2], channels[2], stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(channels[2], channels[3], relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(channels[3], channels[3], stride=2, relu=poolReLU, bn=poolBN)

        self.enc5_1 = CRB2d(channels[3], channels[4], relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(channels[4], channels[4], relu=convReLU, bn=convBN)

        # Decoder: each dec*_1 takes the upsampled features concatenated with
        # the matching encoder output (hence 2x the channel count).
        self.deconv4 = TRB2d(channels[4], channels[3], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(channels[4], channels[3], relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(channels[3], channels[2], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(channels[3], channels[2], relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(channels[2], channels[1], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(channels[2], channels[1], relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(channels[1], channels[0], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(channels[1], channels[0], relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)

        self.classifier = CRB2d(channels[0], num_hologram, relu=False, bn=False)

    def forward(self, x):
        """Return sigmoid maps with ``num_hologram`` channels at the input's spatial size."""
        # Encoder
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)
        enc5_2 = self.enc5_2(enc5_1)

        # Decoder with skip connections
        deconv4 = self.deconv4(enc5_2)
        concat4 = torch.cat((deconv4, enc4_2), dim=1)
        dec4_1 = self.dec4_1(concat4)
        dec4_2 = self.dec4_2(dec4_1)

        deconv3 = self.deconv3(dec4_2)
        concat3 = torch.cat((deconv3, enc3_2), dim=1)
        dec3_1 = self.dec3_1(concat3)
        dec3_2 = self.dec3_2(dec3_1)

        deconv2 = self.deconv2(dec3_2)
        concat2 = torch.cat((deconv2, enc2_2), dim=1)
        dec2_1 = self.dec2_1(concat2)
        dec2_2 = self.dec2_2(dec2_1)

        deconv1 = self.deconv1(dec2_2)
        concat1 = torch.cat((deconv1, enc1_2), dim=1)
        dec1_1 = self.dec1_1(concat1)
        dec1_2 = self.dec1_2(dec1_1)

        # Final classifier; torch.sigmoid avoids building a module per call.
        return torch.sigmoid(self.classifier(dec1_2))


# Sanity check: push one random single-channel IPS x IPS image through an
# untrained BinaryNet on the GPU and print the resulting tensor shape.
model = BinaryNet(
    num_hologram=CH,
    in_planes=1,
    convReLU=False,
    convBN=False,
    poolReLU=False,
    poolBN=False,
    deconvReLU=False,
    deconvBN=False,
).cuda()
test = torch.randn(1, 1, IPS, IPS).cuda()
out = model(test)
print(out.shape)


class Dataset512(Dataset):
    """Grayscale PNG images from ``target_dir`` cropped to IPS x IPS tensors.

    Images are loaded with ``tt.imread(..., meta=self.meta, gray=True)``. An
    image smaller than IPS on either side is first resized with
    ``Resize(IPS)``. Training samples (``isTrain=True``) use a random crop,
    others a center crop. Both are then zero-padded by ``padding`` pixels on
    every side. ``transform``, when given, is applied last.
    """

    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        # os.path.join works whether or not target_dir ends with a separator.
        self.target_list = sorted(glob.glob(os.path.join(target_dir, '*.png')))
        self.center_crop = torchvision.transforms.CenterCrop(IPS)
        self.random_crop = torchvision.transforms.RandomCrop((IPS, IPS))
        self.padding = padding

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Use the metadata this dataset was constructed with, not a module global.
        target = tt.imread(self.target_list[idx], meta=self.meta, gray=True).unsqueeze(0)
        if target.shape[-1] < IPS or target.shape[-2] < IPS:
            target = torchvision.transforms.Resize(IPS)(target)
        crop = self.random_crop if self.isTrain else self.center_crop
        target = crop(target)
        pad = self.padding
        target = torchvision.transforms.functional.pad(target, (pad, pad, pad, pad))
        if self.transform is not None:
            target = self.transform(target)
        return target


# BinaryHologramEnv class
class BinaryHologramEnv(gym.Env):
    """Gymnasium environment that refines a binary hologram one pixel flip at a time.

    On ``reset`` the next target image is drawn from ``trainloader``.
    ``target_function`` (a network) predicts a continuous hologram, which is
    thresholded at 0.5 to form the initial binary state.

    Each ``step``:
      * flips one pixel;
      * re-simulates with ``tt.simulate`` (mean intensity over channels);
      * rewards 800 x the PSNR change.
    Flips that lower PSNR are rolled back.

    Observation: array of shape (4, CH, IPS, IPS) stacking
    [binary state, network output, target image repeated CH times,
    simulated intensity repeated CH times].
    Action: flat index into the CH * IPS * IPS pixels.
    """

    def __init__(self, target_function, trainloader, max_steps=10000, T_PSNR=30, T_steps=10, T_PSNR_DIFF=0.1, max_allowed_changes=1):
        super(BinaryHologramEnv, self).__init__()
        # Observation: (4, CH, IPS, IPS) = [state, model output, target, simulation]
        self.observation_space = spaces.Box(low=0, high=1, shape=(4, CH, IPS, IPS), dtype=np.float32)

        # Action: flat index of the pixel to flip (CH * IPS * IPS choices)
        self.num_pixels = CH * IPS * IPS
        self.action_space = spaces.Discrete(self.num_pixels)

        self.target_function = target_function
        self.trainloader = trainloader

        # Episode limits and success thresholds
        self.max_steps = max_steps
        self.T_PSNR = T_PSNR
        self.T_steps = T_steps
        self.T_PSNR_DIFF = T_PSNR_DIFF
        self.max_allowed_changes = max_allowed_changes  # stored; not read by this class

        # Per-episode state
        self.state = None
        self.observation = None
        self.steps = 0
        self.psnr_sustained_steps = 0

        # Data loader iteration. _batch_index counts batches drawn in the current
        # pass over the loader; _current_batch_index is the batch in use.
        self.data_iter = iter(self.trainloader)
        self._batch_index = 0
        self._current_batch_index = 0
        self.target_image = None

        # Retry / failure bookkeeping (retry logic is currently disabled)
        self.retry_current_target = False
        self.consecutive_fail_count = 0
        self.max_consecutive_failures = 0

        # Highest PSNR gain over the initial PSNR seen in the current episode
        self.max_psnr_diff = float('-inf')

        self.flip_count = 0
        self.previous_psnr = None

        # Cached arrays used to build observations
        self._target_np = None
        self._result_np = None

    # ------------------------------------------------------------------ helpers

    def _next_target(self):
        """Return the next batch from the loader, restarting it when exhausted."""
        try:
            batch = next(self.data_iter)
        except StopIteration:
            print("\033[93m[INFO] Reached the end of dataset. Restarting from the beginning.\033[0m")
            self.data_iter = iter(self.trainloader)
            self._batch_index = 0
            batch = next(self.data_iter)
        self._current_batch_index = self._batch_index
        self._batch_index += 1
        return batch

    def _current_dataset_file(self):
        """Best-effort path of the first sample in the current batch.

        Exact only when the loader uses a sequential (non-shuffled) sampler.
        Returns None when the file cannot be determined.
        """
        loader = self.trainloader
        if not isinstance(loader.sampler, torch.utils.data.SequentialSampler):
            return None
        index = self._current_batch_index * (loader.batch_size or 1)
        target_list = getattr(loader.dataset, "target_list", None)
        if target_list is None or index >= len(target_list):
            return None
        return target_list[index]

    def _simulate(self, z):
        """Propagate the current binary state by ``z``; return intensity averaged over channels (keepdim)."""
        binary = torch.tensor(self.state, dtype=torch.float32).unsqueeze(0).cuda()
        binary = tt.Tensor(binary, meta={'dx': (7.56e-6, 7.56e-6), 'wl': 515e-9})
        sim = tt.simulate(binary, z).abs() ** 2
        return torch.mean(sim, dim=1, keepdim=True)

    @staticmethod
    def _expand_result(result):
        """Drop the batch axis of a simulation result and repeat it CH times along axis 0."""
        return np.repeat(result.squeeze(0).cpu().numpy(), CH, axis=0)

    def _build_observation(self, result_np):
        """Stack [state, model output, target, simulation] along a new leading axis."""
        return np.stack([self.state, self.observation, self._target_np, result_np], axis=0)

    def _log_step(self, psnr_before, psnr_after, psnr_change, psnr_diff, is_new_max,
                  reward, pre_flip_value, coords, color=""):
        """Print a one-line step summary, optionally wrapped in an ANSI color code."""
        current_time = datetime.now().strftime("%H:%M:%S")
        max_tag = " (New Max)" if is_new_max else ""
        color_end = "\033[0m" if color else ""
        print(
            f"{color}Step: {self.steps}, PSNR Before: {psnr_before:.6f}, PSNR After: {psnr_after:.6f}, "
            f"PSNR Change: {psnr_change:.6f}, PSNR Diff: {psnr_diff:.6f}{max_tag}, "
            f"Reward: {reward:.2f}, {current_time} "
            f"Pre-flip Model Output={pre_flip_value:.6f}, "
            f"New State Value={self.state[coords]}, "
            f"Flip Count={self.flip_count}{color_end}"
        )

    # ---------------------------------------------------------------- Gym API

    def reset(self, seed=None, options=None, z=2e-3):
        """Start a new episode on the next target image from the data loader.

        Args:
            seed: forwarded to ``gym.Env.reset`` to (re)seed ``self.np_random``.
            options: unused; accepted for Gymnasium API compatibility.
            z: propagation distance passed to ``tt.simulate``.

        Returns:
            ``(observation, {"state": binary_state})``.
        """
        super().reset(seed=seed)
        torch.cuda.empty_cache()

        # NOTE: retry-on-failure (reusing the same target after a failed
        # episode) is disabled, so every reset draws a new target and
        # consecutive_fail_count / max_consecutive_failures stay at 0.
        self.target_image = self._next_target()

        current_file = self._current_dataset_file()
        print(f"\033[93m[Episode Start] Currently using dataset file: {current_file}\033[0m")

        self.max_psnr_diff = float('-inf')
        self.target_image = self.target_image.cuda()

        with torch.no_grad():
            model_output = self.target_function(self.target_image)

        self.steps = 0
        self.flip_count = 0
        self.psnr_sustained_steps = 0

        # Network output without the batch axis, then thresholded to a binary state.
        self.observation = model_output.squeeze(0).cpu().numpy()
        self.state = (self.observation >= 0.5).astype(np.int8)

        result = self._simulate(z)
        mse = tt.relativeLoss(result, self.target_image, F.mse_loss).detach().cpu().numpy()
        self.initial_psnr = tt.relativeLoss(result, self.target_image, tm.get_PSNR)
        self.previous_psnr = self.initial_psnr

        self._target_np = np.repeat(self.target_image.squeeze(0).cpu().numpy(), CH, axis=0)
        self._result_np = self._expand_result(result)
        combined_observation = self._build_observation(self._result_np)

        print(f"\033[91mResetting environment. Consecutive episode failures: {self.consecutive_fail_count}, Max consecutive episode failures: {self.max_consecutive_failures}\033[0m")

        current_time = datetime.now().strftime("%H:%M:%S")
        print(f"\033[92mInitial MSE: {mse:.6f}, Initial PSNR: {self.initial_psnr:.6f}, {current_time}\033[0m")

        self.retry_current_target = False
        return combined_observation, {"state": self.state}

    def step(self, action, lr=1e-4, z=2e-3):
        """Flip one pixel, re-simulate, and keep the flip only if PSNR did not drop.

        Args:
            action: flat pixel index in ``[0, CH * IPS * IPS)``.
            lr: unused.
            z: propagation distance passed to ``tt.simulate``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``:
              * ``reward`` is 800 x the PSNR change, +100 when the success
                condition is met.
              * ``terminated`` becomes True once ``PSNR >= T_PSNR`` or the PSNR
                gain over the initial PSNR is ``>= T_PSNR_DIFF`` for
                ``T_steps`` consecutive accepted flips.
              * ``truncated`` becomes True once ``max_steps`` steps have been
                taken.
              * A rejected flip is rolled back, and the returned observation
                describes the restored state.
              * ``info["step"]`` is the 0-based index of this step.
        """
        psnr_before = self.previous_psnr

        # Decode the flat action into (channel, row, col)
        channel = action // (IPS * IPS)
        pixel_index = action % (IPS * IPS)
        row = pixel_index // IPS
        col = pixel_index % IPS
        coords = (channel, row, col)

        pre_flip_value = self.observation[channel, row, col]

        # Flip the chosen pixel and simulate the new state
        self.state[coords] = 1 - self.state[coords]
        self.flip_count += 1
        result_after = self._simulate(z)
        psnr_after = tt.relativeLoss(result_after, self.target_image, tm.get_PSNR)
        result_np = self._expand_result(result_after)

        psnr_change = psnr_after - psnr_before
        psnr_diff = psnr_after - self.initial_psnr
        is_max_psnr_diff = psnr_diff > self.max_psnr_diff
        self.max_psnr_diff = max(self.max_psnr_diff, psnr_diff)
        reward = psnr_change * 800

        if psnr_change < 0:
            # Roll the flip back. previous_psnr and the cached simulation still
            # describe the restored state, so the observation is built from them.
            self.state[coords] = 1 - self.state[coords]
            self.flip_count -= 1

            if self.steps % 100 == 0:
                self._log_step(psnr_before, psnr_after, psnr_change, psnr_diff,
                               is_max_psnr_diff, reward, pre_flip_value, coords)

            info = {
                "psnr_before": psnr_before,
                "psnr_after": psnr_after,
                "psnr_change": psnr_change,
                "psnr_diff": psnr_diff,
                "pre_flip_value": pre_flip_value,
                "state_before": self.state.copy(),
                "state_after": None,  # the rejected flip does not change the state
                "observation_before": self.observation.copy(),
                "observation_after": None,
                "failed_action": action,
                "flip_count": self.flip_count,
                "reward": reward,
                "target_image": self.target_image.cpu().numpy(),
                "simulation_result": result_np,  # simulation of the rejected flip
                "step": self.steps,
            }
            self.steps += 1
            truncated = self.steps >= self.max_steps
            return self._build_observation(self._result_np), reward, False, truncated, info

        if self.steps % 100 == 0:
            self._log_step(psnr_before, psnr_after, psnr_change, psnr_diff,
                           is_max_psnr_diff, reward, pre_flip_value, coords)

        # Accept the flip
        self.previous_psnr = psnr_after
        self._result_np = result_np
        combined_observation = self._build_observation(result_np)

        if psnr_after >= self.T_PSNR or psnr_diff >= self.T_PSNR_DIFF:
            self._log_step(psnr_before, psnr_after, psnr_change, psnr_diff,
                           is_max_psnr_diff, reward, pre_flip_value, coords,
                           color="\033[94m")
            self.psnr_sustained_steps += 1
            if self.psnr_sustained_steps >= self.T_steps:
                reward += 100  # bonus for a successful episode
        else:
            self.psnr_sustained_steps = 0

        # Evaluated after updating the counter so the episode ends on the step
        # that reaches the success condition (not one step later).
        terminated = self.psnr_sustained_steps >= self.T_steps

        info = {
            "psnr_before": psnr_before,
            "psnr_after": psnr_after,
            "psnr_change": psnr_change,
            "psnr_diff": psnr_diff,
            "pre_flip_value": pre_flip_value,
            "state_before": self.state.copy(),
            "state_after": self.state.copy(),
            "observation_before": self.observation.copy(),
            "observation_after": combined_observation,
            "failed_action": None,
            "flip_count": self.flip_count,
            "reward": reward,
            "target_image": self.target_image.cpu().numpy(),
            "simulation_result": result_np,
            "action_coords": coords,
            "step": self.steps,
        }

        self.steps += 1
        # Time limit: exactly max_steps steps per episode.
        truncated = self.steps >= self.max_steps

        return combined_observation, reward, terminated, truncated, info


# Data: DIV2K high-resolution images.
batch_size = 1
padding = 0
target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'
# Metadata handed to tt.imread through Dataset512 (keys 'wl' and 'dx').
meta = {'wl': (515e-9), 'dx': (7.56e-6, 7.56e-6)}

train_dataset = Dataset512(target_dir=target_dir, meta=meta, isTrain=True, padding=padding)
valid_dataset = Dataset512(target_dir=valid_dir, meta=meta, isTrain=False, padding=padding)

# Unshuffled loaders: samples are visited in sorted file order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# Pretrained BinaryNet whose predictions seed the environment's binary state.
model = BinaryNet(
    num_hologram=CH, in_planes=1,
    convReLU=False, convBN=False,
    poolReLU=False, poolBN=False,
    deconvReLU=False, deconvBN=False,
).cuda()
model.load_state_dict(torch.load('result_v/2024-12-19 20:37:52.499731_pre_reinforce_8_0.002/2024-12-19 20:37:52.499731_pre_reinforce_8_0.002'))
model.eval()

# RL environment on the training loader, using the environment's default
# max_steps / T_PSNR / T_steps / T_PSNR_DIFF.
env = BinaryHologramEnv(target_function=model, trainloader=train_loader)

from stable_baselines3.common.callbacks import BaseCallback

class StopOnEpisodeCallback(BaseCallback):
    """Stop ``learn()`` once ``max_episodes`` episodes have finished.

    Finished episodes are counted by summing the ``dones`` entry of
    ``self.locals`` on every step, when that entry is present.
    """

    def __init__(self, max_episodes, verbose=1):
        super().__init__(verbose)
        self.max_episodes = max_episodes
        self.episode_count = 0  # finished episodes so far

    def _on_step(self) -> bool:
        dones = self.locals.get("dones")
        if dones is not None:
            self.episode_count += np.sum(dones)

        if self.episode_count < self.max_episodes:
            return True  # keep training

        print(f"Stopping training at episode {self.episode_count}")
        return False  # returning False asks SB3 to stop training


# Checkpoint directory (created if missing).
save_dir = "./ppo_MlpPolicy_models/"
os.makedirs(save_dir, exist_ok=True)

# Rolling "latest" checkpoint: loaded to resume training, overwritten afterwards.
ppo_model_path = os.path.join(save_dir, "ppo_MlpPolicy_latest.zip")
resume_training = True  # resume from ppo_model_path when that file exists

if not (resume_training and os.path.exists(ppo_model_path)):
    if resume_training:
        print(f"Warning: PPO model not found at {ppo_model_path}. Starting training from scratch.")
    print("Starting training from scratch.")
    ppo_model = PPO(
        "MlpPolicy",
        env,
        n_steps=512,
        batch_size=128,
        learning_rate=1e-4,
        gamma=0.99,
        gae_lambda=0.9,
        clip_range=0.2,
        ent_coef=0.01,
        vf_coef=0.5,
        max_grad_norm=0.5,
        verbose=2,
        tensorboard_log="./ppo_MlpPolicy/",
    )
else:
    print(f"Loading trained PPO model from {ppo_model_path}")
    ppo_model = PPO.load(ppo_model_path, env=env)

# Training stops after this many finished episodes, even if timesteps remain.
max_episodes = 800
stop_callback = StopOnEpisodeCallback(max_episodes=max_episodes)

ppo_model.learn(total_timesteps=1_000_000, callback=stop_callback)

# Timestamped snapshot of the trained policy.
print(f"Start model saving at {save_dir}")
current_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
ppo_model_save_path = os.path.join(save_dir, f"ppo_MlpPolicy_{current_date}.zip")
ppo_model.save(ppo_model_save_path)
print(f"PPO Model saved at {ppo_model_save_path}")

# Overwrite the rolling checkpoint that the next run resumes from.
print(f"Start model updating at {save_dir}")
ppo_model_latest_path = os.path.join(save_dir, "ppo_MlpPolicy_latest.zip")
ppo_model.save(ppo_model_latest_path)
print(f"Latest PPO Model updated at {ppo_model_latest_path}")

이 메시지는 콘솔과 파일에 동시에 기록됩니다.


  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


torch.Size([1, 8, 256, 256])
Starting training from scratch.
Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
[93m[Episode Start] Currently using dataset file: /nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/0001.png[0m
[91mResetting environment. Consecutive episode failures: 0, Max consecutive episode failures: 0[0m
[92mInitial MSE: 0.001907, Initial PSNR: 23.926991, 12:56:48[0m
Logging to ./ppo_MlpPolicy/PPO_3
Step: 0, PSNR Before: 23.926991, PSNR After: 23.926935, PSNR Change: -0.000055, PSNR Diff: -0.000055 (New Max), Reward: -0.04, 12:56:49 Pre-flip Model Output=0.305036, New State Value=0, Flip Count=0
Step: 100, PSNR Before: 23.932129, PSNR After: 23.931915, PSNR Change: -0.000214, PSNR Diff: 0.004925 (New Max), Reward: -0.17, 12:56:50 Pre-flip Model Output=0.262809, New State Value=0, Flip Count=27
Step: 200, PSNR Before: 23.935947, PSNR After: 23.936066, PSNR Change: 0.000118, PSNR Diff: 0.009075 (New Max), Reward: 0.09, 12: