# CNN

## Regression

单通道：输入为前一个月的全球海温距平场（1 个通道），预测下一个月（92°W，0°）处的海温距平。

In [8]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from adabound import AdaBound

class Net(nn.Module):
    """Small CNN that regresses one scalar from a stack of SST-anomaly maps.

    Layer shapes for an (N, in_channels, 89, 180) input:
    conv0 -> (N, 3, 45, 91), pool0 -> (N, 3, 45, 45), conv1 -> (N, 3, 15, 15),
    conv2 -> (N, 1, 5, 5), flattened to 25 features, fc1 -> (N, 1).

    Args:
        in_channels: number of input channels (default 1: a single month).
    """

    def __init__(self, in_channels=1):
        super(Net, self).__init__()
        self.conv0 = nn.Conv2d(in_channels, 3, 2, stride = 2, padding = 1)
        self.pool0 = nn.AvgPool2d((1, 2))
        self.conv1 = nn.Conv2d(3, 3, 3, stride = 3)
        self.conv2 = nn.Conv2d(3, 1, 3, stride = 3)
        self.fc1 = nn.Linear(25, 1)

    def forward(self, x):
        # No activation between conv0/pool0 and conv1 (kept as in the original design).
        x = self.pool0(self.conv0(x))
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Flatten per sample. Unlike view(-1, 25), a wrongly sized input now fails in
        # fc1 instead of being silently reshaped into extra batch rows.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

In [3]:
import netCDF4  # netCDF4非Python自带包，需要自行下载

import numpy as np

# NOAA ERSST v5 monthly means
# (ftp://ftp.cdc.noaa.gov/Datasets/noaa.ersst.v5/sst.mnmean.nc); edit for your machine.
SST_PATH = 'D:\\sst.mnmean.nc'

# Read the last 1203 months, then close the file once the data is in memory.
with netCDF4.Dataset(SST_PATH) as f:
    # Masked (missing / land) points become 0 rather than exposing the raw fill value;
    # np.ma.filled also works if the library hands back a plain ndarray.
    SST = np.ma.filled(f.variables['sst'][-1203:, :, :], 0)

# Safety net for any remaining sentinel values (the -2 threshold presumably marks the
# lowest physically plausible sea temperature in degC).
SST[SST < -2] = 0

# Later cells hard-code this layout: 1203 months on an 89 x 180 (lat, lon) grid.
assert SST.shape == (1203, 89, 180), SST.shape

In [4]:
# Monthly climatology: for each of the 12 month slots, the per-grid-point mean over the
# first N_CLIM_YEARS years. 95 * 12 = 1140 months, i.e. everything before the validation
# targets SST[1140:1200], so validation/test months do not leak into the climatology.
N_CLIM_YEARS = 95

month_mean = np.zeros((12,) + SST.shape[-2:])
month_s = np.arange(0, 1200, 12)  # start index of each full year in the record
for i in range(12):
    month_mean[i] = np.average(SST[month_s[:N_CLIM_YEARS] + i], axis=0)

# Convert SST to anomalies in place. Record index t belongs to month slot t % 12; the
# strided slice covers every year, including the trailing partial year (1200-1202),
# so the tail needs no special case.
# NOTE: not idempotent - re-running subtracts the climatology again; reload SST first.
for i in range(12):
    SST[i::12] -= month_mean[i]

In [5]:
# Give every month an explicit channel axis: (months, 1, lat, lon) for Conv2d.
SST = SST.reshape((1203, 1, 89, 180))

# Inputs are month t (the targets built in the next cell are month t + 1).
# The tensors are views that share memory with SST.
X_train, X_valid = (
    torch.from_numpy(SST[lo:hi]) for lo, hi in ((0, 1139), (1139, 1199))
)
# The second-to-last month is the input used to forecast the final month.
X_test = torch.from_numpy(SST[-2][None])

In [7]:
# (lon index, lat index) of the target point, used as SST[..., row[1], row[0]];
# the markdown below refers to it as (92°W, 0°).
row = (134, 44)

# Targets: the anomaly at that point one month after each input month.
y_train, y_valid = (
    torch.from_numpy(SST[first:last, :, row[1], row[0]])
    for first, last in ((1, 1140), (1140, 1200))
)

In [13]:
# Seed the RNG so weight initialisation, and therefore every number below, is reproducible.
torch.manual_seed(0)
net = Net()
criterion = nn.MSELoss()
optimizer = AdaBound(net.parameters(), lr=0.001)

In [14]:
# Full-batch training: one optimizer step per epoch on the whole training set.
for epoch in range(2000):
    optimizer.zero_grad()
    outputs = net(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        # Validation needs only a forward pass, so skip building an autograd graph.
        # Note: `loss` was computed before this epoch's step, `loss_valid` after it.
        with torch.no_grad():
            outputs_valid = net(X_valid)
            loss_valid = criterion(outputs_valid, y_valid)
        print(epoch + 1, loss.item(), loss_valid.item())

1 1.0593490600585938 2.2431704998016357
101 0.39485061168670654 0.585343599319458
201 0.25435343384742737 0.3394608497619629
301 0.2065161168575287 0.28106409311294556
401 0.1944275051355362 0.28916308283805847
501 0.19147314131259918 0.28996068239212036
601 0.18947745859622955 0.29275771975517273
701 0.18798966705799103 0.2954002320766449
801 0.18664179742336273 0.2981739938259125
901 0.185383602976799 0.29874271154403687
1001 0.18432217836380005 0.29979392886161804
1101 0.18343126773834229 0.2997266948223114
1201 0.18276023864746094 0.2997497022151947
1301 0.18226659297943115 0.30181318521499634
1401 0.18182489275932312 0.30027851462364197
1501 0.18145067989826202 0.29985979199409485
1601 0.18112118542194366 0.30041536688804626
1701 0.18081676959991455 0.3008872866630554
1801 0.1805395632982254 0.3021588921546936
1901 0.18074335157871246 0.30766358971595764


In [15]:
# Validation-set predictions; inference only, so no autograd graph is needed.
with torch.no_grad():
    outputs = net(X_valid)
    loss = criterion(outputs, y_valid)

# Print (true, predicted) absolute SST in degC by adding back the climatology.
# y_valid starts at record index 1140 (1140 % 12 == 0), so sample i uses month slot i % 12.
for i in range(len(y_valid)):
    clim = month_mean[i % 12, row[1], row[0]]
    print(y_valid[i].item() + clim, outputs[i].item() + clim)

print('\n')

print(loss.item())

25.193620681762695 24.671341314911842
25.725605010986328 26.25347925722599
26.790624618530273 25.97232162952423
26.54935073852539 25.869835674762726
26.679452896118164 25.159868732094765
26.143993377685547 25.205073356628418
24.62432098388672 24.908365726470947
23.667221069335938 23.532217383384705
23.388002395629883 23.494011878967285
23.72989273071289 23.81073033809662
24.111202239990234 24.186190724372864
24.508861541748047 24.632781267166138
25.26885223388672 25.485126793384552
26.084409713745117 26.336624816060066
27.022750854492188 26.35879099369049
27.623594284057617 26.336759105324745
27.09197425842285 26.2988920211792
26.963388442993164 25.90183401107788
26.37074851989746 25.61638045310974
24.832151412963867 25.364282608032227
25.315359115600586 24.74648380279541
25.576629638671875 25.35601830482483
26.05023956298828 25.594332695007324
26.717212677001953 26.303779363632202
27.224092483520508 27.673407554626465
27.63623046875 28.423676252365112
28.250402450561523 27.98242127895

预测（92°W，0°）今年3月份海温（真实值为27.26599℃）

In [16]:
# X_test ends at SST[-2], so this forecasts the final month SST[-1]. Its climatology slot
# is (len(SST) - 1) % 12, which is 2 (March) for the 1203-month record.
with torch.no_grad():
    y_test = net(X_test) + month_mean[(len(SST) - 1) % 12, row[1], row[0]]
y_test.item()

27.431476593017578

六通道：将连续 6 个月的海温距平场叠成 6 个输入通道，预测下一个月（92°W，0°）处的海温距平。

In [17]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from adabound import AdaBound

class Net(nn.Module):
    """Small CNN that regresses one scalar from several monthly SST-anomaly maps.

    Each input channel is one month. Layer shapes for an (N, in_channels, 89, 180) input:
    conv0 -> (N, 3, 45, 91), pool0 -> (N, 3, 45, 45), conv1 -> (N, 3, 15, 15),
    conv2 -> (N, 1, 5, 5), flattened to 25 features, fc1 -> (N, 1).

    Args:
        in_channels: number of input channels; must equal the window length
            ``seq`` used to build the inputs (default 6).
    """

    def __init__(self, in_channels=6):
        super(Net, self).__init__()
        self.conv0 = nn.Conv2d(in_channels, 3, 2, stride = 2, padding = 1)
        self.pool0 = nn.AvgPool2d((1, 2))
        self.conv1 = nn.Conv2d(3, 3, 3, stride = 3)
        self.conv2 = nn.Conv2d(3, 1, 3, stride = 3)
        self.fc1 = nn.Linear(25, 1)

    def forward(self, x):
        # No activation between conv0/pool0 and conv1 (kept as in the original design).
        x = self.pool0(self.conv0(x))
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Flatten per sample. Unlike view(-1, 25), a wrongly sized input now fails in
        # fc1 instead of being silently reshaped into extra batch rows.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

In [18]:
import torch

# Sliding windows: each sample stacks `seq` consecutive monthly anomaly maps as channels;
# the target is the following month's anomaly at the selected grid point.
seq = 6  # must equal the in_channels of Net.conv0

X_train = np.zeros((1140 - seq, seq, 89, 180), dtype = np.float32)
X_valid = np.zeros((60, seq, 89, 180), dtype = np.float32)
X_test = np.zeros((1, seq, 89, 180), dtype = np.float32)

# Channel i of sample j is month (start + j + i); the last channel is the month just
# before the target. Loop over `seq` (not a literal 6) so changing the window stays
# consistent instead of raising IndexError or leaving channels zero-filled.
for i in range(seq):
    X_train[:, i] = SST[i:1140 - seq + i, 0]
    X_valid[:, i] = SST[1140 - seq + i:1200 - seq + i, 0]
    X_test[:, i] = SST[- (seq + 1) + i, 0]

y_train = SST[seq:1140, 0, row[1], row[0]]
y_valid = SST[1140:1200, 0, row[1], row[0]]

X_train = torch.from_numpy(X_train).float()
X_valid = torch.from_numpy(X_valid).float()
X_test = torch.from_numpy(X_test).float()
# Targets as (n, 1) columns to match the network's (N, 1) output.
y_train = torch.from_numpy(y_train.reshape(len(y_train), 1)).float()
y_valid = torch.from_numpy(y_valid.reshape(len(y_valid), 1)).float()

In [19]:
# Seed the RNG so weight initialisation, and therefore every number below, is reproducible.
torch.manual_seed(0)
net = Net()
criterion = nn.MSELoss()
optimizer = AdaBound(net.parameters(), lr=0.001)

In [20]:
# Full-batch training: one optimizer step per epoch on the whole training set.
for epoch in range(2000):
    optimizer.zero_grad()
    outputs = net(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    if epoch % 100 == 0:
        # Validation needs only a forward pass, so skip building an autograd graph.
        # Note: `loss` was computed before this epoch's step, `loss_valid` after it.
        with torch.no_grad():
            outputs_valid = net(X_valid)
            loss_valid = criterion(outputs_valid, y_valid)
        print(epoch + 1, loss.item(), loss_valid.item())

1 1.0924763679504395 2.330173969268799
101 0.790428876876831 1.3777086734771729
201 0.3616828918457031 1.2509651184082031
301 0.2806044816970825 0.954754650592804
401 0.21504716575145721 0.4674451947212219
501 0.19042183458805084 0.33915314078330994
601 0.1841445416212082 0.30757245421409607
701 0.17647892236709595 0.3088524043560028
801 0.18248756229877472 0.3105855882167816
901 0.17049334943294525 0.32039350271224976
1001 0.16827881336212158 0.32811346650123596
1101 0.16639171540737152 0.33694854378700256
1201 0.16497302055358887 0.3290182948112488
1301 0.16315044462680817 0.33995312452316284
1401 0.16295288503170013 0.3451271653175354
1501 0.16025038063526154 0.3488489091396332
1601 0.17014898359775543 0.36060601472854614
1701 0.15779925882816315 0.3455619812011719
1801 0.15666332840919495 0.3408467471599579
1901 0.1573016345500946 0.3537077307701111


In [21]:
# Validation-set predictions; inference only, so no autograd graph is needed.
with torch.no_grad():
    outputs = net(X_valid)
    loss = criterion(outputs, y_valid)

# Print (true, predicted) absolute SST in degC by adding back the climatology.
# y_valid starts at record index 1140 (1140 % 12 == 0), so sample i uses month slot i % 12.
for i in range(len(y_valid)):
    clim = month_mean[i % 12, row[1], row[0]]
    print(y_valid[i].item() + clim, outputs[i].item() + clim)

print('\n')

print(loss.item())

25.193620681762695 24.78334690630436
25.725605010986328 26.405107468366623
26.790624618530273 25.93099135160446
26.54935073852539 26.343162536621094
26.679452896118164 25.540600538253784
26.143993377685547 25.35278069972992
24.62432098388672 24.916329264640808
23.667221069335938 23.117464900016785
23.388002395629883 23.464094519615173
23.72989273071289 23.92610800266266
24.111202239990234 24.185763955116272
24.508861541748047 24.65487825870514
25.26885223388672 25.239870429039
26.084409713745117 25.964152932167053
27.022750854492188 25.77655065059662
27.623594284057617 26.037035584449768
27.09197425842285 26.034534811973572
26.963388442993164 25.598950624465942
26.37074851989746 25.315369606018066
24.832151412963867 24.98161506652832
25.315359115600586 24.6444730758667
25.576629638671875 25.4720139503479
26.05023956298828 25.534537315368652
26.717212677001953 26.452805519104004
27.224092483520508 27.550655841827393
27.63623046875 28.315449237823486
28.250402450561523 27.80140233039856


预测（92°W，0°）今年3月份海温（真实值为27.26599℃）

In [22]:
# X_test's window ends at SST[-2], so this forecasts the final month SST[-1]. Its
# climatology slot is (len(SST) - 1) % 12, which is 2 (March) for the 1203-month record.
with torch.no_grad():
    y_test = net(X_test) + month_mean[(len(SST) - 1) % 12, row[1], row[0]]
y_test.item()

27.169986724853516

查看参数（六通道模型）

In [23]:
# Print every learned tensor of the network: its name on one line, then its values.
params = net.state_dict()
for name, tensor in params.items():
    print(name, tensor, sep='\n')

conv0.weight
tensor([[[[ 0.1626, -0.1818],
          [ 0.3055, -0.2712]],

         [[-0.0182,  0.1153],
          [ 0.2552, -0.3488]],

         [[ 0.2211, -0.1829],
          [ 0.0003, -0.0799]],

         [[-0.1093, -0.0536],
          [ 0.1107, -0.1153]],

         [[ 0.2299, -0.2277],
          [ 0.3637, -0.1182]],

         [[-0.0832, -0.3807],
          [-0.1627, -0.3990]]],


        [[[ 0.3674, -0.2971],
          [ 0.0445,  0.0142]],

         [[-0.3230,  0.0227],
          [ 0.0703,  0.3456]],

         [[-0.0224,  0.2676],
          [-0.1968, -0.0582]],

         [[-0.0434,  0.1085],
          [ 0.2416,  0.1632]],

         [[-0.0490,  0.0637],
          [-0.2543,  0.0810]],

         [[ 0.2915,  0.4196],
          [ 0.2939,  0.1732]]],


        [[[-0.2022,  0.1321],
          [-0.2898, -0.2228]],

         [[-0.1758, -0.2476],
          [-0.2026,  0.1506]],

         [[ 0.0349,  0.0800],
          [ 0.1520, -0.0624]],

         [[-0.1311, -0.4041],
          [-0.2348, -0.

In [24]:
# Learned bias of the final linear layer (one element, since fc1 has a single output).
fc1_bias = params['fc1.bias']
fc1_bias.numpy()

array([-0.07206452], dtype=float32)

先卷积后 RNN：用同一组卷积层分别提取连续 6 个月中每一帧的特征，再将这 6 步特征序列输入 RNN，预测下一个月（92°W，0°）处的海温距平。

In [27]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from adabound import AdaBound

class Net(nn.Module):
    """Shared CNN encoder applied to every month, followed by an RNN over time.

    Expects x of shape (N, T, in_channels, 89, 180). Each of the T frames is encoded by
    the same conv stack into 1 x 5 x 5 = 25 features. The (N, T, 25) sequence goes
    through a one-layer RNN, and fc1 maps the last step's output to (N, 1).

    Args:
        in_channels: channels per frame (default 1).
        hidden_size: RNN hidden size (default 2).
    """

    def __init__(self, in_channels=1, hidden_size=2):
        super(Net, self).__init__()
        self.conv0 = nn.Conv2d(in_channels, 3, 2, stride = 2, padding = 1)
        self.pool0 = nn.AvgPool2d((1, 2))
        self.conv1 = nn.Conv2d(3, 3, 3, stride = 3)
        self.conv2 = nn.Conv2d(3, 1, 3, stride = 3)
        self.rnn = nn.RNN(
            input_size=25,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True
        )
        self.fc1 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        batch, steps = x.shape[:2]
        # Fold time into the batch so the conv stack runs once over all N*T frames
        # instead of once per time step in a Python loop; row b*T + t is x[b, t].
        frames = x.reshape(batch * steps, *x.shape[2:])
        feats = self.pool0(self.conv0(frames))
        feats = F.relu(self.conv1(feats))
        feats = F.relu(self.conv2(feats))
        # Unfold back to (N, T, features). A mis-sized frame yields != 25 features and
        # the RNN raises, instead of rows being silently re-batched by view(-1, 25).
        feats = feats.reshape(batch, steps, -1)
        x, _ = self.rnn(feats)
        x = self.fc1(x[:, -1, :])
        return x

In [1]:
import torch

# Sliding windows for the CNN+RNN model: each sample is a sequence of `seq` consecutive
# monthly anomaly maps, each with one channel; the target is the following month's
# anomaly at the selected grid point.
seq = 6

X_train = np.zeros((1140 - seq, seq, 1, 89, 180), dtype = np.float32)
X_valid = np.zeros((60, seq, 1, 89, 180), dtype = np.float32)
X_test = np.zeros((1, seq, 1, 89, 180), dtype = np.float32)

# Time step i of sample j is month (start + j + i); the last step is the month just
# before the target. Loop over `seq` (not a literal 6) so changing the window stays
# consistent instead of raising IndexError or leaving steps zero-filled.
for i in range(seq):
    X_train[:, i, 0] = SST[i:1140 - seq + i, 0]
    X_valid[:, i, 0] = SST[1140 - seq + i:1200 - seq + i, 0]
    X_test[:, i, 0] = SST[- (seq + 1) + i, 0]

y_train = SST[seq:1140, 0, row[1], row[0]]
y_valid = SST[1140:1200, 0, row[1], row[0]]

X_train = torch.from_numpy(X_train).float()
X_valid = torch.from_numpy(X_valid).float()
X_test = torch.from_numpy(X_test).float()
# Targets as (n, 1) columns to match the network's (N, 1) output.
y_train = torch.from_numpy(y_train.reshape(len(y_train), 1)).float()
y_valid = torch.from_numpy(y_valid.reshape(len(y_valid), 1)).float()

In [30]:
# Seed the RNG so weight initialisation, and therefore every number below, is reproducible.
torch.manual_seed(0)
net = Net()
criterion = nn.MSELoss()
optimizer = AdaBound(net.parameters(), lr=0.001)

In [31]:
# Full-batch training: one optimizer step per epoch on the whole training set.
for epoch in range(4000):
    optimizer.zero_grad()
    outputs = net(X_train)
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        # Validation needs only a forward pass, so skip building an autograd graph.
        # Note: `loss` was computed before this epoch's step, `loss_valid` after it.
        with torch.no_grad():
            outputs_valid = net(X_valid)
            loss_valid = criterion(outputs_valid, y_valid)
        print(epoch + 1, loss.item(), loss_valid.item())

1 1.453825831413269 3.4785237312316895
11 1.3641479015350342 3.290745496749878
21 1.289772629737854 3.150963068008423
31 1.2160483598709106 2.9811084270477295
41 1.126387357711792 2.7234649658203125
51 1.0313825607299805 2.36138653755188
61 0.9643883109092712 1.9766194820404053
71 0.9423432946205139 1.7252557277679443
81 0.9316885471343994 1.6652253866195679
91 0.9186010360717773 1.7021613121032715
101 0.9077789187431335 1.7163318395614624
111 0.8953521251678467 1.6632744073867798
121 0.8818134069442749 1.622038722038269
131 0.8667826056480408 1.5955651998519897
141 0.8499716520309448 1.529783844947815
151 0.8319727182388306 1.4340851306915283
161 0.8139520883560181 1.3503117561340332
171 0.7955959439277649 1.2886059284210205
181 0.775851845741272 1.243478536605835
191 0.7544150352478027 1.1932389736175537
201 0.7302984595298767 1.1343501806259155
211 0.7035053372383118 1.0706740617752075
221 0.6747082471847534 1.0151780843734741
231 0.6445348858833313 0.9651433825492859
241 0.61414730

1901 0.1591615229845047 0.2918263375759125
1911 0.15904268622398376 0.29157131910324097
1921 0.15892083942890167 0.29087117314338684
1931 0.15879783034324646 0.2904498875141144
1941 0.15867513418197632 0.2903498709201813
1951 0.15855339169502258 0.29001638293266296
1961 0.1584317684173584 0.28984779119491577
1971 0.15831208229064941 0.2895461916923523
1981 0.15819433331489563 0.2890479266643524
1991 0.15807512402534485 0.2889874279499054
2001 0.15795713663101196 0.28873297572135925
2011 0.15783974528312683 0.28844282031059265
2021 0.15772013366222382 0.28776687383651733
2031 0.15759922564029694 0.2877436578273773
2041 0.15747198462486267 0.2875958979129791
2051 0.15734907984733582 0.2869258224964142
2061 0.1572280079126358 0.2864903211593628
2071 0.15710864961147308 0.28616100549697876
2081 0.15699060261249542 0.28590041399002075
2091 0.1568754017353058 0.28536608815193176
2101 0.15675915777683258 0.28523728251457214
2111 0.1566445529460907 0.2851386070251465
2121 0.15653274953365326 0

3761 0.1416388899087906 0.28284141421318054
3771 0.14158231019973755 0.28298231959342957
3781 0.14152663946151733 0.28294071555137634
3791 0.1414782851934433 0.2824057936668396
3801 0.14142553508281708 0.2828768789768219
3811 0.14137478172779083 0.282870352268219
3821 0.14133167266845703 0.28273120522499084
3831 0.1412825882434845 0.2829110026359558
3841 0.14123421907424927 0.28209590911865234
3851 0.14118598401546478 0.2831716239452362
3861 0.14113298058509827 0.283015638589859
3871 0.1410873979330063 0.2822521924972534
3881 0.14103804528713226 0.2826055884361267
3891 0.14099128544330597 0.28282082080841064
3901 0.1409469097852707 0.2821837067604065
3911 0.14089640974998474 0.2822893559932709
3921 0.14084702730178833 0.28288862109184265
3931 0.14080318808555603 0.2827669382095337
3941 0.1407611072063446 0.28265073895454407
3951 0.1407202184200287 0.2828146815299988
3961 0.14068321883678436 0.2827502191066742
3971 0.14064501225948334 0.28297895193099976
3981 0.14060911536216736 0.28334

5631 0.13794517517089844 0.2894881069660187
5641 0.13792909681797028 0.2894783020019531
5651 0.13792046904563904 0.2895129919052124
5661 0.13791495561599731 0.2893691062927246
5671 0.13790304958820343 0.28927621245384216
5681 0.13789476454257965 0.29052576422691345
5691 0.13788066804409027 0.2905583679676056
5701 0.13787665963172913 0.29006436467170715
5711 0.13786260783672333 0.29039424657821655
5721 0.13785555958747864 0.2899400591850281
5731 0.13783757388591766 0.28992384672164917
5741 0.13782630860805511 0.29006442427635193
5751 0.1378134787082672 0.2905634343624115
5761 0.13781093060970306 0.2897886633872986
5771 0.13779282569885254 0.2893907427787781
5781 0.1377752125263214 0.29012778401374817
5791 0.13776487112045288 0.28930041193962097
5801 0.13773909211158752 0.2891503572463989
5811 0.13772737979888916 0.2900186777114868
5821 0.13771869242191315 0.2897266149520874
5831 0.1376902312040329 0.29028764367103577
5841 0.13768824934959412 0.2890856862068176
5851 0.13766445219516754 0

7491 0.13670286536216736 0.29155606031417847
7501 0.13673120737075806 0.28891584277153015
7511 0.13669827580451965 0.2925949990749359
7521 0.1366950273513794 0.2925216257572174
7531 0.13668319582939148 0.29104506969451904
7541 0.13666926324367523 0.2911246120929718
7551 0.13666000962257385 0.29157528281211853
7561 0.13665421307086945 0.29013898968696594
7571 0.13665467500686646 0.2906499207019806
7581 0.13666194677352905 0.28969907760620117
7591 0.13664312660694122 0.29196277260780334
7601 0.13666297495365143 0.29278114438056946
7611 0.13670222461223602 0.2891678512096405
7621 0.13662710785865784 0.2905576825141907
7631 0.13661304116249084 0.2917362153530121
7641 0.13662150502204895 0.2919571101665497
7651 0.13664893805980682 0.28979238867759705
7661 0.13661940395832062 0.2922920882701874
7671 0.1366599202156067 0.28958073258399963
7681 0.13659577071666718 0.2914360761642456
7691 0.1365918517112732 0.29121801257133484
7701 0.13659387826919556 0.2900587022304535
7711 0.13658776879310608

KeyboardInterrupt: 

In [32]:
# Validation-set predictions; inference only, so no autograd graph is needed.
with torch.no_grad():
    outputs = net(X_valid)
    loss = criterion(outputs, y_valid)

# Print (true, predicted) absolute SST in degC by adding back the climatology.
# y_valid starts at record index 1140 (1140 % 12 == 0), so sample i uses month slot i % 12.
for i in range(len(y_valid)):
    clim = month_mean[i % 12, row[1], row[0]]
    print(y_valid[i].item() + clim, outputs[i].item() + clim)

print('\n')

print(loss.item())

25.193620681762695 25.01722028851509
25.725605010986328 26.42677190899849
26.790624618530273 26.16276741027832
26.54935073852539 26.095219999551773
26.679452896118164 25.261855751276016
26.143993377685547 25.640332102775574
24.62432098388672 24.748792052268982
23.667221069335938 23.253288745880127
23.388002395629883 22.943576455116272
23.72989273071289 23.45765197277069
24.111202239990234 24.028814792633057
24.508861541748047 24.712013721466064
25.26885223388672 25.380118012428284
26.084409713745117 26.31078115105629
27.022750854492188 26.47216036915779
27.623594284057617 26.428744226694107
27.09197425842285 26.740322589874268
26.963388442993164 26.447693824768066
26.37074851989746 25.8074848651886
24.832151412963867 25.606664419174194
25.315359115600586 24.760225772857666
25.576629638671875 26.038988828659058
26.05023956298828 25.8993136882782
26.717212677001953 26.674001932144165
27.224092483520508 28.19626259803772
27.63623046875 28.339473485946655
28.250402450561523 28.058966040611

预测（92°W，0°）今年3月份海温（真实值为27.26599℃）

In [33]:
# X_test's sequence ends at SST[-2], so this forecasts the final month SST[-1]. Its
# climatology slot is (len(SST) - 1) % 12, which is 2 (March) for the 1203-month record.
with torch.no_grad():
    y_test = net(X_test) + month_mean[(len(SST) - 1) % 12, row[1], row[0]]
y_test.item()

27.58335304260254