In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
from torchvision import datasets, transforms, models ##

import matplotlib.pyplot as plt

In [23]:
rng = np.random.RandomState(1234)
random_state = 42
batch_size = 100
if torch.cuda.is_available():
  device = 'cuda' 
else:
  devide = 'cpu'
n_epochs = 30
lr = 0.001
z_dim = 10

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

In [4]:
dataloader_train = torch.utils.data.DataLoader(
    datasets.MNIST('./data/MNIST', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=False
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/MNIST/raw



In [5]:
dataloader_valid = torch.utils.data.DataLoader(
    datasets.MNIST('./data/MNIST', train=False, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=False
)

In [6]:
def torch_log(x):
  return torch.log(torch.clamp(x, min=1e-10))

In [7]:
# VAEモデルの実装
class VAE(nn.Module):
    def __init__(self, z_dim):
        super(VAE, self).__init__()
        # Encoder, xを入力にガウス分布のパラメータmu, sigmaを出力
        self.dense_enc1 = nn.Linear(28*28, 200)
        self.dense_enc2 = nn.Linear(200, 200)
        self.dense_encmean = nn.Linear(200, z_dim)
        self.dense_encvar = nn.Linear(200, z_dim)
        # Decoder, zを入力にベルヌーイ分布のパラメータlambdaを出力
        self.dense_dec1 = nn.Linear(z_dim, 200)
        self.dense_dec2 = nn.Linear(200, 200)
        self.dense_dec3 = nn.Linear(200, 28*28)
    
    def _encoder(self, x):
        x = F.relu(self.dense_enc1(x))
        x = F.relu(self.dense_enc2(x))
        mean = self.dense_encmean(x)
        std = F.softplus(self.dense_encvar(x))
        return mean, std
    
    def _sample_z(self, mean, std):
        #再パラメータ化トリック
        epsilon = torch.randn(mean.shape).to(device)
        return mean + std * epsilon
 
    def _decoder(self, z):
        x = F.relu(self.dense_dec1(z))
        x = F.relu(self.dense_dec2(x))
        # 出力が0~1になるようにsigmoid
        x = torch.sigmoid(self.dense_dec3(x))
        return x

    def forward(self, x):
        mean, std = self._encoder(x)
        z = self._sample_z(mean, std)
        x = self._decoder(z)
        return x, z

    def loss(self, x):
        mean, std = self._encoder(x)
        # KL loss(正則化項)の計算. mean, stdは (batch_size , z_dim)
        # torch.sumは上式のJ(=z_dim)に関するもの. torch.meanはbatch_sizeに関するものなので,
        # 上式には書いてありません.
        KL = -0.5 * torch.mean(torch.sum(1 + torch_log(std**2) - mean**2 - std**2, dim=1))
    
        z = self._sample_z(mean, std)
        y = self._decoder(z)

        # reconstruction loss(負の再構成誤差)の計算. x, yともに (batch_size , 784)
        # torch.sumは上式のD(=784)に関するもの. torch.meanはbatch_sizeに関するもの.
        reconstruction = torch.mean(torch.sum(x * torch_log(y) + (1 - x) * torch_log(1 - y), dim=1))
        
        return KL, -reconstruction 

In [8]:
!pip3 install wandb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.13.5-py2.py3-none-any.whl (1.9 MB)
[K     |████████████████████████████████| 1.9 MB 7.2 MB/s 
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle
  Downloading setproctitle-1.3.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.10.1-py2.py3-none-any.whl (166 kB)
[K     |████████████████████████████████| 166 kB 34.7 MB/s 
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting GitPython>=1.0.0
  Downloading GitPython-3.1.29-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 24.4 MB/s 
[?25hCollecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-4.0.9-

In [9]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [10]:
import wandb

In [22]:
hyperparams = {
    'epochs': n_epochs,
    'batch_size': batch_size,
    'lr': lr
}

In [12]:
wandb.init(config=hyperparams, project="VAE-221027", name='Adam')

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mramu13[0m. Use [1m`wandb login --relogin`[0m to force relogin


#ここからTraining


In [13]:
!git clone https://github.com/ramu13/KFAC-Pytorch.git

Cloning into 'KFAC-Pytorch'...
remote: Enumerating objects: 133, done.[K
remote: Counting objects: 100% (87/87), done.[K
remote: Compressing objects: 100% (59/59), done.[K
remote: Total 133 (delta 52), reused 54 (delta 28), pack-reused 46[K
Receiving objects: 100% (133/133), 48.22 KiB | 12.05 MiB/s, done.
Resolving deltas: 100% (62/62), done.


In [14]:
import sys
sys.path.append('/content/KFAC-Pytorch')

In [15]:
from optimizers import kfac

In [16]:
model = VAE(z_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [17]:
# training
for epoch in range(n_epochs * 3):
  losses = []
  KL_losses = []
  reconstruction_losses = []

  model.train()
  for x, _ in dataloader_train:
    x = x.to(device)
    model.zero_grad()
    KL_loss, reconstruction_loss = model.loss(x)
    loss = KL_loss + reconstruction_loss

    loss.backward()
    optimizer.step()

    losses.append(loss.cpu().detach().numpy())
    KL_losses.append(KL_loss.cpu().detach().numpy())
    reconstruction_losses.append(reconstruction_loss.cpu().detach().numpy())

    wandb.log({'epoch': epoch+1,
               'train_KL_loss': KL_loss.item(),
               'train_reconstruction_loss': reconstruction_loss.item(),
               'train_loss': loss.item()})

  
  # testing
  losses_val = []
  model.eval()
  with torch.no_grad():
    for (x, t) in dataloader_valid:
      x = x.to(device)
      KL_loss, reconstruction_loss = model.loss(x)
      loss = KL_loss + reconstruction_loss
      losses_val.append(loss.cpu().detach().numpy())
    print('EPOCH: %d    Train Lower Bound: %lf (KL_loss: %lf. reconstruction_loss: %lf)    Valid Lower Bound: %lf' %
          (epoch+1, np.average(losses), np.average(KL_losses), np.average(reconstruction_losses), np.average(losses_val)))
    
    wandb.log({'epoch': epoch+1,
               'valid_KL_loss': KL_loss.item(),
               'valid_reconstruction_loss': reconstruction_loss.item(),
               'valid_loss': loss.item()})

EPOCH: 1    Train Lower Bound: 187.695435 (KL_loss: 5.513043. reconstruction_loss: 182.182388)    Valid Lower Bound: 151.207916
EPOCH: 2    Train Lower Bound: 139.199783 (KL_loss: 11.091196. reconstruction_loss: 128.108582)    Valid Lower Bound: 135.177231
EPOCH: 3    Train Lower Bound: 128.665939 (KL_loss: 12.611670. reconstruction_loss: 116.054260)    Valid Lower Bound: 126.094315
EPOCH: 4    Train Lower Bound: 123.869164 (KL_loss: 13.162484. reconstruction_loss: 110.706680)    Valid Lower Bound: 122.605858
EPOCH: 5    Train Lower Bound: 121.096809 (KL_loss: 13.485538. reconstruction_loss: 107.611275)    Valid Lower Bound: 119.892914
EPOCH: 6    Train Lower Bound: 118.813202 (KL_loss: 13.944313. reconstruction_loss: 104.868889)    Valid Lower Bound: 117.763741
EPOCH: 7    Train Lower Bound: 116.462288 (KL_loss: 14.588029. reconstruction_loss: 101.874268)    Valid Lower Bound: 116.287498
EPOCH: 8    Train Lower Bound: 114.669769 (KL_loss: 14.911075. reconstruction_loss: 99.758690)    

# KFACによる訓練

In [18]:
hyperparams = {
    'epochs': n_epochs,
    'batch_size': batch_size,
    'lr': lr
}

In [19]:
wandb.init(config=hyperparams, project="VAE-221027", name='KFAC')

VBox(children=(Label(value='0.001 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.195071…

0,1
epoch,▁▁▁▁▁▁▂▂▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇███
train_KL_loss,▁▂▄▄▅▅▅▆▆▆▆▆▆▆▇▆▆▇▇▇▇▇▇▇▇▇▇█▇▇▇▇█▇██████
train_loss,█▆▅▄▄▃▂▃▂▂▃▂▂▂▃▂▂▂▂▂▂▂▂▂▁▂▂▂▂▁▁▁▂▁▁▁▁▂▁▁
train_reconstruction_loss,█▆▅▄▄▃▂▃▂▂▃▂▂▂▃▂▂▂▂▂▂▂▂▂▁▂▂▂▂▁▁▁▂▁▁▁▁▂▁▁
valid_KL_loss,▁▂▄▄▄▅▅▆▆▇▇▇▇▇█
valid_loss,█▅▄▃▃▂▂▂▂▁▁▁▁▁▁
valid_reconstruction_loss,█▅▄▃▃▃▂▂▂▁▁▁▁▁▁

0,1
epoch,15.0
train_KL_loss,16.98786
train_loss,108.54575
train_reconstruction_loss,91.55789
valid_KL_loss,16.7439
valid_loss,111.58166
valid_reconstruction_loss,94.83775


In [20]:
model = VAE(z_dim).to(device)
optimizer = kfac.KFACOptimizer(model, lr=lr)

VAE(
  (dense_enc1): Linear(in_features=784, out_features=200, bias=True)
  (dense_enc2): Linear(in_features=200, out_features=200, bias=True)
  (dense_encmean): Linear(in_features=200, out_features=10, bias=True)
  (dense_encvar): Linear(in_features=200, out_features=10, bias=True)
  (dense_dec1): Linear(in_features=10, out_features=200, bias=True)
  (dense_dec2): Linear(in_features=200, out_features=200, bias=True)
  (dense_dec3): Linear(in_features=200, out_features=784, bias=True)
)
=> We keep following layers in KFAC. 
(0): Linear(in_features=784, out_features=200, bias=True)
(1): Linear(in_features=200, out_features=200, bias=True)
(2): Linear(in_features=200, out_features=10, bias=True)
(3): Linear(in_features=200, out_features=10, bias=True)
(4): Linear(in_features=10, out_features=200, bias=True)
(5): Linear(in_features=200, out_features=200, bias=True)
(6): Linear(in_features=200, out_features=784, bias=True)


In [21]:
# training
for epoch in range(n_epochs * 3):
  losses = []
  KL_losses = []
  reconstruction_losses = []

  model.train()
  for x, _ in dataloader_train:
    x = x.to(device)
    model.zero_grad()
    KL_loss, reconstruction_loss = model.loss(x)
    loss = KL_loss + reconstruction_loss

    loss.backward()
    optimizer.step()

    losses.append(loss.cpu().detach().numpy())
    KL_losses.append(KL_loss.cpu().detach().numpy())
    reconstruction_losses.append(reconstruction_loss.cpu().detach().numpy())

    wandb.log({'epoch': epoch+1,
               'train_KL_loss': KL_loss.item(),
               'train_reconstruction_loss': reconstruction_loss.item(),
               'train_loss': loss.item()})

  
  # testing
  losses_val = []
  model.eval()
  with torch.no_grad():
    for (x, t) in dataloader_valid:
      x = x.to(device)
      KL_loss, reconstruction_loss = model.loss(x)
      loss = KL_loss + reconstruction_loss
      losses_val.append(loss.cpu().detach().numpy())
    print('EPOCH: %d    Train Lower Bound: %lf (KL_loss: %lf. reconstruction_loss: %lf)    Valid Lower Bound: %lf' %
          (epoch+1, np.average(losses), np.average(KL_losses), np.average(reconstruction_losses), np.average(losses_val)))
    
    wandb.log({'epoch': epoch+1,
               'valid_KL_loss': KL_loss.item(),
               'valid_reconstruction_loss': reconstruction_loss.item(),
               'valid_loss': loss.item()})

The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:3029.)
  self.m_aa[m], eigenvectors=True)
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1174.)
  p.data.add_(-group['lr'], d_p)


EPOCH: 1    Train Lower Bound: 214.947632 (KL_loss: 10.005415. reconstruction_loss: 204.942215)    Valid Lower Bound: 139.109787
EPOCH: 2    Train Lower Bound: 122.845482 (KL_loss: 18.739027. reconstruction_loss: 104.106461)    Valid Lower Bound: 115.293121
EPOCH: 3    Train Lower Bound: 114.036743 (KL_loss: 18.530228. reconstruction_loss: 95.506516)    Valid Lower Bound: 112.131447
EPOCH: 4    Train Lower Bound: 111.694633 (KL_loss: 18.470825. reconstruction_loss: 93.223816)    Valid Lower Bound: 110.934525
EPOCH: 5    Train Lower Bound: 110.410767 (KL_loss: 18.462471. reconstruction_loss: 91.948296)    Valid Lower Bound: 109.998138
EPOCH: 6    Train Lower Bound: 120.940155 (KL_loss: 18.509882. reconstruction_loss: 102.430267)    Valid Lower Bound: 132.267334
EPOCH: 7    Train Lower Bound: 131.829346 (KL_loss: 18.530416. reconstruction_loss: 113.298935)    Valid Lower Bound: 131.887085
EPOCH: 8    Train Lower Bound: 131.308960 (KL_loss: 18.529598. reconstruction_loss: 112.779350)    V