# Setting

In [None]:
# Mount Google Drive into the Colab VM under /content/drive.
import google.colab.drive as drive

drive.mount('/content/drive')

In [None]:
# Copy the FFHQ archive onto the Colab runtime VM.
# /content/sample_directory must contain ffhq.zip.
!cp /content/sample_directory/ffhq.zip /content/

# Extract the archive on the VM; all images end up under /content/data.
# -o overwrites existing files without prompting, so re-running this cell
# does not stop at an interactive "replace?" question.
!unzip -o -q /content/ffhq.zip -d /content/data/

In [None]:
# Copy the fake_imgs archive onto the Colab runtime VM.
# /content/sample_directory must contain fake_imgs.zip.
# fake_imgs holds fake images sampled from a DDIM model.
!cp /content/sample_directory/fake_imgs.zip /content/
# -p: do not fail if the directory already exists (e.g. when re-running).
!mkdir -p /content/data/fake

# Extract the archive on the VM; all images end up under /content/data/fake.
# -o overwrites existing files without prompting on re-runs.
!unzip -o -q /content/fake_imgs.zip -d /content/data/fake/

In [None]:
import numpy as np
import os
from PIL import Image
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from torchvision import transforms, datasets, utils

# Dataset

In [None]:
class FFHQDataset(Dataset):
    """Binary real-vs-fake image dataset.

    Flat indices ``0 .. len(real_img_paths) - 1`` map to real images
    (label 1.0); the following ``len(fake_img_paths)`` indices map to fake
    images (label 0.0).

    Args:
        num_total_img: Requested total number of images. Stored as
            ``num_img`` for reference only; the dataset length is always
            derived from the two path sequences.
        real_img_paths: Sequence of file paths to real images.
        fake_img_paths: Sequence of file paths to fake images.
        transform: Optional callable applied to each RGB ``PIL.Image``.
    """

    def __init__(self, num_total_img, real_img_paths, fake_img_paths, transform=None):
        self.num_img = num_total_img
        self.real_img_paths = real_img_paths
        self.fake_img_paths = fake_img_paths
        self.transform = transform

    def __len__(self):
        # Sum the lengths rather than concatenating both lists on every call.
        return len(self.real_img_paths) + len(self.fake_img_paths)

    def _resolve(self, idx):
        """Map a flat index (negative indices allowed) to ``(path, label)``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        pos = idx + size if idx < 0 else idx
        if not 0 <= pos < size:
            raise IndexError(f'index {idx} out of range for dataset of size {size}')
        num_real = len(self.real_img_paths)
        if pos < num_real:
            return self.real_img_paths[pos], 1
        return self.fake_img_paths[pos - num_real], 0

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``label`` is a float32 scalar tensor."""
        path, label = self._resolve(idx)

        # Close the file handle deterministically; convert() returns a new
        # image, so the result stays valid after the source is closed.
        with Image.open(path) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)

        return img, torch.tensor(label, dtype=torch.float32)

# Directories populated by the unzip cells above.
real_img_dir = '/content/data/thumbnails128x128'
fake_img_dir = '/content/data/fake'

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# the train/test split below is reproducible across runs.
# NOTE(review): assumes both directories directly contain image files (no
# subfolders or stray files) — confirm against the extracted archives.
temp_real = sorted(os.listdir(real_img_dir))
temp_fake = sorted(os.listdir(fake_img_dir))

# Number of train/test images => real images + fake images (split 50/50).
num_train_image = 6000
num_test_image = 200

half_train = num_train_image // 2
half_test = num_test_image // 2

# Warn early if a directory cannot fill its share; slicing below would
# otherwise silently yield fewer (or zero) test images.
for which, entries in (('real', temp_real), ('fake', temp_fake)):
    if len(entries) < half_train + half_test:
        print(f'Warning: {which} dir has {len(entries)} entries, '
              f'{half_train + half_test} needed for the requested split.')

# The first half_train files go to train and the next half_test files go to
# test, so the two splits never overlap.
train_real = temp_real[:half_train]
train_fake = temp_fake[:half_train]
test_real = temp_real[half_train:half_train + half_test]
test_fake = temp_fake[half_train:half_train + half_test]

train_real_paths = [os.path.join(real_img_dir, name) for name in train_real]
train_fake_paths = [os.path.join(fake_img_dir, name) for name in train_fake]
test_real_paths = [os.path.join(real_img_dir, name) for name in test_real]
test_fake_paths = [os.path.join(fake_img_dir, name) for name in test_fake]

# Transforms must exist before the datasets that use them are constructed.
# Images are resized to 64x64 and normalized from [0, 1] to [-1, 1].
train_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])
test_transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

train_dataset = FFHQDataset(num_total_img=num_train_image, real_img_paths=train_real_paths, fake_img_paths=train_fake_paths, transform=train_transform)
test_dataset = FFHQDataset(num_total_img=num_test_image, real_img_paths=test_real_paths, fake_img_paths=test_fake_paths, transform=test_transform)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

batch_size = 128

# drop_last=True on train keeps every optimizer step at a full batch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# Model

In [None]:
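# Paste the Discriminator class from the ProGAN implementation here, unchanged. The cells below construct it as Discriminator(steps=4), call it as model(images, 1.0), and index its fromrgb_layers and prog_blocks.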

# Train setup

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): nn.BCELoss expects inputs already in [0, 1], so this assumes
# Discriminator ends in a sigmoid. If it returns raw scores, switch to
# nn.BCEWithLogitsLoss and threshold at 0 instead of 0.5 — confirm.
loss_func = nn.BCELoss()
model = Discriminator(steps=4).to(device)

# Pretrained discriminator weights (loaded as a state_dict).
checkpoint_path = 'model.pt'
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime
# (and vice versa).
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)

# Number of discriminator stages whose weights are frozen (not updated).
n_freeze = 4

# NOTE(review): which stages indices 0..n_freeze-1 refer to depends on how
# Discriminator orders fromrgb_layers / prog_blocks — confirm they are the
# stages intended to be frozen at steps=4. The trainable-parameter count
# printed in the next cell is a quick sanity check.
for i in range(n_freeze):
    # Freeze fromRGB layer i.
    for p in model.fromrgb_layers[i].parameters():
        p.requires_grad = False
    # Freeze prog_block i.
    for p in model.prog_blocks[i].parameters():
        p.requires_grad = False

# Give the optimizer only the parameters that still have requires_grad=True.
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=5e-5
)

In [None]:
# Tally all parameters and the subset left trainable after freezing.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elems = param.numel()
    total_params += n_elems
    if param.requires_grad:
        trainable_params += n_elems
print(f"Trainable: {trainable_params}/{total_params} 파라미터")

# Test before Train

In [None]:
# Baseline test accuracy of the pretrained discriminator before fine-tuning.
# An output above 0.5 is counted as "real" (label 1), otherwise "fake" (0).
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for image, label in test_loader:
        image = image.to(device)
        label = label.to(device)

        output = model(image, 1.0).view(-1)
        predicted = (output > 0.5).long()
        total += label.size(0)
        correct += (predicted == label.long()).sum().item()

    # Guard against an empty test set instead of dividing by zero.
    if total == 0:
        print('Accuracy: n/a (test set is empty)')
    else:
        accuracy = 100 * correct / total
        print(f'Accuracy: {accuracy:.2f}%')

# Train

In [None]:
# Fine-tune the unfrozen discriminator layers, reporting the mean training
# loss and the test accuracy after every epoch.
num_epochs = 10
save_every = 5

for epoch in tqdm(range(num_epochs)):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for image, label in train_loader:
        image = image.to(device)
        label = label.to(device)

        output = model(image, 1.0).view(-1)
        loss = loss_func(output, label)

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for image, label in test_loader:
            image = image.to(device)
            label = label.to(device)

            output = model(image, 1.0).view(-1)
            predicted = (output > 0.5).long()
            total += label.size(0)
            correct += (predicted == label.long()).sum().item()

    # NaN instead of a crash when a loader yields no batches/samples.
    avg_loss = running_loss / num_batches if num_batches else float('nan')
    accuracy = 100 * correct / total if total else float('nan')
    print(f'Epoch {epoch}, Avg loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Save weights every `save_every` epochs and always after the last epoch,
    # so the fully trained model is written. Files go to the working directory
    # (next to 'model.pt') rather than the filesystem root.
    if epoch % save_every == 0 or epoch == num_epochs - 1:
        torch.save(model.state_dict(), f'model_{epoch}.pt')