In [1]:
import model_loader
import pipeline
from PIL import Image
from transformers import CLIPTokenizer
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
import os
from torchvision import transforms
from torchvision.transforms import functional as F

import torch
from torch.optim import  SGD
from tqdm import tqdm
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import functional as TF


from clip import CLIP
from ddpm import DDPMSampler
from pipeline import generate, get_time_embedding

import numpy as np

from encoder import VAE_Encoder
from decoder import VAE_Decoder


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# GPU 설정
device_ids = [0, 1, 2, 3]  
# device = "cuda" if torch.cuda.is_available() else "cpu"
# print(f"Using device: {device} on GPUs {device_ids}")

device = f"cuda:{device_ids[0]}"  # 첫 번째 GPU를 기본 device로 설정
print(f"Using device: {device} on GPUs {device_ids}")

# 토크나이저 및 모델 로드
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
model_file = "/home/NAS_mount/seunghan/v1-5-pruned-emaonly.ckpt"
models = model_loader.preload_models_from_standard_weights(model_file, device)

# Diffusion 모델 병렬화 설정
model = models["diffusion"]
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model, device_ids=device_ids)  # 모델 병렬화
# model = model.to(device)
model = model  # .to(device) 필요 없음, DataParallel이 처리

Using device: cuda:0 on GPUs [0, 1, 2, 3]


In [None]:
import torch
from encoder import VAE_Encoder
from decoder import VAE_Decoder
import model_loader
from transformers import CLIPTokenizer

# GPU 설정
device_ids = [0, 1, 2, 3]
device = f"cuda:{device_ids[0]}"  # 첫 번째 GPU를 기본 device로 설정
print(f"Using device: {device} on GPUs {device_ids}")

# 토크나이저 및 Pre-trained 모델 로드
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
model_file = "/home/NAS_mount/seunghan/v1-5-pruned-emaonly.ckpt"

# Fine-tuned Encoder 및 Decoder 가중치 경로
encoder_weight_path = "/home/fall/latent-diffusion-homemade/ldms/checkpoints/fine_tuned_encoder_epoch_100.pth"
decoder_weight_path = "/home/fall/latent-diffusion-homemade/ldms/checkpoints/fine_tuned_decoder_epoch_100.pth"

# Fine-tuned 모델 로드 함수
def load_finetuned_model(model_file, encoder_weight_path, decoder_weight_path, device, device_ids):
    # Pretrained 모델 로드
    models = model_loader.preload_models_from_standard_weights(model_file, device)

    # Encoder와 Decoder 초기화
    encoder = VAE_Encoder().to(device)
    decoder = VAE_Decoder().to(device)

    # 가중치 로드 함수
    def load_state_dict_without_module(model, state_dict):
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        model.load_state_dict(new_state_dict, strict=False)

    # Fine-tuned 가중치 로드
    encoder_state_dict = torch.load(encoder_weight_path, map_location=device)
    decoder_state_dict = torch.load(decoder_weight_path, map_location=device)
    load_state_dict_without_module(encoder, encoder_state_dict)
    load_state_dict_without_module(decoder, decoder_state_dict)

    # Encoder와 Decoder를 모델 딕셔너리에 추가
    models['encoder'] = encoder
    models['decoder'] = decoder

    # Multi-GPU 처리
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs for parallel processing")
        for key in models:
            if isinstance(models[key], torch.nn.Module):
                models[key] = torch.nn.DataParallel(models[key], device_ids=device_ids)

    return models

# Fine-tuned 모델 로드
models = load_finetuned_model(
    model_file=model_file,
    encoder_weight_path=encoder_weight_path,
    decoder_weight_path=decoder_weight_path,
    device=device,
    device_ids=device_ids
)

# Diffusion 모델 Multi-GPU 설정
diffusion_model = models["diffusion"]
if torch.cuda.device_count() > 1:
    diffusion_model = torch.nn.DataParallel(diffusion_model, device_ids=device_ids)

print("Models successfully loaded and configured.")


In [3]:
import os
from PIL import Image
from torch.utils.data import Dataset

class PairedImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.pairs = self._make_pairs()

    def _make_pairs(self):
        pairs = []
        valid_extensions = {".jpg", ".jpeg", ".png"}  # 유효한 이미지 확장자

        # 각 폴더를 순회하면서 blur와 sharp 이미지 쌍을 생성합니다.
        for folder_name in os.listdir(self.root_dir):
            folder_path = os.path.join(self.root_dir, folder_name)
            blur_dir = os.path.join(folder_path, "blur")
            sharp_dir = os.path.join(folder_path, "sharp")
            
            if not (os.path.isdir(blur_dir) and os.path.isdir(sharp_dir)):
                continue
            
            # blur와 sharp 디렉토리에서 동일한 파일 이름을 가진 이미지 쌍을 찾습니다.
            for image_name in os.listdir(blur_dir):
                # 유효한 이미지 파일만 선택
                if not any(image_name.lower().endswith(ext) for ext in valid_extensions):
                    continue
                
                blur_image_path = os.path.join(blur_dir, image_name)
                sharp_image_path = os.path.join(sharp_dir, image_name)
                
                # 두 파일 모두 존재할 때만 추가
                if os.path.isfile(blur_image_path) and os.path.isfile(sharp_image_path):
                    pairs.append((blur_image_path, sharp_image_path))
        
        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        blur_image_path, sharp_image_path = self.pairs[idx]
        blur_image = Image.open(blur_image_path).convert("RGB")
        sharp_image = Image.open(sharp_image_path).convert("RGB")
        
        if self.transform:
            blur_image = self.transform(blur_image)
            sharp_image = self.transform(sharp_image)
        
        return blur_image, sharp_image  # (blurred input, sharp target)

# 이미지 중앙 부분 crop
transform = transforms.Compose([
    transforms.CenterCrop((512, 512)), 
    transforms.ToTensor()
])

# 데이터셋 및 데이터 로더
train_data = PairedImageDataset("/home/NAS_mount/seunghan/GOPRO/train/", transform=transform)
train_loader = DataLoader(train_data, batch_size=1, shuffle=True)

# Test 데이터셋과 DataLoader 생성
# test_dataset = PairedImageDataset("../images/GOPRO/test", transform=transform)
# test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [4]:
from torchvision.models import vgg16
from torchvision.transforms import Normalize

class PerceptualLoss(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        # VGG 모델의 feature extractor를 사용
        self.vgg = vgg16(pretrained=True).features[:16].to(device).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False  # VGG 파라미터를 고정
        
        # VGG 입력 정규화를 위한 Normalize
        self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def forward(self, pred, target):
        # VGG 입력 정규화 및 특징 추출
        pred = self.normalize(pred)
        target = self.normalize(target)
        pred_features = self.vgg(pred)
        target_features = self.vgg(target)
        
        # Perceptual Loss (특징 차이의 L2 Norm)
        loss = torch.nn.functional.mse_loss(pred_features, target_features)
        return loss

In [None]:
WIDTH = 256
HEIGHT = 256
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

class LDMFineTuner:
    def __init__(self, models, tokenizer, device, log_dir="runs/ldm_finetune", save_path="./checkpoints"):
        self.models = models
        self.tokenizer = tokenizer
        
        self.device = device
        self.clip = models['clip'].to(device)
        self.diffusion = models['diffusion'].to(device)
        self.encoder = models['encoder'].to(device)
        self.decoder = models['decoder'].to(device)
        
        # 옵티마이저 설정 (Diffusion 모델만 학습)
        self.optimizer = torch.optim.SGD(self.diffusion.parameters(), lr=1e-4)
        
        self.sampler = DDPMSampler(generator=torch.Generator(device=device))
        self.sampler.set_inference_timesteps(num_inference_steps=1000) 

        # TensorBoard writer 설정
        self.writer = SummaryWriter(log_dir=log_dir)
        
        # 모델과 옵티마이저 저장 경로 설정
        self.save_path = save_path
        os.makedirs(self.save_path, exist_ok=True)  # 디렉토리 생성

        # Initialize PerceptualLoss
        self.perceptual_loss_fn = PerceptualLoss(device).to(device)        


    def prepare_batch(self, batch):
        blur_images, sharp_images = batch
        blur_images = blur_images.to(self.device)
        sharp_images = sharp_images.to(self.device)

        # 블러 이미지와 선명한 이미지에 대한 다운샘플링된 노이즈 생성
        batch_size, _, height, width = blur_images.size()
        noise_height, noise_width = height // 8, width // 8  # 다운샘플링 크기에 맞게 설정
        noise_for_blur = torch.randn(batch_size, 4, noise_height, noise_width, device=self.device)
        noise_for_sharp = torch.zeros(batch_size, 4, noise_height, noise_width, device=self.device)

        # 블러 이미지를 latent space로 인코딩
        with torch.no_grad():
            blur_latents = self.encoder(blur_images, noise_for_blur)
            sharp_latents = self.encoder(sharp_images, noise_for_sharp)  # Target latents for deblurring
        
        # CLIP을 사용하여 텍스트 임베딩 생성 (기본 프롬프트 사용)
        # 여기서는 임의의 "Deblur image" 프롬프트를 사용하여 context 생성
        tokens = self.tokenizer(["Deblur image"] * blur_images.size(0), 
                                padding="max_length", max_length=77, 
                                return_tensors="pt").input_ids.to(self.device)
        context = self.clip(tokens)
        
        return blur_latents, sharp_latents, context
    
    def train_step(self, blur_latents, sharp_latents, context):
        batch_size = blur_latents.shape[0]
        
        # 랜덤 타임스텝 선택
        t = torch.randint(0, self.sampler.num_train_timesteps, (batch_size,), device=self.device).long()
        
        # 타깃 latents에 노이즈 추가
        noise = torch.randn_like(sharp_latents)
        noisy_sharp_latents = self.sampler.add_noise(sharp_latents, t)

        # 블러 latent에 노이즈 추가
        # noisy_blur_latents = self.sampler.add_noise(blur_latents, t)  # sharp image만을 사용하기 위해 주석처리 
        
        # 시간 임베딩
        time_embedding = get_time_embedding(t).to(self.device)
        
        # Diffusion 모델로 노이즈 예측
        # predicted_noise = self.diffusion(noisy_blur_latents, context, time_embedding) # sharp image만을 사용하기 위해 주석처리 및 아래로 수정
        predicted_noise = self.diffusion(noisy_sharp_latents, context, time_embedding)
        
        # 손실 계산 (예측된 노이즈와 실제 노이즈 간의 차이)
        mse_loss  = torch.nn.functional.mse_loss(predicted_noise, noise)
        
        # Perceptual Loss 계산
        # Decoder를 통해 latent를 이미지로 복원
        predicted_sharp_latents = noisy_sharp_latents - predicted_noise
        predicted_images = self.decoder(predicted_sharp_latents).clamp(0, 1)

        # target_images = self.decoder(noisy_sharp_latents).clamp(0, 1) # sharp image만을 사용하기 위해 주석처리 및 아래로 수정
        target_images = self.decoder(sharp_latents).clamp(0, 1)

        perceptual_loss = self.perceptual_loss_fn(predicted_images, target_images)

        # 손실 결합
        lambda_p = 0.1  # Perceptual loss에 대한 가중치
        combined_loss = mse_loss + lambda_p * perceptual_loss

        return combined_loss

    
    def save_model(self, epoch):
        # 모델과 옵티마이저 상태를 지정된 경로에 저장
        model_path = os.path.join(self.save_path, f"model_epoch_{epoch}.pth")

      
        optimizer_path = os.path.join(self.save_path, f"optimizer_epoch_{epoch}.pth")
        
        torch.save(self.diffusion.state_dict(), model_path)
        torch.save(self.optimizer.state_dict(), optimizer_path)
        
        print(f"Model and optimizer saved at epoch {epoch} in {self.save_path}")


    def train(self, dataloader, num_epochs):
        for epoch in range(num_epochs):
            total_loss = 0
            for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
                blur_images, sharp_images = batch
                blur_latents, sharp_latents, context = self.prepare_batch(batch)
                
                self.optimizer.zero_grad()
                loss = self.train_step(blur_latents, sharp_latents, context)
                loss.backward()
                self.optimizer.step()
                                
                total_loss += loss.item()
            
            avg_loss = total_loss / len(dataloader)
            print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")
            self.writer.add_scalar("Loss/Train", avg_loss, epoch)

            # 모델과 옵티마이저 저장
            self.save_model(epoch + 1)

            # 샘플 블러 이미지를 선택하여 디블러링 성능을 시각화
            # sample_blur_image = blur_images[0]  # 첫 번째 이미지 선택
            # self.log_generated_images(epoch + 1, sample_blur_image)

    from PIL import Image

    def log_generated_images(self, epoch, blur_image, output_dir="./generated_outputs"):
        """
        디블러링 과정을 수행하고 TensorBoard에 기록 및 결과 저장
        """
        output_image = pipeline.generate(
            prompt="Deblur image",
            uncond_prompt="",
            input_image=blur_image,
            strength=0.5,
            do_cfg=True,
            cfg_scale=8,
            sampler_name="ddpm",
            n_inference_steps=50,
            seed=42,
            models=self.models,
            device=self.device,
            idle_device="cpu",
            tokenizer=self.tokenizer,
        )
        # Convert output to PIL image
        deblurred_image = Image.fromarray(output_image)

        # Resize blur image to match deblurred image size
        blur_image_resized = blur_image.resize(deblurred_image.size)

        # Save the blurred and deblurred images
        blur_image_resized.save(os.path.join(output_dir, f"blurred_epoch_{epoch}.png"))
        deblurred_image.save(os.path.join(output_dir, f"deblurred_epoch_{epoch}.png"))

        # Log images to TensorBoard
        blur_tensor = TF.to_tensor(blur_image_resized).unsqueeze(0)  # (1, C, H, W)
        deblurred_tensor = TF.to_tensor(deblurred_image).unsqueeze(0)  # (1, C, H, W)

        self.writer.add_image(f"Blurred_Image/Epoch_{epoch}", blur_tensor, epoch)
        self.writer.add_image(f"Deblurred_Image/Epoch_{epoch}", deblurred_tensor, epoch)


In [6]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()


In [7]:
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# 학습
fine_tuner = LDMFineTuner(models, tokenizer, device)
fine_tuner.train(train_loader, num_epochs=100)

env: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


Epoch 1/100: 100%|██████████| 2103/2103 [55:23<00:00,  1.58s/it]


Epoch 1/100, Average Loss: 2.2096
Model and optimizer saved at epoch 1 in ./checkpoints


Epoch 2/100:   2%|▏         | 50/2103 [01:18<53:59,  1.58s/it]


KeyboardInterrupt: 