In [1]:
import torch #基本モジュール
from torch.autograd import Variable #自動微分用
import torch.nn as nn #ネットワーク構築用
import torch.optim as optim #最適化関数
import torch.nn.functional as F #ネットワーク用の様々な関数
import torch.utils.data #データセット読み込み関連
import torchvision #画像関連
from torchvision import datasets, models, transforms #画像用データセット諸々

import numpy as np

In [2]:
#画像の変形処理
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#CIFAR-10のtrain, testsetのロード
#変形はtransformを適用
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

#DataLoaderの適用->これによりバッチの割り当て・シャッフルをまとめて行うことができる
#batch_sizeでバッチサイズを指定
#num_workersでいくつのコアでデータをロードするか指定(デフォルトはメインのみ)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=16,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5) # [3, 32, 32] => [6, 28, 28]
        self.conv2 = nn.Conv2d(6, 16, 5) # [6, 28, 28] => [16, 24, 24]
        self.pool = nn.MaxPool2d(2, 2) # [N,C,H,W] => [N, C, H/2, W/2]
        self.fc1 = nn.Linear(16 * 5 * 5, 120) # 
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [7]:
#モデル定義
model = CNN()

#Loss関数の指定
criterion = nn.CrossEntropyLoss()

#Optimizerの指定
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [11]:
#トレーニング
#エポック数の指定
for epoch in range(1):  # loop over the dataset multiple times

    #データ全てのトータルロス
    running_loss = 0.0 


    for i, (inputs, labels) in enumerate(trainloader):

        # Variableに変形
        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)
        
        # optimizerの初期化
        # zero the parameter gradients
        optimizer.zero_grad()

        #一連の流れ
        # forward + backward + optimize
        outputs = model(inputs)

        #ここでラベルデータに対するCross-Entropyがとられる
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # ロスの表示
        # print statistics
        print("epoch %d, batch%d/%d, loss%f" % (epoch, i, len(trainloader), loss.item()))
        

print('Finished Training')

epoch 0, batch0/3125, loss1.537807
epoch 0, batch1/3125, loss1.248273
epoch 0, batch2/3125, loss0.902971
epoch 0, batch3/3125, loss1.172384
epoch 0, batch4/3125, loss0.994371
epoch 0, batch5/3125, loss1.498207
epoch 0, batch6/3125, loss1.110374
epoch 0, batch7/3125, loss1.486221
epoch 0, batch8/3125, loss1.571879
epoch 0, batch9/3125, loss2.169888
epoch 0, batch10/3125, loss1.440867
epoch 0, batch11/3125, loss1.777896
epoch 0, batch12/3125, loss1.552136
epoch 0, batch13/3125, loss1.505647
epoch 0, batch14/3125, loss1.245040
epoch 0, batch15/3125, loss1.219121
epoch 0, batch16/3125, loss2.003653
epoch 0, batch17/3125, loss1.407731
epoch 0, batch18/3125, loss1.516555
epoch 0, batch19/3125, loss1.697100
epoch 0, batch20/3125, loss1.508013
epoch 0, batch21/3125, loss1.652302
epoch 0, batch22/3125, loss1.473766
epoch 0, batch23/3125, loss1.503440
epoch 0, batch24/3125, loss1.241295
epoch 0, batch25/3125, loss1.243371
epoch 0, batch26/3125, loss1.020127
epoch 0, batch27/3125, loss1.539554
ep

epoch 0, batch224/3125, loss1.611403
epoch 0, batch225/3125, loss1.652764
epoch 0, batch226/3125, loss1.730226
epoch 0, batch227/3125, loss1.524780
epoch 0, batch228/3125, loss1.941669
epoch 0, batch229/3125, loss1.123502
epoch 0, batch230/3125, loss1.392051
epoch 0, batch231/3125, loss1.744522
epoch 0, batch232/3125, loss1.300610
epoch 0, batch233/3125, loss1.441268
epoch 0, batch234/3125, loss1.302802
epoch 0, batch235/3125, loss1.431087
epoch 0, batch236/3125, loss1.145019
epoch 0, batch237/3125, loss1.505837
epoch 0, batch238/3125, loss1.554678
epoch 0, batch239/3125, loss1.271971
epoch 0, batch240/3125, loss1.277088
epoch 0, batch241/3125, loss1.296713
epoch 0, batch242/3125, loss1.820934
epoch 0, batch243/3125, loss1.436921
epoch 0, batch244/3125, loss1.279294
epoch 0, batch245/3125, loss1.642675
epoch 0, batch246/3125, loss1.338000
epoch 0, batch247/3125, loss1.614737
epoch 0, batch248/3125, loss1.318908
epoch 0, batch249/3125, loss1.402424
epoch 0, batch250/3125, loss0.849732
e

epoch 0, batch445/3125, loss1.277372
epoch 0, batch446/3125, loss0.836580
epoch 0, batch447/3125, loss1.057016
epoch 0, batch448/3125, loss1.365989
epoch 0, batch449/3125, loss1.418499
epoch 0, batch450/3125, loss1.108507
epoch 0, batch451/3125, loss1.269440
epoch 0, batch452/3125, loss1.185823
epoch 0, batch453/3125, loss1.194090
epoch 0, batch454/3125, loss1.618435
epoch 0, batch455/3125, loss1.428174
epoch 0, batch456/3125, loss1.735968
epoch 0, batch457/3125, loss1.554071
epoch 0, batch458/3125, loss1.779329
epoch 0, batch459/3125, loss1.530111
epoch 0, batch460/3125, loss0.999200
epoch 0, batch461/3125, loss1.601512
epoch 0, batch462/3125, loss1.648471
epoch 0, batch463/3125, loss1.331573
epoch 0, batch464/3125, loss1.566744
epoch 0, batch465/3125, loss1.345043
epoch 0, batch466/3125, loss1.163744
epoch 0, batch467/3125, loss1.638698
epoch 0, batch468/3125, loss1.594047
epoch 0, batch469/3125, loss1.243369
epoch 0, batch470/3125, loss1.499059
epoch 0, batch471/3125, loss1.962879
e

epoch 0, batch666/3125, loss1.476666
epoch 0, batch667/3125, loss1.175212
epoch 0, batch668/3125, loss1.897943
epoch 0, batch669/3125, loss1.699627
epoch 0, batch670/3125, loss1.384135
epoch 0, batch671/3125, loss1.079165
epoch 0, batch672/3125, loss1.626162
epoch 0, batch673/3125, loss1.754001
epoch 0, batch674/3125, loss1.471847
epoch 0, batch675/3125, loss1.202730
epoch 0, batch676/3125, loss1.320449
epoch 0, batch677/3125, loss1.636298
epoch 0, batch678/3125, loss1.114363
epoch 0, batch679/3125, loss1.365265
epoch 0, batch680/3125, loss1.190663
epoch 0, batch681/3125, loss2.007893
epoch 0, batch682/3125, loss1.307651
epoch 0, batch683/3125, loss0.902619
epoch 0, batch684/3125, loss1.515663
epoch 0, batch685/3125, loss1.397136
epoch 0, batch686/3125, loss1.389241
epoch 0, batch687/3125, loss1.556030
epoch 0, batch688/3125, loss1.396267
epoch 0, batch689/3125, loss1.751937
epoch 0, batch690/3125, loss1.413673
epoch 0, batch691/3125, loss1.727708
epoch 0, batch692/3125, loss1.961950
e

epoch 0, batch887/3125, loss1.411405
epoch 0, batch888/3125, loss1.744266
epoch 0, batch889/3125, loss1.217621
epoch 0, batch890/3125, loss1.198392
epoch 0, batch891/3125, loss1.399101
epoch 0, batch892/3125, loss1.536877
epoch 0, batch893/3125, loss1.481927
epoch 0, batch894/3125, loss1.346244
epoch 0, batch895/3125, loss1.296812
epoch 0, batch896/3125, loss2.188304
epoch 0, batch897/3125, loss1.188292
epoch 0, batch898/3125, loss0.946612
epoch 0, batch899/3125, loss1.243369
epoch 0, batch900/3125, loss1.312531
epoch 0, batch901/3125, loss1.472349
epoch 0, batch902/3125, loss1.180587
epoch 0, batch903/3125, loss1.428604
epoch 0, batch904/3125, loss1.596985
epoch 0, batch905/3125, loss1.748024
epoch 0, batch906/3125, loss1.306585
epoch 0, batch907/3125, loss1.633890
epoch 0, batch908/3125, loss1.257813
epoch 0, batch909/3125, loss1.270254
epoch 0, batch910/3125, loss1.654736
epoch 0, batch911/3125, loss1.724358
epoch 0, batch912/3125, loss1.154197
epoch 0, batch913/3125, loss1.256527
e

epoch 0, batch1105/3125, loss1.404907
epoch 0, batch1106/3125, loss1.587262
epoch 0, batch1107/3125, loss1.147066
epoch 0, batch1108/3125, loss1.569286
epoch 0, batch1109/3125, loss1.470846
epoch 0, batch1110/3125, loss1.432600
epoch 0, batch1111/3125, loss1.565293
epoch 0, batch1112/3125, loss1.937914
epoch 0, batch1113/3125, loss1.677625
epoch 0, batch1114/3125, loss1.336203
epoch 0, batch1115/3125, loss1.291076
epoch 0, batch1116/3125, loss1.497771
epoch 0, batch1117/3125, loss1.413853
epoch 0, batch1118/3125, loss1.061167
epoch 0, batch1119/3125, loss0.926949
epoch 0, batch1120/3125, loss1.279889
epoch 0, batch1121/3125, loss0.994221
epoch 0, batch1122/3125, loss1.046096
epoch 0, batch1123/3125, loss1.655534
epoch 0, batch1124/3125, loss1.567599
epoch 0, batch1125/3125, loss1.321145
epoch 0, batch1126/3125, loss1.602441
epoch 0, batch1127/3125, loss1.177948
epoch 0, batch1128/3125, loss1.412328
epoch 0, batch1129/3125, loss1.206744
epoch 0, batch1130/3125, loss1.111803
epoch 0, bat

epoch 0, batch1320/3125, loss1.843869
epoch 0, batch1321/3125, loss1.767571
epoch 0, batch1322/3125, loss1.578944
epoch 0, batch1323/3125, loss1.298577
epoch 0, batch1324/3125, loss1.015782
epoch 0, batch1325/3125, loss1.221917
epoch 0, batch1326/3125, loss1.076902
epoch 0, batch1327/3125, loss1.430134
epoch 0, batch1328/3125, loss1.268219
epoch 0, batch1329/3125, loss1.195155
epoch 0, batch1330/3125, loss1.705676
epoch 0, batch1331/3125, loss1.428140
epoch 0, batch1332/3125, loss1.260656
epoch 0, batch1333/3125, loss1.410637
epoch 0, batch1334/3125, loss1.048807
epoch 0, batch1335/3125, loss0.935631
epoch 0, batch1336/3125, loss1.184982
epoch 0, batch1337/3125, loss1.168289
epoch 0, batch1338/3125, loss1.127733
epoch 0, batch1339/3125, loss0.978208
epoch 0, batch1340/3125, loss1.464987
epoch 0, batch1341/3125, loss1.798972
epoch 0, batch1342/3125, loss1.283322
epoch 0, batch1343/3125, loss0.962550
epoch 0, batch1344/3125, loss1.614294
epoch 0, batch1345/3125, loss1.953891
epoch 0, bat

epoch 0, batch1535/3125, loss1.671684
epoch 0, batch1536/3125, loss1.244860
epoch 0, batch1537/3125, loss2.060718
epoch 0, batch1538/3125, loss1.344619
epoch 0, batch1539/3125, loss1.053696
epoch 0, batch1540/3125, loss1.569426
epoch 0, batch1541/3125, loss1.564769
epoch 0, batch1542/3125, loss1.302162
epoch 0, batch1543/3125, loss1.612277
epoch 0, batch1544/3125, loss0.994574
epoch 0, batch1545/3125, loss1.049558
epoch 0, batch1546/3125, loss1.824419
epoch 0, batch1547/3125, loss1.148324
epoch 0, batch1548/3125, loss1.414185
epoch 0, batch1549/3125, loss1.662587
epoch 0, batch1550/3125, loss1.767776
epoch 0, batch1551/3125, loss1.685690
epoch 0, batch1552/3125, loss1.113812
epoch 0, batch1553/3125, loss1.205878
epoch 0, batch1554/3125, loss1.594961
epoch 0, batch1555/3125, loss1.128718
epoch 0, batch1556/3125, loss1.309935
epoch 0, batch1557/3125, loss1.982044
epoch 0, batch1558/3125, loss1.334331
epoch 0, batch1559/3125, loss1.343307
epoch 0, batch1560/3125, loss1.343604
epoch 0, bat

epoch 0, batch1750/3125, loss1.598039
epoch 0, batch1751/3125, loss1.525257
epoch 0, batch1752/3125, loss1.073471
epoch 0, batch1753/3125, loss0.754855
epoch 0, batch1754/3125, loss1.506542
epoch 0, batch1755/3125, loss1.055953
epoch 0, batch1756/3125, loss1.228026
epoch 0, batch1757/3125, loss1.541923
epoch 0, batch1758/3125, loss2.044669
epoch 0, batch1759/3125, loss1.319923
epoch 0, batch1760/3125, loss1.188153
epoch 0, batch1761/3125, loss1.435495
epoch 0, batch1762/3125, loss1.333815
epoch 0, batch1763/3125, loss1.033756
epoch 0, batch1764/3125, loss1.199210
epoch 0, batch1765/3125, loss2.179629
epoch 0, batch1766/3125, loss1.324057
epoch 0, batch1767/3125, loss1.673826
epoch 0, batch1768/3125, loss1.205296
epoch 0, batch1769/3125, loss1.204059
epoch 0, batch1770/3125, loss1.168450
epoch 0, batch1771/3125, loss1.529403
epoch 0, batch1772/3125, loss1.410415
epoch 0, batch1773/3125, loss1.785696
epoch 0, batch1774/3125, loss1.327240
epoch 0, batch1775/3125, loss1.484427
epoch 0, bat

epoch 0, batch1965/3125, loss1.649157
epoch 0, batch1966/3125, loss1.441327
epoch 0, batch1967/3125, loss0.977665
epoch 0, batch1968/3125, loss1.668116
epoch 0, batch1969/3125, loss1.358949
epoch 0, batch1970/3125, loss1.300292
epoch 0, batch1971/3125, loss1.446732
epoch 0, batch1972/3125, loss1.951679
epoch 0, batch1973/3125, loss1.122110
epoch 0, batch1974/3125, loss1.082854
epoch 0, batch1975/3125, loss1.268435
epoch 0, batch1976/3125, loss1.129991
epoch 0, batch1977/3125, loss1.611593
epoch 0, batch1978/3125, loss1.551330
epoch 0, batch1979/3125, loss1.483578
epoch 0, batch1980/3125, loss1.445419
epoch 0, batch1981/3125, loss2.099351
epoch 0, batch1982/3125, loss1.228285
epoch 0, batch1983/3125, loss1.289091
epoch 0, batch1984/3125, loss1.371178
epoch 0, batch1985/3125, loss1.428799
epoch 0, batch1986/3125, loss1.320313
epoch 0, batch1987/3125, loss0.888711
epoch 0, batch1988/3125, loss1.016828
epoch 0, batch1989/3125, loss1.146269
epoch 0, batch1990/3125, loss1.000339
epoch 0, bat

epoch 0, batch2180/3125, loss1.254025
epoch 0, batch2181/3125, loss1.473515
epoch 0, batch2182/3125, loss1.152194
epoch 0, batch2183/3125, loss1.476165
epoch 0, batch2184/3125, loss0.946811
epoch 0, batch2185/3125, loss1.075871
epoch 0, batch2186/3125, loss1.623904
epoch 0, batch2187/3125, loss1.274306
epoch 0, batch2188/3125, loss0.979686
epoch 0, batch2189/3125, loss1.058583
epoch 0, batch2190/3125, loss1.707101
epoch 0, batch2191/3125, loss1.563066
epoch 0, batch2192/3125, loss1.482247
epoch 0, batch2193/3125, loss1.640317
epoch 0, batch2194/3125, loss1.589703
epoch 0, batch2195/3125, loss1.388572
epoch 0, batch2196/3125, loss1.058723
epoch 0, batch2197/3125, loss1.553943
epoch 0, batch2198/3125, loss1.437078
epoch 0, batch2199/3125, loss1.466048
epoch 0, batch2200/3125, loss1.076219
epoch 0, batch2201/3125, loss1.111305
epoch 0, batch2202/3125, loss1.540580
epoch 0, batch2203/3125, loss1.407092
epoch 0, batch2204/3125, loss1.614998
epoch 0, batch2205/3125, loss1.794466
epoch 0, bat

epoch 0, batch2395/3125, loss1.060943
epoch 0, batch2396/3125, loss1.041597
epoch 0, batch2397/3125, loss1.275384
epoch 0, batch2398/3125, loss1.082047
epoch 0, batch2399/3125, loss1.121042
epoch 0, batch2400/3125, loss1.352094
epoch 0, batch2401/3125, loss1.936557
epoch 0, batch2402/3125, loss1.213565
epoch 0, batch2403/3125, loss1.566415
epoch 0, batch2404/3125, loss1.481473
epoch 0, batch2405/3125, loss1.431689
epoch 0, batch2406/3125, loss0.991848
epoch 0, batch2407/3125, loss1.726434
epoch 0, batch2408/3125, loss1.960534
epoch 0, batch2409/3125, loss1.163413
epoch 0, batch2410/3125, loss1.278514
epoch 0, batch2411/3125, loss1.453782
epoch 0, batch2412/3125, loss1.498124
epoch 0, batch2413/3125, loss1.558659
epoch 0, batch2414/3125, loss1.170790
epoch 0, batch2415/3125, loss1.424586
epoch 0, batch2416/3125, loss1.294485
epoch 0, batch2417/3125, loss2.077120
epoch 0, batch2418/3125, loss1.526270
epoch 0, batch2419/3125, loss0.909317
epoch 0, batch2420/3125, loss1.149612
epoch 0, bat

epoch 0, batch2610/3125, loss1.108357
epoch 0, batch2611/3125, loss1.458084
epoch 0, batch2612/3125, loss0.960295
epoch 0, batch2613/3125, loss1.273633
epoch 0, batch2614/3125, loss1.547802
epoch 0, batch2615/3125, loss1.421030
epoch 0, batch2616/3125, loss1.506984
epoch 0, batch2617/3125, loss1.420600
epoch 0, batch2618/3125, loss1.052103
epoch 0, batch2619/3125, loss1.375055
epoch 0, batch2620/3125, loss1.482740
epoch 0, batch2621/3125, loss0.978757
epoch 0, batch2622/3125, loss1.059162
epoch 0, batch2623/3125, loss1.511119
epoch 0, batch2624/3125, loss1.075711
epoch 0, batch2625/3125, loss1.426616
epoch 0, batch2626/3125, loss1.407129
epoch 0, batch2627/3125, loss0.903393
epoch 0, batch2628/3125, loss1.879467
epoch 0, batch2629/3125, loss1.283966
epoch 0, batch2630/3125, loss1.186715
epoch 0, batch2631/3125, loss1.230825
epoch 0, batch2632/3125, loss1.294287
epoch 0, batch2633/3125, loss1.746232
epoch 0, batch2634/3125, loss1.142106
epoch 0, batch2635/3125, loss1.322174
epoch 0, bat

epoch 0, batch2825/3125, loss1.447660
epoch 0, batch2826/3125, loss0.957587
epoch 0, batch2827/3125, loss1.167505
epoch 0, batch2828/3125, loss1.656495
epoch 0, batch2829/3125, loss0.978698
epoch 0, batch2830/3125, loss1.443651
epoch 0, batch2831/3125, loss1.057103
epoch 0, batch2832/3125, loss0.947616
epoch 0, batch2833/3125, loss1.418913
epoch 0, batch2834/3125, loss1.077430
epoch 0, batch2835/3125, loss1.445808
epoch 0, batch2836/3125, loss1.459154
epoch 0, batch2837/3125, loss1.617881
epoch 0, batch2838/3125, loss1.649063
epoch 0, batch2839/3125, loss1.148422
epoch 0, batch2840/3125, loss1.673270
epoch 0, batch2841/3125, loss1.359378
epoch 0, batch2842/3125, loss1.808248
epoch 0, batch2843/3125, loss1.428088
epoch 0, batch2844/3125, loss1.118110
epoch 0, batch2845/3125, loss1.178697
epoch 0, batch2846/3125, loss1.319587
epoch 0, batch2847/3125, loss1.399467
epoch 0, batch2848/3125, loss1.158671
epoch 0, batch2849/3125, loss1.755124
epoch 0, batch2850/3125, loss1.946159
epoch 0, bat

epoch 0, batch3040/3125, loss1.483438
epoch 0, batch3041/3125, loss1.092061
epoch 0, batch3042/3125, loss1.500550
epoch 0, batch3043/3125, loss1.226773
epoch 0, batch3044/3125, loss1.275818
epoch 0, batch3045/3125, loss1.063264
epoch 0, batch3046/3125, loss1.810203
epoch 0, batch3047/3125, loss1.484831
epoch 0, batch3048/3125, loss1.302783
epoch 0, batch3049/3125, loss1.439333
epoch 0, batch3050/3125, loss1.288465
epoch 0, batch3051/3125, loss1.606471
epoch 0, batch3052/3125, loss1.331125
epoch 0, batch3053/3125, loss0.952145
epoch 0, batch3054/3125, loss1.820880
epoch 0, batch3055/3125, loss1.146161
epoch 0, batch3056/3125, loss1.468532
epoch 0, batch3057/3125, loss0.980824
epoch 0, batch3058/3125, loss1.308669
epoch 0, batch3059/3125, loss1.671742
epoch 0, batch3060/3125, loss1.044429
epoch 0, batch3061/3125, loss1.326264
epoch 0, batch3062/3125, loss1.276380
epoch 0, batch3063/3125, loss1.588317
epoch 0, batch3064/3125, loss1.410218
epoch 0, batch3065/3125, loss1.523468
epoch 0, bat