# CNN

## Regression

单通道

In [10]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler #only can be imported above pytorch0.2.0

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.pool0 = nn.AvgPool2d((2, 4), padding=1)
        self.conv1 = nn.Conv2d(1, 5, 4)
        self.pool1 = nn.AvgPool2d(3)
        self.conv2 = nn.Conv2d(5, 5, 3)
        self.pool2 = nn.MaxPool2d(3)
        self.fc1 = nn.Linear(80, 1)
        
    def forward(self, x):
        x = self.pool0(x)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 5*4*4)
        x = self.fc1(x)
        return x

net = Net()
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

In [28]:
import netCDF4  # netCDF4非Python自带包，需要自行下载

import numpy as np

f = netCDF4.Dataset('D:\\sst.mnmean.nc')  # ftp://ftp.cdc.noaa.gov/Datasets/noaa.ersst.v5/sst.mnmean.nc
SST = f.variables['sst'][-1203:, :, :].data

SST[SST < -2] = 0

In [29]:
month_mean = np.zeros((12, 89, 180))
month_s = np.arange(0, 1188, 12)
for i in range(12):
    month_mean[i] = np.average(SST[month_s + i], axis = 0)
    SST[month_s + i] -= month_mean[i]
for i in range(12):
    SST[1188 + i] -= month_mean[i]
for i in range(3):
    SST[1200 + i] -= month_mean[i]

In [8]:
SST = SST.reshape((1200, 1, 89, 180))

X_train = SST[:-16]
X_train = torch.from_numpy(X_train)
X_valid = SST[-16:-4]
X_valid = torch.from_numpy(X_valid)

In [6]:
loc = ((134, 44), \
       (94, 44), \
       (57, 37), \
       (39, 44), \
       (21, 36), \
       (9, 16), \
       (172, 59), \
       (158, 28), \
       (65, 64), \
       (0, 10))

In [11]:
for row in loc:
    y_train = SST[1:-15, :, row[1], row[0]]
    y_train = torch.from_numpy(y_train)
    y_valid = SST[-15:-3, :, row[1], row[0]]
    y_valid = torch.from_numpy(y_valid)

    for epoch in range(500):  # loop over the dataset multiple times
        optimizer.zero_grad()
        outputs = net(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()

        outputs_valid = net(X_valid)
        loss_valid = criterion(outputs_valid, y_valid)
        print(epoch, loss.item(), loss_valid.item())

    break

0 1.0465993881225586 0.6835797429084778
1 0.9747850298881531 0.6549382209777832
2 0.8863504528999329 0.6098104119300842
3 0.7840589284896851 0.5596358180046082
4 0.6766080260276794 0.5468974709510803
5 0.5807129144668579 0.7936881184577942
6 0.5363522171974182 1.088138222694397
7 0.5773686766624451 0.8834357261657715
8 0.615320086479187 0.9258086085319519
9 0.5854175686836243 0.7049086093902588
10 0.5361378788948059 0.474771648645401
11 0.49880969524383545 0.41472306847572327
12 0.4840785562992096 0.4199715554714203
13 0.48431453108787537 0.4274919927120209
14 0.49079370498657227 0.4431491792201996
15 0.49424758553504944 0.46046414971351624
16 0.4938221871852875 0.4491209089756012
17 0.4867117702960968 0.43406984210014343
18 0.47634685039520264 0.4229389429092407
19 0.46304330229759216 0.4076906144618988
20 0.45169714093208313 0.3971867561340332
21 0.44332918524742126 0.4074654281139374
22 0.4409140646457672 0.38971027731895447
23 0.4423217177391052 0.41832205653190613
24 0.44405153393

193 0.18328599631786346 0.27480044960975647
194 0.18276703357696533 0.2728077471256256
195 0.18256059288978577 0.2880725562572479
196 0.18251702189445496 0.2705892026424408
197 0.1822158694267273 0.2832965850830078
198 0.18179568648338318 0.28649887442588806
199 0.1815796047449112 0.2770906388759613
200 0.18146997690200806 0.2947790026664734
201 0.18121550977230072 0.2852121889591217
202 0.18086321651935577 0.28724566102027893
203 0.18062308430671692 0.2973720133304596
204 0.18047688901424408 0.2842743992805481
205 0.1802806854248047 0.2968388497829437
206 0.18000203371047974 0.2930295765399933
207 0.17974315583705902 0.29091888666152954
208 0.1795731633901596 0.3027160167694092
209 0.17941376566886902 0.29282739758491516
210 0.17918448150157928 0.3013637661933899
211 0.17891918122768402 0.30124548077583313
212 0.17869751155376434 0.29702091217041016
213 0.17851226031780243 0.3069303631782532
214 0.17830677330493927 0.2990483343601227
215 0.17807386815547943 0.3057812750339508
216 0.17

384 0.1473223716020584 0.2667428255081177
385 0.14896926283836365 0.3070680797100067
386 0.14825516939163208 0.2888113260269165
387 0.1467156857252121 0.2749788761138916
388 0.1477700173854828 0.30958208441734314
389 0.1476530134677887 0.2895793616771698
390 0.14622369408607483 0.2831544280052185
391 0.14677470922470093 0.3147099018096924
392 0.1470494419336319 0.29177483916282654
393 0.14587847888469696 0.29103267192840576
394 0.14580053091049194 0.3136032521724701
395 0.14629392325878143 0.2900228202342987
396 0.14554639160633087 0.2943248450756073
397 0.14511269330978394 0.3095169961452484
398 0.1455204039812088 0.28800180554389954
399 0.14522916078567505 0.29858991503715515
400 0.14461293816566467 0.30427733063697815
401 0.1446652114391327 0.2897147834300995
402 0.14477141201496124 0.3049187958240509
403 0.14433437585830688 0.30090177059173584
404 0.14404179155826569 0.29287612438201904
405 0.1441570371389389 0.3082398772239685
406 0.14404381811618805 0.29848915338516235
407 0.1436

预测（92°W，0°）今年3月份海温（真实值为27.26599℃）

In [30]:
X_test = SST[-2].reshape((1, 1, 89, 180))
X_test = torch.from_numpy(X_test)

y_fore = net(X_test) + month_mean[2, row[1], row[0]]
y_fore

tensor([[ 26.5089]])

双通道

In [29]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler #only can be imported above pytorch0.2.0

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.pool0 = nn.AvgPool2d((2, 4), padding=1)
        self.conv1 = nn.Conv2d(2, 5, 4)
        self.pool1 = nn.AvgPool2d(3)
        self.conv2 = nn.Conv2d(5, 5, 3)
        self.pool2 = nn.AvgPool2d(3)
        self.fc1 = nn.Linear(80, 1)
        
    def forward(self, x):
        x = self.pool0(x)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(-1, 5*4*4)
        x = self.fc1(x)
        return x

In [16]:
import netCDF4  # netCDF4非Python自带包，需要自行下载

import numpy as np

f = netCDF4.Dataset('D:\\sst.mnmean.nc')  # ftp://ftp.cdc.noaa.gov/Datasets/noaa.ersst.v5/sst.mnmean.nc
SST = f.variables['sst'][-1203:, :, :].data

SST[SST < -2] = 0

In [17]:
month_mean = np.zeros((12, 89, 180))
month_s = np.arange(0, 1188, 12)
for i in range(12):
    month_mean[i] = np.average(SST[month_s + i], axis = 0)
    SST[month_s + i] -= month_mean[i]
for i in range(12):
    SST[1188 + i] -= month_mean[i]
for i in range(3):
    SST[1200 + i] -= month_mean[i]

In [18]:
SST = SST.reshape((1203, 1, 89, 180))

X_train = np.zeros((1186, 2, 89, 180), dtype=np.float32)
X_train[:, 0] = SST[:-17, 0]
X_train[:, 1] = SST[1:-16, 0]
X_train = torch.from_numpy(X_train)

X_valid = np.zeros((12, 2, 89, 180), dtype=np.float32)
X_valid[:, 0] = SST[-17:-5, 0]
X_valid[:, 1] = SST[-16:-4, 0]
X_valid = torch.from_numpy(X_valid)

In [19]:
loc = ((134, 44), \
       (94, 44), \
       (57, 37), \
       (39, 44), \
       (21, 36), \
       (9, 16), \
       (172, 59), \
       (158, 28), \
       (65, 64), \
       (0, 10))

In [33]:
for row in loc:
    y_train = SST[2:-15, :, row[1], row[0]]
    y_train = torch.from_numpy(y_train)
    y_valid = SST[-15:-3, :, row[1], row[0]]
    y_valid = torch.from_numpy(y_valid)
    
    net = Net()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.01)

    for epoch in range(500):  # loop over the dataset multiple times
        optimizer.zero_grad()
        outputs = net(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()

        outputs_valid = net(X_valid)
        loss_valid = criterion(outputs_valid, y_valid)
        print(epoch, loss.item(), loss_valid.item())

    break

0 1.0924222469329834 0.7783412933349609
1 1.0704162120819092 0.7477688789367676
2 1.0477246046066284 0.7139258980751038
3 1.017866849899292 0.689863920211792
4 0.9753577709197998 0.7009252905845642
5 0.9219707250595093 0.7884209156036377
6 0.8792669177055359 0.9841131567955017
7 0.8626843690872192 1.0334067344665527
8 0.8389838337898254 0.8623300194740295
9 0.7954409718513489 0.656138002872467
10 0.7594743371009827 0.5161804556846619
11 0.7386359572410583 0.43809643387794495
12 0.7192434072494507 0.39140284061431885
13 0.6935243606567383 0.3567187786102295
14 0.6629032492637634 0.3238140642642975
15 0.6298184394836426 0.28833717107772827
16 0.5976099371910095 0.2548743784427643
17 0.5731301307678223 0.2385539561510086
18 0.5558366179466248 0.27531805634498596
19 0.5374846458435059 0.3896186649799347
20 0.5245131850242615 0.4896676540374756
21 0.5149450898170471 0.49002495408058167
22 0.5047157406806946 0.4387580156326294
23 0.4967308044433594 0.39900103211402893
24 0.48543277382850647 

193 0.17737534642219543 0.10822869092226028
194 0.17647625505924225 0.11043223738670349
195 0.17786617577075958 0.10831903666257858
196 0.1777893453836441 0.10807955265045166
197 0.17614613473415375 0.10726223140954971
198 0.17587417364120483 0.1052391454577446
199 0.17678102850914001 0.10809841006994247
200 0.17647632956504822 0.10445096343755722
201 0.17529596388339996 0.10432963818311691
202 0.17525231838226318 0.1078687533736229
203 0.17580030858516693 0.10448169708251953
204 0.17538657784461975 0.10577940940856934
205 0.17459642887115479 0.10627207159996033
206 0.17458784580230713 0.1050928607583046
207 0.174891397356987 0.10685473680496216
208 0.1745796501636505 0.10535714775323868
209 0.17398780584335327 0.10542616993188858
210 0.17385916411876678 0.1068679466843605
211 0.17403532564640045 0.10561370849609375
212 0.17391647398471832 0.10663079470396042
213 0.1734778881072998 0.10603330284357071
214 0.17316702008247375 0.10555749386548996
215 0.17317035794258118 0.106790624558925

382 0.15562200546264648 0.1267538219690323
383 0.15575744211673737 0.12887120246887207
384 0.15559297800064087 0.12736012041568756
385 0.15524767339229584 0.12829825282096863
386 0.15493090450763702 0.1281834840774536
387 0.15479522943496704 0.1277867704629898
388 0.15483234822750092 0.12933449447155
389 0.15490949153900146 0.12841914594173431
390 0.15490420162677765 0.13042187690734863
391 0.154795303940773 0.12926225364208221
392 0.1545923501253128 0.1303350329399109
393 0.15436862409114838 0.12970729172229767
394 0.1541934460401535 0.1298225373029709
395 0.15409450232982635 0.13034918904304504
396 0.15405558049678802 0.13000120222568512
397 0.15404854714870453 0.1312214881181717
398 0.15404565632343292 0.13034051656723022
399 0.15402184426784515 0.1318170428276062
400 0.15396752953529358 0.13082094490528107
401 0.15389753878116608 0.13268032670021057
402 0.15380370616912842 0.1316050887107849
403 0.1537168025970459 0.13347980380058289
404 0.15363629162311554 0.13218368589878082
405 

查看参数

In [45]:
params = net.state_dict()
for key, value in params.items():
    print(key)
    print(value)

conv1.weight
tensor([[[[ 0.1510, -0.1077, -0.0834,  0.0309],
          [-0.0083, -0.0001, -0.1092,  0.1961],
          [-0.1258, -0.0724,  0.0615,  0.2049],
          [-0.2861, -0.2778, -0.0905,  0.0435]],

         [[ 0.2954,  0.3766,  0.2090,  0.3081],
          [ 0.2486,  0.3672,  0.3810,  0.5710],
          [-0.2887,  0.1020,  0.2948,  0.5198],
          [-0.3312,  0.2737,  0.4264,  0.2467]]],


        [[[ 0.4285,  0.3927,  0.4240,  0.3062],
          [ 0.2997,  0.3778,  0.2909,  0.1247],
          [ 0.1028, -0.0430, -0.0430, -0.0094],
          [ 0.1991, -0.1548, -0.1521, -0.2047]],

         [[-0.3585, -0.3561,  0.0138, -0.1086],
          [ 0.0507, -0.1134, -0.1795,  0.0252],
          [ 0.1738,  0.0330, -0.1451, -0.1173],
          [ 0.1612, -0.1298, -0.3141, -0.5173]]],


        [[[-0.1371, -0.0787, -0.1715, -0.0309],
          [ 0.0039, -0.1633,  0.0560, -0.0061],
          [-0.1149, -0.2434, -0.0266,  0.1660],
          [-0.3392, -0.1659,  0.0258,  0.3732]],

         [[ 0

In [50]:
params['fc1.bias'].numpy()

array([-0.29721215], dtype=float32)

In [64]:
X_test

tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]]])

In [71]:
mat = net.pool0(X_test)
mat

tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.5791,  1.1537,  1.0077,  ..., -0.0000, -0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.6774,  1.3424,  1.3536,  ..., -0.0000, -0.0000,  0.0000],
          ...,
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]]])

In [72]:
mat = net.pool1(F.relu(net.conv1(mat)))
mat

tensor([[[[ 2.9896,  2.7273,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
            0.0000,  0.0000,  0.0000,  0.0083,  0.0000,  0.0000,  0.3413],
          [ 3.0476,  4.1616,  0.6891,  0.0000,  0.0000,  0.0000,  0.5120,
            0.1046,  0.0506,  0.0256,  0.0559,  0.3914,  0.2587,  0.8099],
          [ 0.4298,  0.0136,  0.0000,  0.0000,  0.0000,  0.0177,  2.2565,
            3.5687,  2.2694,  0.0290,  0.0000,  0.6461,  0.4188,  0.5486],
          [ 0.2837,  0.1487,  0.0131,  0.0000,  1.3816,  1.3646,  0.9308,
            1.7003,  1.4421,  0.3033,  0.0007,  0.8213,  0.7713,  1.3598],
          [ 0.8529,  0.3343,  0.0017,  0.0275,  3.2024,  2.9152,  3.2534,
            1.6996,  0.0464,  0.7146,  0.3012,  2.0293,  1.7992,  1.7493],
          [ 0.1673,  0.7290,  0.0984,  0.5467,  2.5228,  2.6183,  1.9869,
            1.4354,  1.0777,  2.1503,  1.3492,  1.9878,  1.6343,  0.8734],
          [ 0.0161,  1.4348,  1.2058,  1.2815,  0.0603,  0.6792,  1.7345,
            2.1348,  1.8072,  1.

In [73]:
mat = net.pool2(F.relu(net.conv2(mat)))
mat

tensor([[[[ 1.7245,  2.9144,  2.0041,  1.7712],
          [ 2.3932,  1.3243,  2.3043,  1.8784],
          [ 1.4030,  2.8169,  1.7986,  2.3987],
          [ 1.6967,  1.3228,  1.4515,  2.0170]],

         [[ 2.6577,  1.0501,  1.6945,  1.1951],
          [ 0.7797,  1.7055,  0.6594,  0.9934],
          [ 1.5289,  0.6554,  1.5661,  1.0234],
          [ 1.6763,  1.5860,  1.6955,  1.1643]],

         [[ 0.3246,  0.6026,  0.1709,  0.3178],
          [ 0.4125,  0.6620,  0.0015,  0.2274],
          [ 0.4246,  0.6517,  0.1932,  0.7214],
          [ 0.2037,  0.8563,  0.1220,  0.2548]],

         [[ 0.4498,  0.6473,  0.6233,  0.2590],
          [ 0.4052,  1.1604,  0.7552,  1.0241],
          [ 1.0047,  0.7187,  1.1670,  1.3284],
          [ 0.9126,  0.9888,  0.5209,  0.2752]],

         [[ 0.1560,  0.6472,  0.0675,  0.3956],
          [ 0.5628,  0.0311,  0.0000,  0.0269],
          [ 0.0000,  0.5598,  0.0000,  0.0000],
          [ 0.0566,  0.2672,  0.0352,  0.5312]]]])

In [74]:
mat = mat.view(-1, 80)
mat

tensor([[ 1.7245,  2.9144,  2.0041,  1.7712,  2.3932,  1.3243,  2.3043,
          1.8784,  1.4030,  2.8169,  1.7986,  2.3987,  1.6967,  1.3228,
          1.4515,  2.0170,  2.6577,  1.0501,  1.6945,  1.1951,  0.7797,
          1.7055,  0.6594,  0.9934,  1.5289,  0.6554,  1.5661,  1.0234,
          1.6763,  1.5860,  1.6955,  1.1643,  0.3246,  0.6026,  0.1709,
          0.3178,  0.4125,  0.6620,  0.0015,  0.2274,  0.4246,  0.6517,
          0.1932,  0.7214,  0.2037,  0.8563,  0.1220,  0.2548,  0.4498,
          0.6473,  0.6233,  0.2590,  0.4052,  1.1604,  0.7552,  1.0241,
          1.0047,  0.7187,  1.1670,  1.3284,  0.9126,  0.9888,  0.5209,
          0.2752,  0.1560,  0.6472,  0.0675,  0.3956,  0.5628,  0.0311,
          0.0000,  0.0269,  0.0000,  0.5598,  0.0000,  0.0000,  0.0566,
          0.2672,  0.0352,  0.5312]])

In [75]:
mat = net.fc1(mat)
mat

tensor([[ 0.5504]])

In [76]:
mat + month_mean[2, row[1], row[0]]

tensor([[ 27.2570]])

预测（92°W，0°）今年3月份海温（真实值为27.26599℃）

In [31]:
X_test = np.zeros((1, 2, 89, 180), dtype=np.float32)
X_test[0, 0] = SST[-3, 0]
X_test[0, 1] = SST[-2, 0]
X_test = torch.from_numpy(X_test)

y_fore = net(X_test) + month_mean[2, row[1], row[0]]
y_fore

tensor([[ 27.5039]])

In [32]:
for row in loc:
    y_train = SST[2:-15, :, row[1], row[0]]
    y_train = torch.from_numpy(y_train)
    y_valid = SST[-15:-3, :, row[1], row[0]]
    y_valid = torch.from_numpy(y_valid)
    
    net = Net()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.01)

    for epoch in range(1000):  # loop over the dataset multiple times
        optimizer.zero_grad()
        outputs = net(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()

        outputs_valid = net(X_valid)
        loss_valid = criterion(outputs_valid, y_valid)
    
    print(net(X_test) + month_mean[2, row[1], row[0]])

tensor([[ 27.5720]])
tensor([[ 29.1317]])
tensor([[ 27.6299]])
tensor([[ 29.9275]])
tensor([[ 27.4515]])
tensor([[ 3.9730]])
tensor([[ 24.6470]])
tensor([[ 19.6922]])
tensor([[ 16.1494]])
tensor([[ 5.9008]])


## Classfication

数据上传到百度网盘 内涝等级识别 文件夹

In [8]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler #only can be imported above pytorch0.2.0
from torchvision import datasets, models, transforms
import os
import numpy as np

In [None]:
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

image_datasets = datasets.ImageFolder('train',
                                      data_transforms)

test_datasets = datasets.ImageFolder('test',
                                      data_transforms)

In [None]:
def make_weights_for_balanced_classes(images, nclasses):                        
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1
    weight_per_class = [0.] * nclasses
    N = float(sum(count))
    for i in range(nclasses):
        weight_per_class[i] = N/float(count[i])
    weight = [0] * len(images)
    for idx, val in enumerate(images):
        weight[idx] = weight_per_class[val[1]]
    return weight

weight = make_weights_for_balanced_classes(image_datasets, 5)

sampler = torch.utils.data.sampler.WeightedRandomSampler(weight, len(weight))

dataloders = torch.utils.data.DataLoader(image_datasets,
                                         batch_size=1,
                                         sampler = sampler,
                                         num_workers=0)

datatest = torch.utils.data.DataLoader(test_datasets,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=0)

In [None]:
net = models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
for epoch in range(25):  # loop over the dataset multiple times
    scheduler.step()
    i = 0
    loss_sum = 0
    for data in dataloders:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss_sum += loss.item()
        loss.backward()
        optimizer.step()

        i += 1

        if i % 1000 == 0:

            print(i, loss_sum / 1000)

            loss_sum = 0

    correct = np.zeros((5, 5))
    for data in datatest:
        inputs, labels = data
        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        correct[int(labels), int(predicted)] += 1
    for row in correct:
        print(row)

In [None]:
torch.save(net.state_dict(), 'test.pkl')

预测

In [14]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler #only can be imported above pytorch0.2.0
from torchvision import datasets, models, transforms
import os
import numpy as np

data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

test_datasets = datasets.ImageFolder('test_new',
                                      data_transforms)

datatest = torch.utils.data.DataLoader(test_datasets,
                                         batch_size=1,
                                         shuffle=True,
                                         num_workers=0)

net = models.resnet18(pretrained=True)
num_ftrs = net.fc.in_features
net.fc = nn.Linear(num_ftrs, 5)

net.load_state_dict(torch.load('test.pkl'), strict=False)

for data in datatest:
    inputs, labels = data
    outputs = F.softmax(net(inputs))
    print(outputs.detach().numpy()[0], int(labels))



[6.1513543e-05 9.9989283e-01 4.3832726e-05 1.4410913e-06 3.7497728e-07] 1
[1.9392589e-06 9.7590164e-06 1.0727131e-05 1.8801609e-04 9.9978954e-01] 4
[1.6476604e-05 3.4625569e-04 9.9959826e-01 3.7069316e-05 1.9144195e-06] 2
[1.8656605e-05 2.6466485e-04 2.7656858e-04 4.3591481e-01 5.6352532e-01] 3
[9.9962622e-01 3.5922590e-04 7.4621644e-06 5.9296272e-06 1.1686778e-06] 0
