In [3]:
# Mount Google Drive at /content/drive so paths under /content/drive/MyDrive
# used below (dataset and saved models) are reachable.
import google.colab.drive as drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, BatchNormalization, Embedding, Input, Concatenate
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import matplotlib.pyplot as plt

# Dataset location: load_images expects one sub-folder per category here.
# The position of each name in `categories` is used as its integer label.
data_dir = '/content/drive/MyDrive/code/gan'
categories = [
    'angry',
    'happy',
    'panic',
    'sadness',
]

# Image loading function
def load_images(data_dir, categories, target_size=(64, 64),
                extensions=('.png', '.jpg', '.jpeg')):
    """Load every image found under ``data_dir/<category>/``.

    Each category's index in ``categories`` becomes the integer label of the
    images in that category's sub-folder.

    Args:
        data_dir: Root directory containing one sub-folder per category.
        categories: Sub-folder names; the position in the list is the label.
        target_size: ``(height, width)`` every image is resized to.
        extensions: File suffixes treated as images, matched
            case-insensitively.

    Returns:
        ``(images, labels)``. ``images`` stacks the arrays produced by
        ``img_to_array`` (3 channels, assuming ``load_img``'s default RGB
        mode); ``labels`` is a 1-D array of class indices. When no image is
        found, both arrays are empty but keep shapes
        ``(0, height, width, 3)`` and ``(0,)``.

    Raises:
        FileNotFoundError: if a category sub-folder does not exist.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    images = []
    labels = []
    for idx, category in enumerate(categories):
        category_dir = os.path.join(data_dir, category)
        # os.listdir order is arbitrary; sort so runs are reproducible.
        for file in sorted(os.listdir(category_dir)):
            # Case-insensitive so files like 'IMG_01.JPG' are not skipped.
            if not file.lower().endswith(suffixes):
                continue
            img = load_img(os.path.join(category_dir, file),
                           target_size=target_size)
            images.append(img_to_array(img))
            labels.append(idx)
    if not images:
        # np.array([]) would have shape (0,), breaking shape-based code later.
        return (np.empty((0, *target_size, 3), dtype=np.float32),
                np.empty((0,), dtype=np.int64))
    return np.array(images), np.array(labels)

images, labels = load_images(data_dir=data_dir, categories=categories)
# Rescale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
images = (images - 127.5) / 127.5

# GAN model definitions
def build_generator():
    """Build the conditional generator.

    Inputs are a 100-dim noise vector and an int32 class label of shape (1,).
    The output is a 64x64x3 image produced by a tanh layer, so values lie
    in [-1, 1].
    """
    z = Input(shape=(100,))
    class_id = Input(shape=(1,), dtype='int32')

    # Learn a 100-dim vector per class and append it to the noise.
    class_vec = Embedding(len(categories), 100)(class_id)
    class_vec = Flatten()(class_vec)
    h = Concatenate()([z, class_vec])

    # Project to an 8x8x256 feature map.
    h = Dense(256 * 8 * 8, activation="relu")(h)
    h = Reshape((8, 8, 256))(h)
    h = BatchNormalization(momentum=0.8)(h)

    # Two stride-2 upsampling stages: 8x8 -> 16x16 -> 32x32.
    for n_filters in (128, 64):
        h = Conv2DTranspose(n_filters, kernel_size=4, strides=2, padding='same')(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = BatchNormalization(momentum=0.8)(h)

    # Final stride-2 stage to 64x64 with 3 channels in [-1, 1].
    out = Conv2DTranspose(3, kernel_size=4, strides=2, padding='same', activation='tanh')(h)

    return Model([z, class_id], out)

def build_discriminator():
    """Build the conditional discriminator.

    Inputs are a 64x64x3 image and an int32 class label of shape (1,).
    The output is a single sigmoid unit. It is trained toward 1 for real
    images and 0 for generated ones.
    """
    img_shape = (64, 64, 3)
    image_in = Input(shape=img_shape)
    class_id = Input(shape=(1,), dtype='int32')

    # Embed the label as a full image-sized map and concatenate it with the
    # image along the last (channel) axis.
    label_map = Embedding(len(categories), np.prod(img_shape))(class_id)
    label_map = Flatten()(label_map)
    label_map = Reshape(img_shape)(label_map)
    h = Concatenate()([image_in, label_map])

    # Two stride-2 conv stages: 64x64 -> 32x32 -> 16x16.
    for n_filters in (64, 128):
        h = Conv2D(n_filters, kernel_size=4, strides=2, padding='same')(h)
        h = LeakyReLU(alpha=0.2)(h)

    h = Flatten()(h)
    prob = Dense(1, activation='sigmoid')(h)

    return Model([image_in, class_id], prob)

# Compile the discriminator on its own; train() updates it via train_on_batch.
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                      metrics=['accuracy'])

# Build and compile the combined (generator -> discriminator) model that
# train() uses to update the generator.
generator = build_generator()
noise = Input(shape=(100,))
# int32 to match the label Inputs inside build_generator/build_discriminator;
# an Input without dtype would default to float32.
label = Input(shape=(1,), dtype='int32')
img = generator([noise, label])
# Freeze the discriminator inside the combined model so generator updates
# do not move discriminator weights.
# NOTE(review): tf.keras 2.x keeps training the discriminator through its own
# compile above. Under Keras 3 the trainable flag may instead be read when
# discriminator's train function is first traced. Confirm that D's weights
# still change during train().
discriminator.trainable = False
validity = discriminator([img, label])
combined = Model([noise, label], validity)
combined.compile(loss='binary_crossentropy',
                 optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# GAN training
def train(epochs, batch_size=128, save_interval=50):
    """Alternate discriminator and generator updates.

    Each "epoch" here is a single training iteration on one random batch,
    not a full pass over ``images``.

    Args:
        epochs: Number of training iterations.
        batch_size: Generator batch size. The discriminator sees
            ``batch_size // 2`` real and as many generated images per iteration.
        save_interval: ``save_imgs`` is called every this many iterations,
            starting at iteration 0.

    Raises:
        ValueError: if ``batch_size`` is below 2 (empty discriminator batches).
    """
    half_batch = batch_size // 2
    if half_batch < 1:
        raise ValueError("batch_size must be at least 2")

    # Target arrays are constant across iterations; build them once.
    real_y = np.ones((half_batch, 1))
    fake_y = np.zeros((half_batch, 1))
    # (batch, 1) to match the discriminator's Dense(1) output shape.
    valid_y = np.ones((batch_size, 1))

    for epoch in range(epochs):
        # --- Discriminator: one real half-batch and one generated half-batch ---
        idx = np.random.randint(0, images.shape[0], half_batch)
        imgs, labels_batch = images[idx], labels[idx].reshape(-1, 1)

        noise = np.random.normal(0, 1, (half_batch, 100))
        gen_labels = np.random.randint(0, len(categories), half_batch).reshape(-1, 1)
        # verbose=0: predict() otherwise prints a progress bar every iteration.
        gen_imgs = generator.predict([noise, gen_labels], verbose=0)

        d_loss_real = discriminator.train_on_batch([imgs, labels_batch], real_y)
        d_loss_fake = discriminator.train_on_batch([gen_imgs, gen_labels], fake_y)
        # Average of [loss, accuracy] over the real and fake half-batches.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator: train through the frozen discriminator toward "real" ---
        noise = np.random.normal(0, 1, (batch_size, 100))
        sampled_labels = np.random.randint(0, len(categories), batch_size).reshape(-1, 1)
        g_loss = combined.train_on_batch([noise, sampled_labels], valid_y)

        print(f"{epoch} [D loss: {d_loss[0]} | D accuracy: {d_loss[1]}] [G loss: {g_loss}]")

        if epoch % save_interval == 0:
            save_imgs(epoch)

def save_imgs(epoch):
    """Save a 2x2 grid of generated samples to gan_images/gan_image_<epoch>.png.

    Args:
        epoch: Training iteration number, used in the output file name.
    """
    r, c = 2, 2
    noise = np.random.normal(0, 1, (r * c, 100))
    # Labels 0..r*c-1, wrapped by the class count so a larger grid never asks
    # the embedding for a label that does not exist.
    sampled_labels = (np.arange(r * c) % len(categories)).reshape(-1, 1)
    gen_imgs = generator.predict([noise, sampled_labels], verbose=0)
    # Map the generator's tanh range [-1, 1] to [0, 1] for imshow.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c)
    # axs.flat walks the grid row by row, matching the label order above.
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(gen_imgs[cnt, :, :, :])
        ax.set_title(categories[sampled_labels[cnt][0]])
        ax.axis('off')
    os.makedirs("gan_images", exist_ok=True)
    fig.savefig(f"gan_images/gan_image_{epoch}.png")
    # Close this figure explicitly rather than whichever one is "current".
    plt.close(fig)

train(epochs=10000, batch_size=32, save_interval=200)

# Save the trained models (generator, then discriminator, then combined).
for _model, _path in (
    (generator, '/content/drive/MyDrive/code/gan_generator.h5'),
    (discriminator, '/content/drive/MyDrive/code/gan_discriminator.h5'),
    (combined, '/content/drive/MyDrive/code/gan_combined.h5'),
):
    _model.save(_path)
# Drop the loop temporaries so no extra module-level names are left behind.
del _model, _path
