In [1]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

In [188]:
inputx0 = pd.read_csv('inputrnn.csv')
inputy0 = pd.read_csv('tlo2.csv')

In [189]:
days = inputx0.Nday.unique()

In [190]:
d1 = inputx0.Nday.nunique()
d2 = 48
d3 = 7

In [191]:
i1 = np.zeros((d1, d2, d3))
for i in range(0,366):
    da = 'd' + str(i+1)
    aux1 = inputx0[inputx0.Nday == da]
    aux2 = aux1.drop('Nday',axis =1).values
    i1[i] = aux2

In [192]:
i2 = inputy0.loc[:,'l118'].values

In [212]:
i2[i2 == 2] = 0

In [214]:
X_train, X_test, y_train, y_test = train_test_split(i1, i2, train_size = 0.8, test_size = 0.2)
Xt_train = torch.from_numpy(X_train).double()
Xt_test = torch.from_numpy(X_test).double()
yt_train = torch.from_numpy(y_train).double()
yt_test = torch.from_numpy(y_test).double()

In [215]:
train_dataset = torch.utils.data.TensorDataset(Xt_train,yt_train)
test_dataset = torch.utils.data.TensorDataset(Xt_test,yt_test)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          shuffle=False)

In [216]:
class RNN1(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN1, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

In [306]:
n_hidden = 400
n_categories  = 2
n_TS = 7
rnn = RNN1(n_TS, n_hidden, n_categories)

In [307]:
input = Xt_train[0]
hidden = torch.zeros(1, n_hidden)
output, next_hidden = rnn(input[0].view(1,-1).float(), hidden)
print(output)

tensor([[-36.5892,   0.0000]], grad_fn=<LogSoftmaxBackward>)


In [308]:
criterion = nn.NLLLoss()
learning_rate = 0.0001
epoch = 40
print_every = 30
iter = 0
for ep in range(1, epoch):
    for i, (data, labels) in enumerate(train_loader):
        labels = labels.long()
        #data = data.long()
        hidden = rnn.initHidden()
        rnn.zero_grad()
        for j in range(data.size(1)):
            output, hidden = rnn(data[0][j].view(1,-1).float(), hidden)
        loss = criterion(output, labels)
        print(loss.item())
        loss.backward()
        for p in rnn.parameters():
            p.data.add_(-learning_rate, p.grad.data)
        iter += 1


0.0
89.53321838378906
1553.0465087890625
0.0
0.0
0.0
0.0
0.0
0.0
172.75912475585938
404.70416259765625
284.07305908203125
200.71133422851562
1653.8017578125
618.760498046875
0.0
0.0
0.0
0.0
879.3673095703125
814.594970703125
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
417.117431640625
639.7019653320312
0.0
234.42361450195312
619.7881469726562
0.0
315.213134765625
271.1553955078125
857.1898193359375
629.6823120117188
0.0
0.0
852.4925537109375
292.9879455566406
0.0
1104.947509765625
0.0
0.0
0.0
0.019728660583496094
0.0
66.64097595214844
0.0
0.0
0.0
0.0
548.73583984375
192.96524047851562
1446.3095703125
0.0
0.0
369.06695556640625
0.0
0.0
247.77108764648438
0.0
689.3628540039062
214.83067321777344
0.0
0.0
0.0
932.679443359375
0.18509960174560547
0.0
0.0
0.0
22.26390266418457
804.2195434570312
0.0
195.20285034179688
0.0
0.0
772.5343017578125
230.71644592285156
531.4317626953125
412.38885498046875
654.3389892578125
369.62164306640625
511.6240234375
0.0
0.0
323.37664794921875
419.1397399902344
0.0
766.21

60.45775604248047
0.0
1861.279296875
0.0
0.0
0.0
0.0
106.29783630371094
0.0
0.0
1352.541259765625
194.48309326171875
0.0
597.5167846679688
172.8477783203125
0.0
421.5081787109375
0.0
233.282958984375
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
587.1295776367188
490.93414306640625
160.11077880859375
461.5560302734375
0.0
0.0
416.4915771484375
0.0
315.726318359375
0.0
390.69171142578125
0.0
0.0
240.18243408203125
0.0
0.0
861.4116821289062
0.0
0.0
0.0
0.0
0.0
0.0
643.6361083984375
43.16535186767578
0.0
0.0
345.163330078125
0.0
669.4456176757812
0.0
315.98583984375
0.0
0.0
0.0
0.0
0.0
399.24029541015625
0.0
1193.7452392578125
0.0
0.0
0.0
188.65875244140625
0.0
0.0
0.0
0.0
0.0
0.0
530.686767578125
0.0
517.353515625
0.0
0.0
712.0372314453125
0.0
405.80657958984375
1050.783935546875
0.0
0.0
214.98773193359375
0.0
304.9583740234375
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
554.5205078125
0.0
0.0
82.70448303222656
0.0
192.28460693359375
0.0
642.3883056640625
0.0
0.0
0.0
0.12779569625854492
0.0
670.9456787109

0.0
0.0
0.0
0.0
0.0
46.20372009277344
573.49072265625
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
763.0659790039062
0.0
0.0
326.0093688964844
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
557.4422607421875
0.0
0.0
0.0
0.0
215.79708862304688
0.0
46.73497009277344
872.7625732421875
0.0
0.0
0.0
0.0
0.0
509.627197265625
0.0
0.0
0.0
272.8323669433594
0.0
84.04136657714844
0.0
82.85861206054688
554.806640625
0.0
0.0
377.1687316894531
996.1141967773438
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
293.2411804199219
48.21498107910156
955.2583618164062
0.0
0.0
153.89573669433594
430.2177429199219
0.0
728.1912841796875
0.0
0.0
0.0
0.0
0.0
0.0
653.366455078125
0.0
0.0
0.0
187.59214782714844
338.7585144042969
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
37.95799255371094
292.4884948730469
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
395.871826171875
18.297943115234375
306.3072509765625
793.4114379882812
0.0
0.0
576.3594970703125
0.0
10.44573974609375
0.0
1202.1875
0.0
415.1974792480469
0.0
0.0
0.0
0.0
0.0
618.3424072265625
0.0
0.0
0.0
0.0
0.0
50.538909912109375


0.0
0.0
0.0
624.7266845703125
0.0
69.51579284667969
153.20803833007812
0.0
0.0
434.57025146484375
0.0
0.0
0.0
797.6860961914062
0.0
0.0
0.0
164.8331756591797
0.0
0.0
0.0
626.3472900390625
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
569.5530395507812
0.0
0.0
699.9198608398438
0.0
0.0
0.0
0.0
0.0
0.0
162.6106719970703
69.275390625
0.0
1336.896484375
0.0
452.356201171875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
502.448486328125
0.0
0.0
0.0
0.0
0.0
368.0941162109375
0.0
244.82565307617188
0.0
0.0
0.0
0.0
0.0
0.0
270.9649658203125
587.1506958007812
0.0
0.0
0.0
0.0
0.0
0.0
756.6484375
0.0
0.0
0.0
406.0518798828125
0.0
0.0
0.0
521.4967041015625
239.2233428955078
0.0
0.0
603.29541015625
353.82781982421875
0.0
0.0
0.0
897.4471435546875
6.461784362792969
1057.2969970703125
288.00555419921875
0.0
0.0
0.0
0.0
0.0
0.0
301.04852294921875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
72.49128723144531
0.0
0.0
464.7474365234375
119.55680847167969
898.2410888671875
0.0
291.91705322265625
0.0
927.90808

0.0
234.55113220214844
0.0
0.0
1343.396240234375
0.0
0.0
315.8402099609375
786.537353515625
0.0
154.2458953857422
0.0
0.0
941.540771484375
0.0
236.75050354003906
986.102294921875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
35.59700012207031
0.0
0.0
0.0
0.0
883.8306274414062
0.0
0.0
33.09385681152344
0.0
0.0
0.0
0.0
0.0
0.0
508.695556640625
0.0
0.0
220.92015075683594
0.0
804.37158203125
0.0
0.0
309.183837890625
0.0
0.0
0.0
0.0
0.0
639.3785400390625
0.0
0.0
0.0
0.0
0.0
41.48213195800781
0.0
0.0
0.0
0.0
998.877197265625
0.0
0.0
0.0
0.0
330.672119140625
523.8792114257812
0.0
0.0
36.47691345214844
394.7123718261719
0.0
199.94703674316406
440.70526123046875
0.0
0.0
0.0
0.0
37.91514587402344
0.0
545.070068359375
0.0
0.0
0.0
0.0
0.0
0.0
173.1200408935547
0.0
245.38485717773438
0.0
0.0
103.00376892089844
0.0
877.755126953125
0.0
0.0
0.0
19.126571655273438
0.0
1061.67041015625
0.0
140.03395080566406
427.60150146484375
0.0
0.0
0.0
0.0
1.3945722579956055
7.2133960723876

743.861572265625
715.7996215820312
457.9841003417969
0.0
0.0
627.2990112304688
0.0
330.5274963378906
0.0
0.0
0.0
0.0
0.0
0.0
414.749267578125
0.0
105.84347534179688
400.70965576171875
198.71231079101562
622.7323608398438
401.4421081542969
0.0
867.3130493164062
420.1950378417969
0.0
102.86119079589844
0.0
0.0
170.12472534179688
0.0
0.0
573.1754150390625
0.0
188.59805297851562
1462.78564453125
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
65.29379272460938
0.0
277.997314453125
0.0
31.55465316772461
94.76927185058594
123.55087280273438
424.2607421875
0.0
0.0
267.4136657714844
140.39208984375
0.0
126.00507354736328
0.0
0.0
48.7946891784668
0.0
0.0
1197.755126953125
149.3260498046875
0.0
0.0
133.94781494140625
355.89666748046875
404.38201904296875
0.0
0.0
0.0
0.0
0.0
0.0
320.39794921875
0.0
63.48334503173828
487.19384765625
0.0
0.0
0.0
0.0
0.0
0.0
915.0335693359375
0.0
0.0
0.0
0.0
13.091565132141113
0.0
0.0
209.72808837890625
0.003960609436035156
0.0
0.0
423.2264404296875
576.0125122070312
0.0
14.2280378

0.0
423.5977783203125
0.0
0.0
576.81787109375
308.38458251953125
0.0
0.0
0.0
166.546630859375
0.0
304.65716552734375
0.0
0.0
0.04428863525390625
0.0
263.2280578613281
0.0
0.0
550.5680541992188
0.0
0.0
0.0
0.0
298.59539794921875
0.0
0.0
947.8848876953125
114.49794006347656
0.0
0.0
0.0
341.84173583984375
0.0
141.11769104003906
438.41632080078125
0.0
0.0
0.0004647970199584961
0.0
0.0
0.0
0.0
0.0
0.0
0.0
51.08644104003906
557.863525390625
0.0
0.0
0.0
0.0
124.61228942871094
0.0
443.0717468261719
297.68603515625
327.3915710449219
0.0
0.0
0.0
0.0
0.0
0.0
0.0
648.0213623046875
0.0
0.0
0.0
0.0
304.8757019042969
179.5868377685547
0.0
640.428955078125
407.56005859375
0.0
145.13693237304688
0.0
332.4517822265625
0.4466896057128906
0.0
0.0
292.0047302246094
682.4972534179688
192.51303100585938
0.0
0.0
370.9468688964844
0.0
0.0
334.6681823730469
0.0
362.8516540527344
0.0024566650390625
274.9472961425781
0.0
1738.332763671875
495.4977111816406
0.0
0.0
320.1351013183594
0.0
6.198883056640625e-06
0.0
0

353.40130615234375
356.9593505859375
499.26055908203125
203.1463623046875
0.0
0.0
0.0
0.0
588.5153198242188
0.0
370.71343994140625
0.0
305.47930908203125
0.0
0.0
0.0
0.0
0.0
0.0
153.50457763671875
0.0
1232.953857421875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
380.85284423828125
628.7278442382812
0.0
155.895263671875
529.5330810546875
69.592041015625
0.0
0.0
0.0
0.0
0.0
0.0
1040.543701171875
441.2335205078125
0.0
0.0
213.56320190429688
0.0
79.91015625
0.0
917.1309204101562
0.0
0.0
0.0
0.0
364.49951171875
0.0
128.38912963867188
0.0
0.0
0.0
837.60986328125
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
369.2964782714844
682.99755859375
0.0
0.0
0.0
0.0
0.0
1.9073486328125e-06
0.0
757.8870239257812
0.0
184.6722412109375
0.0
0.0
0.0
0.0
319.4010925292969
800.6806030273438
0.0
108.0018539428711
0.0
0.0
431.40155029296875
0.0
267.7834167480469
58.352535247802734
0.0
0.0
212.30270385742188
0.0
0.0
710.2320556640625
0.0
0.0
0.0
0.0
0.0
0.0
524.7659301757812
0.0
416.353515625
0.0
0.0
0.0
0.0
457.187225341796

0.0
0.0
0.0
0.0
347.6812744140625
0.0
0.0
0.0
0.0
0.0
0.0
513.3698120117188
0.0
0.0
0.0
0.0
0.0
0.0
324.2777099609375
0.0
0.0
398.8779296875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
438.7186279296875
534.4938354492188
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
2.5510525703430176
0.0
0.0
1217.52001953125
0.0
182.870849609375
0.0
0.0
0.0
444.8306579589844
850.729248046875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
414.0486755371094
0.0
0.0
0.0
38.06360626220703
1598.579345703125
0.0
0.0
0.0
0.0
44.30150604248047
432.8711853027344
0.0
0.0
0.0
567.6114501953125
0.0
0.0
351.4607849121094
0.0
47.54308319091797
160.22439575195312
0.0
0.0
891.5550537109375
633.1436767578125
0.0
0.0
275.545654296875
0.0
59.32589340209961
0.0
0.0
0.0
0.0
0.0
471.72467041015625
394.5447692871094
0.0
338.5074462890625
0.0
0.0
511.7591552734375
0.0
1044.9730224609375
422.3844299316406
0.0
0.0
0.0
0.0
0.0
0.0
282.03271484375
0.0
0.0
0.0
1407.6494140625
0.0
313.8848876953125
1025.48779296875
0.0
0.0
0.0
0.0
0.0
0.0


77.3001480102539
388.3682556152344
270.9880065917969
0.0
560.1849365234375
0.0
489.497802734375
216.39486694335938
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
931.4561767578125
770.746337890625
0.0
88.4393539428711
796.428955078125
234.51925659179688
0.0
0.0
0.0
0.0
771.2539672851562
9.886747360229492
0.0
0.0
0.0
0.0
0.0
657.1387939453125
811.8531494140625
0.0
0.0
0.0
628.8211669921875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
537.869140625
136.56832885742188
0.0
0.0
265.8985290527344
0.0
649.787109375
0.0
0.0
334.9402160644531
457.1305847167969
0.0
0.0
0.0
0.0
0.0
164.71444702148438
0.0
553.155029296875
70.1655044555664
0.0
488.0253601074219
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
240.51638793945312
147.17355346679688
0.0
223.09933471679688
0.0
0.0
0.0
0.0
159.26156616210938
421.9134216308594
0.0
0.0
0.0
142.34420776367188
521.519775390625
0.0
0.0
0.0
0.0
0.8829250335693359
0.0
272.5225830078125
0.0
0.0
0.0
146.73004150390625
142.39723205566406
0.0
0.0
0.0
40.40972900390625
0.0
615.6486206054688
0.0
25

663.6669921875
0.0
259.5634765625
157.97647094726562
0.0
0.0
0.0
0.0
0.0
0.0
0.0
272.21148681640625
0.0
0.0
0.0
27.876850128173828
0.0
631.9508666992188
208.18630981445312
609.73046875
0.0
0.0
387.249267578125
0.0
0.0
615.8674926757812
0.0
0.0
373.4820556640625
996.7236938476562
29.371318817138672
0.0
0.0
1012.6494140625
0.0
0.0
0.0
0.0
0.0
0.0
0.0
591.4609985351562
336.85113525390625
5.53131103515625e-05
0.0
0.0
1.2859010696411133
0.0
0.0
901.62939453125
79.04898071289062
680.6392822265625
0.0
336.5814514160156
0.0
0.0
0.0
0.0
0.0
342.55145263671875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
187.59536743164062
734.0814819335938
753.1356811523438
0.0
0.0
258.29302978515625
562.9214477539062
0.0
0.0
0.0
0.0
228.4899139404297
0.0
0.0
235.19900512695312
0.0
93.79716491699219
0.0
136.26248168945312
150.28067016601562
0.0
0.0
0.0
1115.7030029296875
71.04093933105469
0.0
0.0
110.32876586914062
0.0
0.0
0.0
0.0
0.0
80.53630065917969
0.0
957.51220703125
0.0
0.0
0.0
0.0
177.18325805664062
0.0
0.0
460.9

0.0
726.4729614257812
28.810937881469727
0.0
0.0
0.0
0.0
242.09808349609375
0.0
0.0
0.0
882.31201171875
0.0
0.0
1243.157470703125
270.3220520019531
0.0
0.0
0.0
0.0
0.0
458.39190673828125
132.02542114257812
581.3975219726562
0.0
0.0
319.6181945800781
0.0
0.0
0.0
544.3316040039062
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
810.7315673828125
0.0
222.57794189453125
0.0
0.0
0.0
0.0
0.0
0.0
0.0
243.79251098632812
28.11321258544922
101.46273803710938
0.0
0.0
0.0
1581.552978515625
103.95108032226562
0.0
0.0
0.0
674.8866577148438
0.0
0.0
216.58596801757812
0.0
0.0
706.7286987304688
577.6032104492188
182.43133544921875
0.0
0.0
0.0
0.0
0.0
0.0
0.0
44.534034729003906
577.2249755859375
0.0
619.947998046875
165.52935791015625
0.0
30.561254501342773
0.0
1258.1632080078125
0.0
99.31460571289062
63.10027313232422
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
659.7764892578125
0.0
696.5527954101562
0.0
0.0
4.76837158203125e-06
85.50857543945312
0.0
0.0
0.0
0.0
0.0
0.0
173.22808837890625
866.69720

0.0
835.7368774414062
119.31094360351562
0.0
0.0
144.41799926757812
0.0
0.0
0.0
0.0
5.340576171875e-05
0.0
0.0
591.489501953125
275.34942626953125
0.0
133.30221557617188
1258.8602294921875
0.0
0.0
0.0
13.198402404785156
654.6453247070312
0.0
0.0
154.49050903320312
603.9089965820312
467.0879211425781
0.0
488.15228271484375
161.23379516601562
0.0
0.0
75.14018249511719
0.0
0.0
0.0
0.0
395.2903137207031
0.0
0.0
0.0
683.15185546875
0.0
0.0
0.0
0.0
369.2026672363281
0.0
406.26031494140625
0.0
0.0
11.055227279663086
461.63397216796875
0.0
0.0
0.0
1006.4788818359375
691.788818359375
0.0
0.0
0.0
110.93008422851562
0.0
0.0
0.0
598.9124145507812
0.0
0.0
0.0
0.0
697.4494018554688
0.0
132.1967010498047
32.19156265258789
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
158.22146606445312
0.0
0.0
380.5218811035156
0.0
759.8834228515625
0.0
0.0
0.0
0.0
394.5683898925781
0.0
0.0
0.0
0.0
0.0
0.0
256.17047119140625
204.73147583007812
0.0
1129.3118896484375
894.7797241210938
0.0
0.0
485.13360595703125
380.023101806640

In [313]:
def getcat(output):
    top_n, top_i = output.topk(1)
    cat = top_i[0].item()
    return cat

In [320]:
correct = 0
for i, (data, labels) in enumerate(test_loader):
    labels = labels.item()
    for j in range(data.size(1)):
        output, hidden = rnn(data[0][j].view(1,-1).float(), hidden)
    cag1 = getcat(output)
    if cag1 == labels:
        correct += 1

In [323]:
correct/y_test.size

0.527027027027027