# In[3]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.models import Sequential
from tensorflow.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
import os

# Load MNIST (only the training images are used; labels and the test split are
# discarded), then scale pixels to [0, 1] as float32 and flatten each 28x28
# image into a 784-element row.
(x_train, _), (_, _) = mnist.load_data()
x_train = (x_train.astype('float32') / 255).reshape(-1, 28 * 28)

# Generator: maps a 100-dim noise vector to a 28x28 image. The sigmoid output
# keeps pixel values in [0, 1], matching the scaling applied to x_train.
generator = Sequential()
generator.add(Dense(256, activation='relu', input_shape=(100,)))
generator.add(Dense(784, activation='sigmoid'))
generator.add(Reshape((28, 28)))

# Discriminator: scores a 28x28 image with a single sigmoid unit.
discriminator = Sequential()
discriminator.add(Flatten(input_shape=(28, 28)))
discriminator.add(Dense(256, activation='relu'))
discriminator.add(Dense(1, activation='sigmoid'))

# Compile the discriminator on its own while it is still trainable, then freeze
# it before compiling the stacked model. Keras uses the `trainable` state at
# compile time, so training `gan` only updates the generator's weights.
discriminator.compile(loss='binary_crossentropy', optimizer='adam')
discriminator.trainable = False

# Combined GAN model: noise -> generator -> discriminator.
gan = Sequential()
gan.add(generator)
gan.add(discriminator)
gan.compile(loss='binary_crossentropy', optimizer='adam')

# Create the output folder for the sample images. exist_ok=True avoids the
# check-then-create race and still raises if the path exists as a regular file,
# instead of failing later inside savefig.
os.makedirs('./gan_images', exist_ok=True)

# Training. NOTE: each "epoch" here is a single step on one mini-batch of
# `batch_size` real + `batch_size` generated images, not a full dataset pass.
epochs = 10000
batch_size = 32
noise_dim = 100  # must match the generator's input_shape=(100,)
save_interval = 100  # save a sample strip (and report losses) every N steps
num_samples = 10  # number of images in each saved sample strip

# The targets never change between steps, so build them once.
# Discriminator batch: first half real (label 1), second half generated (label 0).
y = np.zeros(2 * batch_size)
y[:batch_size] = 1
# Generator step: label fakes as "real" so the generator learns to fool the
# discriminator.
y2 = np.ones(batch_size)

for epoch in range(epochs):
    # Generate a batch of fake images from Gaussian noise. The generator already
    # outputs (batch, 28, 28). verbose=0 suppresses the progress bar that
    # predict() would otherwise print on every call.
    noise = np.random.normal(0, 1, (batch_size, noise_dim))
    generated_images = generator.predict(noise, verbose=0)

    # Sample real images (with replacement). x_train rows are flattened 784-vectors,
    # so restore the (28, 28) layout the discriminator expects.
    real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]
    real_images = real_images.reshape(batch_size, 28, 28)
    x = np.concatenate([real_images, generated_images], axis=0)

    # Discriminator step.
    discriminator.trainable = True
    d_loss = discriminator.train_on_batch(x, y)

    # Generator step through the combined model (discriminator frozen).
    noise = np.random.normal(0, 1, (batch_size, noise_dim))
    discriminator.trainable = False
    g_loss = gan.train_on_batch(noise, y2)

    # Every `save_interval` steps: report losses and save a strip of samples.
    if epoch % save_interval == 0:
        print(f"step {epoch}: d_loss={d_loss}, g_loss={g_loss}")

        noise = np.random.normal(0, 1, (num_samples, noise_dim))
        generated_images = generator.predict(noise, verbose=0)
        generated_images = (generated_images * 255).astype('uint8')

        fig, axs = plt.subplots(1, num_samples, figsize=(num_samples, 1))
        for ax, image in zip(axs, generated_images):
            # Pin the display range. Otherwise imshow rescales each image to its
            # own min/max, which hides how dark or bright the samples really are.
            ax.imshow(image, cmap='gray', vmin=0, vmax=255)
            ax.axis('off')

        fig.savefig(f"./gan_images/generated_{epoch}.png")
        plt.close(fig)  # close this specific figure to avoid accumulating figures

































































































































