# In [1]:
import sys
import logging
from datetime import datetime
import os

# Every run writes its own log file inside this directory (created on demand).
log_dir = "log"
os.makedirs(log_dir, exist_ok=True)

# Base name of the running script without extension; "interactive" when there
# is no __file__ (REPL / Jupyter kernel).
current_file = (
    os.path.splitext(os.path.basename(__file__))[0]
    if '__file__' in globals()
    else "interactive"
)

# Start-of-run timestamp, embedded in the log file name so runs never collide.
current_datetime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_filename = os.path.join(log_dir, f"{current_file}_{current_datetime}.log")

# Root logger: bare messages at INFO level, sent to the run's log file and to
# the console (a StreamHandler with no argument writes to sys.stderr).
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[logging.FileHandler(log_filename), logging.StreamHandler()],
)

class Tee:
    """File-like object that duplicates every write onto several streams.

    Used to mirror ``sys.stdout`` into a log file. Each write is flushed
    immediately on every stream so the log file is always up to date.
    Attributes not defined here (``isatty``, ``encoding``, ``fileno``, ...)
    are delegated to the first stream, so code that inspects ``sys.stdout``
    keeps working after it has been replaced by a ``Tee``.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, data):
        """Write ``data`` to every stream and return ``len(data)``, like ``TextIOBase.write``."""
        for file in self.files:
            file.write(data)
            file.flush()  # keep the log file current in real time
        return len(data)

    def flush(self):
        """Flush every underlying stream."""
        for file in self.files:
            file.flush()

    def __getattr__(self, name):
        # Only called for attributes not found normally. Never delegate dunder
        # names (copy/pickle probe them), and fail cleanly on an instance that
        # has no streams (or was created without __init__).
        files = self.__dict__.get("files")
        if name.startswith("__") or not files:
            raise AttributeError(name)
        return getattr(files[0], name)


# Mirror stdout to both the console and the run's log file. The handle stays
# open for the rest of the process because sys.stdout keeps writing to it;
# the logging FileHandler configured above appends to the same file.
log_file = open(log_filename, "a")
sys.stdout = Tee(sys.stdout, log_file)

# Smoke test: one line through print (goes through the Tee) and one through logging.
print("이 메시지는 콘솔과 파일에 동시에 기록됩니다.")
logging.info("로깅 메시지입니다.")

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import MaskablePPO
from stable_baselines3.common.policies import ActorCriticPolicy
import gymnasium as gym
from gymnasium import spaces
from datetime import datetime
import glob
import torchOptics.optics as tt
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import matplotlib.pyplot as plt
import pickle
import torchvision
import tqdm
import time
import pandas as pd
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3 import PPO
import warnings

# Side length in pixels of the square images / holograms, and the number of
# hologram channels the network predicts (passed as num_hologram below).
IPS, CH = 512, 8

# Silence every Python warning for the remainder of the run.
warnings.filterwarnings('ignore')

# Timestamp string for this run, formatted for use in file names.
current_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# cuDNN is switched off for all subsequent PyTorch operations.
torch.backends.cudnn.enabled = False

class BinaryNet(nn.Module):
    """U-Net style encoder/decoder producing ``num_hologram`` maps in (0, 1).

    Four stride-2 convolution stages downsample the input and four 2x2
    transposed-convolution stages upsample it again, with skip connections
    concatenated at every level. H and W must be divisible by 16 for the skip
    concatenations to line up; the output then has the input's spatial size.

    Args:
        num_hologram: number of output channels.
        final: accepted for compatibility but currently unused; the output
            always goes through a sigmoid.
        in_planes: number of input channels.
        channels: per-level widths; only the first five entries are used.
        convReLU / convBN: add an activation (``nn.Tanh``, despite the flag
            name) / ``BatchNorm2d`` after each stride-1 3x3 convolution.
        poolReLU / poolBN: the same, for the stride-2 downsampling convolutions.
        deconvReLU / deconvBN: add ``nn.ReLU`` / ``BatchNorm2d`` (BatchNorm
            first) after each transposed convolution.
    """

    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=(32, 64, 128, 256, 512, 1024, 2048, 4096),
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super().__init__()

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            """Conv2d, then optional Tanh (``relu`` flag), then optional BatchNorm2d."""
            layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            if relu:
                # NOTE: the flag is named "relu" but adds Tanh; kept as-is so the
                # network's behaviour and checkpoint layout are unchanged.
                layers.append(nn.Tanh())
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            return nn.Sequential(*layers)

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            """ConvTranspose2d, then optional BatchNorm2d, then optional ReLU."""
            # kernel_size / stride / bias were previously ignored (hard-coded);
            # they are now honoured. The defaults equal the old hard-coded values.
            layers = [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride, padding=0,
                                         bias=bias)]
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            if relu:
                layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        # Encoder: two convolutions per level, then a stride-2 "pool" convolution.
        self.enc1_1 = CRB2d(in_planes, channels[0], relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(channels[0], channels[0], stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(channels[0], channels[1], relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(channels[1], channels[1], stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(channels[1], channels[2], relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(channels[2], channels[2], stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(channels[2], channels[3], relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(channels[3], channels[3], stride=2, relu=poolReLU, bn=poolBN)

        # Bottleneck.
        self.enc5_1 = CRB2d(channels[3], channels[4], relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(channels[4], channels[4], relu=convReLU, bn=convBN)

        # Decoder: upsample, concatenate the matching encoder output (doubling
        # the channel count), then two convolutions.
        self.deconv4 = TRB2d(channels[4], channels[3], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(channels[4], channels[3], relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(channels[3], channels[2], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(channels[3], channels[2], relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(channels[2], channels[1], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(channels[2], channels[1], relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(channels[1], channels[0], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(channels[1], channels[0], relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)

        # 3x3 convolution to num_hologram channels, no activation / BatchNorm.
        self.classifier = CRB2d(channels[0], num_hologram, relu=False, bn=False)

    def forward(self, x):
        """Map ``x`` of shape (N, in_planes, H, W) to (N, num_hologram, H, W) in (0, 1)."""
        # Encoder
        enc1_2 = self.enc1_2(self.enc1_1(x))
        pool1 = self.pool1(enc1_2)

        enc2_2 = self.enc2_2(self.enc2_1(pool1))
        pool2 = self.pool2(enc2_2)

        enc3_2 = self.enc3_2(self.enc3_1(pool2))
        pool3 = self.pool3(enc3_2)

        enc4_2 = self.enc4_2(self.enc4_1(pool3))
        pool4 = self.pool4(enc4_2)

        enc5_2 = self.enc5_2(self.enc5_1(pool4))

        # Decoder with skip connections (concatenate along channels).
        dec4_2 = self.dec4_2(self.dec4_1(torch.cat((self.deconv4(enc5_2), enc4_2), dim=1)))
        dec3_2 = self.dec3_2(self.dec3_1(torch.cat((self.deconv3(dec4_2), enc3_2), dim=1)))
        dec2_2 = self.dec2_2(self.dec2_1(torch.cat((self.deconv2(dec3_2), enc2_2), dim=1)))
        dec1_2 = self.dec1_2(self.dec1_1(torch.cat((self.deconv1(dec2_2), enc1_2), dim=1)))

        # Final classifier squashed to (0, 1); torch.sigmoid avoids building a
        # new nn.Sigmoid module on every call.
        return torch.sigmoid(self.classifier(dec1_2))


# Build the hologram network on the GPU with every optional activation /
# BatchNorm disabled, then push one random IPS x IPS image through it as a
# shape check.
model = BinaryNet(num_hologram=CH, in_planes=1, convReLU=False,
                  convBN=False, poolReLU=False, poolBN=False,
                  deconvReLU=False, deconvBN=False).cuda()
test = torch.randn(1, 1, IPS, IPS).cuda()
# no_grad: only the output shape is needed. Without it the global `out` would
# keep the whole autograd graph (every intermediate activation) alive on the GPU.
with torch.no_grad():
    out = model(test)
print(out.shape)


class Dataset512(Dataset):
    """Dataset of grayscale PNG images returned as IPS x IPS crops.

    Args:
        target_dir: directory containing the ``*.png`` files.
        meta: metadata dict forwarded to ``tt.imread``.
        transform: optional callable applied to each returned tensor.
        isTrain: random crop when True, center crop otherwise.
        padding: zero padding added on every side after cropping.
    """

    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        # os.path.join works whether or not target_dir ends with a separator.
        self.target_list = sorted(glob.glob(os.path.join(target_dir, '*.png')))
        self.center_crop = torchvision.transforms.CenterCrop(IPS)
        self.random_crop = torchvision.transforms.RandomCrop((IPS, IPS))
        self.padding = padding

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Use this dataset's own meta; previously the module-level `meta` was read.
        target = tt.imread(self.target_list[idx], meta=self.meta, gray=True).unsqueeze(0)
        # Resize(int) scales the shorter edge to IPS, so the crop below always fits.
        if target.shape[-1] < IPS or target.shape[-2] < IPS:
            target = torchvision.transforms.Resize(IPS)(target)
        crop = self.random_crop if self.isTrain else self.center_crop
        target = crop(target)
        pad = self.padding
        target = torchvision.transforms.functional.pad(target, (pad, pad, pad, pad))
        if self.transform is not None:
            target = self.transform(target)
        return target


# BinaryHologramEnv 클래스
class BinaryHologramEnv(gym.Env):
    """Gymnasium environment that refines a binary hologram by single-pixel flips.

    On ``reset`` the next target image is taken from ``trainloader``,
    ``target_function`` is run on it, and its output is thresholded at 0.5
    to obtain the initial binary state of shape (CH, IPS, IPS). Each action
    flips one pixel of that state; a flip that lowers the simulated PSNR is
    rolled back.

    Observations have shape (4, CH, IPS, IPS) and stack, in order: the binary
    state, the ``target_function`` output, the target image repeated CH
    times, and the simulated intensity repeated CH times.
    """

    def __init__(self, target_function, trainloader, max_steps=10000, T_PSNR=30, T_steps=10, T_PSNR_DIFF=0.1, max_allowed_changes=1):
        """Initialise the environment.

        Args:
            target_function: callable applied (under ``torch.no_grad``) to a
                target image to produce the continuous hologram prediction.
            trainloader: iterable of target images; restarted when exhausted.
            max_steps: number of steps after which an episode is truncated.
            T_PSNR: PSNR value that counts as success.
            T_steps: consecutive successful steps required to terminate.
            T_PSNR_DIFF: PSNR gain over the initial PSNR that counts as success.
            max_allowed_changes: stored only; ``step`` always flips one pixel.
        """
        super(BinaryHologramEnv, self).__init__()

        # Observation: 4 stacked (CH, IPS, IPS) planes (see class docstring).
        self.observation_space = spaces.Box(low=0, high=1, shape=(4, CH, IPS, IPS), dtype=np.float32)

        # Action: flat index of the pixel to flip, channel * IPS * IPS + row * IPS + col.
        self.num_pixels = CH * IPS * IPS
        self.action_space = spaces.Discrete(self.num_pixels)

        self.target_function = target_function
        self.trainloader = trainloader

        self.max_steps = max_steps
        self.T_PSNR = T_PSNR
        self.T_steps = T_steps
        self.T_PSNR_DIFF = T_PSNR_DIFF
        self.max_allowed_changes = max_allowed_changes

        # Per-episode state.
        self.state = None  # binary hologram, int8 array (CH, IPS, IPS)
        self.observation = None  # target_function output as a numpy array
        self.last_observation = None  # last stacked observation handed to the agent
        self._target_np = None  # target image repeated CH times, cached per episode
        self.steps = 0
        self.psnr_sustained_steps = 0

        self.data_iter = iter(self.trainloader)
        self.target_image = None

        # When True, the next reset() reuses the current target (last episode failed).
        self.retry_current_target = False

        # Consecutive-failure bookkeeping.
        self.consecutive_fail_count = 0
        self.max_consecutive_failures = 0

        # Best PSNR gain over the initial PSNR seen in the current episode.
        self.max_psnr_diff = float('-inf')

        self.flip_count = 0

    def _simulate(self, state, z):
        """Propagate a binary state over distance ``z`` and return its intensity.

        The state is given a leading batch axis, simulated with ``tt.simulate``,
        squared in magnitude and averaged over dim 1 (the hologram channels,
        ``keepdim=True``). The result is a GPU tensor.
        """
        binary = torch.tensor(state, dtype=torch.float32).unsqueeze(0).cuda()
        binary = tt.Tensor(binary, meta={'dx': (7.56e-6, 7.56e-6), 'wl': 515e-9})
        sim = tt.simulate(binary, z).abs() ** 2
        return torch.mean(sim, dim=1, keepdim=True)

    def _combine_observation(self, result):
        """Stack [state, model output, target, simulated intensity] into one array.

        ``np.stack`` copies its inputs, so later in-place changes to
        ``self.state`` do not alter the returned observation.

        Returns:
            tuple: (stacked observation, simulation result repeated CH times).
        """
        result_np = np.repeat(result.squeeze(0).cpu().numpy(), CH, axis=0)
        combined = np.stack([self.state, self.observation, self._target_np, result_np], axis=0)
        return combined, result_np

    def reset(self, seed=None, options=None, z=2e-3):
        """Start a new episode and return its first observation.

        Loads the next target image (or reuses the current one if the previous
        episode failed), runs ``target_function`` on it, binarises the output
        at 0.5 to get the initial state, then simulates that state and records
        the initial MSE / PSNR.

        Args:
            seed (int, optional): seed forwarded to ``gym.Env.reset``.
            options (dict, optional): unused.
            z (float, optional): propagation distance for the simulation. Default 2e-3.

        Returns:
            np.ndarray: stacked observation of shape (4, CH, IPS, IPS).
            dict: ``{"state": binary state, "mask": action mask}``.
        """
        super().reset(seed=seed)
        torch.cuda.empty_cache()

        if self.retry_current_target:  # previous episode failed
            self.consecutive_fail_count += 1
        else:
            self.consecutive_fail_count = 0
        self.max_consecutive_failures = max(self.max_consecutive_failures, self.consecutive_fail_count)

        # Advance to a new image only after a non-failed episode.
        if not self.retry_current_target:
            try:
                self.target_image = next(self.data_iter)
            except StopIteration:
                self.data_iter = iter(self.trainloader)
                self.target_image = next(self.data_iter)

        self.max_psnr_diff = float('-inf')
        self.target_image = self.target_image.cuda()
        self.steps = 0
        self.psnr_sustained_steps = 0

        with torch.no_grad():
            model_output = self.target_function(self.target_image)
        # Drop the batch axis: (CH, IPS, IPS).
        self.observation = model_output.squeeze(0).cpu().numpy()
        self.state = (self.observation >= 0.5).astype(np.int8)  # binary state

        result = self._simulate(self.state, z)
        mse = tt.relativeLoss(result, self.target_image, F.mse_loss).detach().cpu().numpy()
        self.initial_psnr = tt.relativeLoss(result, self.target_image, tm.get_PSNR)

        self._target_np = np.repeat(self.target_image.squeeze(0).cpu().numpy(), CH, axis=0)
        combined_observation, _ = self._combine_observation(result)
        self.last_observation = combined_observation

        print(f"\033[91mResetting environment. Consecutive episode failures: {self.consecutive_fail_count}, Max consecutive episode failures: {self.max_consecutive_failures}\033[0m")

        current_time = datetime.now().strftime("%H:%M:%S")
        print(f"\033[92mInitial MSE: {mse:.6f}, Initial PSNR: {self.initial_psnr:.6f}, {current_time}\033[0m")

        self.retry_current_target = False
        mask = self.create_action_mask(self.observation)
        return combined_observation, {"state": self.state, "mask": mask}

    def initialize_state(self, z=2e-3):  # Probably unused legacy code.
        """Legacy initialiser: compute the initial state and simulation for the current target.

        Unlike ``reset``, ``self.observation`` keeps the batch axis here.

        Args:
            z (float): propagation distance for the simulation. Default 2e-3.

        Returns:
            np.ndarray: raw model output.
            dict: ``{"state": binary state, "mask": action mask}``.
        """
        with torch.no_grad():
            model_output = self.target_function(self.target_image)
        self.observation = model_output.cpu().numpy()

        # Initial state is the model output binarised at 0.5.
        self.state = (self.observation >= 0.5).astype(np.int8)

        binary = torch.tensor(self.state, dtype=torch.float32).cuda()
        binary = tt.Tensor(binary, meta={'dx': (7.56e-6, 7.56e-6), 'wl': 515e-9})

        sim = tt.simulate(binary, z).abs()**2
        result = torch.mean(sim, dim=1, keepdim=True)

        mse = tt.relativeLoss(result, self.target_image, F.mse_loss).detach().cpu().numpy()
        psnr = tt.relativeLoss(result, self.target_image, tm.get_PSNR)

        print(f"Initial MSE: {mse:.6f}, Initial PSNR: {psnr:.6f}, {datetime.now()}")

        self.simulation_result = result.detach().cpu().numpy()

        mask = self.create_action_mask(self.observation)

        return self.observation, {"state": self.state, "mask": mask}

    def create_action_mask(self, observation=None):
        """Build the flat action mask: 1 where a pixel may be flipped, 0 otherwise.

        A pixel is flippable when the corresponding model output is strictly
        between 0 and 1. Size-1 axes (e.g. a batch axis) are squeezed first;
        the flat index is the C-order index channel * H * W + row * W + col,
        which matches the decoding in ``step`` when H == W == IPS.

        Args:
            observation: model output to build the mask from; defaults to
                ``self.observation``.

        Returns:
            np.ndarray: int8 array of length ``num_pixels``.
        """
        source = self.observation if observation is None else observation
        obs = np.asarray(source).squeeze()
        mask = np.zeros(self.num_pixels, dtype=np.int8)
        # Vectorised replacement of the former per-pixel Python loop.
        mask[np.flatnonzero((obs > 0) & (obs < 1))] = 1
        return mask

    def step(self, action, lr=1e-4, z=2e-3):
        """Flip one pixel and return the Gymnasium 5-tuple.

        - Masked-out actions are rejected with reward -10.
        - A flip that lowers PSNR is rolled back, reward -10.
        - Otherwise the flip is kept and the reward is ``800 * psnr_diff``
          (PSNR gain over the episode's initial PSNR), plus 100 on steps where
          the success condition has held for ``T_steps`` consecutive steps;
          the episode then terminates.
        - If ``psnr_diff`` drops below -0.01 the episode fails (reward -100,
          terminated) and the same target is retried at the next reset.

        Every step, including rejected ones, counts toward ``max_steps``; the
        episode is truncated once ``max_steps`` steps have been taken.

        Args:
            action (int): flat pixel index ``channel * IPS * IPS + row * IPS + col``.
            lr (float, optional): unused.
            z (float, optional): propagation distance for the simulation.

        Returns:
            observation (np.ndarray): stacked (4, CH, IPS, IPS) observation.
            float: reward.
            bool: terminated.
            bool: truncated.
            dict: diagnostics (PSNR values, mask, states, ...).
        """
        mask = self.create_action_mask(self.observation)
        if mask.flatten()[action] == 0:
            # Invalid action: state unchanged. Return the last correctly-shaped
            # observation and count the step so the episode can still end.
            self.steps += 1
            truncated = self.steps >= self.max_steps
            return self.last_observation, -10.0, False, truncated, {"mask": mask}

        # PSNR of the current state, before the flip.
        result_before = self._simulate(self.state, z)
        psnr_before = tt.relativeLoss(result_before, self.target_image, tm.get_PSNR)

        # Decode the flat action into (channel, row, col).
        channel, pixel_index = divmod(action, IPS * IPS)
        row, col = divmod(pixel_index, IPS)

        pre_flip_value = self.observation[channel, row, col]
        state_before = self.state.copy()

        # Flip the pixel and simulate the new state.
        self.state[channel, row, col] = 1 - self.state[channel, row, col]
        self.flip_count += 1
        result_after = self._simulate(self.state, z)
        psnr_after = tt.relativeLoss(result_after, self.target_image, tm.get_PSNR)

        psnr_change = psnr_after - psnr_before
        psnr_diff = psnr_after - self.initial_psnr

        if psnr_change < 0:
            # Reject the flip. Restore the pixel *before* building the
            # observation so the agent sees the state the env is actually in.
            self.state[channel, row, col] = 1 - self.state[channel, row, col]
            self.flip_count -= 1
            failed_observation, result_np = self._combine_observation(result_before)
            failed_reward = -10

            info = {
                "psnr_before": psnr_before,
                "psnr_after": psnr_after,
                "psnr_change": psnr_change,
                "psnr_diff": psnr_diff,
                "mask": mask,
                "pre_flip_value": pre_flip_value,
                "state_before": state_before,
                "state_after": None,  # state is not updated on a rejected flip
                "observation_before": self.observation.copy(),
                "observation_after": None,
                "failed_action": action,
                "flip_count": self.flip_count,
                "reward": failed_reward,
                "target_image": self.target_image.cpu().numpy(),
                "simulation_result": result_np,  # simulation of the (restored) current state
                "step": self.steps,
            }

            self.last_observation = failed_observation
            self.steps += 1
            truncated = self.steps >= self.max_steps
            return failed_observation, failed_reward, False, truncated, info

        # Accepted flip: only accepted flips can set a new best PSNR gain.
        is_max_psnr_diff = psnr_diff > self.max_psnr_diff
        self.max_psnr_diff = max(self.max_psnr_diff, psnr_diff)

        combined_observation, result_np = self._combine_observation(result_after)
        self.last_observation = combined_observation

        reward = psnr_diff * 800

        if psnr_diff < -0.01:
            print(f"\033[91mEpisode failed: PSNR Diff {psnr_diff:.6f} < -0.01 at step {self.steps}\033[0m")
            self.retry_current_target = True  # retry this target at the next reset
            return combined_observation, -100.0, True, False, {"psnr_diff": psnr_diff, "mask": None}

        if is_max_psnr_diff:
            current_time = datetime.now().strftime("%H:%M:%S")
            print(
                f"\033[94mStep: {self.steps}, PSNR Before: {psnr_before:.6f}, PSNR After: {psnr_after:.6f}, "
                f"PSNR Change: {psnr_change:.6f}, PSNR Diff: {psnr_diff:.6f} (New Max), "
                f"Reward: {reward:.2f}, {current_time} "
                f"Pre-flip Model Output={pre_flip_value:.6f}, "
                f"New State Value={self.state[channel, row, col]}, "
                f"Flip Count={self.flip_count}\033[0m"
            )

        # Periodic progress line every 100 steps.
        if self.steps % 100 == 0:
            current_time = datetime.now().strftime("%H:%M:%S")
            print(
                f"Step: {self.steps}, PSNR Before: {psnr_before:.6f}, PSNR After: {psnr_after:.6f}, "
                f"PSNR Change: {psnr_change:.6f}, PSNR Diff: {psnr_diff:.6f}, "
                f"Reward: {reward:.2f}, {current_time} "
                f"Pre-flip Model Output={pre_flip_value:.6f}, "
                f"New State Value={self.state[channel, row, col]}, "
                f"Flip Count={self.flip_count}"
            )

        # Success: PSNR >= T_PSNR or PSNR gain >= T_PSNR_DIFF, held for T_steps steps.
        if psnr_after >= self.T_PSNR or psnr_diff >= self.T_PSNR_DIFF:
            self.psnr_sustained_steps += 1
            if self.psnr_sustained_steps >= self.T_steps:
                reward += 100
        else:
            self.psnr_sustained_steps = 0
        # Evaluated after the update so success terminates on the step it is reached.
        terminated = self.psnr_sustained_steps >= self.T_steps

        info = {
            "psnr_before": psnr_before,
            "psnr_after": psnr_after,
            "psnr_change": psnr_change,
            "psnr_diff": psnr_diff,
            "mask": mask,
            "pre_flip_value": pre_flip_value,
            "state_before": state_before,
            "state_after": self.state.copy(),
            "observation_before": self.observation.copy(),
            "observation_after": combined_observation,
            "failed_action": None,
            "flip_count": self.flip_count,
            "reward": reward,
            "target_image": self.target_image.cpu().numpy(),
            "simulation_result": result_np,
            "action_coords": (channel, row, col),
            "step": self.steps,
        }

        del result_before, result_after
        torch.cuda.empty_cache()

        self.steps += 1
        # Time limit only: the episode is cut after exactly max_steps steps.
        truncated = self.steps >= self.max_steps

        return combined_observation, reward, terminated, truncated, info

def initialize_weights(m):
    """Initialise the parameters of a single module (suitable for ``model.apply``).

    - ``nn.Conv2d``: Kaiming-uniform weights with ReLU gain, zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0.
    - ``nn.Linear``: Kaiming-uniform weights (PyTorch's default gain), zero bias.

    Other module types are left untouched. Parameters that do not exist
    (``bias=False``, ``affine=False``) are skipped instead of raising.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False.
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)


import torch
from torch import nn
from stable_baselines3.common.policies import ActorCriticPolicy

class CustomCNNPolicy(ActorCriticPolicy):
    """Actor-critic policy with a separate small CNN per observation component.

    The (B, 4, C, H, W) observation is split along dim 1 into binary state,
    model output, target image and simulation result. Each part goes through
    its own CNN; the flattened features are concatenated and fed to an actor
    head (one logit per action) and a critic head (one value).

    NOTE(review): ``ActorCriticPolicy.__init__`` still builds its default
    features extractor / mlp_extractor / action_net / value_net. The methods
    overridden here do not use them, but they still occupy memory.
    """

    def __init__(self, observation_space, action_space, lr_schedule, *args, **kwargs):
        super(CustomCNNPolicy, self).__init__(observation_space, action_space, lr_schedule, *args, **kwargs)

        # Each component is a (C, H, W) slice of the (4, C, H, W) observation.
        in_channels, height, width = observation_space.shape[1:]

        # Two-conv branches for the binary state and the model output ...
        self.state_cnn = self._conv_branch(in_channels, second_conv=True)
        self.obs_cnn = self._conv_branch(in_channels, second_conv=True)
        # ... single-conv branches for the target and simulation images.
        self.target_cnn = self._conv_branch(in_channels, second_conv=False)
        self.simulation_cnn = self._conv_branch(in_channels, second_conv=False)

        # Infer the total flattened feature size with a dummy forward pass.
        with torch.no_grad():
            dummy_input = torch.zeros(1, in_channels, height, width)
            total_features = sum(
                branch(dummy_input).shape[1]
                for branch in (self.state_cnn, self.obs_cnn, self.target_cnn, self.simulation_cnn)
            )

        # Actor (policy network)
        self.actor = nn.Sequential(
            nn.Linear(total_features, 256),
            nn.ReLU(),
            nn.Linear(256, self.action_space.n)
        )

        # Critic (value network)
        self.critic = nn.Sequential(
            nn.Linear(total_features, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )

        # super().__init__() built self.optimizer before the modules above
        # existed, so it would never update them. Rebuild it the same way
        # ActorCriticPolicy._build() does, now over all parameters.
        self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    @staticmethod
    def _conv_branch(in_channels, second_conv):
        """Conv(in->16)+ReLU+MaxPool(2), optionally Conv(16->32)+ReLU, then Flatten."""
        layers = [
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        if second_conv:
            layers += [nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1), nn.ReLU()]
        layers.append(nn.Flatten())
        return nn.Sequential(*layers)

    def _features(self, obs):
        """Run the four CNN branches and return the concatenated (B, F) features."""
        obs = obs.float()
        # (B, 4, C, H, W) -> four (B, C, H, W) tensors.
        state, obs_out, target, simulation = (part.squeeze(1) for part in torch.chunk(obs, 4, dim=1))
        return torch.cat(
            [
                self.state_cnn(state),
                self.obs_cnn(obs_out),
                self.target_cnn(target),
                self.simulation_cnn(simulation),
            ],
            dim=1,
        )

    def forward(self, obs, deterministic=False):
        """Return (actions, values, log_probs) for a batch of observations."""
        features = self._features(obs)
        values = self.critic(features)
        distribution = self.action_dist.proba_distribution(self.actor(features))
        actions = distribution.get_actions(deterministic=deterministic)
        log_probs = distribution.log_prob(actions)
        return actions, values, log_probs

    def _predict(self, obs, deterministic=False):
        """Return actions only (mode if deterministic, otherwise sampled)."""
        return self.get_distribution(obs).get_actions(deterministic=deterministic)

    def get_distribution(self, obs):
        """Return the action distribution for ``obs``."""
        # proba_distribution's parameter is named action_logits, so pass positionally.
        return self.action_dist.proba_distribution(self.actor(self._features(obs)))

    def predict_values(self, obs):
        """Return state values from the custom critic (used for bootstrapping)."""
        return self.critic(self._features(obs))

    def evaluate_actions(self, obs, actions):
        """Return (log_prob, values, entropy) of ``actions`` under the current policy."""
        features = self._features(obs)
        distribution = self.action_dist.proba_distribution(self.actor(features))
        return distribution.log_prob(actions), self.critic(features), distribution.entropy()

import torch

from stable_baselines3.common.callbacks import BaseCallback, CallbackList

class DummyCallback(BaseCallback):
    """
    Callback that does nothing except let training proceed.

    Serves as a placeholder so that ``collect_rollouts`` always receives a
    callback object.
    """

    def __init__(self):
        super().__init__()

    def _on_step(self) -> bool:
        # Invoked on every step; returning True tells the trainer to keep going.
        return True

class CustomPPO(PPO):
    """
    PPO variant that collects rollouts on one GPU and runs gradient updates on another.

    Only ``learn`` is overridden; everything else is inherited from ``PPO``.
    """

    def learn(self, total_timesteps, log_interval=10, *,
              env_device="cuda:0", train_device="cuda:1"):
        """
        Alternate rollout collection and policy updates until
        ``total_timesteps`` environment steps have been collected.

        Args:
            total_timesteps: Number of environment timesteps to collect.
            log_interval: Print progress and dump the SB3 logger every
                ``log_interval`` collect/train iterations. Falsy disables this.
                (Previously it was compared with the timestep counter, a
                multiple of ``n_steps``, so it rarely triggered.)
            env_device: CUDA device holding the policy while collecting rollouts.
            train_device: CUDA device holding the policy during ``self.train()``.

        Returns:
            ``self``, as with SB3's ``learn``.
        """
        device_env = torch.device(env_device)
        device_train = torch.device(train_device)

        # The policy lives on the env device while it generates actions.
        self.policy.to(device_env)

        # Resets counters and configures the logger/tensorboard run.
        self._setup_learn(total_timesteps)

        callback = CallbackList([DummyCallback()])
        callback.init_callback(self)  # give the callbacks access to this model
        callback.on_training_start(locals(), globals())

        iteration = 0
        # collect_rollouts advances self.num_timesteps by n_steps * n_envs, so
        # counting timesteps (rather than stepping by n_steps) is also correct
        # for multi-env setups.
        while self.num_timesteps < total_timesteps:
            # torch.cuda.device sets the *current* CUDA device, so allocations
            # on an index-less "cuda" device land on the env GPU here.
            with torch.cuda.device(device_env):
                continue_training = self.collect_rollouts(
                    env=self.env,
                    rollout_buffer=self.rollout_buffer,
                    n_rollout_steps=self.n_steps,
                    callback=callback,
                )
            if not continue_training:
                # A callback asked to stop; don't train on a partial rollout.
                break

            iteration += 1
            # Keeps learning-rate / clip-range schedules in sync with progress.
            self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

            with torch.cuda.device(device_train):
                self.policy.to(device_train)
                try:
                    self.train()
                finally:
                    # Always hand the policy back to the env GPU, even if train() fails.
                    self.policy.to(device_env)

            if log_interval and iteration % log_interval == 0:
                # train() only records metrics; dumping writes them to
                # stdout/tensorboard.
                self.logger.dump(step=self.num_timesteps)
                print(f"Step {self.num_timesteps}/{total_timesteps}: Training complete for this batch.")

        callback.on_training_end()
        return self


# --- Data / optics configuration --------------------------------------------
batch_size = 1
target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'
# Metadata shared by both datasets. 'wl' / 'dx' are presumably wavelength and
# pixel pitch in metres — confirm against torchOptics.
meta = {'wl': 515e-9, 'dx': (7.56e-6, 7.56e-6)}
padding = 0

# Training split is built first, then validation.
train_dataset, valid_dataset = (
    Dataset512(target_dir=split_dir, meta=meta, isTrain=is_train, padding=padding)
    for split_dir, is_train in ((target_dir, True), (valid_dir, False))
)

# Only the training loader shuffles.
train_loader, valid_loader = (
    DataLoader(split_dataset, batch_size=batch_size, shuffle=shuffle)
    for split_dataset, shuffle in ((train_dataset, True), (valid_dataset, False))
)

# Pre-trained BinaryNet: weights come from a fixed checkpoint path, then the
# network is switched to eval mode.
model = BinaryNet(
    num_hologram=8,
    in_planes=1,
    convReLU=False,
    convBN=False,
    poolReLU=False,
    poolBN=False,
    deconvReLU=False,
    deconvBN=False,
).cuda()
model.load_state_dict(
    torch.load('result_v/2024-12-19 20:37:52.499731_pre_reinforce_8_0.002/2024-12-19 20:37:52.499731_pre_reinforce_8_0.002')
)
model.eval()


# Action-mask provider handed to ActionMasker.
def mask_fn(env):
    """Ask ``env`` for the action mask matching its current observation."""
    build_mask = env.create_action_mask
    return build_mask(env.observation)

# Build the training environment: wrap the hologram env with the action-mask
# provider, vectorise it (a single env), then normalise observations and rewards.
# NOTE(review): CustomPPO subclasses PPO. Action masks exposed by ActionMasker
# are, as far as I know, only consumed by sb3_contrib's MaskablePPO (imported
# at the top of the file). Confirm masking is actually applied during rollouts.
env = ActionMasker(
    BinaryHologramEnv(target_function=model, trainloader=train_loader),
    mask_fn,
)
venv = VecNormalize(
    make_vec_env(lambda: env, n_envs=1),
    norm_obs=True,
    norm_reward=True,
    clip_obs=10.0,
)

# PPO with the custom CNN policy.
ppo_model = CustomPPO(
    CustomCNNPolicy,
    venv,
    n_steps=256,
    batch_size=64,
    learning_rate=1e-4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_range=0.2,
    vf_coef=0.5,
    max_grad_norm=0.2,  # gradient clipping
    verbose=2,
    tensorboard_log="./ppo_custom_policy/",
)

ppo_model.learn(total_timesteps=1_000_000)

# Persist the trained model.
# NOTE(review): `current_date` is not defined in the visible setup code (which
# defines `current_datetime`). Confirm it exists before this line runs.
ppo_model.save(f"ppo_with_mask_{current_date}")


"""

# 환경 생성에 새로운 데이터 로더 적용
env = BinaryHologramEnv(
    target_function=model,
    trainloader=train_loader,  # 업데이트된 train_loader 사용
    #max_steps=10000,
    #T_PSNR=30,
    #T_steps=10
)

# ActionMasker 래퍼 적용
env = ActionMasker(env, mask_fn)

# Vectorized 환경 생성
venv = make_vec_env(lambda: env, n_envs=1)

#obs = venv.reset()
#print(f"[DEBUG] Venv Reset Observation: obs shape={obs.shape}, type={type(obs)}")

venv = VecNormalize(venv, norm_obs=True, norm_reward=True, clip_obs=10.0)

#obs = venv.reset()
#print(f"[DEBUG] After VecNormalize: obs shape={obs.shape}, type={type(obs)}")

ppo_model = PPO(
    CustomCNNPolicy,  # Use the custom policy
    venv,
    verbose=2,
    n_steps=256,
    batch_size=64,
    gamma=0.99,
    gae_lambda=0.95,
    learning_rate=1e-4,  # Learning rate
    clip_range=0.2,
    vf_coef=0.5,
    max_grad_norm=0.2,  # Gradient clipping
    tensorboard_log="./ppo_custom_policy/"
)
"""

"""
# PPO 학습
ppo_model = PPO(
    "MlpPolicy",
    venv,
    verbose=2,
    n_steps=256,
    batch_size=64,
    gamma=0.99,
    gae_lambda=0.95,
    learning_rate=1e-4,  # 학습률 감소
    clip_range=0.2,
    vf_coef=0.5,
    max_grad_norm=0.2,  # Gradient clipping 추가
    tensorboard_log="./ppo_with_mask/"
)
"""

# NOTE(review): this second training run looks like a leftover from the older
# pipeline that was disabled by wrapping it in the string literals above.
# At this point `ppo_model` is the CustomPPO instance trained earlier in this
# script (1,000,000 timesteps) and saved under the same
# f"ppo_with_mask_{current_date}" name. These lines train it for another
# 10,000,000 timesteps and then presumably overwrite that save.
# Confirm this is intended; otherwise remove these lines.
ppo_model.learn(total_timesteps=10000000)

# Save the trained model.
ppo_model.save(f"ppo_with_mask_{current_date}")

이 메시지는 콘솔과 파일에 동시에 기록됩니다.


로깅 메시지입니다.
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


torch.Size([1, 8, 512, 512])
Using cuda device
[91mResetting environment. Consecutive episode failures: 0, Max consecutive episode failures: 0[0m
[92mInitial MSE: 0.001446, Initial PSNR: 27.761915, 11:10:43[0m
Logging to ./ppo_custom_policy/run_4
[94mStep: 5, PSNR Before: 27.761915, PSNR After: 27.762030, PSNR Change: 0.000114, PSNR Diff: 0.000114 (New Max), Reward: 0.09, 11:10:49 Pre-flip Model Output=0.016652, New State Value=1, Flip Count=1[0m
[94mStep: 34, PSNR Before: 27.762030, PSNR After: 27.762035, PSNR Change: 0.000006, PSNR Diff: 0.000120 (New Max), Reward: 0.10, 11:11:13 Pre-flip Model Output=0.255861, New State Value=1, Flip Count=3[0m
[94mStep: 47, PSNR Before: 27.762035, PSNR After: 27.762066, PSNR Change: 0.000031, PSNR Diff: 0.000151 (New Max), Reward: 0.12, 11:11:24 Pre-flip Model Output=0.379479, New State Value=1, Flip Count=4[0m
[94mStep: 49, PSNR Before: 27.762066, PSNR After: 27.762150, PSNR Change: 0.000084, PSNR Diff: 0.000235 (New Max), Reward: 0.19,

OutOfMemoryError: CUDA out of memory. Tried to allocate 512.00 MiB. GPU 1 has a total capacity of 23.65 GiB of which 160.69 MiB is free. Including non-PyTorch memory, this process has 23.49 GiB memory in use. Of the allocated memory 22.77 GiB is allocated by PyTorch, and 273.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)