In [0]:
# Mount Google Drive at /content/drive so the project files stored there are
# reachable from this runtime (google.colab is only available on Colab).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [0]:
# Switch the working directory to the project folder on the mounted Drive and
# list its contents; the relative paths used later ("data", "./image/") and the
# local imports (nice, utils) are resolved from here.
%cd /content/drive/My Drive/nice/mynice/mynice
!ls

/content/drive/My Drive/nice/mynice/mynice
data	  image    mynice.ipynb  __pycache__  tyuumoku.md
gimon.md  memo.py  nice.py	 train.py     utils.py


In [0]:
"""
Basic NICE setup
  - specialized for MNIST
  - the probability distribution of the last layer is assumed to be logistic


Written with reference to "https://github.com/fmu2/NICE"
"""

import torch, torchvision
import numpy as np
import torch.utils.data as data
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import nice
import utils



class args:
    """Namespace of training hyper-parameters, exposed as class attributes.

    NOTE(review): main() defines its own local variables with the same values
    and does not read this class.
    """

    # --- optimisation ---
    lr = 1e-3
    momentum = 0.9  # Adam parameter (main() passes its same-named local as betas[0])
    decay = 0.99    # Adam beta value (main() passes its same-named local as betas[1])

    # --- schedule / batching ---
    batch_size = 200
    max_itr = 25000

    # --- generation ---
    sample_size = 64  # number of samples to generate

def main():
    """Train a NICE flow on MNIST, logging the loss and saving image grids.

    Builds the MNIST training loader, a ``nice.NICE`` model with a logistic
    prior from ``utils`` and an Adam optimizer, then runs ``max_itr``
    optimisation steps. Every ``log_interval`` steps it prints the mean loss
    and bits/dim over that window and writes two PNG grids into ``./image/``:
    a reconstruction of the current batch (``testmnistiter<N>.png``) and
    samples drawn from a logistic distribution and passed through the inverse
    flow (``mnistiter<N>.png``).

    Runs on ``cuda:0`` when a GPU is available and on the CPU otherwise.

    Returns:
        None. Results are reported through stdout and image files only.
    """
    import os  # function-scope import: only needed for the output directory

    # ---- settings ----
    debug = False

    # Fall back to the CPU so the notebook also runs on a runtime without GPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The sampler below was originally called with a device *string*; keep
    # passing a string rather than a torch.device object.
    device_str = str(device)

    batch_size = 200
    max_itr = 25000
    coupling = 4
    mask_config = 1
    num_workers = 8      # number of workers the DataLoader uses to read data
    log_interval = 1000  # iterations between loss reports / image dumps
    num_samples = 10     # images generated from the prior at each report
    image_dir = './image/'

    lr = 1e-3
    momentum = 0.9  # passed to Adam as betas[0]
    decay = 0.99    # passed to Adam as betas[1]

    # The images are only converted to tensors; no augmentation is applied.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Each dataset item is (image tensor of shape (c, h, w), label).
    trainset = torchvision.datasets.MNIST(
        root="data", train=True, transform=transform, download=True)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, num_workers=num_workers, shuffle=True)

    (full_dim, mid_dim, hidden) = (1 * 28 * 28, 1000, 5)

    if debug:
        # Show the first training image as a sanity check.
        print(trainset[0][0].size())
        showimg = trainset[0][0].numpy().reshape(28, 28)
        plt.imshow(showimg)
        plt.show()

    # Distribution object handed to the model as its prior.
    prior = utils.logistic_distribution()

    flow = nice.NICE(prior, in_out_dim=full_dim, mid_dim=mid_dim,
                     num_coupling=coupling, hidden=hidden,
                     mask_config=mask_config).to(device)

    optimizer = torch.optim.Adam(
        flow.parameters(), lr=lr, betas=(momentum, decay), eps=1e-4)

    # Separate distribution instance used only for drawing latent samples;
    # created once instead of at every report.
    log_dist = utils.logistic_distribution()

    # save_image() does not create missing directories.
    os.makedirs(image_dir, exist_ok=True)

    total_iter = 0
    running_loss = 0.0

    # One pass over trainloader is an epoch; keep cycling until max_itr steps.
    while total_iter < max_itr:
        for images, _ in trainloader:  # second element is the label (unused)
            if total_iter >= max_itr:
                break

            flow.train()  # training mode (undoes the eval() calls below)
            total_iter += 1
            optimizer.zero_grad()

            # utils.prepare_data presumably turns the image batch into the
            # flat (batch, full_dim) input the flow expects -- TODO confirm.
            inputs = utils.prepare_data(images).to(device)

            # The second return value is negated and batch-averaged to form
            # the quantity that is minimised.
            outputs, loss = flow(inputs)
            loss = -loss.mean()
            running_loss += loss.item()

            loss.backward()
            optimizer.step()

            if debug:
                # Invertibility check: mapping the latent back should
                # reproduce the input.
                with torch.no_grad():
                    test_rev = flow(outputs, reverse=True)
                print(inputs[0])
                print(test_rev[0])

            if total_iter % log_interval == 0:
                mean_loss = running_loss / log_interval
                # Loss-to-bits/dim conversion kept exactly as in the original
                # code, where it was computed but never printed.
                bit_per_dim = (mean_loss + np.log(256.) * full_dim) \
                    / (full_dim * np.log(2.))
                print('iter %s:' % total_iter,
                      'loss = %.3f' % mean_loss,
                      'bits/dim = %.3f' % bit_per_dim)
                running_loss = 0.0

                # eval() switches layers such as dropout / batch-norm to
                # inference behaviour; gradient tracking is disabled
                # separately by torch.no_grad().
                flow.eval()
                with torch.no_grad():
                    # Reconstruction: data -> latent -> data.
                    z, _ = flow(inputs)
                    testsample = flow(z, reverse=True)
                    testsample = utils.prepare_data(testsample, reverse=True)
                    torchvision.utils.save_image(
                        torchvision.utils.make_grid(testsample),
                        os.path.join(image_dir,
                                     "testmnist" + 'iter%d.png' % total_iter))

                    # Generation: latent sample -> data.
                    sample = log_dist.sample(size=(num_samples, full_dim),
                                             device=device_str)
                    sample = flow(sample, reverse=True)
                    sample = utils.prepare_data(sample, reverse=True)
                    torchvision.utils.save_image(
                        torchvision.utils.make_grid(sample),
                        os.path.join(image_dir,
                                     "mnist" + 'iter%d.png' % total_iter))

            if debug:
                flow.eval()
                with torch.no_grad():
                    # Sample on the model's device; the original sampled on
                    # "cpu", which does not match a model living on the GPU.
                    sample = log_dist.sample((num_samples, full_dim), device_str)
                    sample = flow(sample, reverse=True)
                    sample *= 255

                    print("生成完了")

In [0]:
# Run the training loop defined above. Progress is printed every 1000
# iterations and image grids are written under ./image/.
main()

iter 1000: loss = 707.019
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 2000: loss = -43.561
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 3000: loss = -684.668
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 4000: loss = -1146.738
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 5000: loss = -1480.323
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 6000: loss = -1729.545
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 7000: loss = -1905.179
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 8000: loss = -1971.315
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 9000: loss = -1993.021
torch.Size([200, 784])
torch.Size([200])
torch.Size([10, 784])
torch.Size([10])
iter 10000: loss = -2009.674
torc