<a href="https://colab.research.google.com/github/tomonari-masada/course2021-stats2/blob/main/14_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 変分オートエンコーダの実践
* MNISTデータセットに対して変分オートエンコーダを適用してみる。

## 準備

In [None]:
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import plotly.express as px

%config InlineBackend.figure_format = 'retina'

## PyTorchの準備
* PyTorchについての細かな説明は割愛します・・・。
 * 以下のコードはPyTorch公式のVAEサンプルをベースにしている: https://github.com/pytorch/examples/blob/master/vae/main.py

In [None]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [None]:
# Seed PyTorch's global random number generator so repeated runs start
# from the same random state.
torch.manual_seed(123)

* GPUが使えるときは使う。
 * ランタイムのタイプをGPUにしておく。

In [None]:
# Use the CUDA device when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Show which device was selected.
device

## MNISTデータを取得
* PyTorchに用意されている仕組みを使ってデータを取得し、学習に使える状態にする。

In [None]:
batch_size = 200

# Extra DataLoader options (one worker process, pinned host memory) are only
# passed when a CUDA device is available.
if torch.cuda.is_available():
  kwargs = {'num_workers': 1, 'pin_memory': True}
else:
  kwargs = {}

# Training split; downloaded into ../data if it is not already there.
mnist_train = datasets.MNIST('../data', train=True, download=True,
                             transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, **kwargs)

# Test split, read from the same ../data directory.
mnist_test = datasets.MNIST('../data', train=False,
                            transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=batch_size, shuffle=True, **kwargs)

## エンコーダとデコーダの実装

In [None]:
class VAE(nn.Module):
  """Fully connected variational autoencoder for 784-element inputs.

  The encoder produces a mean and a log-variance vector of size z_dim; a
  latent code is sampled from them and the decoder maps it back to 784
  values passed through a sigmoid.
  """

  def __init__(self, z_dim=20):
    """Create the encoder and decoder layers.

    Args:
      z_dim: Size of the latent code.
    """
    super().__init__()
    # Encoder: one shared hidden layer followed by two parallel heads.
    self.fc1 = nn.Linear(784, 400)
    self.fc21 = nn.Linear(400, z_dim)  # head producing the mean
    self.fc22 = nn.Linear(400, z_dim)  # head producing the log-variance
    # Decoder: latent code -> hidden layer -> 784 outputs.
    self.fc3 = nn.Linear(z_dim, 400)
    self.fc4 = nn.Linear(400, 784)

  def encode(self, x):
    """Return (mean, log-variance) for a batch of flattened inputs."""
    hidden = torch.relu(self.fc1(x))
    mu = self.fc21(hidden)
    logvar = self.fc22(hidden)
    return mu, logvar

  def reparameterize(self, mu, logvar):
    """Sample mu + sigma * noise, with noise drawn from a standard normal."""
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

  def decode(self, z):
    """Map latent codes to 784 sigmoid outputs."""
    hidden = torch.relu(self.fc3(z))
    logits = self.fc4(hidden)
    return logits.sigmoid()

  def forward(self, x):
    """Encode, sample and decode; returns (reconstruction, mean, log-variance)."""
    flat = x.view(-1, 784)  # flatten each input to a 784-vector
    mu, logvar = self.encode(flat)
    latent = self.reparameterize(mu, logvar)
    return self.decode(latent), mu, logvar

## 学習の準備
* モデルのインスタンスを作成
* オプティマイザを作成

In [None]:
# VAE with a 10-dimensional latent code, moved to the selected device.
model = VAE(z_dim=10).to(device)
# Adam over all model parameters.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

## 損失関数を定義
* 損失関数はELBOにマイナスをつけたもの（negative ELBO）で、前半（データ尤度の項＝再構成誤差）と後半（KL情報量の項）の和として計算する。

In [None]:
def loss_function(recon_x, x, mu, logvar):
  """Negative ELBO: reconstruction term plus KL term, summed over the batch.

  Args:
    recon_x: Reconstructions, one 784-element row per example.
    x: Original inputs; reshaped to rows of 784 before comparison.
    mu: Mean output of the encoder.
    logvar: Log-variance output of the encoder.

  Returns:
    A scalar tensor: binary cross-entropy summed over all elements plus the
    KL divergence summed over all latent dimensions.
  """
  targets = x.view(-1, 784)
  reconstruction = F.binary_cross_entropy(recon_x, targets, reduction='sum')

  # Closed-form KL divergence, see Appendix B of
  # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
  # https://arxiv.org/abs/1312.6114
  # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
  kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
  divergence = -0.5 * kl_inner.sum()

  return reconstruction + divergence

## 訓練データで学習を実行する関数

In [None]:
log_interval = 50  # print a progress line every this many batches

def train(epoch):
  """Run one optimization pass over the training set.

  Prints the per-example loss of every `log_interval`-th batch and, at the
  end, the loss averaged over the whole training set.

  Args:
    epoch: Epoch number, used only in the printed messages.
  """
  model.train()
  dataset_size = len(train_loader.dataset)
  num_batches = len(train_loader)
  running_loss = 0
  for step, (images, _) in enumerate(train_loader):
    images = images.to(device)
    optimizer.zero_grad()  # clear gradients from the previous step
    recon, mu, logvar = model(images)
    loss = loss_function(recon, images, mu, logvar)  # negative ELBO
    loss.backward()
    batch_loss = loss.item()
    running_loss += batch_loss
    optimizer.step()
    if step % log_interval != 0:
      continue
    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, step * len(images), dataset_size,
        100. * step / num_batches,
        batch_loss / len(images)))

  print('====> Epoch: {} Average loss: {:.4f}'.format(
      epoch, running_loss / dataset_size))

## テストデータ上での評価をおこなう関数

In [None]:
def test(epoch):
  """Evaluate the model on the test set and save a reconstruction image.

  Sums the loss over every batch of `test_loader` and prints the average per
  example. For the first batch it also writes 'reconstruction_<epoch>.png':
  up to 8 inputs in the first row and their reconstructions in the second.

  Args:
    epoch: Epoch number, used only in the output file name.
  """
  model.eval()
  test_loss = 0
  with torch.no_grad():  # no autograd graph is needed for evaluation
    for i, (data, _) in enumerate(test_loader):
      data = data.to(device)
      recon_batch, mu, logvar = model(data)
      test_loss += loss_function(recon_batch, data, mu, logvar).item()
      if i == 0:
        n = min(data.size(0), 8)
        # Let view() infer the batch dimension instead of reading the global
        # batch_size, which fails when the batch is smaller than batch_size
        # or when batch_size is reassigned by a later cell.
        comparison = torch.cat([data[:n],
                                recon_batch.view(-1, 1, 28, 28)[:n]])
        save_image(comparison.cpu(),
                   'reconstruction_' + str(epoch) + '.png', nrow=n)

  test_loss /= len(test_loader.dataset)
  print('====> Test set loss: {:.4f}'.format(test_loss))

In [None]:
epochs = 10

# Size of the latent code the decoder expects. Reading it from the model keeps
# the prior samples below consistent with however the VAE was instantiated;
# a hard-coded 20 fails in model.decode() when the model uses another z_dim.
latent_dim = model.fc3.in_features

for epoch in range(1, epochs + 1):
  train(epoch)
  test(epoch)
  # Decode 64 latent codes drawn from a standard normal and save them as an
  # image, to inspect what the decoder generates after each epoch.
  with torch.no_grad():
    sample = torch.randn(64, latent_dim).to(device)
    sample = model.decode(sample).cpu()
    save_image(sample.view(64, 1, 28, 28),
               'sample_' + str(epoch) + '.png')

## 全てのテストデータについて潜在表現を得る

In [None]:
# Collect the encoder mean and the class label for every test example.
mean_chunks, label_chunks = [], []
model.eval()
with torch.no_grad():
  for images, targets in test_loader:
    _, mu_batch, _ = model(images.to(device))
    mean_chunks.append(mu_batch)
    label_chunks.append(targets)

# Join the per-batch tensors and convert them to NumPy arrays.
labels = torch.cat(label_chunks, 0).cpu().numpy()
print(labels.shape)
means = torch.cat(mean_chunks, 0).cpu().numpy()
print(means.shape)

* PCAによる可視化のコードは次のページを参考にした: https://plotly.com/python/pca-visualization/

## テストデータの潜在表現をPCAで可視化

In [None]:
# Project the latent means onto their first three principal components.
pca = PCA(n_components=3)
components = pca.fit_transform(means)

# Percentage of the variance captured by the three retained components.
total_var = 100 * pca.explained_variance_ratio_.sum()

axis_names = {'0': 'PC 1', '1': 'PC 2', '2': 'PC 3'}
fig = px.scatter_3d(
    components,
    x=0, y=1, z=2,
    color=labels,
    labels=axis_names,
    title=f'Total Explained Variance: {total_var:.2f}%',
)
fig.update_layout(
    width=900,
    height=500,
    margin={'l': 20, 'r': 20, 'b': 20, 't': 20},
)
fig.show()

## FashionMNISTのテストデータだけを読み込む

In [None]:
batch_size = 200

# Extra DataLoader options (one worker process, pinned host memory) are only
# passed when a CUDA device is available.
if torch.cuda.is_available():
  kwargs = {'num_workers': 1, 'pin_memory': True}
else:
  kwargs = {}

# FashionMNIST test split only; downloaded into ../data if it is missing.
fashion_test = datasets.FashionMNIST('../data', train=False, download=True,
                                     transform=transforms.ToTensor())
fashion_test_loader = torch.utils.data.DataLoader(
    fashion_test, batch_size=batch_size, shuffle=True, **kwargs)

## MNISTで学習させたVAEを使ってFashionMNISTの全てのテストデータの潜在表現を得る

In [None]:
# Collect the encoder mean and the class label for every FashionMNIST test
# example, using the model as it currently stands.
fashion_mean_chunks, fashion_label_chunks = [], []
model.eval()
with torch.no_grad():
  for images, targets in fashion_test_loader:
    _, mu_batch, _ = model(images.to(device))
    fashion_mean_chunks.append(mu_batch)
    fashion_label_chunks.append(targets)

# Join the per-batch tensors and convert them to NumPy arrays.
fashion_labels = torch.cat(fashion_label_chunks, 0).cpu().numpy()
print(fashion_labels.shape)
fashion_means = torch.cat(fashion_mean_chunks, 0).cpu().numpy()
print(fashion_means.shape)

### FashionMNISTのラベルはプラス10しておく

In [None]:
# Move the FashionMNIST class ids from 0-9 to 10-19 so they do not collide
# with the MNIST digit labels. Taking `% 10` first makes the cell idempotent:
# re-running it leaves the labels at 10-19 instead of adding another 10.
fashion_labels = fashion_labels % 10 + 10

In [None]:
# Inspect the FashionMNIST label array.
fashion_labels

## 両方のデータセットのテストデータの潜在表現を5000個ずつとって合併する

In [None]:
# Number of examples taken from the front of each dataset's arrays.
n_per_dataset = 5000

all_labels = np.concatenate(
    (labels[:n_per_dataset], fashion_labels[:n_per_dataset]))
print(all_labels.shape)

all_means = np.concatenate(
    (means[:n_per_dataset], fashion_means[:n_per_dataset]))
print(all_means.shape)

## 合併したベクトル集合をPCAで可視化

In [None]:
# Project the merged latent means onto their first three principal components.
pca = PCA(n_components=3)
components = pca.fit_transform(all_means)

# Percentage of the variance captured by the three retained components.
total_var = 100 * pca.explained_variance_ratio_.sum()

axis_names = {'0': 'PC 1', '1': 'PC 2', '2': 'PC 3'}
fig = px.scatter_3d(
    components,
    x=0, y=1, z=2,
    color=all_labels,
    labels=axis_names,
    title=f'Total Explained Variance: {total_var:.2f}%',
)
fig.update_layout(
    width=900,
    height=500,
    margin={'l': 20, 'r': 20, 'b': 20, 't': 20},
)
fig.show()

## k-meansで二種類のテストセットの潜在表現をどのくらい綺麗に分離できるか調べる

In [None]:
# Split the merged latent vectors into two clusters. n_init is set explicitly
# because its default differs between scikit-learn releases; pinning it keeps
# the clustering comparable regardless of the installed version.
kmeans = KMeans(n_clusters=2, n_init=10, random_state=0).fit(all_means)

In [None]:
# Cluster index (0 or 1) assigned to each latent vector.
kmeans.labels_

In [None]:
# Dataset id per example: labels 0-9 (MNIST) give 0, labels 10-19
# (FashionMNIST) give 1.
dataset_ids = all_labels // 10
# k-means numbers its clusters arbitrarily, so count the agreement under both
# possible cluster-to-dataset assignments and report the better one.
matches = (dataset_ids == kmeans.labels_).sum()
max(matches, len(all_labels) - matches)