In [121]:
import torch #基本モジュール
from torch.autograd import Variable #自動微分用
import torch.nn as nn #ネットワーク構築用
import torch.optim as optim #最適化関数
import torch.nn.functional as F #ネットワーク用の様々な関数
import torch.utils.data #データセット読み込み関連
import torchvision #画像関連
from torchvision import datasets, models, transforms #画像用データセット諸々
from torchvision.utils import save_image

import argparse
import os
import numpy as np
import math

In [2]:
#画像の変形処理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#CIFAR-10のtrain, testsetのロード
#変形はtransformを適用
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

#DataLoaderの適用->これによりバッチの割り当て・シャッフルをまとめて行うことができる
#batch_sizeでバッチサイズを指定
#num_workersでいくつのコアでデータをロードするか指定(デフォルトはメインのみ)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=16,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [131]:
# Generator
# input [100] reshape?
# layer [256, 4, 4] conv
# layer [128, 8, 8] conv
# layer [64, 16, 16]
# output [3, 32, 32] tanh

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.pc1 = nn.Linear(100, 256 * 4 * 4)
        conv1 = nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2, output_padding=1)
        conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1)
        conv3 = nn.ConvTranspose2d(64, 3, 5, stride=2, padding=2, output_padding=1)
        relu = nn.LeakyReLU()
        batch1 =  nn.BatchNorm2d(128)
        batch2 =  nn.BatchNorm2d(64)
        batch3 =  nn.BatchNorm2d(3)
        out = nn.Tanh()
        
        self.model = nn.Sequential(
            conv1, batch1, relu, conv2, batch2, relu, conv3, batch3, out
        )

    def forward(self, z):
        x = self.pc1(z).view(-1, 256, 4, 4)
        return self.model(x)

In [128]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        conv1 = nn.Conv2d(3, 3, 5, stride=2, padding=2)
        conv2 = nn.Conv2d(3, 3, 5, stride=2, padding=2)
        conv3 = nn.Conv2d(3, 3, 5, stride=2, padding=2)
        pc = nn.Conv2d(3, 1, 4)
        relu = nn.LeakyReLU()
        batch1 =  nn.BatchNorm2d(3)
        batch2 =  nn.BatchNorm2d(3)
        out = nn.Sigmoid()
        
        self.model = nn.Sequential(
            conv1, batch1, relu, conv2, batch2, relu, conv3, pc, out
        )

    def forward(self, img):
        return self.model(img).flatten()

In [132]:
#モデル定義
generator = Generator()
discriminator = Discriminator()

#Loss関数の指定
criterion = nn.BCELoss()

#Optimizerの指定
optimizer_g = optim.Adam(generator.parameters())
optimizer_d = optim.Adam(discriminator.parameters())

In [149]:
os.makedirs("images", exist_ok=True)

#トレーニング
#エポック数の指定
for epoch in range(1):  # loop over the dataset multiple times

    #データ全てのトータルロス
    running_loss = 0.0 


    for i, (inputs, _) in enumerate(trainloader):

        # Variableに変形
        # wrap them in Variable
        inputs = Variable(inputs, requires_grad=True)
        
        valid = Variable(torch.ones(16), requires_grad=False)
        fake = Variable(torch.zeros(16), requires_grad=False)
        
        #generator
        
        # optimizerの初期化
        # zero the parameter gradients
        optimizer_g.zero_grad()

        z = Variable(torch.randn(16, 100), requires_grad=True)
        
        #一連の流れ
        # forward + backward + optimize
        img = generator(z)

        #ここでラベルデータに対するCross-Entropyがとられる
        loss_g = criterion(discriminator(img), valid)
        loss_g.backward(retain_graph=True)
        optimizer_g.step()
        
        #discriminator
        optimizer_d.zero_grad()
        
        # optimizerの初期化
        # zero the parameter gradients

        #ここでラベルデータに対するCross-Entropyがとられる
        valid_loss = criterion(discriminator(inputs), valid)
        fake_loss = criterion(discriminator(img.clone()), fake)
        loss_d = (valid_loss + fake_loss) / 2
        loss_d.backward(retain_graph=True)
        optimizer_d.step()
        
        # ロスの表示
        # print statistics
        print("epoch %d, batch%d/%d, loss%f vs %f" 
              % (epoch, i, len(trainloader), loss_g.item(), loss_d.item()))
        if i % 100 == 0:
            save_image(img.data[:9], "images/%d.png" % (i + epoch * len(trainloader)),
                       nrow=3, normalize=True)

print('Finished Training')

epoch 0, batch0/3125, loss0.351315 vs 1.028077
epoch 0, batch1/3125, loss0.369206 vs 1.006213
epoch 0, batch2/3125, loss0.350297 vs 1.030429
epoch 0, batch3/3125, loss0.367223 vs 0.985919
epoch 0, batch4/3125, loss0.394143 vs 0.993026
epoch 0, batch5/3125, loss0.387170 vs 0.972744
epoch 0, batch6/3125, loss0.368839 vs 0.997858
epoch 0, batch7/3125, loss0.409399 vs 0.966508
epoch 0, batch8/3125, loss0.412364 vs 0.944446
epoch 0, batch9/3125, loss0.458398 vs 0.922321
epoch 0, batch10/3125, loss0.414565 vs 0.934955
epoch 0, batch11/3125, loss0.416813 vs 0.932097
epoch 0, batch12/3125, loss0.445249 vs 0.918775
epoch 0, batch13/3125, loss0.445315 vs 0.924134
epoch 0, batch14/3125, loss0.458750 vs 0.901384
epoch 0, batch15/3125, loss0.460268 vs 0.902660
epoch 0, batch16/3125, loss0.486586 vs 0.880614
epoch 0, batch17/3125, loss0.493724 vs 0.870169
epoch 0, batch18/3125, loss0.486443 vs 0.883592
epoch 0, batch19/3125, loss0.515550 vs 0.860184
epoch 0, batch20/3125, loss0.521310 vs 0.853265
ep

epoch 0, batch169/3125, loss0.727194 vs 0.692473
epoch 0, batch170/3125, loss0.722783 vs 0.703479
epoch 0, batch171/3125, loss0.718220 vs 0.695258
epoch 0, batch172/3125, loss0.713538 vs 0.698288
epoch 0, batch173/3125, loss0.707654 vs 0.709992
epoch 0, batch174/3125, loss0.701174 vs 0.703226
epoch 0, batch175/3125, loss0.693240 vs 0.713136
epoch 0, batch176/3125, loss0.688583 vs 0.713919
epoch 0, batch177/3125, loss0.678131 vs 0.723100
epoch 0, batch178/3125, loss0.668854 vs 0.719102
epoch 0, batch179/3125, loss0.659923 vs 0.723920
epoch 0, batch180/3125, loss0.652393 vs 0.718799
epoch 0, batch181/3125, loss0.641119 vs 0.734658
epoch 0, batch182/3125, loss0.635338 vs 0.736421
epoch 0, batch183/3125, loss0.633072 vs 0.742476
epoch 0, batch184/3125, loss0.632770 vs 0.738190
epoch 0, batch185/3125, loss0.634247 vs 0.747009
epoch 0, batch186/3125, loss0.630301 vs 0.739177
epoch 0, batch187/3125, loss0.629357 vs 0.751808
epoch 0, batch188/3125, loss0.632941 vs 0.734642
epoch 0, batch189/31

epoch 0, batch336/3125, loss0.727612 vs 0.668810
epoch 0, batch337/3125, loss0.725739 vs 0.678604
epoch 0, batch338/3125, loss0.722607 vs 0.672090
epoch 0, batch339/3125, loss0.719462 vs 0.679131
epoch 0, batch340/3125, loss0.709117 vs 0.679688
epoch 0, batch341/3125, loss0.710622 vs 0.683856
epoch 0, batch342/3125, loss0.698518 vs 0.691499
epoch 0, batch343/3125, loss0.700238 vs 0.697086
epoch 0, batch344/3125, loss0.691933 vs 0.680220
epoch 0, batch345/3125, loss0.685243 vs 0.692387
epoch 0, batch346/3125, loss0.686278 vs 0.687990
epoch 0, batch347/3125, loss0.681989 vs 0.686250
epoch 0, batch348/3125, loss0.674850 vs 0.687187
epoch 0, batch349/3125, loss0.671148 vs 0.692507
epoch 0, batch350/3125, loss0.666930 vs 0.698576
epoch 0, batch351/3125, loss0.668789 vs 0.690871
epoch 0, batch352/3125, loss0.667391 vs 0.688939
epoch 0, batch353/3125, loss0.667970 vs 0.692920
epoch 0, batch354/3125, loss0.670497 vs 0.709496
epoch 0, batch355/3125, loss0.667954 vs 0.682828
epoch 0, batch356/31

epoch 0, batch503/3125, loss0.616125 vs 0.754467
epoch 0, batch504/3125, loss0.617960 vs 0.750324
epoch 0, batch505/3125, loss0.616274 vs 0.769526
epoch 0, batch506/3125, loss0.622403 vs 0.785701
epoch 0, batch507/3125, loss0.610477 vs 0.776591
epoch 0, batch508/3125, loss0.620184 vs 0.754143
epoch 0, batch509/3125, loss0.618746 vs 0.733567
epoch 0, batch510/3125, loss0.616921 vs 0.764644
epoch 0, batch511/3125, loss0.617928 vs 0.756597
epoch 0, batch512/3125, loss0.621627 vs 0.747213
epoch 0, batch513/3125, loss0.616730 vs 0.749250
epoch 0, batch514/3125, loss0.624305 vs 0.756073
epoch 0, batch515/3125, loss0.624351 vs 0.744381
epoch 0, batch516/3125, loss0.620777 vs 0.751109
epoch 0, batch517/3125, loss0.625525 vs 0.751195
epoch 0, batch518/3125, loss0.626569 vs 0.741080
epoch 0, batch519/3125, loss0.628033 vs 0.739348
epoch 0, batch520/3125, loss0.633797 vs 0.743570
epoch 0, batch521/3125, loss0.641131 vs 0.736515
epoch 0, batch522/3125, loss0.645296 vs 0.738671
epoch 0, batch523/31

epoch 0, batch670/3125, loss0.873734 vs 0.602039
epoch 0, batch671/3125, loss0.876556 vs 0.613187
epoch 0, batch672/3125, loss0.879726 vs 0.618883
epoch 0, batch673/3125, loss0.875889 vs 0.619223
epoch 0, batch674/3125, loss0.879724 vs 0.593682
epoch 0, batch675/3125, loss0.852143 vs 0.621951
epoch 0, batch676/3125, loss0.854031 vs 0.628927
epoch 0, batch677/3125, loss0.854962 vs 0.630547
epoch 0, batch678/3125, loss0.856077 vs 0.602278
epoch 0, batch679/3125, loss0.854980 vs 0.609248
epoch 0, batch680/3125, loss0.846259 vs 0.633608
epoch 0, batch681/3125, loss0.835921 vs 0.627037
epoch 0, batch682/3125, loss0.835920 vs 0.603821
epoch 0, batch683/3125, loss0.828570 vs 0.638448
epoch 0, batch684/3125, loss0.816882 vs 0.637848
epoch 0, batch685/3125, loss0.811159 vs 0.623969
epoch 0, batch686/3125, loss0.799242 vs 0.661276
epoch 0, batch687/3125, loss0.781074 vs 0.659362
epoch 0, batch688/3125, loss0.762626 vs 0.640435
epoch 0, batch689/3125, loss0.752848 vs 0.690076
epoch 0, batch690/31

epoch 0, batch837/3125, loss0.562739 vs 0.758953
epoch 0, batch838/3125, loss0.580777 vs 0.771114
epoch 0, batch839/3125, loss0.589790 vs 0.772877
epoch 0, batch840/3125, loss0.604410 vs 0.746437
epoch 0, batch841/3125, loss0.617985 vs 0.749029
epoch 0, batch842/3125, loss0.631235 vs 0.735968
epoch 0, batch843/3125, loss0.641826 vs 0.742466
epoch 0, batch844/3125, loss0.651029 vs 0.723550
epoch 0, batch845/3125, loss0.668165 vs 0.729784
epoch 0, batch846/3125, loss0.675598 vs 0.723440
epoch 0, batch847/3125, loss0.683447 vs 0.717912
epoch 0, batch848/3125, loss0.698117 vs 0.722677
epoch 0, batch849/3125, loss0.709967 vs 0.702456
epoch 0, batch850/3125, loss0.722660 vs 0.708402
epoch 0, batch851/3125, loss0.733836 vs 0.703364
epoch 0, batch852/3125, loss0.745275 vs 0.700453
epoch 0, batch853/3125, loss0.751785 vs 0.682356
epoch 0, batch854/3125, loss0.760481 vs 0.712287
epoch 0, batch855/3125, loss0.768226 vs 0.694195
epoch 0, batch856/3125, loss0.776745 vs 0.692823
epoch 0, batch857/31

epoch 0, batch1004/3125, loss0.658488 vs 0.695910
epoch 0, batch1005/3125, loss0.656453 vs 0.697136
epoch 0, batch1006/3125, loss0.656863 vs 0.699016
epoch 0, batch1007/3125, loss0.656617 vs 0.689900
epoch 0, batch1008/3125, loss0.655782 vs 0.691795
epoch 0, batch1009/3125, loss0.655264 vs 0.698885
epoch 0, batch1010/3125, loss0.655400 vs 0.710662
epoch 0, batch1011/3125, loss0.654945 vs 0.695100
epoch 0, batch1012/3125, loss0.654166 vs 0.693339
epoch 0, batch1013/3125, loss0.653848 vs 0.697614
epoch 0, batch1014/3125, loss0.655178 vs 0.700433
epoch 0, batch1015/3125, loss0.653128 vs 0.696547
epoch 0, batch1016/3125, loss0.651823 vs 0.698709
epoch 0, batch1017/3125, loss0.651321 vs 0.703621
epoch 0, batch1018/3125, loss0.649552 vs 0.702234
epoch 0, batch1019/3125, loss0.649226 vs 0.706136
epoch 0, batch1020/3125, loss0.649584 vs 0.701922
epoch 0, batch1021/3125, loss0.648416 vs 0.702965
epoch 0, batch1022/3125, loss0.648264 vs 0.709415
epoch 0, batch1023/3125, loss0.647926 vs 0.697988


epoch 0, batch1167/3125, loss0.683243 vs 0.695996
epoch 0, batch1168/3125, loss0.684510 vs 0.698816
epoch 0, batch1169/3125, loss0.683345 vs 0.696856
epoch 0, batch1170/3125, loss0.686377 vs 0.694806
epoch 0, batch1171/3125, loss0.690341 vs 0.698029
epoch 0, batch1172/3125, loss0.690001 vs 0.702957
epoch 0, batch1173/3125, loss0.692164 vs 0.693830
epoch 0, batch1174/3125, loss0.692800 vs 0.692767
epoch 0, batch1175/3125, loss0.695077 vs 0.689118
epoch 0, batch1176/3125, loss0.695695 vs 0.689545
epoch 0, batch1177/3125, loss0.698347 vs 0.692364
epoch 0, batch1178/3125, loss0.699806 vs 0.690811
epoch 0, batch1179/3125, loss0.699487 vs 0.694152
epoch 0, batch1180/3125, loss0.699178 vs 0.688630
epoch 0, batch1181/3125, loss0.701934 vs 0.688067
epoch 0, batch1182/3125, loss0.701504 vs 0.688391
epoch 0, batch1183/3125, loss0.701974 vs 0.691896
epoch 0, batch1184/3125, loss0.702047 vs 0.689286
epoch 0, batch1185/3125, loss0.703447 vs 0.689653
epoch 0, batch1186/3125, loss0.705093 vs 0.688876


epoch 0, batch1330/3125, loss0.649027 vs 0.696290
epoch 0, batch1331/3125, loss0.645612 vs 0.698823
epoch 0, batch1332/3125, loss0.642898 vs 0.691755
epoch 0, batch1333/3125, loss0.638432 vs 0.697426
epoch 0, batch1334/3125, loss0.637343 vs 0.694197
epoch 0, batch1335/3125, loss0.631594 vs 0.703834
epoch 0, batch1336/3125, loss0.629014 vs 0.706536
epoch 0, batch1337/3125, loss0.624333 vs 0.709349
epoch 0, batch1338/3125, loss0.621392 vs 0.701724
epoch 0, batch1339/3125, loss0.619854 vs 0.710807
epoch 0, batch1340/3125, loss0.617218 vs 0.713360
epoch 0, batch1341/3125, loss0.615155 vs 0.713093
epoch 0, batch1342/3125, loss0.614470 vs 0.713004
epoch 0, batch1343/3125, loss0.612798 vs 0.719730
epoch 0, batch1344/3125, loss0.611635 vs 0.721325
epoch 0, batch1345/3125, loss0.612531 vs 0.716045
epoch 0, batch1346/3125, loss0.613998 vs 0.719832
epoch 0, batch1347/3125, loss0.615596 vs 0.719686
epoch 0, batch1348/3125, loss0.617751 vs 0.723676
epoch 0, batch1349/3125, loss0.620776 vs 0.715079


epoch 0, batch1493/3125, loss0.673396 vs 0.698233
epoch 0, batch1494/3125, loss0.673061 vs 0.697635
epoch 0, batch1495/3125, loss0.673079 vs 0.698134
epoch 0, batch1496/3125, loss0.673123 vs 0.698090
epoch 0, batch1497/3125, loss0.673224 vs 0.697243
epoch 0, batch1498/3125, loss0.673204 vs 0.698309
epoch 0, batch1499/3125, loss0.673240 vs 0.697255
epoch 0, batch1500/3125, loss0.673196 vs 0.699462
epoch 0, batch1501/3125, loss0.673506 vs 0.698599
epoch 0, batch1502/3125, loss0.673491 vs 0.697249
epoch 0, batch1503/3125, loss0.673897 vs 0.698262
epoch 0, batch1504/3125, loss0.673714 vs 0.698082
epoch 0, batch1505/3125, loss0.674008 vs 0.695885
epoch 0, batch1506/3125, loss0.674006 vs 0.696610
epoch 0, batch1507/3125, loss0.674456 vs 0.697308
epoch 0, batch1508/3125, loss0.674702 vs 0.696199
epoch 0, batch1509/3125, loss0.674937 vs 0.695662
epoch 0, batch1510/3125, loss0.675293 vs 0.696237
epoch 0, batch1511/3125, loss0.675317 vs 0.695083
epoch 0, batch1512/3125, loss0.675616 vs 0.694928


epoch 0, batch1656/3125, loss0.726984 vs 0.685105
epoch 0, batch1657/3125, loss0.726991 vs 0.687924
epoch 0, batch1658/3125, loss0.726327 vs 0.687989
epoch 0, batch1659/3125, loss0.726717 vs 0.684669
epoch 0, batch1660/3125, loss0.726258 vs 0.686509
epoch 0, batch1661/3125, loss0.726215 vs 0.685492
epoch 0, batch1662/3125, loss0.726046 vs 0.686899
epoch 0, batch1663/3125, loss0.725745 vs 0.686410
epoch 0, batch1664/3125, loss0.725660 vs 0.686392
epoch 0, batch1665/3125, loss0.725009 vs 0.685047
epoch 0, batch1666/3125, loss0.724952 vs 0.688154
epoch 0, batch1667/3125, loss0.724693 vs 0.680991
epoch 0, batch1668/3125, loss0.724764 vs 0.684649
epoch 0, batch1669/3125, loss0.724347 vs 0.686098
epoch 0, batch1670/3125, loss0.723984 vs 0.689513
epoch 0, batch1671/3125, loss0.723869 vs 0.687787
epoch 0, batch1672/3125, loss0.723325 vs 0.687152
epoch 0, batch1673/3125, loss0.722727 vs 0.685314
epoch 0, batch1674/3125, loss0.722165 vs 0.684502
epoch 0, batch1675/3125, loss0.721478 vs 0.686661


epoch 0, batch1819/3125, loss0.681280 vs 0.684242
epoch 0, batch1820/3125, loss0.680291 vs 0.683715
epoch 0, batch1821/3125, loss0.678675 vs 0.684205
epoch 0, batch1822/3125, loss0.678135 vs 0.685081
epoch 0, batch1823/3125, loss0.676070 vs 0.687911
epoch 0, batch1824/3125, loss0.675564 vs 0.687828
epoch 0, batch1825/3125, loss0.673461 vs 0.685797
epoch 0, batch1826/3125, loss0.671102 vs 0.691071
epoch 0, batch1827/3125, loss0.670116 vs 0.692520
epoch 0, batch1828/3125, loss0.668146 vs 0.694975
epoch 0, batch1829/3125, loss0.666860 vs 0.686930
epoch 0, batch1830/3125, loss0.666075 vs 0.690137
epoch 0, batch1831/3125, loss0.664441 vs 0.690781
epoch 0, batch1832/3125, loss0.664157 vs 0.689784
epoch 0, batch1833/3125, loss0.662460 vs 0.699289
epoch 0, batch1834/3125, loss0.661228 vs 0.693533
epoch 0, batch1835/3125, loss0.660194 vs 0.696986
epoch 0, batch1836/3125, loss0.660995 vs 0.695308
epoch 0, batch1837/3125, loss0.660214 vs 0.695380
epoch 0, batch1838/3125, loss0.660432 vs 0.696287


epoch 0, batch1982/3125, loss0.692316 vs 0.690076
epoch 0, batch1983/3125, loss0.692716 vs 0.689922
epoch 0, batch1984/3125, loss0.693100 vs 0.689534
epoch 0, batch1985/3125, loss0.693512 vs 0.689100
epoch 0, batch1986/3125, loss0.693824 vs 0.689991
epoch 0, batch1987/3125, loss0.694141 vs 0.690884
epoch 0, batch1988/3125, loss0.694575 vs 0.687850
epoch 0, batch1989/3125, loss0.695013 vs 0.689398
epoch 0, batch1990/3125, loss0.695309 vs 0.688155
epoch 0, batch1991/3125, loss0.695553 vs 0.690697
epoch 0, batch1992/3125, loss0.695948 vs 0.689236
epoch 0, batch1993/3125, loss0.696091 vs 0.689770
epoch 0, batch1994/3125, loss0.696468 vs 0.688728
epoch 0, batch1995/3125, loss0.696726 vs 0.689305
epoch 0, batch1996/3125, loss0.696998 vs 0.689138
epoch 0, batch1997/3125, loss0.697087 vs 0.689379
epoch 0, batch1998/3125, loss0.697396 vs 0.690386
epoch 0, batch1999/3125, loss0.697458 vs 0.689024
epoch 0, batch2000/3125, loss0.697858 vs 0.689440
epoch 0, batch2001/3125, loss0.697733 vs 0.688363


epoch 0, batch2145/3125, loss0.694782 vs 0.690295
epoch 0, batch2146/3125, loss0.694764 vs 0.690487
epoch 0, batch2147/3125, loss0.694855 vs 0.689764
epoch 0, batch2148/3125, loss0.694845 vs 0.690092
epoch 0, batch2149/3125, loss0.694928 vs 0.690483
epoch 0, batch2150/3125, loss0.694927 vs 0.690381
epoch 0, batch2151/3125, loss0.694994 vs 0.691025
epoch 0, batch2152/3125, loss0.695073 vs 0.690037
epoch 0, batch2153/3125, loss0.695104 vs 0.690529
epoch 0, batch2154/3125, loss0.695052 vs 0.690612
epoch 0, batch2155/3125, loss0.695162 vs 0.692249
epoch 0, batch2156/3125, loss0.695151 vs 0.690889
epoch 0, batch2157/3125, loss0.695095 vs 0.689889
epoch 0, batch2158/3125, loss0.695164 vs 0.690568
epoch 0, batch2159/3125, loss0.695185 vs 0.690026
epoch 0, batch2160/3125, loss0.695200 vs 0.690629
epoch 0, batch2161/3125, loss0.695158 vs 0.691191
epoch 0, batch2162/3125, loss0.695107 vs 0.689779
epoch 0, batch2163/3125, loss0.695156 vs 0.689876
epoch 0, batch2164/3125, loss0.695100 vs 0.690134


epoch 0, batch2308/3125, loss0.699244 vs 0.692692
epoch 0, batch2309/3125, loss0.697712 vs 0.693018
epoch 0, batch2310/3125, loss0.697847 vs 0.693460
epoch 0, batch2311/3125, loss0.697145 vs 0.693123
epoch 0, batch2312/3125, loss0.696171 vs 0.695256
epoch 0, batch2313/3125, loss0.695800 vs 0.693953
epoch 0, batch2314/3125, loss0.694835 vs 0.693844
epoch 0, batch2315/3125, loss0.694694 vs 0.694981
epoch 0, batch2316/3125, loss0.693762 vs 0.695994
epoch 0, batch2317/3125, loss0.693130 vs 0.695127
epoch 0, batch2318/3125, loss0.692437 vs 0.699352
epoch 0, batch2319/3125, loss0.691196 vs 0.695850
epoch 0, batch2320/3125, loss0.690461 vs 0.695140
epoch 0, batch2321/3125, loss0.689701 vs 0.702652
epoch 0, batch2322/3125, loss0.689247 vs 0.699389
epoch 0, batch2323/3125, loss0.689064 vs 0.697508
epoch 0, batch2324/3125, loss0.688790 vs 0.696901
epoch 0, batch2325/3125, loss0.688876 vs 0.696262
epoch 0, batch2326/3125, loss0.688877 vs 0.696733
epoch 0, batch2327/3125, loss0.688981 vs 0.695946


epoch 0, batch2471/3125, loss0.692366 vs 0.693482
epoch 0, batch2472/3125, loss0.692415 vs 0.693741
epoch 0, batch2473/3125, loss0.692418 vs 0.694042
epoch 0, batch2474/3125, loss0.692504 vs 0.698219
epoch 0, batch2475/3125, loss0.692473 vs 0.693047
epoch 0, batch2476/3125, loss0.692502 vs 0.693077
epoch 0, batch2477/3125, loss0.692531 vs 0.693145
epoch 0, batch2478/3125, loss0.692612 vs 0.694692
epoch 0, batch2479/3125, loss0.692606 vs 0.693577
epoch 0, batch2480/3125, loss0.692663 vs 0.693421
epoch 0, batch2481/3125, loss0.692644 vs 0.693208
epoch 0, batch2482/3125, loss0.692623 vs 0.692878
epoch 0, batch2483/3125, loss0.692577 vs 0.694200
epoch 0, batch2484/3125, loss0.692503 vs 0.694179
epoch 0, batch2485/3125, loss0.692550 vs 0.693489
epoch 0, batch2486/3125, loss0.692499 vs 0.693329
epoch 0, batch2487/3125, loss0.692370 vs 0.692639
epoch 0, batch2488/3125, loss0.692405 vs 0.693587
epoch 0, batch2489/3125, loss0.692424 vs 0.693756
epoch 0, batch2490/3125, loss0.692491 vs 0.692485


epoch 0, batch2634/3125, loss0.681924 vs 0.701023
epoch 0, batch2635/3125, loss0.682045 vs 0.693483
epoch 0, batch2636/3125, loss0.682597 vs 0.695980
epoch 0, batch2637/3125, loss0.682841 vs 0.701755
epoch 0, batch2638/3125, loss0.683189 vs 0.693839
epoch 0, batch2639/3125, loss0.683674 vs 0.693567
epoch 0, batch2640/3125, loss0.684183 vs 0.693681
epoch 0, batch2641/3125, loss0.684617 vs 0.693272
epoch 0, batch2642/3125, loss0.685013 vs 0.693158
epoch 0, batch2643/3125, loss0.685535 vs 0.693328
epoch 0, batch2644/3125, loss0.685745 vs 0.694790
epoch 0, batch2645/3125, loss0.686336 vs 0.693247
epoch 0, batch2646/3125, loss0.686804 vs 0.692768
epoch 0, batch2647/3125, loss0.687088 vs 0.692083
epoch 0, batch2648/3125, loss0.687405 vs 0.692578
epoch 0, batch2649/3125, loss0.687646 vs 0.691466
epoch 0, batch2650/3125, loss0.688160 vs 0.693683
epoch 0, batch2651/3125, loss0.688370 vs 0.693688
epoch 0, batch2652/3125, loss0.688694 vs 0.693160
epoch 0, batch2653/3125, loss0.688803 vs 0.692319


epoch 0, batch2797/3125, loss0.711499 vs 0.687614
epoch 0, batch2798/3125, loss0.711470 vs 0.686093
epoch 0, batch2799/3125, loss0.711464 vs 0.686079
epoch 0, batch2800/3125, loss0.711442 vs 0.686377
epoch 0, batch2801/3125, loss0.711499 vs 0.684799
epoch 0, batch2802/3125, loss0.711375 vs 0.688957
epoch 0, batch2803/3125, loss0.710979 vs 0.686188
epoch 0, batch2804/3125, loss0.710614 vs 0.687604
epoch 0, batch2805/3125, loss0.710079 vs 0.686400
epoch 0, batch2806/3125, loss0.709813 vs 0.686340
epoch 0, batch2807/3125, loss0.709364 vs 0.689810
epoch 0, batch2808/3125, loss0.708752 vs 0.683850
epoch 0, batch2809/3125, loss0.708340 vs 0.688935
epoch 0, batch2810/3125, loss0.707619 vs 0.688053
epoch 0, batch2811/3125, loss0.706883 vs 0.689843
epoch 0, batch2812/3125, loss0.706676 vs 0.687411
epoch 0, batch2813/3125, loss0.706017 vs 0.689182
epoch 0, batch2814/3125, loss0.706436 vs 0.687455
epoch 0, batch2815/3125, loss0.705642 vs 0.688503
epoch 0, batch2816/3125, loss0.704919 vs 0.690468


epoch 0, batch2960/3125, loss0.688240 vs 0.697645
epoch 0, batch2961/3125, loss0.695657 vs 0.700063
epoch 0, batch2962/3125, loss0.702104 vs 0.691762
epoch 0, batch2963/3125, loss0.706872 vs 0.698296
epoch 0, batch2964/3125, loss0.715451 vs 0.694947
epoch 0, batch2965/3125, loss0.721937 vs 0.689383
epoch 0, batch2966/3125, loss0.724801 vs 0.695927
epoch 0, batch2967/3125, loss0.724783 vs 0.706397
epoch 0, batch2968/3125, loss0.724559 vs 0.695505
epoch 0, batch2969/3125, loss0.722271 vs 0.701122
epoch 0, batch2970/3125, loss0.720421 vs 0.708512
epoch 0, batch2971/3125, loss0.718703 vs 0.701395
epoch 0, batch2972/3125, loss0.717641 vs 0.702318
epoch 0, batch2973/3125, loss0.715747 vs 0.721539
epoch 0, batch2974/3125, loss0.715152 vs 0.715133
epoch 0, batch2975/3125, loss0.716236 vs 0.713421
epoch 0, batch2976/3125, loss0.718520 vs 0.702444
epoch 0, batch2977/3125, loss0.722896 vs 0.711898
epoch 0, batch2978/3125, loss0.723846 vs 0.708758
epoch 0, batch2979/3125, loss0.725074 vs 0.705554


epoch 0, batch3123/3125, loss0.692660 vs 0.694355
epoch 0, batch3124/3125, loss0.692380 vs 0.692529
Finished Training


## debug

In [73]:

input = torch.rand(16, 100)
# input = torch.rand(16, 256, 4, 4)
# pc1 = nn.Linear(100, 256)
# pc2 = nn.Linear(256, 256 * 4)
# pc3 = nn.Linear(256 * 4, 256 * 4 * 4)
pc1 = nn.Linear(100, 256 * 4 * 4)

conv1 = nn.ConvTranspose2d(256, 128, 5, stride=2, padding=2, output_padding=1)
conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2, padding=2, output_padding=1)
# conv2 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
conv3 = nn.ConvTranspose2d(64, 3, 5, stride=2, padding=2, output_padding=1)
relu = nn.LeakyReLU()
batch1 =  nn.BatchNorm2d(128)
batch2 =  nn.BatchNorm2d(64)
batch3 =  nn.BatchNorm2d(3)
out = nn.Tanh()

pre_pipe = [pc1]
pipe = [conv1, batch1, relu, conv2, batch2, relu, conv3, batch3, out]

def print_pipe(pipe, input):
    print(tuple(input.shape))
    for layer in pipe:
        print("")
        print("↓", layer)
        input = layer(input)
        try:
            print("↓ weight:", tuple(layer.weight.shape))
            print("↓ bias:", tuple(layer.bias.shape))
        except AttributeError:
            pass
        print("")
        print(tuple(input.shape))
    return input

input = print_pipe(pre_pipe, input)
input = input.view(-1, 256, 4, 4)
output = print_pipe(pipe, input)

(16, 100)

↓ Linear(in_features=100, out_features=4096, bias=True)
↓ weight: (4096, 100)
↓ bias: (4096,)

(16, 4096)
(16, 256, 4, 4)

↓ ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
↓ weight: (256, 128, 5, 5)
↓ bias: (128,)

(16, 128, 8, 8)

↓ BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
↓ weight: (128,)
↓ bias: (128,)

(16, 128, 8, 8)

↓ LeakyReLU(negative_slope=0.01)

(16, 128, 8, 8)

↓ ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
↓ weight: (128, 64, 5, 5)
↓ bias: (64,)

(16, 64, 16, 16)

↓ BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
↓ weight: (64,)
↓ bias: (64,)

(16, 64, 16, 16)

↓ LeakyReLU(negative_slope=0.01)

(16, 64, 16, 16)

↓ ConvTranspose2d(64, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), output_padding=(1, 1))
↓ weight: (64, 3, 5, 5)
↓ bias: (3,)

(16, 3, 32, 32)

↓ BatchNorm2d(3, eps=1e

In [105]:
# ! no linear

# Discriminator
# input: [3, 32, 32]
# layer [] conv & leakyReLU
# output [1]

input = torch.rand(16, 3, 32, 32)

conv1 = nn.Conv2d(3, 3, 5, stride=2, padding=2)
conv2 = nn.Conv2d(3, 3, 5, stride=2, padding=2)
conv3 = nn.Conv2d(3, 3, 5, stride=2, padding=2)
pc = nn.Conv2d(3, 1, 4)
relu = nn.LeakyReLU()
batch1 =  nn.BatchNorm2d(3)
batch2 =  nn.BatchNorm2d(3)
out = nn.Sigmoid()

pipe = [conv1, batch1, relu, conv2, batch2, relu, conv3, pc, out]

def print_pipe(pipe, input):
    print(tuple(input.shape))
    for layer in pipe:
        print("")
        print("↓", layer)
        input = layer(input)
        try:
            print("↓ weight:", tuple(layer.weight.shape))
            print("↓ bias:", tuple(layer.bias.shape))
        except AttributeError:
            pass
        print("")
        print(tuple(input.shape))
    return input

output = print_pipe(pipe, input)
output.flatten()

(16, 3, 32, 32)

↓ Conv2d(3, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
↓ weight: (3, 3, 5, 5)
↓ bias: (3,)

(16, 3, 16, 16)

↓ BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
↓ weight: (3,)
↓ bias: (3,)

(16, 3, 16, 16)

↓ LeakyReLU(negative_slope=0.01)

(16, 3, 16, 16)

↓ Conv2d(3, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
↓ weight: (3, 3, 5, 5)
↓ bias: (3,)

(16, 3, 8, 8)

↓ BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
↓ weight: (3,)
↓ bias: (3,)

(16, 3, 8, 8)

↓ LeakyReLU(negative_slope=0.01)

(16, 3, 8, 8)

↓ Conv2d(3, 3, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
↓ weight: (3, 3, 5, 5)
↓ bias: (3,)

(16, 3, 4, 4)

↓ Conv2d(3, 1, kernel_size=(4, 4), stride=(1, 1))
↓ weight: (1, 3, 4, 4)
↓ bias: (1,)

(16, 1, 1, 1)

↓ Sigmoid()

(16, 1, 1, 1)


tensor([0.5577, 0.5150, 0.5563, 0.5154, 0.5007, 0.5275, 0.5034, 0.5058, 0.5836,
        0.5439, 0.5099, 0.5122, 0.5819, 0.4811, 0.5670, 0.5467],
       grad_fn=<ViewBackward>)

In [119]:
torch.randn(100)

tensor([-1.1631e+00,  4.8946e-01, -9.0157e-01, -1.9192e-01,  4.3275e-01,
        -5.4418e-01, -1.6182e-01, -3.3701e-01,  7.7639e-01,  2.3676e-02,
        -1.3282e+00, -7.3611e-01, -5.1528e-02, -2.7328e-04,  1.4827e+00,
        -1.4353e+00,  1.2153e+00,  7.6078e-01,  1.7725e+00, -2.3429e-01,
        -6.8469e-01,  5.2390e-01, -1.0138e+00,  1.5600e+00, -1.4037e-01,
         1.7212e-01, -8.9528e-01, -8.2807e-02, -3.8931e-01, -2.7677e-01,
        -3.6832e-01, -1.9761e+00, -3.2232e-01,  5.7051e-01,  4.3654e-02,
         6.4523e-02, -1.0529e+00,  4.3544e-01, -4.4667e-01,  7.6477e-01,
        -2.9100e-01, -8.9874e-02,  3.8855e-01,  3.6674e-02,  1.3649e+00,
         1.8008e+00,  1.3792e+00,  4.2143e-02,  1.5043e-01,  9.8680e-01,
        -3.8085e-01,  3.0854e-01,  6.8735e-01, -4.8478e-01, -7.4162e-01,
         1.0992e-01, -2.4957e-01, -2.4525e-01,  1.5533e+00,  4.8228e-01,
         1.8783e+00, -2.9088e-01, -1.8227e+00, -4.0738e-01, -7.7897e-01,
         4.2631e-01, -1.8404e+00,  2.5480e-01, -2.1

In [118]:
trainloader.__iter__().next()


[tensor([[[[ 0.5608,  0.6000,  0.7020,  ...,  0.5608,  0.6471,  0.6863],
           [ 0.4980,  0.6627,  0.6941,  ...,  0.6078,  0.6471,  0.5765],
           [ 0.4902,  0.6627,  0.6549,  ...,  0.6549,  0.6235,  0.5529],
           ...,
           [ 0.1608, -0.0824, -0.0902,  ...,  0.2627,  0.7804,  0.6392],
           [ 0.3490, -0.1137, -0.4196,  ..., -0.3333,  0.3020,  0.5686],
           [ 0.2471, -0.2627, -0.4510,  ..., -0.4431, -0.2314,  0.3020]],
 
          [[ 0.5059,  0.5373,  0.6314,  ...,  0.4824,  0.5843,  0.6314],
           [ 0.4431,  0.6000,  0.6235,  ...,  0.5294,  0.5765,  0.5137],
           [ 0.4353,  0.6000,  0.5843,  ...,  0.5843,  0.5529,  0.4980],
           ...,
           [ 0.1059, -0.1294, -0.1451,  ...,  0.2314,  0.7490,  0.5843],
           [ 0.2863, -0.1608, -0.4667,  ..., -0.3647,  0.2627,  0.5137],
           [ 0.2157, -0.2941, -0.4824,  ..., -0.4824, -0.2784,  0.2549]],
 
          [[ 0.4275,  0.5216,  0.6471,  ...,  0.5294,  0.5922,  0.5765],
           [ 

In [None]:
nn.BCELoss()

In [139]:
Variable(torch.ones(16), requires_grad=False)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])