In [None]:
""" Trainer.ipynb """

# import
import os
import torch
import gc
from torch.optim import Adam
from torchvision.utils import save_image
import torch.nn as nn

## Trainer

In [None]:
""" Trainer """

# Generator와 Discriminator 학습 + fade-in 기법 적용
class Trainer:
    """Train a progressive GAN's Generator and Discriminator with fade-in blending.

    ``Generator`` and ``Discriminator`` are defined elsewhere in the notebook.
    This class calls them as ``Generator(steps)`` / ``Discriminator(steps)`` to
    build, ``generator(z)`` to sample and ``discriminator(x, alpha)`` to score.
    ``alpha`` is the fade-in blend factor. It rises linearly from 0 to 1 over
    ``fade_epochs`` epochs, with one increment per training batch.

    NOTE(review): the Generator is called without ``alpha`` while the
    Discriminator receives it. Confirm the Generator does not need the fade-in
    factor too.
    """

    def __init__(self, steps, device, train_loader, val_loader, checkpoint_path=None,
                 output_dir="/content/drive/MyDrive/ProGAN", fade_epochs=5,
                 latent_dim=128, g_steps=2):
        """Build the models, optimizers and output folders; optionally resume.

        Args:
            steps: Number of progressive stages, forwarded to Generator/Discriminator.
            device: torch device that models and noise tensors are placed on.
            train_loader: Iterable of real-image batches supporting ``len()``.
                Each item is moved with ``.to(device)``.
            val_loader: Same as ``train_loader``, used for validation loss.
            checkpoint_path: If given, a checkpoint file to resume training from.
            output_dir: Root folder for images and checkpoints. The default
                matches the original Google Drive location, which helps the
                files survive Colab runtime disconnects.
            fade_epochs: Number of epochs over which ``alpha`` ramps from 0 to 1.
            latent_dim: Channel size of the (N, latent_dim, 1, 1) noise input.
                Must match what the Generator expects.
            g_steps: Generator updates performed per Discriminator update.

        Raises:
            ValueError: If ``fade_epochs`` or ``g_steps`` is not positive.
        """
        if fade_epochs <= 0:
            raise ValueError(f"fade_epochs must be positive, got {fade_epochs}")
        if g_steps <= 0:
            raise ValueError(f"g_steps must be positive, got {g_steps}")

        self.steps = steps  # number of progressive stages
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.latent_dim = latent_dim
        self.g_steps = g_steps

        # Models are created before the fixed noise below, keeping the same RNG
        # consumption order as before (this matters for seeded reproducibility).
        self.generator = Generator(steps).to(device)
        self.discriminator = Discriminator(steps).to(device)

        # Binary real-vs-fake objective. BCELoss requires inputs in [0, 1], so
        # the Discriminator is assumed to end in a sigmoid (TODO confirm).
        self.criterion = nn.BCELoss()

        # Adam with beta1 = 0, as written in the original hyper-parameters.
        self.g_opt = Adam(self.generator.parameters(), lr=0.003, betas=(0.0, 0.99))
        self.d_opt = Adam(self.discriminator.parameters(), lr=0.001, betas=(0.0, 0.99))

        # --- fade-in control ---
        # alpha grows by alpha_gap per training batch, so it reaches 1 after
        # `fade_epochs` full epochs. max(..., 1) guards against an empty loader.
        self.alpha = 0.0
        self.alpha_gap = 1 / (max(len(train_loader), 1) * fade_epochs)

        # Output folders: final samples, per-epoch previews, checkpoints.
        self.final_sample_dir = os.path.join(output_dir, "Final_images")
        self.epoch_image_dir = os.path.join(output_dir, "epoch_images")
        self.checkpoint_dir = os.path.join(output_dir, "checkpoints")
        for folder in (self.final_sample_dir, self.epoch_image_dir, self.checkpoint_dir):
            os.makedirs(folder, exist_ok=True)

        # --- fixed sampling noise ---
        # Re-using the same z every epoch makes generated images comparable
        # across epochs. These tensors are stored in checkpoints as well.
        self.test_z = torch.randn(1, latent_dim, 1, 1, device=self.device)
        self.test_z_last = torch.randn(50, latent_dim, 1, 1, device=self.device)

        self.start_epoch = 0
        self.history = {'g_train': [], 'd_train': [], 'g_val': [], 'd_val': []}

        if checkpoint_path is not None:
            self._load_checkpoint(checkpoint_path)

    # ------------------------------------------------------------------ helpers
    def _labels(self, bs):
        """Return smoothed targets ``(real, fake)`` of shape (bs, 1): 0.9 and 0.1."""
        real_lbl = torch.full((bs, 1), 0.9, device=self.device)
        fake_lbl = torch.full((bs, 1), 0.1, device=self.device)
        return real_lbl, fake_lbl

    def _noise(self, bs):
        """Sample a (bs, latent_dim, 1, 1) standard-normal noise batch on ``self.device``."""
        return torch.randn(bs, self.latent_dim, 1, 1, device=self.device)

    # --------------------------------------------------------------- checkpoints
    def _save_checkpoint(self, epoch):
        """Write models, optimizers, alpha, history and fixed noise for ``epoch``."""
        path = os.path.join(self.checkpoint_dir, f"checkpoint_epoch{epoch}.pt")
        torch.save({
            'epoch': epoch,
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'g_opt': self.g_opt.state_dict(),
            'd_opt': self.d_opt.state_dict(),
            'alpha': self.alpha,
            'history': self.history,
            # Persist the preview noise so resumed runs keep comparable images.
            'test_z': self.test_z,
            'test_z_last': self.test_z_last,
        }, path)
        print(f"Checkpoint saved: {path}")

    def _load_checkpoint(self, path):
        """Restore training state from ``path``; training resumes at epoch + 1.

        ``map_location`` lets a checkpoint written on one device (e.g. GPU) be
        loaded on another (e.g. CPU). Checkpoints created before fixed noise was
        saved are still accepted; the freshly sampled noise is kept in that case.
        torch.load unpickles data, so only load checkpoints you trust.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator'])
        self.discriminator.load_state_dict(checkpoint['discriminator'])
        self.g_opt.load_state_dict(checkpoint['g_opt'])
        self.d_opt.load_state_dict(checkpoint['d_opt'])
        self.alpha = checkpoint['alpha']
        self.start_epoch = checkpoint['epoch'] + 1
        self.history = checkpoint['history']
        if 'test_z' in checkpoint:
            self.test_z = checkpoint['test_z'].to(self.device)
        if 'test_z_last' in checkpoint:
            self.test_z_last = checkpoint['test_z_last'].to(self.device)
        print(f"Resuming training from epoch {self.start_epoch}")

    # -------------------------------------------------------------------- images
    def save_epoch_image(self, epoch):
        """Render ``self.test_z`` and save it as ``<epoch_image_dir>/epoch_XXX.png``.

        Leaves the Generator in eval mode; ``train_epoch`` switches it back.
        """
        # eval() puts BatchNorm/Dropout layers (if any) into inference mode.
        self.generator.eval()
        with torch.no_grad():
            fake = self.generator(self.test_z)
            path = os.path.join(self.epoch_image_dir, f"epoch_{epoch:03d}.png")
            # Generator output is assumed to lie in [-1, 1] (TODO confirm, e.g. tanh head).
            save_image(fake[0], path, normalize=True, value_range=(-1, 1))

    def save_last_epoch_image(self, epoch):
        """Render every vector of ``self.test_z_last`` and save one PNG per sample."""
        self.generator.eval()
        with torch.no_grad():
            fake = self.generator(self.test_z_last)
            for i in range(fake.size(0)):
                path = os.path.join(self.final_sample_dir,
                                    f"epoch_{epoch:03d}_sample_{i:02d}.png")
                save_image(fake[i], path, normalize=True, value_range=(-1, 1))

    # ------------------------------------------------------------ train / valid
    def train_epoch(self):
        """Run one training epoch and return ``(mean G loss, mean D loss)``.

        Both values are per-update means. The D loss per batch is the average
        of its real and fake BCE terms. The G loss per batch is the average over
        the ``g_steps`` generator updates, so it is comparable with
        ``valid_epoch``. Returns ``(0.0, 0.0)`` for an empty loader.
        """
        self.generator.train()
        self.discriminator.train()
        g_loss_sum = 0.0
        d_loss_sum = 0.0

        for real in self.train_loader:
            real = real.to(self.device)
            bs = real.size(0)
            real_lbl, fake_lbl = self._labels(bs)

            # --- Discriminator update ---
            fake = self.generator(self._noise(bs))
            # detach(): this step must not update the Generator.
            d_fake = self.discriminator(fake.detach(), self.alpha)
            d_real = self.discriminator(real, self.alpha)
            # Smoothed targets: real -> 0.9, fake -> 0.1.
            d_loss = self.criterion(d_fake, fake_lbl) + self.criterion(d_real, real_lbl)

            self.d_opt.zero_grad()
            d_loss.backward()
            self.d_opt.step()
            d_loss_sum += d_loss.item() / 2  # mean of the two BCE terms

            # --- Generator update(s) with fresh noise ---
            # The backward pass here also fills the Discriminator's .grad. Those
            # gradients are cleared by d_opt.zero_grad() before the next D step.
            for _ in range(self.g_steps):
                fake = self.generator(self._noise(bs))
                # Push the Discriminator to label fakes as real.
                g_loss = self.criterion(self.discriminator(fake, self.alpha), real_lbl)

                self.g_opt.zero_grad()
                g_loss.backward()
                self.g_opt.step()
                g_loss_sum += g_loss.item() / self.g_steps

            # Advance the fade-in blend, capped at 1.
            self.alpha = min(1.0, self.alpha + self.alpha_gap)

        # Release memory once per epoch. Calling these every batch only slowed training.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        n_batches = max(len(self.train_loader), 1)
        return g_loss_sum / n_batches, d_loss_sum / n_batches

    def valid_epoch(self):
        """Compute ``(mean G loss, mean D loss)`` on ``val_loader`` without updates.

        Uses the same smoothed targets and per-term averaging as ``train_epoch``.
        Returns ``(0.0, 0.0)`` for an empty loader.
        """
        self.generator.eval()
        self.discriminator.eval()
        g_loss_sum = 0.0
        d_loss_sum = 0.0

        with torch.no_grad():
            for real in self.val_loader:
                real = real.to(self.device)
                bs = real.size(0)
                real_lbl, fake_lbl = self._labels(bs)

                # Discriminator loss on one fake batch and the real batch.
                fake = self.generator(self._noise(bs))
                d_fake = self.discriminator(fake, self.alpha)
                d_real = self.discriminator(real, self.alpha)
                d_loss = self.criterion(d_fake, fake_lbl) + self.criterion(d_real, real_lbl)
                d_loss_sum += d_loss.item() / 2

                # Generator loss on a fresh fake batch.
                fake = self.generator(self._noise(bs))
                g_loss_sum += self.criterion(self.discriminator(fake, self.alpha), real_lbl).item()

        n_batches = max(len(self.val_loader), 1)
        return g_loss_sum / n_batches, d_loss_sum / n_batches

    def run(self, epochs):
        """Train from ``self.start_epoch`` up to ``epochs`` (exclusive).

        Each epoch does the following:
        - records train and validation losses in ``self.history``;
        - prints them;
        - saves a preview image and a checkpoint.

        After the final epoch it also saves the 50 individual sample images.
        Returns ``self.history``.
        """
        for epoch in range(self.start_epoch, epochs):
            g_t, d_t = self.train_epoch()
            g_v, d_v = self.valid_epoch()
            self.history['g_train'].append(g_t)
            self.history['d_train'].append(d_t)
            self.history['g_val'].append(g_v)
            self.history['d_val'].append(d_v)

            print(f"Epoch {epoch}: G_loss {g_t:.4f}/{g_v:.4f}, D_loss {d_t:.4f}/{d_v:.4f}")

            self.save_epoch_image(epoch)
            self._save_checkpoint(epoch)

            # The last epoch also saves each fixed-noise sample individually.
            # The checkpoint above already captures the final state, so no second save.
            if epoch == epochs - 1:
                self.save_last_epoch_image(epoch)

        return self.history