In [1]:
import torch
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from PIL import Image
import numpy as np
import model_loader

from skimage.metrics import structural_similarity as ssim
import lpips

import torch
import torch.nn.functional as F


from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os


from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim

from torchvision.utils import save_image


In [2]:
# 디바이스 설정
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 모델 정의
encoder = VAE_Encoder().to(DEVICE)
decoder = VAE_Decoder().to(DEVICE)

# Fine-tuned 모델 가중치 파일 경로
encoder_weight_path = "/home/fall/latent-diffusion-homemade/ldms/checkpoints/blur_encoder_conv_epoch_220.pth"
decoder_weight_path = "/home/fall/latent-diffusion-homemade/ldms/checkpoints/blur_decoder_conv_epoch_220.pth"

def load_state_dict_without_module(model, state_dict):
    new_state_dict = {}
    for k, v in state_dict.items():
        name = k.replace("module.", "")  # Remove the 'module.' prefix
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

# Load encoder weights
encoder_state_dict = torch.load(encoder_weight_path, map_location=DEVICE)
load_state_dict_without_module(encoder, encoder_state_dict)

# Load decoder weights
decoder_state_dict = torch.load(decoder_weight_path, map_location=DEVICE)
load_state_dict_without_module(decoder, decoder_state_dict)

print("Fine-tuned encoder and decoder successfully loaded.")



  encoder_state_dict = torch.load(encoder_weight_path, map_location=DEVICE)
  decoder_state_dict = torch.load(decoder_weight_path, map_location=DEVICE)


Fine-tuned encoder and decoder successfully loaded.


In [3]:
# def preprocess_image(image_path, target_size=(256, 256)):
#     image = Image.open(image_path).convert("RGB")
#     image = image.crop(target_size)  # Autoencoder 입력 크기로 조정
#     image = np.array(image).astype(np.float32) / 255.0  # Normalize to [0, 1]
#     image = torch.tensor(image).permute(2, 0, 1).unsqueeze(0).to(DEVICE)  # (H, W, C) -> (1, C, H, W)
    
#     return image

def reconstruct_image(encoder, decoder, image):
    with torch.no_grad():
        # Encoder로 Latent 공간으로 변환
        noise = torch.randn(image.shape[0], 4, image.shape[2] // 8, image.shape[3] // 8).to(DEVICE)
        latents = encoder(image, noise)
        
        # Decoder로 Latent를 복원
        reconstructed_image = decoder(latents)
        reconstructed_image = reconstructed_image.clamp(0, 1)  # 값 제한 [0, 1]
        
        return reconstructed_image
    
def compute_mse(original, reconstructed):
    return ((original - reconstructed) ** 2).mean().item()


def compute_ssim(input_image, reconstructed_image):
    # 디버깅용 출력
    print("Input image shape:", input_image.shape)
    print("Reconstructed image shape:", reconstructed_image.shape)

    # 텐서를 numpy 배열로 변환 및 차원 변환 (C, H, W -> H, W, C)
    input_image_np = input_image.squeeze().permute(1, 2, 0).cpu().numpy()
    reconstructed_image_np = reconstructed_image.squeeze().permute(1, 2, 0).cpu().numpy()

    # SSIM 계산
    ssim_value = ssim(
        input_image_np,
        reconstructed_image_np,
        data_range=1.0,        # 정규화된 값 [0, 1] 범위로 설정
        multichannel=True,     # 다중 채널 이미지 처리
        channel_axis=-1        # 채널 축 위치 (-1은 마지막 축)
    )
    return ssim_value

lpips_loss = lpips.LPIPS(net="alex").to(DEVICE)

def compute_lpips(original, reconstructed):
    return lpips_loss(original, reconstructed).item()

def reconstruction_loss(original, reconstructed):
    return F.mse_loss(reconstructed, original)




Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/fall/anaconda3/envs/ldms_311/lib/python3.11/site-packages/lpips/weights/v0.1/alex.pth


  self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)


In [None]:
# 입력 이미지 경로
image_path = "/home/fall/latent-diffusion-homemade/images/blur_image.png"

# 이미지 전처리
input_image = preprocess_image(image_path)

# Reconstruction
reconstructed_image = reconstruct_image(encoder, decoder, input_image)

# 성능 평가
mse = compute_mse(input_image, reconstructed_image)
ssim_value = compute_ssim(input_image, reconstructed_image)
lpips_value = compute_lpips(input_image, reconstructed_image)

# 결과 출력
print(f"MSE: {mse:.4f}, SSIM: {ssim_value:.4f}, LPIPS: {lpips_value:.4f}")

# 복원된 이미지 저장
reconstructed_image_np = reconstructed_image.squeeze().cpu().numpy().transpose(1, 2, 0) * 255.0
reconstructed_image_np = reconstructed_image_np.astype(np.uint8)
Image.fromarray(reconstructed_image_np).save("reconstructed_image.png")

In [3]:
class GOPRODataset(Dataset):
    def __init__(self, root_dir, mode='train', transform=None):
        """
        GOPRODataset 생성자
        :param root_dir: 데이터셋의 최상위 디렉토리 경로
        :param mode: 'train' 또는 'test'
        :param transform: 이미지 전처리 변환
        """
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.data = []

        # 허용된 이미지 확장자
        valid_extensions = {".jpg", ".jpeg", ".png"}

        # 폴더 탐색 및 blur, sharp 이미지 경로 저장
        base_dir = os.path.join(root_dir, mode)
        for subdir in os.listdir(base_dir):
            blur_dir = os.path.join(base_dir, subdir, 'blur')
            sharp_dir = os.path.join(base_dir, subdir, 'sharp')
            if os.path.exists(blur_dir) and os.path.exists(sharp_dir):
                blur_images = sorted(os.listdir(blur_dir))
                sharp_images = sorted(os.listdir(sharp_dir))
                for blur_img, sharp_img in zip(blur_images, sharp_images):
                    if os.path.splitext(blur_img)[1].lower() in valid_extensions and \
                       os.path.splitext(sharp_img)[1].lower() in valid_extensions:
                        self.data.append({
                            'blur': os.path.join(blur_dir, blur_img),
                            'sharp': os.path.join(sharp_dir, sharp_img)
                        })
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        blur_image = Image.open(item['blur']).convert('RGB')
        sharp_image = Image.open(item['sharp']).convert('RGB')

        if self.transform:
            blur_image = self.transform(blur_image)
            sharp_image = self.transform(sharp_image)

        return blur_image, sharp_image

# 이미지 전처리
transform = transforms.Compose([
    transforms.CenterCrop((512, 512)),  # AE 입력 크기로 조정
    transforms.ToTensor(),         # [0, 1] 범위로 정규화
])

# 데이터셋 경로
root_dir = "/home/NAS_mount/seunghan/GOPRO/"

# 데이터셋 생성
train_dataset = GOPRODataset(root_dir=root_dir, mode='train', transform=transform)
test_dataset = GOPRODataset(root_dir=root_dir, mode='test', transform=transform)

# DataLoader 생성
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)




In [5]:
# 모델 로드 (encoder, decoder는 미리 로드되어 있다고 가정)
# 모델을 평가 모드로 설정
encoder.eval()
decoder.eval()

# 테스트 평가 루프
total_loss = 0
criterion = nn.MSELoss()

with torch.no_grad():
    for blur_images, sharp_images in test_loader:
        # DataLoader에서 반환된 데이터를 GPU로 전송
        sharp_images = sharp_images.to(DEVICE)
        
        # Latent 공간으로 변환 (Encoder)
        noise = torch.randn(sharp_images.size(0), 4, sharp_images.size(2) // 8, sharp_images.size(3) // 8).to(DEVICE)
        latents = encoder(sharp_images, noise)
        
        # 복원 이미지 생성 (Decoder)
        reconstructed = decoder(latents)
        
        # 손실 계산 (Reconstruction Loss)
        loss = nn.MSELoss(reconstructed, sharp_images)
        total_loss += loss.item()

    print(f"Test Loss: {total_loss / len(test_loader):.4f}")

TypeError: unsupported operand type(s) for /: 'tuple' and 'float'

In [5]:
# 복원된 이미지 저장
with torch.no_grad():
    for blur_images, _ in test_loader:
        blur_images = blur_images.to(DEVICE)
        
        noise = torch.randn(blur_images.size(0), 4, blur_images.size(2) // 8, blur_images.size(3) // 8).to(DEVICE)
        latents, _, _ = encoder(blur_images, noise)
        reconstructed = decoder(latents).clamp(0, 1)
        
        save_image(reconstructed, "blur_reconstructed_images.png")
        break  # 한 배치만 저장

KeyboardInterrupt: 

In [5]:
# 이미지 전처리 및 모델 설정
with torch.no_grad():
    for blur_images, sharp_images in test_loader:
        sharp_images = sharp_images.to(DEVICE)
        
        # Encoder와 Decoder를 통해 복원된 이미지 생성
        noise = torch.randn(sharp_images.size(0), 4, sharp_images.size(2) // 8, sharp_images.size(3) // 8).to(DEVICE)
        latents,_,_ = encoder(sharp_images, noise)
        reconstructed = decoder(latents).clamp(0, 1)

        # 세로 방향으로 이미지를 나란히 저장하기 위해 두 이미지를 합침
        # dim=2는 세로 방향으로 이미지를 합침
        comparison_image = torch.cat((sharp_images, reconstructed), dim=2)  # 세로로 이어붙이기
        
        # 이미지를 하나의 파일로 저장
        save_image(comparison_image, "100_bilinear_sharp_vs_reconstructed_vertical.png")
        break  # 한 배치만 저장


In [5]:
# 이미지 전처리 및 모델 설정
with torch.no_grad():
    for blur_images, sharp_images in test_loader:
        blur_images = blur_images.to(DEVICE)
        
        # Encoder와 Decoder를 통해 복원된 이미지 생성
        noise = torch.randn(blur_images.size(0), 4, blur_images.size(2) // 8, blur_images.size(3) // 8).to(DEVICE)
        latents,_,_ = encoder(blur_images, noise)
        reconstructed = decoder(latents).clamp(0, 1)

        # 세로 방향으로 이미지를 나란히 저장하기 위해 두 이미지를 합침
        # dim=2는 세로 방향으로 이미지를 합침
        comparison_image = torch.cat((blur_images, reconstructed), dim=2)  # 세로로 이어붙이기
        
        # 이미지를 하나의 파일로 저장
        save_image(comparison_image, "220_bilinear_blur_vs_reconstructed_vertical.png")
        break  # 한 배치만 저장

In [None]:

class GOPRODataset(Dataset):
    def __init__(self, root_dir, mode='train', transform=None):
        """
        GOPRODataset 생성자
        :param root_dir: 데이터셋의 최상위 디렉토리 경로
        :param mode: 'train' 또는 'test'
        :param transform: 이미지 전처리 변환
        """
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.data = []

        # 폴더 탐색 및 blur, sharp 이미지 경로 저장
        base_dir = os.path.join(root_dir, mode)
        for subdir in os.listdir(base_dir):
            blur_dir = os.path.join(base_dir, subdir, 'blur')
            sharp_dir = os.path.join(base_dir, subdir, 'sharp')
            if os.path.exists(blur_dir) and os.path.exists(sharp_dir):
                blur_images = sorted(os.listdir(blur_dir))
                sharp_images = sorted(os.listdir(sharp_dir))
                for blur_img, sharp_img in zip(blur_images, sharp_images):
                    self.data.append({
                        'blur': os.path.join(blur_dir, blur_img),
                        'sharp': os.path.join(sharp_dir, sharp_img)
                    })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        blur_image = Image.open(item['blur']).convert('RGB')
        sharp_image = Image.open(item['sharp']).convert('RGB')

        if self.transform:
            blur_image = self.transform(blur_image)
            sharp_image = self.transform(sharp_image)

        return blur_image, sharp_image

# 이미지 전처리
transform = transforms.Compose([
    transforms.Resize((512, 512)),  # AE 입력 크기로 조정
    transforms.ToTensor(),         # [0, 1] 범위로 정규화
])

# 데이터셋 경로
root_dir = "/home/NAS_mount/seunghan/GOPRO/"

# 데이터셋 생성
train_dataset = GOPRODataset(root_dir=root_dir, mode='train', transform=transform)
test_dataset = GOPRODataset(root_dir=root_dir, mode='test', transform=transform)

# DataLoader 생성
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)


In [None]:
encoder.eval()
decoder.eval()

# 입력 이미지
image_path = "/home/fall/latent-diffusion-homemade/images/blur_image.png"
image = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to(DEVICE)

# Reconstruction
with torch.no_grad():
    noise = torch.randn(1, 4, 512 // 8, 512 // 8).to(DEVICE)
    latent = encoder(image, noise)
    reconstructed_image = decoder(latent).clamp(0, 1)

# 이미지 저장
from torchvision.utils import save_image
save_image(reconstructed_image, "reconstructed_image.png")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from encoder import VAE_Encoder
from decoder import VAE_Decoder

# 모델 초기화
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = VAE_Encoder().to(device)
decoder = VAE_Decoder().to(device)

# 손실 함수 및 옵티마이저
criterion = nn.MSELoss()
optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)

In [None]:
# # 학습 루프
# num_epochs = 100
# for epoch in range(num_epochs):
#     encoder.train()
#     decoder.train()
#     total_loss = 0

#     for blur_images, sharp_images in train_loader:
#         blur_images, sharp_images = blur_images.to(device), sharp_images.to(device)
        
#         # Forward pass
#         noise = torch.randn(blur_images.size(0), 4, blur_images.size(2) // 8, blur_images.size(3) // 8).to(device)
#         latents = encoder(blur_images, noise)
#         reconstructed = decoder(latents)
        
#         # Reconstruction Loss
#         loss = criterion(reconstructed, sharp_images)

#         # Backward pass
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item()

#     # Epoch 결과 출력
#     print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")


In [None]:
# # 테스트 루프
# encoder.eval()
# decoder.eval()
# total_loss = 0

# with torch.no_grad():
#     for blur_images, sharp_images in test_loader:
#         blur_images, sharp_images = blur_images.to(device), sharp_images.to(device)
        
#         # Forward pass
#         noise = torch.randn(blur_images.size(0), 4, blur_images.size(2) // 8, blur_images.size(3) // 8).to(device)
#         latents = encoder(blur_images, noise)
#         reconstructed = decoder(latents)

#         # Reconstruction Loss
#         loss = criterion(reconstructed, sharp_images)
#         total_loss += loss.item()

#     print(f"Test Loss: {total_loss / len(test_loader):.4f}")


In [None]:
from torchvision.utils import save_image

# 복원된 이미지 저장
with torch.no_grad():
    for blur_images, sharp_images in test_loader:
        blur_images = blur_images.to(device)
        
        noise = torch.randn(blur_images.size(0), 4, blur_images.size(2) // 8, blur_images.size(3) // 8).to(device)
        latents = encoder(blur_images, noise)
        reconstructed = decoder(latents).clamp(0, 1)
        
        save_image(reconstructed, "reconstructed_images.png")
        break  # 한 배치만 저장


## 학습 코드

In [1]:
import torch
import torch.nn as nn  
from encoder import VAE_Encoder
from decoder import VAE_Decoder
import model_loader

from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
from contperceptual import LPIPSWithDiscriminator

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

DEVICE = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")

# 미리 학습된 모델 파일 경로
model_file = "/home/NAS_mount/seunghan/v1-5-pruned-emaonly.ckpt"

# 모델 로드
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)

# Encoder와 Decoder 초기화
encoder = VAE_Encoder()
decoder = VAE_Decoder()

# DataParallel로 Multi-GPU 활용
if torch.cuda.device_count() > 1:
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)

# 모델을 GPU로 이동
encoder.to(DEVICE)
decoder.to(DEVICE)

# 가중치 로드
encoder.load_state_dict(models['encoder'].state_dict(), strict=False)
decoder.load_state_dict(models['decoder'].state_dict(), strict=False)


  from .autonotebook import tqdm as notebook_tqdm


KeyError: 'encoder'

In [None]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# from contperceptual import LPIPSWithDiscriminator
from tqdm import tqdm

criterion = nn.MSELoss()

# KL-divergence와 MSE 혼합 손실 함수
kl_weight = 1e-6 # KL-divergence의 가중치

# 옵티마이저
optimizer = optim.Adam(
    list(encoder.parameters()) + list(decoder.parameters()), 
    lr=1e-4,         # 학습률
    weight_decay=1e-4  # 가중치 감쇠 (옵션, 필요 시 제거 가능)
)


# optimizer = optim.SGD(
#     list(encoder.parameters()) + list(decoder.parameters()), 
#     lr=1e-3,         # 학습률 (SGD는 일반적으로 Adam보다 더 높은 학습률 사용)
#     momentum=0.9,    # 모멘텀 (옵션)
#     weight_decay=1e-4  # 가중치 감쇠 (L2 정규화, 옵션)
# )

# optimizer_weight_path = "/home/fall/latent-diffusion-homemade/ldms/checkpoints/sharp_optimizer_encoder_decoder_bilinear_epoch_100.pth"
# optimizer_state_dict = torch.load(optimizer_weight_path, map_location=DEVICE)
# optimizer.load_state_dict(optimizer_state_dict)


class GOPRODataset(Dataset):
    def __init__(self, root_dir, mode='train', transform=None):
        """
        GOPRODataset 생성자
        :param root_dir: 데이터셋의 최상위 디렉토리 경로
        :param mode: 'train' 또는 'test'
        :param transform: 이미지 전처리 변환
        """
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.data = []

        # 허용된 이미지 확장자
        valid_extensions = {".jpg", ".jpeg", ".png"}

        # 폴더 탐색 및 blur, sharp 이미지 경로 저장
        base_dir = os.path.join(root_dir, mode)
        for subdir in os.listdir(base_dir):
            blur_dir = os.path.join(base_dir, subdir, 'blur')
            sharp_dir = os.path.join(base_dir, subdir, 'sharp')
            if os.path.exists(blur_dir) and os.path.exists(sharp_dir):
                blur_images = sorted(os.listdir(blur_dir))
                sharp_images = sorted(os.listdir(sharp_dir))
                for blur_img, sharp_img in zip(blur_images, sharp_images):
                    if os.path.splitext(blur_img)[1].lower() in valid_extensions and \
                       os.path.splitext(sharp_img)[1].lower() in valid_extensions:
                        self.data.append({
                            'blur': os.path.join(blur_dir, blur_img),
                            'sharp': os.path.join(sharp_dir, sharp_img)
                        })
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        blur_image = Image.open(item['blur']).convert('RGB')
        sharp_image = Image.open(item['sharp']).convert('RGB')

        if self.transform:
            blur_image = self.transform(blur_image)
            sharp_image = self.transform(sharp_image)

        return blur_image, sharp_image

# 이미지 전처리
transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),  # AE 입력 크기로 조정
    transforms.ToTensor(),         # [0, 1] 범위로 정규화
])

# 데이터셋 경로
root_dir = "/home/NAS_mount/seunghan/GOPRO/"

# 데이터셋 생성
train_dataset = GOPRODataset(root_dir=root_dir, mode='train', transform=transform)
# test_dataset = GOPRODataset(root_dir=root_dir, mode='test', transform=transform)

# DataLoader 생성
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
# test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=1)

# 학습 루프
num_epochs = 300
# 학습 루프 수정
for epoch in range(num_epochs):
    encoder.train()
    decoder.train()

    total_loss = 0

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0
    with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}/{num_epochs}") as tepoch:
        for blur_images, sharp_images in train_loader:
            sharp_images = sharp_images.to(DEVICE)
            
            # Forward pass
            noise = torch.randn(sharp_images.size(0), 4, sharp_images.size(2) // 8, sharp_images.size(3) // 8).to(DEVICE)
            latents, mean, log_variance = encoder(sharp_images, noise)
            reconstructed = decoder(latents)

            
            # Reconstruction Loss
            mse_loss  = criterion(reconstructed, sharp_images)
            # KL-divergence Loss
            kl_loss = torch.mean(-0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp(), dim=1))

            # Combined Loss
            loss = mse_loss + kl_weight * kl_loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")

    # 모델 저장 (50 에포크마다)
    if (epoch + 1) % 25 == 0:
        torch.save(encoder.state_dict(), f"/home/fall/latent-diffusion-homemade/ldms/checkpoints/sharpencoder_bilinear_epoch_{epoch+1}.pth")
        torch.save(decoder.state_dict(), f"/home/fall/latent-diffusion-homemade/ldms/checkpoints/sharpdecoder_bilinear_epoch_{epoch+1}.pth")
        torch.save(optimizer.state_dict(), f"/home/fall/latent-diffusion-homemade/ldms/checkpoints/sharpoptimizer_encoder_decoder_bilinear_epoch_{epoch+1}.pth")

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


Epoch 1/300:   0%|          | 0/132 [22:35<?, ?batch/s]


Epoch [1/300], Loss: 0.0268


Epoch 2/300:   0%|          | 0/132 [24:17<?, ?batch/s]


Epoch [2/300], Loss: 0.0108


Epoch 3/300:   0%|          | 0/132 [33:46<?, ?batch/s]


Epoch [3/300], Loss: 0.0079


Epoch 4/300:   0%|          | 0/132 [23:28<?, ?batch/s]


KeyboardInterrupt: 

In [None]:
# 테스트 데이터셋 로드
test_dataset = GOPRODataset(root_dir=root_dir, mode='test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

# 평가
encoder.eval()
decoder.eval()
total_loss = 0

with torch.no_grad():
    for blur_images, sharp_images in test_loader:
        blur_images, sharp_images = blur_images.to(DEVICE), sharp_images.to(DEVICE)
        
        # Forward pass
        noise = torch.randn(blur_images.size(0), 4, blur_images.size(2) // 8, blur_images.size(3) // 8).to(DEVICE)
        latents = encoder(blur_images, noise)
        reconstructed = decoder(latents)

        # Reconstruction Loss
        loss = criterion(reconstructed, blur_images)
        total_loss += loss.item()

    print(f"Test Loss: {total_loss / len(test_loader):.4f}")

In [None]:
from torchvision.utils import save_image

# 복원된 이미지 저장
with torch.no_grad():
    for blur_images in test_loader:
        blur_images = blur_images.to(DEVICE)
        
        noise = torch.randn(blur_images.size(0), 4, blur_images.size(2) // 8, blur_images.size(3) // 8).to(DEVICE)
        latents = encoder(blur_images, noise)
        reconstructed = decoder(latents).clamp(0, 1)
        
        save_image(reconstructed, "blur_reconstructed_images.png")
        break  # 한 배치만 저장


## clear encoder decoder 학습

In [None]:
import torch
import torch.nn as nn  
from encoder import VAE_Encoder
from decoder import VAE_Decoder
import model_loader

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 미리 학습된 모델 파일 경로
model_file = "/home/NAS_mount/seunghan/v1-5-pruned-emaonly.ckpt"

# 모델 로드
models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)

# Encoder와 Decoder 초기화
encoder = VAE_Encoder()
decoder = VAE_Decoder()

# DataParallel로 Multi-GPU 활용
if torch.cuda.device_count() > 1:
    encoder = nn.DataParallel(encoder)
    decoder = nn.DataParallel(decoder)

# 모델을 GPU로 이동
encoder.to(DEVICE)
decoder.to(DEVICE)

# 가중치 로드
encoder.load_state_dict(models['encoder'].state_dict(), strict=False)
decoder.load_state_dict(models['decoder'].state_dict(), strict=False)


In [None]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# 손실 함수 및 옵티마이저 정의
criterion = nn.MSELoss()
optimizer = optim.SGD(
    list(encoder.parameters()) + list(decoder.parameters()), 
    lr=1e-3,         # 학습률 (SGD는 일반적으로 Adam보다 더 높은 학습률 사용)
    momentum=0.9,    # 모멘텀 (옵션)
    weight_decay=1e-4  # 가중치 감쇠 (L2 정규화, 옵션)
)
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class GOPRODataset(Dataset):
    def __init__(self, root_dir, mode='train', transform=None):
        """
        GOPRODataset 생성자
        :param root_dir: 데이터셋의 최상위 디렉토리 경로
        :param mode: 'train' 또는 'test'
        :param transform: 이미지 전처리 변환
        """
        self.root_dir = root_dir
        self.mode = mode
        self.transform = transform
        self.data = []

        # 허용된 이미지 확장자
        valid_extensions = {".jpg", ".jpeg", ".png"}

        # 폴더 탐색 및 blur, sharp 이미지 경로 저장
        base_dir = os.path.join(root_dir, mode)
        for subdir in os.listdir(base_dir):
            blur_dir = os.path.join(base_dir, subdir, 'blur')
            sharp_dir = os.path.join(base_dir, subdir, 'sharp')
            if os.path.exists(blur_dir) and os.path.exists(sharp_dir):
                blur_images = sorted(os.listdir(blur_dir))
                sharp_images = sorted(os.listdir(sharp_dir))
                for blur_img, sharp_img in zip(blur_images, sharp_images):
                    if os.path.splitext(blur_img)[1].lower() in valid_extensions and \
                       os.path.splitext(sharp_img)[1].lower() in valid_extensions:
                        self.data.append({
                            'blur': os.path.join(blur_dir, blur_img),
                            'sharp': os.path.join(sharp_dir, sharp_img)
                        })
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        blur_image = Image.open(item['blur']).convert('RGB')
        sharp_image = Image.open(item['sharp']).convert('RGB')

        if self.transform:
            blur_image = self.transform(blur_image)
            sharp_image = self.transform(sharp_image)

        return blur_image, sharp_image

# 이미지 전처리
transform = transforms.Compose([
    transforms.Resize((512, 512)),  # AE 입력 크기로 조정
    transforms.ToTensor(),         # [0, 1] 범위로 정규화
])

# 데이터셋 경로
root_dir = "/home/NAS_mount/seunghan/GOPRO/"

# 데이터셋 생성
train_dataset = GOPRODataset(root_dir=root_dir, mode='train', transform=transform)
# test_dataset = GOPRODataset(root_dir=root_dir, mode='test', transform=transform)

# DataLoader 생성
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
# test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=1)

# 학습 루프
num_epochs = 100
for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0

    for _, sharp_images in train_loader:
        sharp_images = sharp_images.to(DEVICE)
        
        # Forward pass
        noise = torch.randn(sharp_images.size(0), 4, sharp_images.size(2) // 8, sharp_images.size(3) // 8).to(DEVICE)
        latents = encoder(sharp_images, noise)
        reconstructed = decoder(latents)
        
        # Reconstruction Loss
        loss = criterion(reconstructed, sharp_images)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}")

    # 모델 저장 (10 에포크마다)
    if (epoch + 1) % 25 == 0:
        torch.save(encoder.state_dict(), f"fine_tuned_CLEAR_encoder_epoch_{epoch+101}.pth")
        torch.save(decoder.state_dict(), f"fine_tuned_CLEAR_decoder_epoch_{epoch+101}.pth")


In [None]:
# 평가
encoder.eval()
decoder.eval()
total_loss = 0

with torch.no_grad():
    for blur_images, sharp_images in test_loader:
        sharp_images = sharp_images.to(DEVICE)
        
        # Forward pass
        noise = torch.randn(sharp_images.size(0), 4, sharp_images.size(2) // 8, sharp_images.size(3) // 8).to(DEVICE)
        latents = encoder(sharp_images, noise)
        reconstructed = decoder(latents)

        # Reconstruction Loss
        loss = criterion(reconstructed, sharp_images)
        total_loss += loss.item()

    print(f"Test Loss: {total_loss / len(test_loader):.4f}")

In [None]:
# 이미지 전처리 및 모델 설정
with torch.no_grad():
    for blur_images, sharp_images in test_loader:
        sharp_images = sharp_images.to(DEVICE)
        
        # Encoder와 Decoder를 통해 복원된 이미지 생성
        noise = torch.randn(sharp_images.size(0), 4, sharp_images.size(2) // 8, sharp_images.size(3) // 8).to(DEVICE)
        latents = encoder(sharp_images, noise)
        reconstructed = decoder(latents).clamp(0, 1)

        # 세로 방향으로 이미지를 나란히 저장하기 위해 두 이미지를 합침
        # dim=2는 세로 방향으로 이미지를 합침
        comparison_image = torch.cat((sharp_images, reconstructed), dim=2)  # 세로로 이어붙이기
        
        # 이미지를 하나의 파일로 저장
        save_image(comparison_image, "sharp_vs_reconstructed_vertical.png")
        break  # 한 배치만 저장