In [1]:
import torchOptics.optics as tt
import warnings
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torchvision
import datetime
import tqdm
import time
import pandas as pd

# 이미지 데이터를 처리하기 위한 커스텀 Dataset 클래스
class Dataset512(Dataset):
    """Image dataset that yields 1024x1024 crops of grayscale PNG files.

    Args:
        target_dir: Path prefix that is concatenated with ``'*.png'`` to glob
            the image files, so a directory path must end with a separator.
        meta: Metadata forwarded to ``tt.imread`` for every image.
        transform: Stored on the instance; not used by this class.
        isTrain: If True a random 1024x1024 crop is taken, otherwise a
            center crop.
        padding: Number of pixels padded on each of the four sides after
            cropping.
    """

    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        # Sorted so that an index always maps to the same file.
        self.target_list = sorted(glob.glob(target_dir+'*.png'))
        self.center_crop = torchvision.transforms.CenterCrop(1024)  # used when isTrain is False
        self.random_crop = torchvision.transforms.RandomCrop((1024, 1024))  # used when isTrain is True
        self.padding = padding

    def __len__(self):
        """Return the number of PNG files found at construction time."""
        return len(self.target_list)

    def __getitem__(self, idx):
        """Load image ``idx``, crop it to 1024x1024 and pad it.

        Returns:
            The cropped and padded image as returned by ``tt.imread`` with an
            extra leading dimension added by ``unsqueeze(0)``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Read the image with the metadata given to *this* instance. The
        # previous code read a notebook-level global named ``meta``, which
        # ignored the constructor argument and failed outside the notebook.
        target = tt.imread(self.target_list[idx], meta=self.meta, gray=True).unsqueeze(0)
        # Upscale images that are smaller than the crop size in either dimension.
        if target.shape[-1] < 1024 or target.shape[-2] < 1024:
            target = torchvision.transforms.Resize(1024)(target)
        # Training uses a random crop, validation a deterministic center crop.
        crop = self.random_crop if self.isTrain else self.center_crop
        target = crop(target)
        pad = self.padding
        return torchvision.transforms.functional.pad(target, (pad, pad, pad, pad))

# 학습 및 검증 데이터를 로드
# Metadata required by the optical simulation; passed to the datasets below.
meta = dict(wl=(515e-9), dx=(7.56e-6, 7.56e-6))

# Locations of the training and validation images.
target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'

# Loader / dataset settings.
batch_size = 1
padding = 0

# Training set uses random crops (isTrain=True), validation set center crops.
train_dataset = Dataset512(
    target_dir=target_dir,
    meta=meta,
    isTrain=True,
    padding=padding,
)
valid_dataset = Dataset512(
    target_dir=valid_dir,
    meta=meta,
    isTrain=False,
    padding=padding,
)

# Both loaders iterate in file order (no shuffling).
trainloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False)
validloader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)



  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
  warn(f"Failed to load image Python extension: {e}")


In [2]:
class BinaryNet(nn.Module):
    """U-Net-style encoder/decoder that maps an image to ``num_hologram`` maps in (0, 1).

    The encoder has five levels; down-sampling is done by stride-2
    convolutions (``pool*``) and up-sampling by transposed convolutions
    (``deconv*``). Encoder features are concatenated with the up-sampled
    decoder features along the channel dimension before each decoder stage.

    Args:
        num_hologram: Number of output channels of the final convolution.
        final: Accepted for backward compatibility; currently unused. The
            output activation is always a sigmoid.
        in_planes: Number of input channels.
        channels: Channel widths per level. Only the first five entries are
            used by this architecture.
        convReLU, convBN: Add the activation / batch norm after the regular
            convolutions. Note that the activation added is ``nn.Tanh``.
        poolReLU, poolBN: Same, for the stride-2 down-sampling convolutions.
        deconvReLU, deconvBN: Add ``nn.ReLU`` / batch norm after the
            transposed convolutions.
    """

    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=(32, 64, 128, 256, 512, 1024, 2048, 4096),
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super(BinaryNet, self).__init__()

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            """Conv2d, optionally followed by Tanh and then BatchNorm2d."""
            layers = [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            # The flag is called ``relu`` but the activation is Tanh; this is
            # kept as-is so existing checkpoints and training runs stay valid.
            if relu:
                layers.append(nn.Tanh())
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            return nn.Sequential(*layers)

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            """ConvTranspose2d, optionally followed by BatchNorm2d and then ReLU."""
            # Honour the arguments instead of hard-coding 2 / 2 / True; the
            # defaults reproduce the previous behaviour exactly.
            layers = [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride, padding=0,
                                         bias=bias)]
            if bn:
                layers.append(nn.BatchNorm2d(num_features=out_channels))
            if relu:
                layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        # Encoder: two convolutions per level, then a stride-2 convolution.
        self.enc1_1 = CRB2d(in_planes, channels[0], relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(channels[0], channels[0], stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(channels[0], channels[1], relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(channels[1], channels[1], stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(channels[1], channels[2], relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(channels[2], channels[2], stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(channels[2], channels[3], relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(channels[3], channels[3], stride=2, relu=poolReLU, bn=poolBN)

        # Bottleneck.
        self.enc5_1 = CRB2d(channels[3], channels[4], relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(channels[4], channels[4], relu=convReLU, bn=convBN)

        # Decoder: the first convolution of each level takes the concatenation
        # of the up-sampled features and the skip connection, hence the
        # doubled input width (this relies on channels[i+1] == 2 * channels[i]).
        self.deconv4 = TRB2d(channels[4], channels[3], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(channels[4], channels[3], relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(channels[3], channels[2], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(channels[3], channels[2], relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(channels[2], channels[1], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(channels[2], channels[1], relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(channels[1], channels[0], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(channels[1], channels[0], relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)

        # Final 3x3 convolution to num_hologram channels, no activation / BN.
        self.classifier = CRB2d(channels[0], num_hologram, relu=False, bn=False)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch with ``in_planes`` channels. Height and width must
                survive four stride-2 down-samplings and 2x up-samplings with
                matching sizes (e.g. multiples of 16).

        Returns:
            Tensor with ``num_hologram`` channels and the input's spatial
            size, squashed to (0, 1) by a sigmoid.
        """
        # Encoder
        enc1_2 = self.enc1_2(self.enc1_1(x))
        enc2_2 = self.enc2_2(self.enc2_1(self.pool1(enc1_2)))
        enc3_2 = self.enc3_2(self.enc3_1(self.pool2(enc2_2)))
        enc4_2 = self.enc4_2(self.enc4_1(self.pool3(enc3_2)))
        enc5_2 = self.enc5_2(self.enc5_1(self.pool4(enc4_2)))

        # Decoder with skip connections (concatenate on the channel axis)
        concat4 = torch.cat((self.deconv4(enc5_2), enc4_2), dim=1)
        dec4_2 = self.dec4_2(self.dec4_1(concat4))

        concat3 = torch.cat((self.deconv3(dec4_2), enc3_2), dim=1)
        dec3_2 = self.dec3_2(self.dec3_1(concat3))

        concat2 = torch.cat((self.deconv2(dec3_2), enc2_2), dim=1)
        dec2_2 = self.dec2_2(self.dec2_1(concat2))

        concat1 = torch.cat((self.deconv1(dec2_2), enc1_2), dim=1)
        dec1_2 = self.dec1_2(self.dec1_1(concat1))

        # Final classifier followed by a sigmoid
        return torch.sigmoid(self.classifier(dec1_2))


# Path of the saved weights that are passed to load_state_dict below.
CHECKPOINT_PATH = 'result/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002'

# Build the network without activations / batch norm in any stage and
# move it to the GPU.
model = BinaryNet(num_hologram=8, in_planes=1, convReLU=False,
                  convBN=False, poolReLU=False, poolBN=False,
                  deconvReLU=False, deconvBN=False).cuda()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs, and avoids executing arbitrary pickled code.
model.load_state_dict(torch.load(CHECKPOINT_PATH, weights_only=True))
model.eval()


  model.load_state_dict(torch.load('result/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002'))  # 저장된 모델 경로를 입력


BinaryNet(
  (enc1_1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (enc1_2): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (pool1): Sequential(
    (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
  (enc2_1): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (enc2_2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (pool2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
  (enc3_1): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (enc3_2): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (pool3): Sequential(
    (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  )
  (enc4_1): Sequential(
    (0): Conv2d(128, 256, kernel_size

In [3]:
import torch
import torch.nn.functional as F
import torchOptics.metrics as tm
from gymnasium import spaces
import gymnasium as gym
import numpy as np


class BinaryHologramEnv(gym.Env):
    """Gymnasium environment whose observations are outputs of a pre-trained model.

    Every ``reset`` and every ``step`` pulls the next batch from the data
    loader, runs it through the model and exposes the result (batch dimension
    removed) as the new observation. The action is a flat binary vector with
    one entry per observation element.
    """

    def __init__(self, model, validloader, max_steps=1000, T_PSNR=30, T_steps=100, reward_function=None):
        """
        Custom Gym environment for binary hologram generation.

        Args:
            model: Pre-trained PyTorch model for generating observations.
            validloader: DataLoader for validation dataset.
            max_steps: Maximum number of steps per episode.
            T_PSNR: Target PSNR value for successful termination.
            T_steps: Number of consecutive steps to maintain the target PSNR.
            reward_function: Custom function to compute the reward (optional).
                Called as ``reward_function(state, action)``.
        """
        super(BinaryHologramEnv, self).__init__()
        self.model = model.eval().cuda()
        # Keep the loader itself so that a fresh iterator can be created once
        # the current one is exhausted.
        self._loader = validloader
        self.validloader = iter(validloader)

        # Calculate observation space shape from one example forward pass
        example_input = next(iter(validloader)).cuda()
        with torch.no_grad():
            example_output = self.model(example_input)
        self.output_shape = example_output.shape

        # Define observation and action spaces (batch dimension dropped)
        obs_shape = tuple(self.output_shape[1:])
        self.observation_space = spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.float32)
        # One binary decision per observation element; plain int for the size.
        self.action_space = spaces.MultiBinary(int(np.prod(obs_shape)))

        # Environment parameters
        self.max_steps = max_steps
        self.T_PSNR = T_PSNR
        self.T_steps = T_steps
        self.reward_function = reward_function  # Optional custom reward function

        # Internal state
        self.state = None
        self.steps = 0
        self.psnr_sustained_steps = 0

    def _next_observation(self):
        """Fetch the next batch, run the model and return the output as float32 numpy."""
        try:
            batch = next(self.validloader)
        except StopIteration:
            # Calling iter() on an exhausted iterator returns the same
            # exhausted iterator, so restart from the DataLoader instead.
            self.validloader = iter(self._loader)
            batch = next(self.validloader)

        with torch.no_grad():
            out = self.model(batch.cuda())
        # squeeze(0) removes the batch dimension (requires a batch size of 1).
        return out.cpu().numpy().squeeze(0).astype(np.float32)

    def reset(self, seed=None, options=None):
        """
        Reset the environment and initialize state.

        Returns:
            Tuple: (initial_state, info) where info is an empty dict.
        """
        super().reset(seed=seed)
        self.state = self._next_observation()
        self.steps = 0
        self.psnr_sustained_steps = 0
        return self.state, {}

    def step(self, action):
        """
        Execute a step in the environment.

        Args:
            action: Binary action array matching the model's output shape.

        Returns:
            Tuple: (new_state, reward, terminated, truncated, info)
        """
        # Reshape action to match the state shape
        action = np.reshape(action, self.state.shape)
        action_tensor = torch.tensor(action, dtype=torch.float32).cuda()

        # Target tensor for metrics calculation.
        # NOTE(review): the state has had its batch dimension removed, so
        # axis=1 is the second remaining axis, not the channel axis - confirm
        # that this is the intended averaging axis.
        target = torch.tensor(self.state.mean(axis=1, keepdims=True), dtype=torch.float32).cuda()

        # Expand target to match the shape of action_tensor
        target_expanded = target.expand_as(action_tensor)

        # Calculate PSNR and MSE as plain Python floats so that ``info`` does
        # not hold on to tensors.
        mse = tt.relativeLoss(action_tensor, target_expanded, F.mse_loss).item()
        psnr = float(tt.relativeLoss(action_tensor, target_expanded, tm.get_PSNR))

        # Compute reward using custom function or default MSE-based reward
        if self.reward_function:
            reward = self.reward_function(self.state, action)
        else:
            reward = -mse  # Default reward is negative MSE

        # Update episode status
        self.steps += 1
        truncated = self.steps >= self.max_steps

        # Count consecutive steps at or above the target PSNR
        if psnr >= self.T_PSNR:
            self.psnr_sustained_steps += 1
        else:
            self.psnr_sustained_steps = 0
        terminated = self.psnr_sustained_steps >= self.T_steps

        # The next observation comes from the next image in the loader
        self.state = self._next_observation()

        info = {"mse": mse, "psnr": psnr, "psnr_sustained_steps": self.psnr_sustained_steps}
        return self.state, reward, terminated, truncated, info

In [4]:
from stable_baselines3.common.env_checker import check_env

# 환경 초기화
env = BinaryHologramEnv(model=model, validloader=validloader, max_steps=1000, T_PSNR=30, T_steps=100)

# Validate the environment against the Gymnasium API (warnings enabled)
check_env(env, warn=True)


  return super().__new__(cls, torch.Tensor(x), *args, **kwargs).to(device)
  target = tt.imread(self.target_list[idx], meta=meta, gray=True).unsqueeze(0)
  if target.shape[-1] < 1024 or target.shape[-2] < 1024:
  return x.ndim >= 2
  channels = 1 if img.ndim == 2 else img.shape[-3]
  height, width = img.shape[-2:]
  return img[..., top:bottom, left:right]
  if img.ndim < 4:
  img = img.unsqueeze(dim=0)
  out_dtype = img.dtype
  img = img.squeeze(dim=0)
  if elem.is_nested:
  if elem.layout in {torch.sparse_coo, torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}:
  return torch.stack(batch, 0, out=out)
  example_input = next(iter(validloader)).cuda()
  return F.conv2d(input, weight, bias, self.stride,
  return F.conv_transpose2d(
  concat4 = torch.cat((deconv4, enc4_2), dim=1)
  concat3 = torch.cat((deconv3, enc3_2), dim=1)
  concat2 = torch.cat((deconv2, enc2_2), dim=1)
  concat1 = torch.cat((deconv1, enc1_2), dim=1)
  return torch.sigmoid(input)
  self.output_shap

In [5]:
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import RecurrentPPO

# Path of the saved weights that are passed to load_state_dict below.
CHECKPOINT_PATH = 'result/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002'

model = BinaryNet(num_hologram=8, in_planes=1, convReLU=False,
                  convBN=False, poolReLU=False, poolBN=False,
                  deconvReLU=False, deconvBN=False).cuda()
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs, and avoids executing arbitrary pickled code.
model.load_state_dict(torch.load(CHECKPOINT_PATH, weights_only=True))
model.eval()

# Create the custom Gym environment
env = BinaryHologramEnv(model=model, validloader=validloader, max_steps=1000, T_PSNR=30, T_steps=100)

# Create a vectorized environment with running observation / reward normalisation
venv = make_vec_env(lambda: env, n_envs=1)
venv = VecNormalize(venv, norm_obs=True, norm_reward=True, clip_obs=10.0)

# Configure the Recurrent PPO agent.
# NOTE(review): the recorded run of this cell stopped with a CUDA
# out-of-memory error on a single 32 GiB allocation. That size is consistent
# with the LSTM input weights for a flattened 8x1024x1024 observation
# (8,388,608 features x 4*256 rows x 4 bytes), assuming sb3_contrib's default
# LSTM hidden size of 256 - confirm. Training at this resolution will need a
# smaller observation/action space or a different policy architecture.
ppo_model = RecurrentPPO(
    "MlpLstmPolicy",
    venv,
    verbose=1,
    n_steps=2048,
    batch_size=64,
    gamma=0.99,
    learning_rate=3e-4,
    tensorboard_log="./ppo_lstm/"
)

# Train the agent
ppo_model.learn(total_timesteps=100000)

# Save the trained agent
ppo_model.save("recurrent_ppo_binary_hologram")
# The VecNormalize running statistics are not part of the saved agent; store
# them as well so the policy can later be run with the same normalisation.
venv.save("recurrent_ppo_binary_hologram_vecnormalize.pkl")


  model.load_state_dict(torch.load('result/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002/2024-12-07 19:38:09.105795_pre_reinforce_8_0.002'))
  target = tt.imread(self.target_list[idx], meta=meta, gray=True).unsqueeze(0)
  if target.shape[-1] < 1024 or target.shape[-2] < 1024:
  example_input = next(iter(validloader)).cuda()
  concat4 = torch.cat((deconv4, enc4_2), dim=1)
  concat3 = torch.cat((deconv3, enc3_2), dim=1)
  concat2 = torch.cat((deconv2, enc2_2), dim=1)
  concat1 = torch.cat((deconv1, enc1_2), dim=1)
  self.output_shape = example_output.shape


Using cuda device


OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 GiB. GPU 0 has a total capacity of 23.65 GiB of which 20.73 GiB is free. Process 582476 has 384.00 MiB memory in use. Including non-PyTorch memory, this process has 2.54 GiB memory in use. Of the allocated memory 2.10 GiB is allocated by PyTorch, and 1.76 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)