In [1]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import MaskablePPO
from stable_baselines3.common.policies import ActorCriticPolicy
import gymnasium as gym
from gymnasium import spaces
from datetime import datetime
import glob
import torchOptics.optics as tt
import torch.nn as nn
import torchOptics.metrics as tm
import torch.nn.functional as F
import torch.optim
import torch
from torch.utils.data import Dataset, DataLoader
import os
import glob
import matplotlib.pyplot as plt
import pickle
import torchvision
import tqdm
import time
import pandas as pd
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.wrappers import ActionMasker
from stable_baselines3 import PPO

# 현재 날짜와 시간을 가져와 포맷 지정
current_date = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

torch.backends.cudnn.enabled = False

class BinaryNet(nn.Module):
    def __init__(self, num_hologram, final='Sigmoid', in_planes=3,
                 channels=[32, 64, 128, 256, 512, 1024, 2048, 4096],
                 convReLU=True, convBN=True, poolReLU=True, poolBN=True,
                 deconvReLU=True, deconvBN=True):
        super(BinaryNet, self).__init__()

        def CRB2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, relu=True, bn=True):
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            if relu:
                layers += [nn.Tanh()]
            if bn:
                layers += [nn.BatchNorm2d(num_features=out_channels)]

            cbr = nn.Sequential(*layers)  # *으로 list unpacking

            return cbr

        def TRB2d(in_channels, out_channels, kernel_size=2, stride=2, bias=True, relu=True, bn=True):
            layers = []
            layers += [nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                          kernel_size=2, stride=2, padding=0,
                                          bias=True)]
            if bn:
                layers += [nn.BatchNorm2d(num_features=out_channels)]
            if relu:
                layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers)  # *으로 list unpacking

            return cbr

        self.enc1_1 = CRB2d(in_planes, channels[0], relu=convReLU, bn=convBN)
        self.enc1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)
        self.pool1 = CRB2d(channels[0], channels[0], stride=2, relu=poolReLU, bn=poolBN)

        self.enc2_1 = CRB2d(channels[0], channels[1], relu=convReLU, bn=convBN)
        self.enc2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)
        self.pool2 = CRB2d(channels[1], channels[1], stride=2, relu=poolReLU, bn=poolBN)

        self.enc3_1 = CRB2d(channels[1], channels[2], relu=convReLU, bn=convBN)
        self.enc3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)
        self.pool3 = CRB2d(channels[2], channels[2], stride=2, relu=poolReLU, bn=poolBN)

        self.enc4_1 = CRB2d(channels[2], channels[3], relu=convReLU, bn=convBN)
        self.enc4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)
        self.pool4 = CRB2d(channels[3], channels[3], stride=2, relu=poolReLU, bn=poolBN)

        self.enc5_1 = CRB2d(channels[3], channels[4], relu=convReLU, bn=convBN)
        self.enc5_2 = CRB2d(channels[4], channels[4], relu=convReLU, bn=convBN)

        self.deconv4 = TRB2d(channels[4], channels[3], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec4_1 = CRB2d(channels[4], channels[3], relu=convReLU, bn=convBN)
        self.dec4_2 = CRB2d(channels[3], channels[3], relu=convReLU, bn=convBN)

        self.deconv3 = TRB2d(channels[3], channels[2], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec3_1 = CRB2d(channels[3], channels[2], relu=convReLU, bn=convBN)
        self.dec3_2 = CRB2d(channels[2], channels[2], relu=convReLU, bn=convBN)

        self.deconv2 = TRB2d(channels[2], channels[1], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec2_1 = CRB2d(channels[2], channels[1], relu=convReLU, bn=convBN)
        self.dec2_2 = CRB2d(channels[1], channels[1], relu=convReLU, bn=convBN)

        self.deconv1 = TRB2d(channels[1], channels[0], relu=deconvReLU, bn=deconvBN, stride=2)
        self.dec1_1 = CRB2d(channels[1], channels[0], relu=convReLU, bn=convBN)
        self.dec1_2 = CRB2d(channels[0], channels[0], relu=convReLU, bn=convBN)

        self.classifier = CRB2d(channels[0], num_hologram, relu=False, bn=False)

    def forward(self, x):
        # Encoder
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)
        enc5_2 = self.enc5_2(enc5_1)

        deconv4 = self.deconv4(enc5_2)
        concat4 = torch.cat((deconv4, enc4_2), dim=1)
        dec4_1 = self.dec4_1(concat4)
        dec4_2 = self.dec4_2(dec4_1)

        deconv3 = self.deconv3(dec4_2)
        concat3 = torch.cat((deconv3, enc3_2), dim=1)
        dec3_1 = self.dec3_1(concat3)
        dec3_2 = self.dec3_2(dec3_1)

        deconv2 = self.deconv2(dec3_2)
        concat2 = torch.cat((deconv2, enc2_2), dim=1)
        dec2_1 = self.dec2_1(concat2)
        dec2_2 = self.dec2_2(dec2_1)

        deconv1 = self.deconv1(dec2_2)
        concat1 = torch.cat((deconv1, enc1_2), dim=1)
        dec1_1 = self.dec1_1(concat1)
        dec1_2 = self.dec1_2(dec1_1)

        # Final classifier
        out = self.classifier(dec1_2)
        out = nn.Sigmoid()(out)
        return out


model = BinaryNet(num_hologram=8, in_planes=1, convReLU=False,
                  convBN=False, poolReLU=False, poolBN=False,
                  deconvReLU=False, deconvBN=False).cuda()
test = torch.randn(1, 1, 1024, 1024).cuda()
out = model(test)
print(out.shape)


# 간단한 데이터셋 클래스
class SimpleDataset(torch.utils.data.Dataset):
    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        """
        target_dir: 데이터가 저장된 디렉토리 경로
        meta: 이미지 메타 정보
        transform: 이미지 전처리
        isTrain: 학습용(True) 또는 검증용(False) 데이터셋 여부
        padding: 이미지 패딩
        """
        self.target_list = sorted(glob.glob(target_dir + '*.png'))
        self.meta = meta
        self.transform = transform
        self.isTrain = isTrain
        self.padding = padding
        self.random_crop = transforms.RandomCrop((1024, 1024))
        self.center_crop = transforms.CenterCrop(1024)

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        # 이미지 로드
        img = tt.imread(self.target_list[idx], meta=self.meta, gray=True).unsqueeze(0)

        # 크기 조정
        if img.shape[-2] < 1024 or img.shape[-1] < 1024:
            img = transforms.Resize((1024, 1024))(img)

        # 학습용 랜덤 크롭 또는 검증용 센터 크롭
        if self.isTrain:
            img = self.random_crop(img)
        else:
            img = self.center_crop(img)

        # 패딩 추가
        img = transforms.functional.pad(
            img, (self.padding, self.padding, self.padding, self.padding)
        )

        # 변환 적용 (optional)
        if self.transform:
            img = self.transform(img)

        return img


#BinaryHologramEnv 클래스
class BinaryHologramEnv(gym.Env):
    def __init__(self, target_function, trainloader, max_steps=1000000, T_PSNR=30, T_steps=100):
        """
        target_function: 타겟 이미지와의 손실(MSE 또는 PSNR) 계산 함수.
        trainloader: 학습 데이터셋 로더.
        max_steps: 최대 타임스텝 제한.
        T_PSNR: 목표 PSNR 값.
        T_steps: PSNR 목표를 유지해야 하는 최소 타임스텝.
        """
        super(BinaryHologramEnv, self).__init__()

        # 관찰 공간 (1, 8, 1024, 1024)
        self.observation_space = spaces.Box(low=0, high=1, shape=(1, 8, 1024, 1024), dtype=np.float32)

        # 행동 공간: MultiBinary 데이터
        self.action_space = spaces.MultiBinary(1 * 8 * 1024 * 1024)

        # 모델 및 데이터 로더 설정
        self.target_function = target_function  # BinaryNet 모델
        self.trainloader = trainloader          # 학습 데이터 로더

        # 에피소드 설정
        self.max_steps = max_steps
        self.T_PSNR = T_PSNR
        self.T_steps = T_steps

        # 학습 상태
        self.state = None
        self.observation = None
        self.steps = 0
        self.psnr_sustained_steps = 0

        # 학습 데이터셋에서 첫 배치 추출
        self.data_iter = iter(self.trainloader)
        self.target_image = None

    def reset(self, seed=None, options=None):
        try:
            self.target_image = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.trainloader)
            self.target_image = next(self.data_iter)

        self.target_image = self.target_image.cuda()
        with torch.no_grad():
            model_output = self.target_function(self.target_image)
        self.observation = model_output.cpu().numpy()  # (1, 8, 1024, 1024)

        self.steps = 0
        self.psnr_sustained_steps = 0
        self.state = (self.observation >= 0.5).astype(np.int8)  # 이진화 상태

        mask = self.create_action_mask(self.observation)
        return self.observation, {"state": self.state, "mask": mask}

    def step(self, action, lr=1e-4, z=2e-3):
        action = np.reshape(action, (1, 8, 1024, 1024)).astype(np.int8)
        binary = torch.tensor(action, dtype=torch.float32).cuda()
        binary = tt.Tensor(binary, meta={'dx': (7.56e-6, 7.56e-6), 'wl': 515e-9})  # meta 정보 포함

        # 시뮬레이션
        sim = tt.simulate(binary, z).abs()**2
        result = torch.mean(sim, dim=1, keepdim=True)

        # MSE 및 PSNR 계산
        mse = tt.relativeLoss(result, self.target_image, F.mse_loss).detach().cpu().numpy()
        psnr = tt.relativeLoss(result, self.target_image, tm.get_PSNR)
        reward = -mse

        # 상태 업데이트
        self.state = np.logical_xor(self.state, action).astype(np.float32)

        # 종료 조건
        terminated = self.steps >= self.max_steps or self.psnr_sustained_steps >= self.T_steps
        truncated = self.steps >= self.max_steps

        if psnr >= self.T_PSNR:
            self.psnr_sustained_steps += 1
        else:
            self.psnr_sustained_steps = 0

        # 단일 채널 관찰값 반환
        self.observation = result.detach().cpu().numpy()

        # 추가 정보 반환
        mask = self.create_action_mask(self.observation)
        info = {"mse": mse, "psnr": psnr, "mask": mask}
        return self.observation, reward, terminated, truncated, info


    def create_action_mask(self, observation):
        """
        관찰값에 따라 행동 마스크 생성.
        관찰값이 0~0.2인 경우 -> 행동 0으로 고정.
        관찰값이 0.8~1인 경우 -> 행동 1로 고정.
        """
        mask = np.ones_like(observation, dtype=np.int8)  # 기본적으로 모든 행동 가능
        mask[observation <= 0.2] = 0  # 관찰값이 0~0.2면 행동 0으로 고정
        mask[observation >= 0.8] = 1  # 관찰값이 0.8~1이면 행동 1로 고정
        return mask

class Dataset512(Dataset):
    def __init__(self, target_dir, meta, transform=None, isTrain=True, padding=0):
        self.target_dir = target_dir
        self.transform = transform
        self.meta = meta
        self.isTrain = isTrain
        self.target_list = sorted(glob.glob(target_dir+'*.png'))
        self.center_crop = torchvision.transforms.CenterCrop(1024)
        self.random_crop = torchvision.transforms.RandomCrop((1024, 1024))
        self.padding = padding

    def __len__(self):
        return len(self.target_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        target = tt.imread(self.target_list[idx], meta=meta, gray=True).unsqueeze(0)
        if target.shape[-1] < 1024 or target.shape[-2] < 1024:
            target = torchvision.transforms.Resize(1024)(target)
        if self.isTrain:
            target = self.random_crop(target)
            target = torchvision.transforms.functional.pad(target, (self.padding, self.padding, self.padding, self.padding))
        else:
            target = self.center_crop(target)
            target = torchvision.transforms.functional.pad(target, (self.padding, self.padding, self.padding, self.padding))
        return target


def initialize_weights(m):
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight.data, nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight.data, 1)
        nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0)

batch_size = 1
target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'
meta = {'wl': (515e-9), 'dx': (7.56e-6, 7.56e-6)}  # 메타 정보
padding = 0

# Dataset512 클래스 사용
train_dataset = Dataset512(target_dir=target_dir, meta=meta, isTrain=True, padding=padding)
valid_dataset = Dataset512(target_dir=valid_dir, meta=meta, isTrain=False, padding=padding)

# DataLoader 생성
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# BinaryNet 모델 로드
model = BinaryNet(num_hologram=8, in_planes=1, convReLU=False, convBN=False,
                  poolReLU=False, poolBN=False, deconvReLU=False, deconvBN=False).cuda()
model.load_state_dict(torch.load('result_v/2024-12-15 14:02:27.770108_pre_reinforce_8_0.002/2024-12-15 14:02:27.770108_pre_reinforce_8_0.002'))
model.eval()

# 데이터 로더 생성
#target_dir = '/nfs/dataset/DIV2K/DIV2K_train_HR/DIV2K_train_HR/'
#valid_dir = '/nfs/dataset/DIV2K/DIV2K_valid_HR/DIV2K_valid_HR/'
#meta = {'wl': 515e-9, 'dx': (7.56e-6, 7.56e-6)}  # 메타 정보
#padding = 0

#train_dataset = SimpleDataset(target_dir=target_dir, meta=meta, isTrain=True, padding=padding)
#valid_dataset = SimpleDataset(target_dir=valid_dir, meta=meta, isTrain=False, padding=padding)

#batch_size = 1
#train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
#valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# BinaryNet 모델 로드
#model = BinaryNet(num_hologram=8, in_planes=1, convReLU=False, convBN=False,
#                  poolReLU=False, poolBN=False, deconvReLU=False, deconvBN=False).cuda()
#model.load_state_dict(torch.load('result_v/2024-12-15 14:02:27.770108_pre_reinforce_8_0.002/2024-12-15 14:02:27.770108_pre_reinforce_8_0.002'))
#model.eval()


# 마스크 함수 정의
def mask_fn(env):
    return env.create_action_mask(env.observation)

# 환경 생성에 새로운 데이터 로더 적용
env = BinaryHologramEnv(
    target_function=model,
    trainloader=train_loader,  # 업데이트된 train_loader 사용
    max_steps=1000,
    T_PSNR=30,
    T_steps=100
)

# ActionMasker 래퍼 적용
env = ActionMasker(env, mask_fn)

# Vectorized 환경 생성
venv = make_vec_env(lambda: env, n_envs=1)
venv = VecNormalize(venv, norm_obs=True, norm_reward=True, clip_obs=10.0)

# PPO 학습
ppo_model = PPO(
    "MlpPolicy",
    venv,
    verbose=1,
    n_steps=1024,
    batch_size=64,
    gamma=0.99,
    learning_rate=3e-4,
    tensorboard_log="./ppo_with_mask/"
)

ppo_model.learn(total_timesteps=100000)

# 학습된 모델 저장
ppo_model.save(f"ppo_with_mask_{current_date}")

  warn(f"Failed to load image Python extension: {e}")
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


torch.Size([1, 8, 1024, 1024])
Using cuda device


  model.load_state_dict(torch.load('result_v/2024-12-15 14:02:27.770108_pre_reinforce_8_0.002/2024-12-15 14:02:27.770108_pre_reinforce_8_0.002'))
  return super().__new__(cls, torch.Tensor(x), *args, **kwargs).to(device)
  target = tt.imread(self.target_list[idx], meta=meta, gray=True).unsqueeze(0)
  if target.shape[-1] < 1024 or target.shape[-2] < 1024:
  return x.ndim >= 2
  channels = 1 if img.ndim == 2 else img.shape[-3]
  height, width = img.shape[-2:]
  return img[..., top:bottom, left:right]
  if img.ndim < 4:
  img = img.unsqueeze(dim=0)
  out_dtype = img.dtype
  img = img.squeeze(dim=0)
  if elem.is_nested:
  if elem.layout in {torch.sparse_coo, torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}:
  return torch.stack(batch, 0, out=out)
  self.target_image = self.target_image.cuda()
  return F.conv2d(input, weight, bias, self.stride,
  return F.conv_transpose2d(
  concat4 = torch.cat((deconv4, enc4_2), dim=1)
  concat3 = torch.cat((deconv3, enc3_2), dim=1)


Logging to ./ppo_with_mask/PPO_12


  size = len(tensor.shape)
  for _ in range(size - len(tensor.shape)):
  superKron = torch.ones(superSample, superSample, device=tensor.device)
  shape = tensor.shape[-2:]
  x = [torch.linspace(-0.5, 0.5, s, device=tensor.device) for s in shape]
  one = torch.ones(1, device=tensor.device)
  return fftshift(torch.fft.fftn(tensor, dim=(-2,-1)))
  shifts = [tensor.size(dim)//2 for dim in dims]
  return torch.roll(tensor, shifts, dims)
  AS = AS*filter
  return ifft(AS * get_ASM_kernel(tensor.shape, z, dx=meta['dx'], wl=meta['wl'], device=tensor.device, **kwargs))
  f = get_f(shape=shape[-2:], dx=dx, device=str(device))
  shifts = [-(tensor.size(dim)//2) for dim in dims]
  return torch.roll(tensor, shifts, dims)
  return torch.fft.ifftn(ifftshift(tensor), dim=(-2,-1))
  return torch.roll(tensor, [-int(wl/dx0*p*z/dx0+0.5) for p,dx0 in zip(filterPos, dx)], dims=(-2,-1))
  return timg[..., p_size1:-p_size1, p_size2:-p_size2]
  for i in range(len(r.shape)-size):
  sim = tt.simulate(binary, z).

-----------------------------
| time/              |      |
|    fps             | 4    |
|    iterations      | 1    |
|    time_elapsed    | 227  |
|    total_timesteps | 1024 |
-----------------------------


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 23.65 GiB of which 1.71 GiB is free. Process 684255 has 384.00 MiB memory in use. Including non-PyTorch memory, this process has 21.55 GiB memory in use. Of the allocated memory 19.67 GiB is allocated by PyTorch, and 1.43 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)