In [None]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import MaskablePPO
from stable_baselines3.common.policies import ActorCriticPolicy
import gymnasium as gym
from gymnasium import spaces
from datetime import datetime
import glob
import torchOptics.optics as tt
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import matplotlib.pyplot as plt
import pickle
import torchvision
import tqdm
import time
import pandas as pd
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3 import PPO
import warnings

# Silence every warning for the rest of the run.
warnings.simplefilter('ignore')

# Start-up timestamp, formatted so it can be embedded in output file names.
current_date = f"{datetime.now():%Y-%m-%d_%H-%M-%S}"

# Switch the cuDNN backend off globally.
torch.backends.cudnn.enabled = False

class BinaryNet(nn.Module):
    """U-Net style encoder/decoder producing ``num_hologram`` maps in (0, 1).

    The encoder has four stride-2 convolutional down-sampling stages, the
    decoder four transposed-convolution up-sampling stages with skip
    connections (channel concatenation), and the final 3x3 convolution is
    followed by a sigmoid. Height and width must be divisible by 16 so that
    the skip connections line up.

    Args:
        num_hologram: Number of output channels.
        final: Accepted for backward compatibility; not used (the output
            activation is always a sigmoid).
        in_planes: Number of input channels.
        channels: Channel widths per depth level; only the first five
            entries are used.
        convReLU / poolReLU: Append the activation to the regular /
            down-sampling convolution blocks. Note that these blocks use
            ``nn.Tanh`` despite the flag name.
        convBN / poolBN: Append ``BatchNorm2d`` to those blocks.
        deconvReLU / deconvBN: Append ``nn.ReLU`` / ``BatchNorm2d`` to the
            transposed-convolution blocks.

    Raises:
        ValueError: If fewer than five channel widths are supplied.
    """

    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=(32, 64, 128, 256, 512, 1024, 2048, 4096),
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super().__init__()

        if len(channels) < 5:
            raise ValueError("channels must provide at least 5 widths")

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            # Conv -> (Tanh) -> (BatchNorm). The layer order fixes the indices
            # used in state_dict keys, so it must not change.
            layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            if relu:
                layers.append(nn.Tanh())
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            return nn.Sequential(*layers)

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            # ConvTranspose -> (BatchNorm) -> (ReLU). The arguments are now
            # honoured instead of being silently replaced by 2 / 2 / True.
            layers = [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride, padding=0,
                                         bias=bias)]
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            if relu:
                layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        # Encoder: two convolutions per level, then a stride-2 convolution
        # that halves the resolution.
        self.enc1_1 = CRB2d(in_planes, channels[0], relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(channels[0], channels[0], stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(channels[0], channels[1], relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(channels[1], channels[1], stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(channels[1], channels[2], relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(channels[2], channels[2], stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(channels[2], channels[3], relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(channels[3], channels[3], stride=2, relu=poolReLU, bn=poolBN)

        # Bottleneck.
        self.enc5_1 = CRB2d(channels[3], channels[4], relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(channels[4], channels[4], relu=convReLU, bn=convBN)

        # Decoder: the first convolution of each level receives the
        # up-sampled features concatenated with the skip connection, hence
        # twice the channel count on its input.
        self.deconv4 = TRB2d(channels[4], channels[3], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(channels[4], channels[3], relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(channels[3], channels[2], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(channels[3], channels[2], relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(channels[2], channels[1], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(channels[2], channels[1], relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(channels[1], channels[0], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(channels[1], channels[0], relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)

        self.classifier = CRB2d(channels[0], num_hologram, relu=False, bn=False)

    def forward(self, x):
        """Run the network; the output has ``num_hologram`` channels at the input resolution."""
        # Encoder (the last feature map of each level is kept as a skip connection).
        skip1 = self.enc1_2(self.enc1_1(x))
        skip2 = self.enc2_2(self.enc2_1(self.pool1(skip1)))
        skip3 = self.enc3_2(self.enc3_1(self.pool2(skip2)))
        skip4 = self.enc4_2(self.enc4_1(self.pool3(skip3)))
        features = self.enc5_2(self.enc5_1(self.pool4(skip4)))

        # Decoder: up-sample, concatenate the matching skip, refine.
        features = torch.cat((self.deconv4(features), skip4), dim=1)
        features = self.dec4_2(self.dec4_1(features))

        features = torch.cat((self.deconv3(features), skip3), dim=1)
        features = self.dec3_2(self.dec3_1(features))

        features = torch.cat((self.deconv2(features), skip2), dim=1)
        features = self.dec2_2(self.dec2_1(features))

        features = torch.cat((self.deconv1(features), skip1), dim=1)
        features = self.dec1_2(self.dec1_1(features))

        # Final classifier squashed into (0, 1).
        return torch.sigmoid(self.classifier(features))


# Smoke test: build the network with every activation / normalisation switch
# turned off and push one random 128x128 single-channel sample through it on
# the GPU, printing the resulting output shape.
model = BinaryNet(
    num_hologram=8,
    in_planes=1,
    **dict.fromkeys(
        ("convReLU", "convBN", "poolReLU", "poolBN", "deconvReLU", "deconvBN"),
        False,
    ),
).cuda()
test = torch.randn(1, 1, 128, 128).cuda()
out = model(test)
print(out.shape)


class Dataset512(Dataset):
    """Dataset yielding 128x128 grayscale crops of the PNG files matching ``target_dir + '*.png'``.

    Args:
        target_dir: Path prefix of the image files. It is concatenated
            directly with ``'*.png'``, so a directory must end with a path
            separator.
        meta: Metadata forwarded to ``tt.imread`` for every image.
        transform: Stored on the instance; not applied by this class.
        isTrain: If true a random crop is taken, otherwise a centre crop.
        padding: Number of pixels padded on each side after cropping.
    """

    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        # Sorted so that indices map to files deterministically.
        self.target_list = sorted(glob.glob(target_dir + '*.png'))
        self.center_crop = torchvision.transforms.CenterCrop(128)
        self.random_crop = torchvision.transforms.RandomCrop((128, 128))
        self.padding = padding

    def __len__(self):
        """Return the number of image files found."""
        return len(self.target_list)

    def __getitem__(self, idx):
        """Load image ``idx`` and return its cropped (and padded) version."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Use the metadata given to this dataset rather than whatever global
        # named ``meta`` happens to exist in the importing module.
        target = tt.imread(self.target_list[idx], meta=self.meta, gray=True).unsqueeze(0)
        # Images smaller than the crop window are up-scaled first.
        if target.shape[-1] < 128 or target.shape[-2] < 128:
            target = torchvision.transforms.Resize(128)(target)
        crop = self.random_crop if self.isTrain else self.center_crop
        target = crop(target)
        return torchvision.transforms.functional.pad(
            target, (self.padding, self.padding, self.padding, self.padding))


# BinaryHologramEnv 클래스
class BinaryHologramEnv(gym.Env):
    """Gymnasium environment in which an agent refines a binary hologram by flipping pixels.

    Each episode draws a target image from ``trainloader``, lets
    ``target_function`` predict holograms for it and binarises them at 0.5.
    An action is a flat binary vector; after masking, every set bit flips the
    corresponding hologram pixel (XOR). The reward is driven by the PSNR
    change of the simulated reconstruction relative to the episode's initial
    PSNR.

    Args:
        target_function: Callable (the pretrained network) mapping the target
            image to the initial hologram prediction.
        trainloader: Iterable of target images; restarted when exhausted.
        max_steps: Step count at which the episode is truncated.
        T_PSNR: PSNR threshold that counts as "goal reached".
        T_steps: Number of consecutive steps the PSNR must stay at or above
            ``T_PSNR`` for the episode to terminate.
        max_allowed_changes: Maximum number of pixels the mask leaves
            flippable in one step.
    """

    # Optical parameters passed as ``meta`` to ``tt.Tensor`` before every simulation.
    _OPTICS_META = {'dx': (7.56e-6, 7.56e-6), 'wl': 515e-9}
    # Shape shared by observations, states and (reshaped) actions.
    _STATE_SHAPE = (1, 8, 128, 128)

    def __init__(self, target_function, trainloader, max_steps=1000, T_PSNR=30, T_steps=10, max_allowed_changes=10):
        super().__init__()

        self.observation_space = spaces.Box(low=0, high=1, shape=self._STATE_SHAPE, dtype=np.float32)
        self.action_space = spaces.MultiBinary(int(np.prod(self._STATE_SHAPE)))

        # Model and data source.
        self.target_function = target_function
        self.trainloader = trainloader

        # Episode settings.
        self.max_steps = max_steps
        self.T_PSNR = T_PSNR
        self.T_steps = T_steps
        self.max_allowed_changes = max_allowed_changes

        # Per-episode state, filled in by reset().
        self.state = None
        self.observation = None
        self.initial_psnr = None
        self.steps = 0
        self.psnr_sustained_steps = 0

        self.data_iter = iter(self.trainloader)
        self.target_image = None

    def _evaluate_state(self, state, z):
        """Simulate ``state`` at distance ``z`` and score it against the target image.

        Returns:
            tuple[float, float]: ``(mse, psnr)`` as plain Python floats, so
            that rewards and info dicts do not keep device tensors alive.
        """
        binary = torch.tensor(state, dtype=torch.float32).cuda()
        binary = tt.Tensor(binary, meta=dict(self._OPTICS_META))

        # Intensity of the propagated field, averaged over the hologram channels.
        sim = tt.simulate(binary, z).abs()**2
        result = torch.mean(sim, dim=1, keepdim=True)

        mse = float(tt.relativeLoss(result, self.target_image, F.mse_loss))
        psnr = float(tt.relativeLoss(result, self.target_image, tm.get_PSNR))
        return mse, psnr

    def _predict_initial_hologram(self, z):
        """Set observation/state from the network prediction and return ``(mse, psnr)``."""
        with torch.no_grad():
            model_output = self.target_function(self.target_image)
        self.observation = model_output.cpu().numpy()
        self.state = (self.observation >= 0.5).astype(np.int8)  # binarised hologram
        return self._evaluate_state(self.state, z)

    def reset(self, seed=None, options=None, lr=1e-4, z=2e-3):
        """Start a new episode on the next target image.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: Unused.
            lr: Unused; kept for interface compatibility.
            z: Simulation distance.

        Returns:
            tuple: ``(observation, info)`` where ``observation`` is the
            network output and ``info`` holds the binarised ``"state"`` and
            an action ``"mask"``.
        """
        super().reset(seed=seed)
        torch.cuda.empty_cache()

        try:
            self.target_image = next(self.data_iter)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            self.data_iter = iter(self.trainloader)
            self.target_image = next(self.data_iter)
        self.target_image = self.target_image.cuda()

        self.steps = 0
        self.psnr_sustained_steps = 0

        # The initial PSNR is the baseline all later rewards are measured against.
        mse, self.initial_psnr = self._predict_initial_hologram(z)

        current_time = datetime.now().strftime("%H:%M:%S")
        print(f"Initial MSE: {mse:.6f}, Initial PSNR: {self.initial_psnr:.6f}, {current_time}")

        mask = self.create_action_mask(self.observation)
        return self.observation, {"state": self.state, "mask": mask}

    def initialize_state(self, z=2e-3):
        """Recompute the initial observation and state for the current target image.

        The observation stays the network output so that it always matches
        ``observation_space``; a single-channel reconstruction would not, and
        a mask derived from it would broadcast over every hologram channel.

        Args:
            z: Simulation distance.

        Returns:
            tuple: ``(observation, info)`` with the same layout as ``reset``.

        Raises:
            RuntimeError: If no target image has been loaded yet.
        """
        if self.target_image is None:
            raise RuntimeError("reset() must be called before the environment is used")

        mse, psnr = self._predict_initial_hologram(z)
        print(f"Initial MSE: {mse:.6f}, Initial PSNR: {psnr:.6f}, {datetime.now()}")

        mask = self.create_action_mask(self.observation)
        return self.observation, {"state": self.state, "mask": mask}

    def create_action_mask(self, observation):
        """Build a 0/1 mask of the pixels that may be flipped.

        Entries whose observation is ``<= 0.2`` are blocked; all others
        (including values ``>= 0.8``) stay selectable. If more than
        ``max_allowed_changes`` entries remain, a uniformly random subset of
        exactly that size is kept, so successive calls generally return
        different masks.

        NOTE(review): once the observation is the binary state, every pixel
        that is currently 0 is blocked, so pixels can only be switched off -
        confirm that this is intended.

        Args:
            observation (np.ndarray): Current observation.

        Returns:
            np.ndarray: int8 mask with the shape of ``observation``.
        """
        mask = np.ones_like(observation, dtype=np.int8)
        mask[observation <= 0.2] = 0

        allowed_indices = np.flatnonzero(mask)
        if len(allowed_indices) > self.max_allowed_changes:
            selected_indices = np.random.choice(allowed_indices, self.max_allowed_changes, replace=False)
            mask = np.zeros(mask.size, dtype=np.int8)
            mask[selected_indices] = 1
            mask = mask.reshape(observation.shape)

        return mask

    def step(self, action, lr=1e-4, z=2e-3):
        """Advance the environment by one step.

        The first call after ``reset`` ignores ``action`` and only
        re-initialises the state (reward 0). Afterwards the masked action is
        XOR-ed onto the state, the result is simulated and scored.

        Args:
            action (np.ndarray): Flat binary action; reshaped to the state shape.
            lr: Unused; kept for interface compatibility.
            z: Simulation distance.

        Returns:
            tuple: ``(observation, reward, terminated, truncated, info)``.
            ``terminated`` is set when the PSNR dropped more than 1 below the
            initial value (reward -100) or stayed at or above ``T_PSNR`` for
            ``T_steps`` consecutive steps; ``truncated`` when ``max_steps``
            is reached.
        """
        if self.steps == 0:
            print("Executing reset logic for the first step")
            self.steps += 1
            observation, info = self.initialize_state(z)
            return observation, 0.0, False, False, info

        action = np.reshape(action, self._STATE_SHAPE).astype(np.int8)

        # Only flips permitted by the mask are applied.
        mask = self.create_action_mask(self.observation)
        masked_action = action * mask
        num_changes = int(np.sum(masked_action))
        # Safety net: the mask already caps the number of selectable pixels.
        over_budget = num_changes > self.max_allowed_changes

        new_state = np.logical_xor(self.state, masked_action).astype(np.int8)
        mse, psnr = self._evaluate_state(new_state, z)
        psnr_diff = psnr - self.initial_psnr

        # Failure: quality fell more than 1 (PSNR units) below the starting
        # point. The state is left untouched.
        if psnr_diff < -1:
            print(f"Episode failed: PSNR Diff {psnr_diff:.6f} < -1 at step {self.steps}")
            return self.observation, -100.0, True, False, {"mse": mse, "psnr": psnr, "psnr_diff": psnr_diff, "mask": None}

        reward = 0.0
        if over_budget:
            reward -= 50
        reward += psnr_diff * 7
        if over_budget:
            reward -= 0.1 * num_changes

        # Progress report every 100 steps.
        if self.steps % 100 == 0:
            current_time = datetime.now().strftime("%H:%M:%S")
            print(f"Step: {self.steps}, MSE: {mse:.6f}, PSNR: {psnr:.6f}, PSNR Diff: {psnr_diff:.6f}, "
                  f"Changes: {num_changes}, Reward: {reward:.2f}, {current_time}")

        self.state = new_state
        # Separate float32 copy so the observation matches observation_space.
        self.observation = new_state.astype(np.float32)

        # Update the streak first so that reaching the goal ends the episode
        # on the step it happens, not one step later.
        if psnr >= self.T_PSNR:
            self.psnr_sustained_steps += 1
        else:
            self.psnr_sustained_steps = 0

        # Success is a termination; running out of steps is a truncation.
        terminated = self.psnr_sustained_steps >= self.T_steps
        truncated = self.steps >= self.max_steps

        mask = self.create_action_mask(self.observation)
        info = {"mse": mse, "psnr": psnr, "psnr_diff": psnr_diff, "mask": mask}

        torch.cuda.empty_cache()

        self.steps += 1
        return self.observation, reward, terminated, truncated, info


def initialize_weights(m):
    """Initialise one module in place; intended for ``model.apply(initialize_weights)``.

    * ``nn.Conv2d``: Kaiming-uniform weights (ReLU gain), zero bias.
    * ``nn.BatchNorm2d``: weight 1, bias 0.
    * ``nn.Linear``: Kaiming-uniform weights (default gain), zero bias.

    Modules of any other type are left untouched. Missing parameters
    (``bias=False`` layers, ``affine=False`` batch norm) are skipped instead
    of raising ``AttributeError``.

    Args:
        m: The module to initialise.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

batch_size = 1
target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'
meta = {'wl': (515e-9), 'dx': (7.56e-6, 7.56e-6)}  # metadata handed to the datasets
padding = 0

# Training / validation datasets (random crop vs. centre crop).
train_dataset = Dataset512(target_dir=target_dir, meta=meta, isTrain=True, padding=padding)
valid_dataset = Dataset512(target_dir=valid_dir, meta=meta, isTrain=False, padding=padding)

# Data loaders.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# Load the pretrained BinaryNet and freeze it in inference mode.
model = BinaryNet(num_hologram=8, in_planes=1, convReLU=False, convBN=False,
                  poolReLU=False, poolBN=False, deconvReLU=False, deconvBN=False).cuda()
# map_location keeps the load independent of the device the checkpoint was
# saved from; weights_only restricts unpickling to tensors and plain
# containers, which is all a state dict needs.
model.load_state_dict(torch.load(
    'result_v/2024-12-19 21:37:10.439713_pre_reinforce_8_0.002/2024-12-19 21:37:10.439713_pre_reinforce_8_0.002',
    map_location='cuda',
    weights_only=True,
))
model.eval()


# 마스크 함수 정의
def mask_fn(env):
    """Return the action mask the environment derives from its current observation.

    One-argument adapter used as the mask callback handed to ``ActionMasker``.
    """
    latest_observation = env.observation
    return env.create_action_mask(latest_observation)

# 환경 생성에 새로운 데이터 로더 적용
# Build the environment around the pretrained network and the training loader.
env = BinaryHologramEnv(
    target_function=model,
    trainloader=train_loader,
    max_steps=1000,
    T_PSNR=30,
    T_steps=10
)

# Attach the action-mask callback.
# NOTE(review): masks exposed through ActionMasker are consumed by
# MaskablePPO; confirm whether the RecurrentPPO model below uses them at all.
env = ActionMasker(env, mask_fn)

# Vectorised (single) environment with running observation/reward normalisation.
venv = make_vec_env(lambda: env, n_envs=1)
venv = VecNormalize(venv, norm_obs=True, norm_reward=True, clip_obs=10.0)

# PPO training
#ppo_model = PPO(
#    "MlpPolicy",
#    venv,
#    verbose=2,
#    n_steps=1024,
#    batch_size=64,
#    gamma=0.99,
#    learning_rate=3e-4,
#    tensorboard_log="./ppo_with_mask/"
#)

#ppo_model.learn(total_timesteps=10000000)

# Save the trained model
#ppo_model.save(f"ppo_with_mask_{current_date}")

from sb3_contrib import RecurrentPPO

policy_kwargs = dict(
    # Separate actor / critic MLPs. Passed as a plain dict: the former
    # list-of-dict form is deprecated in Stable-Baselines3.
    net_arch=dict(pi=[256, 256], vf=[256, 256]),
    lstm_hidden_size=128,
    shared_lstm=False  # actor and critic do not share one LSTM
)

ppo_model = RecurrentPPO(
    "MlpLstmPolicy",
    venv,
    verbose=2,
    n_steps=256,
    batch_size=64,
    gamma=0.99,
    gae_lambda=0.95,
    learning_rate=1e-5,
    clip_range=0.2,
    vf_coef=0.5,
    max_grad_norm=0.5,  # gradient clipping
    tensorboard_log="./ppo_with_mask/",
    policy_kwargs=policy_kwargs
)


# Train
ppo_model.learn(total_timesteps=10000000)

# Save the model, tagged with the start-up timestamp
ppo_model.save(f"ppo_with_mask_{current_date}")


# 평가용 환경 생성
#eval_env = make_vec_env(lambda: env, n_envs=1)

# EvalCallback 추가
#eval_callback = EvalCallback(
#    eval_env,
#    best_model_save_path='./logs/',
#    log_path='./logs/',
#    eval_freq=10000,  # 평가 빈도 (타임스텝 기준)
#    deterministic=True,
#    render=False
#)

#ppo_model = PPO(
#    "MlpPolicy",
#    venv,
#    verbose=2,
#    n_steps=1024,
#    batch_size=64,
#    gamma=0.99,
#    learning_rate=3e-4,
#    tensorboard_log="./ppo_with_mask/"
#)

# 학습 시작 (콜백 추가)
#ppo_model.learn(total_timesteps=10000000, callback=eval_callback)

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


torch.Size([1, 8, 128, 128])
Using cuda device
Initial MSE: 0.003386, Initial PSNR: 23.341049, 08:27:07
Logging to ./ppo_with_mask/RecurrentPPO_35
Executing reset logic for the first step
Initial MSE: 0.003386, Initial PSNR: 23.341049, 2024-12-23 08:27:07.614790
Step: 100, MSE: 0.003933, PSNR: 22.691204, PSNR Diff: -0.649845, Changes: 4, Reward: -4.55, 08:27:09
Episode failed: PSNR Diff -1.002575 < -1 at step 165
Initial MSE: 0.001719, Initial PSNR: 25.492575, 08:27:10
Executing reset logic for the first step
Initial MSE: 0.001719, Initial PSNR: 25.492575, 2024-12-23 08:27:10.675499
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 166      |
|    ep_rew_mean     | -757     |
| time/              |          |
|    fps             | 47       |
|    iterations      | 1        |
|    time_elapsed    | 5        |
|    total_timesteps | 256      |
---------------------------------
Step: 100, MSE: 0.001964, PSNR: 24.913910, PSNR Diff: -0.578665, Chang