In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import umap

from google.colab import drive
# Mount Google Drive at /content/drive so the dataset path under
# /content/drive/MyDrive can be read.
drive.mount(mountpoint='/content/drive')

# Emotion label definitions: class name -> integer class id.
# Ids are assigned 0, 1, 2, ... in the order the names are listed.
emotion_labels = {
    emotion: class_id
    for class_id, emotion in enumerate(
        ("happy", "sad", "anger", "surprise", "disgust", "neutral", "fear")
    )
}

# UMAP visualization helper.
def visualize_umap(features, labels, title, label_names=None):
    """Embed ``features`` in 2-D with UMAP and show a scatter plot colored by label.

    Args:
        features: One row per sample; passed directly to
            ``umap.UMAP.fit_transform``.
        labels: One class id per row of ``features``, used both as the
            scatter colour values and to name legend entries. Presumably
            integer ids matching the values of ``label_names``.
        title: Plot title.
        label_names: Optional mapping of class name -> class id used to name
            the legend entries. Defaults to the module-level ``emotion_labels``.
            Ids with no matching name are shown as their numeric value.
    """
    if label_names is None:
        label_names = emotion_labels
    # Invert name -> id so each legend entry can be named from its id.
    id_to_name = {class_id: name for name, class_id in label_names.items()}

    labels = np.asarray(labels)
    reducer = umap.UMAP(n_neighbors=30, min_dist=0.1, random_state=42)
    embedding = reducer.fit_transform(features)

    plt.figure(figsize=(10, 10))
    scatter = plt.scatter(embedding[:, 0], embedding[:, 1], c=labels, cmap='viridis')
    # legend_elements(num=None) returns one handle per distinct value present
    # in ``labels``, in sorted order. The legend text must therefore come from
    # those same values. A fixed list of every class name goes out of step
    # whenever some class is absent from the data.
    handles, _ = scatter.legend_elements(num=None)
    legend_text = [id_to_name.get(int(class_id), str(class_id)) for class_id in np.unique(labels)]
    plt.legend(handles=handles, labels=legend_text, fontsize=18)
    plt.title(title, fontsize=27)
    plt.show()

# TFRecord file path and the feature schema used to parse each record.
tfrecord_path = (
    '/content/drive/MyDrive/Validation/EWHA_Alexithymia/PLEASE_FINAL/'
    'YES_Z_train.tfrecord'
)
# Fixed-length features expected in every serialized example:
#   image     - one encoded image stored as a byte string
#   landmarks - 468 * 3 = 1404 floats (NOTE(review): presumably 468 face
#               landmarks x 3 coordinates -- confirm against the writer)
#   label     - one int64 class id (presumably an emotion_labels value)
#   z_values  - 8 floats (meaning not visible here -- TODO document)
feature_description = dict(
    image=tf.io.FixedLenFeature([], tf.string),
    landmarks=tf.io.FixedLenFeature([468 * 3], tf.float32),
    label=tf.io.FixedLenFeature([], tf.int64),
    z_values=tf.io.FixedLenFeature([8], tf.float32),
)

# Load parsed records from a TFRecord file.
def load_data(tfrecord_path, feature_description):
    """Return a ``tf.data.Dataset`` of parsed examples read from ``tfrecord_path``.

    Each element is the dict that ``tf.io.parse_single_example`` produces for
    ``feature_description`` (feature name -> tensor). Parsing happens as the
    dataset is iterated.
    """
    raw_records = tf.data.TFRecordDataset(tfrecord_path)
    return raw_records.map(
        lambda serialized: tf.io.parse_single_example(serialized, feature_description)
    )

# Load the dataset.
dataset = load_data(tfrecord_path, feature_description)

# Collect every record into per-feature Python lists, in file order.
images, landmarks, z_values, labels = [], [], [], []
for record in dataset:
    # Decode the JPEG bytes and flatten the pixel array into a 1-D vector.
    # NOTE(review): np.array(images) only forms a 2-D matrix if every image
    # decodes to the same height, width and channel count -- confirm.
    image = tf.io.decode_jpeg(record['image']).numpy().flatten()
    lm, zv, label = (record[key].numpy() for key in ('landmarks', 'z_values', 'label'))

    images.append(image)
    landmarks.append(lm)
    z_values.append(zv)
    labels.append(label)

# UMAP visualization: one plot per feature type, all colored by the same labels.
label_array = np.array(labels)
for feature_list, plot_title in (
    (images, 'Image Data UMAP Visualization'),
    (landmarks, 'Landmark Data UMAP Visualization'),
    (z_values, 'Z-Values Data UMAP Visualization'),
):
    visualize_umap(np.array(feature_list), label_array, plot_title)
