In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import EarlyStopping
from skimage.metrics import structural_similarity as ssim
import itertools

# Emotion name -> integer class id (also the position in the one-hot label).
emotion_labels = {
    name: class_id
    for class_id, name in enumerate(
        ("happy", "sad", "anger", "surprise", "disgust", "neutral", "fear")
    )
}

# Indices of the landmarks kept by extract_selected_landmarks (101 points,
# 3 values each -> 303 features).
landmark_indices = [
    70, 63, 105, 66, 107, 55, 65, 52, 53, 46,
    300, 293, 334, 296, 336, 285, 295, 282, 283, 276,
    33, 246, 161, 160, 159, 158, 157, 173, 133, 155,
    154, 153, 145, 144, 163, 7, 263, 466, 388, 387,
    386, 385, 384, 398, 362, 382, 381, 380, 374, 373,
    390, 249, 78, 191, 80, 81, 82, 13, 312, 311,
    310, 415, 308, 324, 318, 402, 317, 14, 87, 178,
    88, 95, 61, 185, 40, 39, 37, 0, 267, 269,
    270, 409, 291, 375, 321, 405, 314, 17, 84, 181,
    91, 146, 64, 4, 294, 168, 6, 197, 195, 5,
    1,
]

# Landmark selection helper.
def extract_selected_landmarks(landmarks):
    """Pick the rows listed in ``landmark_indices`` and flatten them.

    Args:
        landmarks: tensor whose first axis indexes landmarks; the caller
            reshapes it to ``[-1, 3]`` so each row holds 3 values.

    Returns:
        1-D tensor of length ``3 * len(landmark_indices)``.
    """
    flat_length = 3 * len(landmark_indices)
    chosen_rows = tf.gather(landmarks, landmark_indices, axis=0)
    return tf.reshape(chosen_rows, [flat_length])

def calculate_z_value_combinations(z_values, num_values=8, combo_size=3):
    """Return the product of every ``combo_size``-element subset of the z-values.

    Subsets are taken over positions ``0 .. num_values - 1`` of ``z_values`` in
    ``itertools.combinations`` order. With the defaults the result has
    C(8, 3) = 56 entries, the same count as the module-level ``z_comb_dim``.

    Args:
        z_values: 1-D tensor with at least ``num_values`` entries.
        num_values: number of leading entries that participate (default 8).
        combo_size: number of entries multiplied together per subset (default 3).

    Returns:
        1-D tensor of length ``C(num_values, combo_size)`` with the dtype of
        ``z_values``; empty when ``combo_size > num_values``.
    """
    combos = list(itertools.combinations(range(num_values), combo_size))
    # Give the index table an explicit 2-D shape so the empty case
    # (combo_size > num_values) and the zero-width case (combo_size == 0)
    # still reduce along the right axis.
    index_table = np.asarray(combos, dtype=np.int32).reshape(len(combos), combo_size)
    # One gather + one reduction instead of a Python loop of small TF ops.
    selected = tf.gather(z_values, index_table)  # (n_combos, combo_size)
    return tf.reduce_prod(selected, axis=-1)

# TFRecord parsing function.
def parse_and_prepare_tfrecord(tfrecord):
    """Parse one serialized example into model inputs and a one-hot label.

    Args:
        tfrecord: scalar string tensor holding one serialized example.

    Returns:
        ``((image, selected_landmarks, z_product_combinations), label)`` where
        ``image`` is a single-channel float32 tensor scaled to [-1, 1],
        ``selected_landmarks`` is a 1-D vector of 3 values per entry of
        ``landmark_indices``, ``z_product_combinations`` holds the 3-way
        products of the 8 stored z-values, and ``label`` is one-hot over
        ``emotion_labels``.
    """
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "landmarks": tf.io.VarLenFeature(tf.float32),
        "label": tf.io.FixedLenFeature([], tf.int64),
        # Exactly 8 floats per example; no default_value, so the field is required.
        "z_values": tf.io.FixedLenFeature([8], tf.float32),
    }
    parsed_features = tf.io.parse_single_example(tfrecord, features)

    # Force 3 channels so rgb_to_grayscale also works when a JPEG is stored
    # as single-channel grayscale (channels=0 would keep 1 channel and fail).
    image = tf.io.decode_jpeg(parsed_features["image"], channels=3)
    image = tf.image.rgb_to_grayscale(image)
    # Scale uint8 [0, 255] to [-1, 1].
    image = tf.cast(image, tf.float32) * (2. / 255) - 1

    landmarks = tf.sparse.to_dense(parsed_features["landmarks"])
    # The landmarks are stored as a flat list; regroup into rows of 3 values
    # (presumably x, y, z -- TODO confirm against the writer).
    landmarks = tf.reshape(landmarks, [-1, 3])
    selected_landmarks = extract_selected_landmarks(landmarks)
    selected_landmarks.set_shape([len(landmark_indices) * 3])

    label = tf.one_hot(parsed_features["label"], depth=len(emotion_labels))

    z_values = parsed_features["z_values"]
    z_product_combinations = calculate_z_value_combinations(z_values)

    return (image, selected_landmarks, z_product_combinations), label

# Dataset loading and preparation.
def load_and_prepare_dataset(tfrecord_file, batch_size=None):
    """Build a parsed, optionally batched, prefetching dataset from TFRecords.

    Args:
        tfrecord_file: file path (or list of paths) for ``tf.data.TFRecordDataset``.
        batch_size: if given, group examples into batches of this size; the
            final partial batch is kept.

    Returns:
        A ``tf.data.Dataset`` yielding the tuples produced by
        ``parse_and_prepare_tfrecord`` (batched when ``batch_size`` is set).
    """
    dataset = tf.data.TFRecordDataset(tfrecord_file)
    # Parse records in parallel; tf.data preserves element order unless
    # determinism is explicitly turned off.
    dataset = dataset.map(parse_and_prepare_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    if batch_size is not None:
        dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)

# Input dimensions shared by the helper functions below.
input_dim = len(landmark_indices) * 3  # flattened landmark vector length (303)
image_shape = (48, 48, 1)  # shape of the image data
z_comb_dim = len(list(itertools.combinations(range(8), 3)))  # number of z-value products (56)

# Dataset locations and loading.
DATASET_DIR = "/content/drive/MyDrive/clean_0317/Dataset_clean"
BATCH_SIZE = 16

train_tfrecord_file = f"{DATASET_DIR}/clean_YES_Z_train.tfrecord"
val_tfrecord_file = f"{DATASET_DIR}/clean_YES_Z_val.tfrecord"
test_tfrecord_file = f"{DATASET_DIR}/clean_YES_Z_test.tfrecord"

train_dataset = load_and_prepare_dataset(train_tfrecord_file, batch_size=BATCH_SIZE)
val_dataset = load_and_prepare_dataset(val_tfrecord_file, batch_size=BATCH_SIZE)
test_dataset = load_and_prepare_dataset(test_tfrecord_file, batch_size=BATCH_SIZE)

# Generated-image visualization.
def visualize_gen_images(epoch, generator, examples=10, dim=(1, 10), figsize=(10, 1)):
    """Render a grid of generator samples and save it as a PNG.

    The generator is fed pure standard-normal noise for all three of its
    inputs (image, landmarks, z-products). The figure is written to
    ``gan1_0407_generated_image_epoch_<epoch>.png`` in the current working
    directory.

    Args:
        epoch: number embedded in the output file name.
        generator: model whose ``predict`` takes ``[image, landmarks, z]``.
        examples: number of samples to draw.
        dim: ``(rows, cols)`` of the subplot grid; should fit ``examples``.
        figsize: matplotlib figure size in inches.
    """
    # Draw the noise in this exact order so the global NumPy RNG stream is
    # consumed the same way every time.
    z_noise = np.random.normal(0, 1, [examples, z_comb_dim])
    lm_noise = np.random.normal(0, 1, [examples, input_dim])
    image_noise = np.random.normal(0, 1, (examples, *image_shape))
    samples = generator.predict([image_noise, lm_noise, z_noise])

    plt.figure(figsize=figsize)
    for position, sample in enumerate(samples, start=1):
        plt.subplot(dim[0], dim[1], position)
        # NOTE(review): 'gray_r' renders intensities inverted -- confirm intended.
        plt.imshow(sample[:, :, 0], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan1_0407_generated_image_epoch_%d.png' % epoch)
    plt.close()

# SSIM evaluation.
def calculate_ssim(generator, test_dataset, data_range=2.0):
    """Return the mean SSIM between test images and the generator's outputs.

    For each batch the generator receives the real images and landmarks plus
    freshly drawn standard-normal noise of width ``z_comb_dim``. Each output is
    compared against the image it was conditioned on.

    Args:
        generator: model whose ``predict`` takes ``[images, landmarks, noise]``.
        test_dataset: batched dataset yielding
            ``((images, landmarks, z_product_combinations), labels)``.
        data_range: dynamic range passed to ``structural_similarity``. The
            parser scales images to [-1, 1], hence the default of 2.0. It must
            be explicit because scikit-image no longer infers it for floats.

    Returns:
        Mean SSIM over all test images, or ``nan`` if the dataset is empty.
    """
    ssim_values = []
    for (images, landmarks, _z_products), _labels in test_dataset:
        noise = np.random.normal(0, 1, [images.shape[0], z_comb_dim])
        # NOTE(review): training feeds the dataset's z-value products in this
        # input slot, while evaluation uses random noise -- confirm intended.
        generated_images = generator.predict([images, landmarks, noise])
        for original, generated in zip(images.numpy(), generated_images):
            ssim_values.append(
                ssim(
                    original.squeeze(),
                    generated.squeeze(),
                    channel_axis=None,
                    win_size=3,
                    data_range=data_range,
                )
            )
    if not ssim_values:
        return float("nan")
    return np.mean(ssim_values)

# Load the saved models and recompile the combined model for further training.
generator_path = '/content/drive/MyDrive/clean_0317/gan1_clean_0407_generator_model.h5'
combined_model_path = '/content/drive/MyDrive/clean_0317/gan1_clean_0407_combined_model.h5'

generator = load_model(generator_path)
combined_model = load_model(combined_model_path)
# NOTE(review): the datasets yield one-hot labels over the 7 emotion classes;
# confirm binary_crossentropy matches the combined model's output head.
combined_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])


def _resolve_trained_generator(combined, standalone):
    """Return the generator whose weights ``combined`` actually trains.

    The two models come from separate ``load_model`` calls and share no
    weights, so fitting ``combined`` never updates ``standalone``. If
    ``combined`` contains a nested model named like ``standalone``, return that
    nested model; otherwise warn and fall back to ``standalone``.
    """
    # NOTE(review): assumes the combined model nests the generator under the
    # name it was saved with -- confirm against the model-building code.
    try:
        nested = combined.get_layer(standalone.name)
    except ValueError:
        nested = None
    if isinstance(nested, tf.keras.Model):
        return nested
    print("Warning: generator not found inside the combined model; "
          "visualization and SSIM use the separately loaded generator.")
    return standalone


trained_generator = _resolve_trained_generator(combined_model, generator)


def _visualize_every_10_epochs(epoch, logs=None):
    """Save generated samples after every 10th training epoch."""
    # Keras passes a 0-based epoch index; file names use the 1-based number.
    if (epoch + 1) % 10 == 0:
        visualize_gen_images(epoch + 1, trained_generator)


# Stop after 5 epochs without val_loss improvement and restore the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# Visualize during training (the images must reflect the model at that epoch).
visualization_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=_visualize_every_10_epochs)
epochs = 100

# Train.
history = combined_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    callbacks=[early_stopping, visualization_callback],
)

# Evaluate the trained generator with mean SSIM on the test set.
ssim_value = calculate_ssim(trained_generator, test_dataset)
print("테스트 데이터셋 SSIM 평균:", ssim_value)




Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
테스트 데이터셋 SSIM 평균: 4.441490425603398e-05
