# Autoencoder for MNIST Compression and Latent Space Visualization
---
이 노트북에서는 MNIST 데이터셋으로 Autoencoder를 학습하여 압축된 특징(latent space)을 추출하고, t-SNE로 이를 2차원에 시각화한 뒤 원본 이미지와 재구성 이미지를 비교합니다.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.manifold import TSNE

## 1. 데이터 준비

In [1]:
# Load MNIST through Keras: an (images, labels) pair for each of train / test.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()


def _prepare_images(images):
    """Convert an image stack to float32 in [0, 1] with a trailing channel axis.

    Pixels are divided by 255 (mapping 0-255 values onto [0, 1]), and the
    (N, 28, 28) stack is reshaped to (N, 28, 28, 1) so it matches a
    channels-last ``Input(shape=(28, 28, 1))`` layer.
    """
    as_float = images.astype("float32") / 255.0
    return as_float.reshape(len(as_float), 28, 28, 1)


x_train = _prepare_images(x_train)
x_test = _prepare_images(x_test)

print("Train shape:", x_train.shape)
print("Test shape:", x_test.shape)

NameError: name 'tf' is not defined

## 2. Autoencoder 모델 정의

In [None]:
latent_dim = 64  # width of the bottleneck (compressed code)

# ---- Encoder: (28, 28, 1) image -> 784 -> 128 -> latent_dim ----
input_img = layers.Input(shape=(28, 28, 1))
flat = layers.Flatten()(input_img)
enc_hidden = layers.Dense(128, activation="relu")(flat)
latent = layers.Dense(latent_dim, activation="relu")(enc_hidden)

# ---- Decoder: latent_dim -> 128 -> 784 -> (28, 28, 1) image ----
# Sigmoid keeps every reconstructed pixel in (0, 1), the same range as the
# /255-scaled inputs, which suits the binary cross-entropy loss used here.
dec_hidden = layers.Dense(128, activation="relu")(latent)
dec_flat = layers.Dense(28 * 28, activation="sigmoid")(dec_hidden)
decoded = layers.Reshape((28, 28, 1))(dec_flat)

# Two models over the same layer graph: the full autoencoder (what gets
# trained) and its encoder half, which shares those weights and is used
# to extract latent vectors.
autoencoder = models.Model(input_img, decoded)
encoder = models.Model(input_img, latent)

autoencoder.compile(optimizer="adam", loss="binary_crossentropy")
autoencoder.summary()

## 3. 모델 학습

In [None]:
# An autoencoder learns to reproduce its input, so the images are both the
# inputs and the targets. The test split is passed as validation data so a
# validation loss is reported alongside the training loss.
train_config = dict(
    epochs=10,
    batch_size=256,
    shuffle=True,
    validation_data=(x_test, x_test),
)
history = autoencoder.fit(x_train, x_train, **train_config)

## 4. Latent Space 추출

In [None]:
# Encode every test image into its latent_dim-wide code; the result is one
# row per test sample.
latent_vectors = encoder.predict(x_test)
num_codes, code_dim = latent_vectors.shape
print("Latent vectors shape:", (num_codes, code_dim))

## 5. Latent Space 시각화 (t-SNE)

In [None]:
# t-SNE gets slow on large inputs, so only the first `n_tsne` test samples
# are embedded. A single constant slices both the latent vectors and their
# labels, keeping the two aligned.
n_tsne = 2000

tsne = TSNE(n_components=2, random_state=42)
latent_tsne = tsne.fit_transform(latent_vectors[:n_tsne])

plt.figure(figsize=(10, 8))
scatter = plt.scatter(
    latent_tsne[:, 0],
    latent_tsne[:, 1],
    c=y_test[:n_tsne],  # colour each point by its digit label
    cmap="tab10",
    alpha=0.7,
)
cbar = plt.colorbar(scatter, ticks=range(10))
cbar.set_label("digit label")
plt.title("t-SNE visualization of latent space")
plt.show()

## 6. 원본 이미지와 재구성 이미지 비교

In [None]:
# Reconstruct the whole test split (kept as `decoded_imgs` so other cells can
# reuse it), then show the first `n` originals (top row) above their
# reconstructions (bottom row).
decoded_imgs = autoencoder.predict(x_test)

n = 10  # number of digits to display
# squeeze=False keeps `axes` 2-D even when n == 1, so row/column indexing holds.
fig, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)
rows = (("Original", x_test), ("Reconstructed", decoded_imgs))
for row_axes, (title, images) in zip(axes, rows):
    for ax, img in zip(row_axes, images[:n]):
        ax.imshow(img.reshape(28, 28), cmap="gray")
        ax.set_title(title)
        ax.axis("off")
fig.tight_layout()
plt.show()