In [2]:
import torch
from torchvision import transforms
from PIL import Image
import numpy as np
import os

from gan.model import Generator
from gan.functions import preprocess_data

In [4]:
def load_checkpoint(checkpoint_path, netG):
    checkpoint = torch.load(checkpoint_path)
    netG.load_state_dict(checkpoint['netG_state_dict'])
    print(f"Checkpoint loaded from {checkpoint_path}")

In [6]:
def infer(input_image_path, output_image_path, checkpoint_path, device):
    # 모델 초기화
    netG = Generator().to(device)
    
    # 체크포인트 로드
    load_checkpoint(checkpoint_path, netG)
    
    # 모델을 평가 모드로 설정
    netG.eval()
    
    # 이미지 전처리
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor()
    ])
    
    image = Image.open(input_image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # 배치 차원을 추가하고 디바이스로 이동
    
    # 인퍼런스 수행
    with torch.no_grad():
        generated_image = netG(image)
    
    # 결과 이미지를 저장
    generated_image = generated_image.squeeze(0).cpu().detach().numpy()  # 배치 차원 제거
    generated_image = np.transpose(generated_image, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    
    # 3채널씩 끊어서 RGB 이미지로 저장
    num_channels = generated_image.shape[2]
    assert num_channels % 3 == 0, "Generated image channels should be divisible by 3"
    
    for i in range(0, num_channels, 3):
        rgb_image = generated_image[:, :, i:i+3]
        rgb_image = (rgb_image * 255).astype(np.uint8)  # 0-1 범위를 0-255 범위로 변환
        rgb_image = Image.fromarray(rgb_image)
        rgb_output_path = f"{output_image_path}_rgb_{i//3}.png"
        rgb_image.save(rgb_output_path)
        print(f"RGB image {i//3} saved to {rgb_output_path}")

In [11]:
if __name__ == "__main__":
    input_image_path = '054000.png'
    output_image_path = './'  # 확장자 없이 기본 경로만 지정
    
    
    files = os.listdir(os.path.join(os.getcwd(), "checkpoints"))[1]
    checkpoint_path = os.path.join(os.path.join(os.getcwd(), "checkpoints"), files)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    infer(input_image_path, output_image_path, checkpoint_path, device)

Checkpoint loaded from c:\work\ffhq_texture_generator\checkpoints\checkpoint_epoch_5.pth
RGB image 0 saved to ./_rgb_0.png
RGB image 1 saved to ./_rgb_1.png
RGB image 2 saved to ./_rgb_2.png
RGB image 3 saved to ./_rgb_3.png


'c:\\work\\ffhq_texture_generator\\checkpoints\\checkpoint_epoch_5.pth'