# 필체 스타일 트랜스퍼 GAN - Colab용 전체 코드 (PyTorch)

1. 라이브러리 로드

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import io, os
import zipfile
import random
from google.colab import files

2. 이미지 업로드 및 압축 해제

In [None]:
# Directory that uploaded zip archives are extracted into; the dataset helpers
# in this notebook list and open images relative to it.
UPLOAD_DIR = "uploaded_data"
# exist_ok=True lets this cell be re-run without failing on an existing directory.
os.makedirs(UPLOAD_DIR, exist_ok=True)

def upload_and_extract_zip():
    """Prompt for a Colab file upload and extract every uploaded zip archive.

    Each uploaded entry whose name ends in ``.zip`` (case-insensitive) is
    opened from memory and extracted into ``UPLOAD_DIR``. Other uploads are
    ignored. The completion message is printed only when at least one
    archive was actually extracted; otherwise a notice is printed instead.
    """
    uploaded = files.upload()
    extracted_count = 0
    for name, payload in uploaded.items():
        # Compare in lowercase so "DATA.ZIP" is handled like "data.zip".
        if not name.lower().endswith(".zip"):
            continue
        # The upload is held in memory as bytes, so wrap it in a file-like object.
        with zipfile.ZipFile(io.BytesIO(payload), 'r') as zip_ref:
            zip_ref.extractall(UPLOAD_DIR)
        extracted_count += 1

    if extracted_count:
        print(f"압축 해제 완료: {UPLOAD_DIR}")
    else:
        print("업로드된 zip 파일이 없습니다.")


3. 이미지 전처리 함수

In [None]:
def preprocess(img):
    """Resize an image to 64x64 and return it as a batched tensor.

    The image is resized, converted with ``ToTensor`` and given a leading
    batch dimension of size 1. For a 3-channel input the result has shape
    [1, 3, 64, 64].
    """
    resize = transforms.Resize((64, 64))
    to_tensor = transforms.ToTensor()
    image_tensor = to_tensor(resize(img))
    # Add the batch axis expected by the networks.
    return image_tensor.unsqueeze(0)

4. 이미지 파일 목록 수집 함수 (학습용/테스트용으로 분리)

In [None]:
def load_dataset(split_ratio=0.8):
    """List the image files in ``UPLOAD_DIR`` and split them into train/test.

    Only the top level of ``UPLOAD_DIR`` is scanned. Files are matched by
    extension (``.png``, ``.jpg``, ``.jpeg``) without regard to case.

    Args:
        split_ratio: Fraction of the files assigned to the training list;
            the count is truncated with ``int``.

    Returns:
        A ``(train_files, test_files)`` tuple of file-name lists.
    """
    image_exts = (".png", ".jpg", ".jpeg")
    # Sort first: os.listdir order is arbitrary, and a fixed starting order
    # makes the shuffle reproducible under a fixed random seed.
    all_files = sorted(
        f for f in os.listdir(UPLOAD_DIR) if f.lower().endswith(image_exts)
    )
    random.shuffle(all_files)
    split = int(len(all_files) * split_ratio)
    return all_files[:split], all_files[split:]

5. 이미지 로딩 함수 (파일 이름이 이미지의 의미(내용)를 나타낸다고 가정)

In [None]:
def load_image_by_name(filename):
    """Open ``filename`` from ``UPLOAD_DIR`` and return it as an RGB PIL image.

    Args:
        filename: File name relative to ``UPLOAD_DIR``.

    Returns:
        A new image in ``'RGB'`` mode.
    """
    path = os.path.join(UPLOAD_DIR, filename)
    # convert() returns a separate, fully loaded image, so the file handle
    # opened by Image.open can be released as soon as the conversion is done.
    with Image.open(path) as img:
        return img.convert('RGB')

6. 네트워크 구성 (Encoder, Generator, Discriminator)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a feature vector.

    Three stride-2 3x3 convolutions (3 -> 32 -> 64 -> 128 channels), each
    followed by ReLU, are globally average-pooled to 1x1 and projected by a
    linear layer to ``output_dim`` features per sample.
    """

    def __init__(self, output_dim=128):
        super().__init__()
        stages = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            stages.append(
                nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)
            )
            stages.append(nn.ReLU())
            in_channels = out_channels
        # Global pooling makes the feature size independent of the input resolution.
        stages.append(nn.AdaptiveAvgPool2d(1))
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Linear(in_channels, output_dim)

    def forward(self, x):
        """Return a (batch, output_dim) tensor for an image batch ``x``."""
        pooled = self.conv(x)
        # Drop the trailing 1x1 spatial dimensions left by the pooling layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


In [None]:
class Generator(nn.Module):
    """Decoder that turns a style vector and a content vector into an image.

    The two vectors are concatenated, projected to a 128x8x8 feature map and
    upsampled by three stride-2 transposed convolutions to a 3-channel 64x64
    image. The final sigmoid keeps values in [0, 1].
    """

    def __init__(self, latent_dim=128):
        super().__init__()
        # The linear layer takes both vectors side by side, hence latent_dim * 2.
        self.fc = nn.Linear(latent_dim * 2, 128 * 8 * 8)

        layers = [nn.BatchNorm2d(128), nn.ReLU()]
        upsample_plan = (
            (128, 64, nn.ReLU),
            (64, 32, nn.ReLU),
            (32, 3, nn.Sigmoid),
        )
        for c_in, c_out, activation in upsample_plan:
            # Kernel 4 / stride 2 / padding 1 doubles the spatial size.
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1))
            layers.append(activation())
        self.deconv = nn.Sequential(*layers)

    def forward(self, style, content):
        """Generate images from ``style`` and ``content`` (joined along dim 1)."""
        latent = torch.cat((style, content), dim=1)
        feature_map = self.fc(latent)
        feature_map = feature_map.view(-1, 128, 8, 8)
        return self.deconv(feature_map)


In [None]:
class Discriminator(nn.Module):
    """Convolutional classifier that outputs one probability per image.

    Three stride-2 4x4 convolutions (3 -> 32 -> 64 -> 128 channels) with
    LeakyReLU(0.2) are flattened and passed through a linear layer and a
    sigmoid. The linear layer expects 128 * 8 * 8 features, which matches a
    64x64 input after three halvings.
    """

    def __init__(self):
        super().__init__()
        layers = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            layers.append(
                nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1)
            )
            layers.append(nn.LeakyReLU(0.2))
            in_channels = out_channels
        layers.extend([
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 1),
            nn.Sigmoid(),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (batch, 1) tensor of sigmoid outputs for image batch ``x``."""
        probability = self.conv(x)
        return probability

7. 손실 함수 및 옵티마이저 정의

In [None]:
# Loss functions: binary cross-entropy for the adversarial terms and L1 for
# comparing style codes.
BCE = nn.BCELoss()
L1 = nn.L1Loss()

# Networks (instantiated in this order).
style_encoder = Encoder()
content_encoder = Encoder()
generator = Generator()
discriminator = Discriminator()

# One optimizer updates the generator together with both encoders; the
# discriminator has its own optimizer.
# NOTE(review): content_encoder is registered here, but train_step never calls
# it, so it receives no gradients there — confirm whether that is intended.
optim_G = torch.optim.Adam(
    (
        param
        for module in (generator, style_encoder, content_encoder)
        for param in module.parameters()
    ),
    lr=1e-4,
)
optim_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4)


8. 학습 스텝 함수 (이미지 한 장 단위, 전체 반복은 10번 셀에서 수행)

In [None]:
def train_step(style_img, content_vec):
    """Run one discriminator update and one generator update on one image.

    Args:
        style_img: Image accepted by ``preprocess``; it serves both as the
            style source and as the "real" sample for the discriminator.
        content_vec: Tensor or numeric sequence used as the content code.
            It is converted to float32 and given a batch dimension of 1.

    Returns:
        A tuple ``(generated, loss_D, loss_G)``: the generated image tensor
        detached from the graph, and the two loss values as Python floats.
    """
    style_tensor = preprocess(style_img)
    # as_tensor avoids the copy-construct UserWarning that torch.tensor()
    # raises when content_vec is already a tensor; detach() keeps gradients
    # from ever reaching the caller's vector, as before.
    content_code = (
        torch.as_tensor(content_vec, dtype=torch.float32).detach().unsqueeze(0)
    )

    real_label = torch.ones(1, 1)
    fake_label = torch.zeros(1, 1)

    # Named style_code rather than F so the torch.nn.functional alias is not
    # shadowed inside this function.
    style_code = style_encoder(style_tensor)

    G_out = generator(style_code, content_code)

    # --- Discriminator update ---
    D_real = discriminator(style_tensor)
    # detach() stops the discriminator loss from back-propagating into the generator.
    D_fake = discriminator(G_out.detach())
    loss_D = BCE(D_real, real_label) + BCE(D_fake, fake_label)
    optim_D.zero_grad()
    loss_D.backward()
    optim_D.step()

    # --- Generator / encoder update ---
    # A fresh forward pass through the just-updated discriminator. The backward
    # pass below also leaves gradients on the discriminator's parameters; they
    # are cleared by optim_D.zero_grad() at the start of the next call.
    D_fake_for_G = discriminator(G_out)
    adv_loss = BCE(D_fake_for_G, real_label)
    # NOTE(review): style_code is not detached here, so the style loss sends
    # gradients through both encoder passes — confirm this is intended.
    style_loss = L1(style_encoder(G_out), style_code)
    loss_G = adv_loss + 5 * style_loss
    optim_G.zero_grad()
    loss_G.backward()
    optim_G.step()

    return G_out.detach(), loss_D.item(), loss_G.item()

9. 시각화


In [None]:
def show_image(tensor_img):
    """Display an image tensor with matplotlib, axes hidden.

    Args:
        tensor_img: Tensor that becomes channel-first (C, H, W) once its
            size-1 dimensions are squeezed, e.g. a [1, 3, H, W] batch.
    """
    # detach() and cpu() make .numpy() work for tensors that still require
    # grad or live on a GPU; both are no-ops for a detached CPU tensor.
    img = tensor_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()
    plt.imshow(img)
    plt.axis('off')
    plt.show()

10. 실행

In [None]:
# Ask for the handwriting archive and unpack it before anything else runs.
print("손글씨 zip 파일 업로드:")
upload_and_extract_zip()

# Split the extracted file names into a training list and a held-out list.
train_files, test_files = load_dataset()
print(f"학습용 {len(train_files)}개, 테스트용 {len(test_files)}개")

# Initial loss readouts, so the per-epoch report has values even if no
# training step runs.
d_loss = g_loss = 0.0

def string_to_vector(s, dim=128):
    """Fold the UTF-8 bytes of ``s`` into a float vector of length ``dim``.

    Byte ``i`` adds ``byte / 255.0`` to slot ``i % dim``, so strings longer
    than ``dim`` bytes wrap around and accumulate. An empty string yields an
    all-zero vector.
    """
    vec = torch.zeros(dim)
    raw = s.encode('utf-8')
    # Slot index for each byte position, wrapping at dim.
    slots = (position % dim for position in range(len(raw)))
    for slot, byte in zip(slots, raw):
        vec[slot] += byte / 255.0
    return vec

print("학습 시작")
num_epochs = 30
# Last generated image. It stays None when train_files is empty (for example
# a single uploaded image with the default 0.8 split), which previously made
# the show_image call below fail with a NameError.
output = None
for epoch in range(num_epochs):
    random.shuffle(train_files)
    for fname in train_files:
        style_img = load_image_by_name(fname)
        # The file name without its extension is the label that gets encoded
        # into the content vector.
        label = os.path.splitext(fname)[0]
        content_vec = string_to_vector(label)
        output, d_loss, g_loss = train_step(style_img, content_vec)
    # The losses printed are those of the last sample processed in this epoch.
    print(f"Epoch {epoch+1}/{num_epochs} - D Loss: {d_loss:.4f}, G Loss: {g_loss:.4f}")

if output is not None:
    show_image(output)
else:
    print("학습용 이미지가 없어 결과 이미지를 표시할 수 없습니다.")

print("전체 학습 완료")

손글씨 zip 파일 업로드:


TypeError: 'NoneType' object is not subscriptable