# Tutorial: CNN with Pytorch using Opp dataset

In [1]:
import sys
sys.path.append('../')

In [2]:
# general
import numpy as np
import pandas as pd
from progressbar import ProgressBar

# dataset
from actreco.datasets.opportunityc import Opportunity
from actreco.utils.generator import batch_generator

# modeling
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


In [3]:
# dataset

In [13]:
train_set = Opportunity(userID='S2,S3,S4')
x_train_list = train_set.get('X')
y_train_list = train_set.get('y')

In [14]:
test_set = Opportunity(userID='S1')
x_test_list = test_set.get('X')
y_test_list = test_set.get('y')

In [15]:
class VanilaCNN(nn.Module):
    def __init__(self, l_sample, nb_modal, nb_in_filter, nb_filter_list, kw, nb_unit, nb_out):
        super(VanilaCNN, self).__init__()
        self.l_sample = l_sample
        self.nb_modal = nb_modal
        self.nb_conv = len(nb_filter_list)
        
        if isinstance(kw, int):
            kw = [kw] * self.nb_conv
        assert len(nb_filter_list) == len(kw)
        nb_filter_list = [nb_in_filter] + nb_filter_list
        
        self.kw = kw
        self.nb_filter_list = nb_filter_list

        convs = []
        for i in range(self.nb_conv):
            convs.append(nn.Conv2d(nb_filter_list[i], nb_filter_list[i+1], (1, kw[i])))
        self.convs = nn.ModuleList(convs)
        self.fc1 = nn.Linear(self.conv_output_shape(), nb_unit)
        self.fc2 = nn.Linear(nb_unit, nb_out)
    
    def forward(self, x):
        for conv in self.convs:
            x = F.max_pool2d(F.relu(conv(x)), (1, 2))
            dropout = nn.Dropout2d(0.5)
            # x = dropout(x)
        x = x.view(-1, self.conv_output_shape())
        x = F.dropout(self.fc1(x), 0.5)
        x = F.relu(x)
        x = self.fc2(x)
        return x
    
    def conv_output_shape(self):
        length = self.l_sample
        for i in range(self.nb_conv):
            length = int((length - self.kw[i] + 1) /2)
        return length * self.nb_filter_list[-1] * self.nb_modal

In [28]:
nb_modal_dim = 2
nb_modal = train_set.nb_modal
nb_out = train_set.nb_class

kw = 3
nb_filter_list = [50, 40, 20]
nb_in_filter = 1
drop_rate = 0.5
nb_unit = 400

batch_size = 128
l_sample = 30
interval = int(l_sample) * 0.5

nb_iter = 2000
report_each = 10

model = VanilaCNN(l_sample=l_sample, nb_modal=nb_modal, nb_in_filter=1, nb_filter_list=nb_filter_list, kw=kw, nb_unit=nb_unit, nb_out=nb_out).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.CrossEntropyLoss()

In [33]:
model.train()

p = ProgressBar(max_value=nb_iter)
gen = batch_generator(x_train_list, y_train_list, batch_size, l_sample, train=True, nb_iter=nb_iter, categorical=False)
for batch_idx, (X, y) in enumerate(gen):
    batch_idx += 1
    tX = Variable(torch.from_numpy(X.astype('float32')), requires_grad=False).cuda()
    ty = Variable(torch.LongTensor(y), requires_grad=False).cuda()
    
    optimizer.zero_grad()
    output = model(tX)
    
    loss = criterion(output, ty)
    loss.backward()
    optimizer.step()
    if batch_idx % report_each == 0:
        print("{} samples, {} batch \t Loss: {:.6f}".format(batch_idx * batch_size, batch_idx, loss.data[0]))
    p.update(batch_idx)
    

  0% (14 of 2000) |                       | Elapsed Time: 0:00:00 ETA:  0:01:24

1280 samples, 10 batch 	 Loss: 0.283480


  1% (23 of 2000) |                       | Elapsed Time: 0:00:00 ETA:  0:01:23

2560 samples, 20 batch 	 Loss: 0.301822


  1% (34 of 2000) |                       | Elapsed Time: 0:00:01 ETA:  0:01:22

3840 samples, 30 batch 	 Loss: 0.163103


  2% (43 of 2000) |                       | Elapsed Time: 0:00:01 ETA:  0:01:22

5120 samples, 40 batch 	 Loss: 0.327317


  2% (54 of 2000) |                       | Elapsed Time: 0:00:02 ETA:  0:01:21

6400 samples, 50 batch 	 Loss: 0.339387


  3% (63 of 2000) |                       | Elapsed Time: 0:00:02 ETA:  0:01:21

7680 samples, 60 batch 	 Loss: 0.305300


  3% (75 of 2000) |                       | Elapsed Time: 0:00:03 ETA:  0:01:20

8960 samples, 70 batch 	 Loss: 0.192852


  4% (82 of 2000) |                       | Elapsed Time: 0:00:03 ETA:  0:01:22

10240 samples, 80 batch 	 Loss: 0.297979


  4% (93 of 2000) |#                      | Elapsed Time: 0:00:04 ETA:  0:01:27

11520 samples, 90 batch 	 Loss: 0.288208


  5% (105 of 2000) |#                     | Elapsed Time: 0:00:04 ETA:  0:01:29

12800 samples, 100 batch 	 Loss: 0.258656


  5% (114 of 2000) |#                     | Elapsed Time: 0:00:05 ETA:  0:01:28

14080 samples, 110 batch 	 Loss: 0.177608


  6% (123 of 2000) |#                     | Elapsed Time: 0:00:05 ETA:  0:01:28

15360 samples, 120 batch 	 Loss: 0.347669


  6% (135 of 2000) |#                     | Elapsed Time: 0:00:05 ETA:  0:01:24

16640 samples, 130 batch 	 Loss: 0.337460


  7% (144 of 2000) |#                     | Elapsed Time: 0:00:06 ETA:  0:01:19

17920 samples, 140 batch 	 Loss: 0.222116


  7% (155 of 2000) |#                     | Elapsed Time: 0:00:06 ETA:  0:01:16

19200 samples, 150 batch 	 Loss: 0.173426


  8% (164 of 2000) |#                     | Elapsed Time: 0:00:07 ETA:  0:01:14

20480 samples, 160 batch 	 Loss: 0.243142


  8% (173 of 2000) |#                     | Elapsed Time: 0:00:07 ETA:  0:01:14

21760 samples, 170 batch 	 Loss: 0.256578


  9% (184 of 2000) |##                    | Elapsed Time: 0:00:07 ETA:  0:01:13

23040 samples, 180 batch 	 Loss: 0.211985


  9% (193 of 2000) |##                    | Elapsed Time: 0:00:08 ETA:  0:01:13

24320 samples, 190 batch 	 Loss: 0.205997


 10% (205 of 2000) |##                    | Elapsed Time: 0:00:08 ETA:  0:01:13

25600 samples, 200 batch 	 Loss: 0.167587


 10% (214 of 2000) |##                    | Elapsed Time: 0:00:09 ETA:  0:01:13

26880 samples, 210 batch 	 Loss: 0.156177


 11% (223 of 2000) |##                    | Elapsed Time: 0:00:09 ETA:  0:01:13

28160 samples, 220 batch 	 Loss: 0.254321


 11% (234 of 2000) |##                    | Elapsed Time: 0:00:09 ETA:  0:01:12

29440 samples, 230 batch 	 Loss: 0.218105


 12% (243 of 2000) |##                    | Elapsed Time: 0:00:10 ETA:  0:01:12

30720 samples, 240 batch 	 Loss: 0.337280


 12% (254 of 2000) |##                    | Elapsed Time: 0:00:10 ETA:  0:01:11

32000 samples, 250 batch 	 Loss: 0.345940


 13% (263 of 2000) |##                    | Elapsed Time: 0:00:11 ETA:  0:01:10

33280 samples, 260 batch 	 Loss: 0.256340


 13% (275 of 2000) |###                   | Elapsed Time: 0:00:11 ETA:  0:01:09

34560 samples, 270 batch 	 Loss: 0.224109


 14% (284 of 2000) |###                   | Elapsed Time: 0:00:11 ETA:  0:01:09

35840 samples, 280 batch 	 Loss: 0.148641


 14% (293 of 2000) |###                   | Elapsed Time: 0:00:12 ETA:  0:01:09

37120 samples, 290 batch 	 Loss: 0.253190


 15% (304 of 2000) |###                   | Elapsed Time: 0:00:12 ETA:  0:01:09

38400 samples, 300 batch 	 Loss: 0.139847


 15% (313 of 2000) |###                   | Elapsed Time: 0:00:13 ETA:  0:01:09

39680 samples, 310 batch 	 Loss: 0.244233


 16% (325 of 2000) |###                   | Elapsed Time: 0:00:13 ETA:  0:01:08

40960 samples, 320 batch 	 Loss: 0.209142


 16% (333 of 2000) |###                   | Elapsed Time: 0:00:13 ETA:  0:01:08

42240 samples, 330 batch 	 Loss: 0.147320


 17% (344 of 2000) |###                   | Elapsed Time: 0:00:14 ETA:  0:01:08

43520 samples, 340 batch 	 Loss: 0.220233


 17% (353 of 2000) |###                   | Elapsed Time: 0:00:14 ETA:  0:01:07

44800 samples, 350 batch 	 Loss: 0.249852


 18% (364 of 2000) |####                  | Elapsed Time: 0:00:15 ETA:  0:01:07

46080 samples, 360 batch 	 Loss: 0.178560


 18% (373 of 2000) |####                  | Elapsed Time: 0:00:15 ETA:  0:01:08

47360 samples, 370 batch 	 Loss: 0.168465


 19% (383 of 2000) |####                  | Elapsed Time: 0:00:16 ETA:  0:01:08

48640 samples, 380 batch 	 Loss: 0.250009


 19% (395 of 2000) |####                  | Elapsed Time: 0:00:16 ETA:  0:01:08

49920 samples, 390 batch 	 Loss: 0.155370


 20% (404 of 2000) |####                  | Elapsed Time: 0:00:16 ETA:  0:01:07

51200 samples, 400 batch 	 Loss: 0.289069


 20% (415 of 2000) |####                  | Elapsed Time: 0:00:17 ETA:  0:01:06

52480 samples, 410 batch 	 Loss: 0.156936


 21% (424 of 2000) |####                  | Elapsed Time: 0:00:17 ETA:  0:01:06

53760 samples, 420 batch 	 Loss: 0.298257


 21% (433 of 2000) |####                  | Elapsed Time: 0:00:18 ETA:  0:01:04

55040 samples, 430 batch 	 Loss: 0.230051


 22% (445 of 2000) |####                  | Elapsed Time: 0:00:18 ETA:  0:01:03

56320 samples, 440 batch 	 Loss: 0.251861


 22% (454 of 2000) |####                  | Elapsed Time: 0:00:19 ETA:  0:01:03

57600 samples, 450 batch 	 Loss: 0.243635


 23% (465 of 2000) |#####                 | Elapsed Time: 0:00:19 ETA:  0:01:02

58880 samples, 460 batch 	 Loss: 0.259490


 23% (473 of 2000) |#####                 | Elapsed Time: 0:00:19 ETA:  0:01:02

60160 samples, 470 batch 	 Loss: 0.323166


 24% (485 of 2000) |#####                 | Elapsed Time: 0:00:20 ETA:  0:01:01

61440 samples, 480 batch 	 Loss: 0.328504


 24% (494 of 2000) |#####                 | Elapsed Time: 0:00:20 ETA:  0:01:01

62720 samples, 490 batch 	 Loss: 0.230251


 25% (503 of 2000) |#####                 | Elapsed Time: 0:00:21 ETA:  0:01:01

64000 samples, 500 batch 	 Loss: 0.277737


 25% (515 of 2000) |#####                 | Elapsed Time: 0:00:21 ETA:  0:01:01

65280 samples, 510 batch 	 Loss: 0.245572


 26% (523 of 2000) |#####                 | Elapsed Time: 0:00:21 ETA:  0:01:01

66560 samples, 520 batch 	 Loss: 0.310799


 26% (535 of 2000) |#####                 | Elapsed Time: 0:00:22 ETA:  0:01:00

67840 samples, 530 batch 	 Loss: 0.266622


 27% (544 of 2000) |#####                 | Elapsed Time: 0:00:22 ETA:  0:00:59

69120 samples, 540 batch 	 Loss: 0.330352


 27% (553 of 2000) |######                | Elapsed Time: 0:00:23 ETA:  0:00:59

70400 samples, 550 batch 	 Loss: 0.320879


 28% (565 of 2000) |######                | Elapsed Time: 0:00:23 ETA:  0:00:59

71680 samples, 560 batch 	 Loss: 0.322719


 28% (574 of 2000) |######                | Elapsed Time: 0:00:23 ETA:  0:00:58

72960 samples, 570 batch 	 Loss: 0.159692


 29% (583 of 2000) |######                | Elapsed Time: 0:00:24 ETA:  0:00:58

74240 samples, 580 batch 	 Loss: 0.164632


 29% (595 of 2000) |######                | Elapsed Time: 0:00:24 ETA:  0:00:57

75520 samples, 590 batch 	 Loss: 0.335535


 30% (604 of 2000) |######                | Elapsed Time: 0:00:25 ETA:  0:00:57

76800 samples, 600 batch 	 Loss: 0.162970


 30% (613 of 2000) |######                | Elapsed Time: 0:00:25 ETA:  0:00:57

78080 samples, 610 batch 	 Loss: 0.178640


 31% (625 of 2000) |######                | Elapsed Time: 0:00:26 ETA:  0:00:56

79360 samples, 620 batch 	 Loss: 0.154346


 31% (633 of 2000) |######                | Elapsed Time: 0:00:26 ETA:  0:00:56

80640 samples, 630 batch 	 Loss: 0.182621


 32% (645 of 2000) |#######               | Elapsed Time: 0:00:26 ETA:  0:00:55

81920 samples, 640 batch 	 Loss: 0.248413


 32% (654 of 2000) |#######               | Elapsed Time: 0:00:27 ETA:  0:00:55

83200 samples, 650 batch 	 Loss: 0.147819


 33% (665 of 2000) |#######               | Elapsed Time: 0:00:27 ETA:  0:00:54

84480 samples, 660 batch 	 Loss: 0.138391


 33% (674 of 2000) |#######               | Elapsed Time: 0:00:28 ETA:  0:00:53

85760 samples, 670 batch 	 Loss: 0.358932


 34% (683 of 2000) |#######               | Elapsed Time: 0:00:28 ETA:  0:00:53

87040 samples, 680 batch 	 Loss: 0.211805


 34% (694 of 2000) |#######               | Elapsed Time: 0:00:28 ETA:  0:00:53

88320 samples, 690 batch 	 Loss: 0.080155


 35% (703 of 2000) |#######               | Elapsed Time: 0:00:29 ETA:  0:00:52

89600 samples, 700 batch 	 Loss: 0.160956


 35% (714 of 2000) |#######               | Elapsed Time: 0:00:29 ETA:  0:00:52

90880 samples, 710 batch 	 Loss: 0.099826


 36% (723 of 2000) |#######               | Elapsed Time: 0:00:30 ETA:  0:00:52

92160 samples, 720 batch 	 Loss: 0.244508


 36% (735 of 2000) |########              | Elapsed Time: 0:00:30 ETA:  0:00:51

93440 samples, 730 batch 	 Loss: 0.321499


 37% (744 of 2000) |########              | Elapsed Time: 0:00:30 ETA:  0:00:51

94720 samples, 740 batch 	 Loss: 0.151449


 37% (753 of 2000) |########              | Elapsed Time: 0:00:31 ETA:  0:00:50

96000 samples, 750 batch 	 Loss: 0.288369


 38% (763 of 2000) |########              | Elapsed Time: 0:00:31 ETA:  0:00:50

97280 samples, 760 batch 	 Loss: 0.127043


 38% (775 of 2000) |########              | Elapsed Time: 0:00:32 ETA:  0:00:50

98560 samples, 770 batch 	 Loss: 0.312534


 39% (783 of 2000) |########              | Elapsed Time: 0:00:32 ETA:  0:00:52

99840 samples, 780 batch 	 Loss: 0.189759


 39% (794 of 2000) |########              | Elapsed Time: 0:00:33 ETA:  0:00:54

101120 samples, 790 batch 	 Loss: 0.263731


 40% (803 of 2000) |########              | Elapsed Time: 0:00:33 ETA:  0:00:54

102400 samples, 800 batch 	 Loss: 0.171545


 40% (814 of 2000) |########              | Elapsed Time: 0:00:33 ETA:  0:00:52

103680 samples, 810 batch 	 Loss: 0.189764


 41% (823 of 2000) |#########             | Elapsed Time: 0:00:34 ETA:  0:00:52

104960 samples, 820 batch 	 Loss: 0.267462


 41% (835 of 2000) |#########             | Elapsed Time: 0:00:34 ETA:  0:00:49

106240 samples, 830 batch 	 Loss: 0.231727


 42% (844 of 2000) |#########             | Elapsed Time: 0:00:35 ETA:  0:00:47

107520 samples, 840 batch 	 Loss: 0.170849


 42% (853 of 2000) |#########             | Elapsed Time: 0:00:35 ETA:  0:00:46

108800 samples, 850 batch 	 Loss: 0.185887


 43% (864 of 2000) |#########             | Elapsed Time: 0:00:36 ETA:  0:00:46

110080 samples, 860 batch 	 Loss: 0.351439


 43% (873 of 2000) |#########             | Elapsed Time: 0:00:36 ETA:  0:00:46

111360 samples, 870 batch 	 Loss: 0.256561


 44% (885 of 2000) |#########             | Elapsed Time: 0:00:36 ETA:  0:00:46

112640 samples, 880 batch 	 Loss: 0.180496


 44% (893 of 2000) |#########             | Elapsed Time: 0:00:37 ETA:  0:00:45

113920 samples, 890 batch 	 Loss: 0.242713


 45% (905 of 2000) |#########             | Elapsed Time: 0:00:37 ETA:  0:00:44

115200 samples, 900 batch 	 Loss: 0.210267


 45% (914 of 2000) |##########            | Elapsed Time: 0:00:38 ETA:  0:00:44

116480 samples, 910 batch 	 Loss: 0.260131


 46% (923 of 2000) |##########            | Elapsed Time: 0:00:38 ETA:  0:00:43

117760 samples, 920 batch 	 Loss: 0.138618


 46% (935 of 2000) |##########            | Elapsed Time: 0:00:38 ETA:  0:00:43

119040 samples, 930 batch 	 Loss: 0.274195


 47% (943 of 2000) |##########            | Elapsed Time: 0:00:39 ETA:  0:00:42

120320 samples, 940 batch 	 Loss: 0.229983


 47% (955 of 2000) |##########            | Elapsed Time: 0:00:39 ETA:  0:00:42

121600 samples, 950 batch 	 Loss: 0.219030


 48% (963 of 2000) |##########            | Elapsed Time: 0:00:40 ETA:  0:00:42

122880 samples, 960 batch 	 Loss: 0.235099


 48% (975 of 2000) |##########            | Elapsed Time: 0:00:40 ETA:  0:00:42

124160 samples, 970 batch 	 Loss: 0.272920


 49% (984 of 2000) |##########            | Elapsed Time: 0:00:40 ETA:  0:00:42

125440 samples, 980 batch 	 Loss: 0.161036


 49% (993 of 2000) |##########            | Elapsed Time: 0:00:41 ETA:  0:00:41

126720 samples, 990 batch 	 Loss: 0.115860


 50% (1005 of 2000) |##########           | Elapsed Time: 0:00:41 ETA:  0:00:41

128000 samples, 1000 batch 	 Loss: 0.249236


 50% (1013 of 2000) |##########           | Elapsed Time: 0:00:42 ETA:  0:00:41

129280 samples, 1010 batch 	 Loss: 0.209280


 51% (1025 of 2000) |##########           | Elapsed Time: 0:00:42 ETA:  0:00:40

130560 samples, 1020 batch 	 Loss: 0.180783


 51% (1034 of 2000) |##########           | Elapsed Time: 0:00:43 ETA:  0:00:40

131840 samples, 1030 batch 	 Loss: 0.122997


 52% (1043 of 2000) |##########           | Elapsed Time: 0:00:43 ETA:  0:00:40

133120 samples, 1040 batch 	 Loss: 0.186392


 52% (1055 of 2000) |###########          | Elapsed Time: 0:00:43 ETA:  0:00:39

134400 samples, 1050 batch 	 Loss: 0.251314


 53% (1064 of 2000) |###########          | Elapsed Time: 0:00:44 ETA:  0:00:39

135680 samples, 1060 batch 	 Loss: 0.203112


 53% (1073 of 2000) |###########          | Elapsed Time: 0:00:44 ETA:  0:00:39

136960 samples, 1070 batch 	 Loss: 0.137384


 54% (1085 of 2000) |###########          | Elapsed Time: 0:00:45 ETA:  0:00:38

138240 samples, 1080 batch 	 Loss: 0.231788


 54% (1094 of 2000) |###########          | Elapsed Time: 0:00:45 ETA:  0:00:38

139520 samples, 1090 batch 	 Loss: 0.208735


 55% (1103 of 2000) |###########          | Elapsed Time: 0:00:45 ETA:  0:00:37

140800 samples, 1100 batch 	 Loss: 0.153626


 55% (1114 of 2000) |###########          | Elapsed Time: 0:00:46 ETA:  0:00:37

142080 samples, 1110 batch 	 Loss: 0.232597


 56% (1123 of 2000) |###########          | Elapsed Time: 0:00:46 ETA:  0:00:36

143360 samples, 1120 batch 	 Loss: 0.268613


 56% (1135 of 2000) |###########          | Elapsed Time: 0:00:47 ETA:  0:00:36

144640 samples, 1130 batch 	 Loss: 0.293430


 57% (1143 of 2000) |############         | Elapsed Time: 0:00:47 ETA:  0:00:36

145920 samples, 1140 batch 	 Loss: 0.140571


 57% (1155 of 2000) |############         | Elapsed Time: 0:00:48 ETA:  0:00:36

147200 samples, 1150 batch 	 Loss: 0.108118


 58% (1163 of 2000) |############         | Elapsed Time: 0:00:48 ETA:  0:00:36

148480 samples, 1160 batch 	 Loss: 0.216706


 58% (1174 of 2000) |############         | Elapsed Time: 0:00:48 ETA:  0:00:35

149760 samples, 1170 batch 	 Loss: 0.130257


 59% (1183 of 2000) |############         | Elapsed Time: 0:00:49 ETA:  0:00:35

151040 samples, 1180 batch 	 Loss: 0.206098


 59% (1193 of 2000) |############         | Elapsed Time: 0:00:49 ETA:  0:00:34

152320 samples, 1190 batch 	 Loss: 0.251470


 60% (1205 of 2000) |############         | Elapsed Time: 0:00:50 ETA:  0:00:33

153600 samples, 1200 batch 	 Loss: 0.186186


 60% (1214 of 2000) |############         | Elapsed Time: 0:00:50 ETA:  0:00:33

154880 samples, 1210 batch 	 Loss: 0.241921


 61% (1225 of 2000) |############         | Elapsed Time: 0:00:51 ETA:  0:00:32

156160 samples, 1220 batch 	 Loss: 0.179275


 61% (1233 of 2000) |############         | Elapsed Time: 0:00:51 ETA:  0:00:32

157440 samples, 1230 batch 	 Loss: 0.265361


 62% (1243 of 2000) |#############        | Elapsed Time: 0:00:51 ETA:  0:00:33

158720 samples, 1240 batch 	 Loss: 0.110933


 62% (1254 of 2000) |#############        | Elapsed Time: 0:00:52 ETA:  0:00:34

160000 samples, 1250 batch 	 Loss: 0.113963


 63% (1263 of 2000) |#############        | Elapsed Time: 0:00:52 ETA:  0:00:33

161280 samples, 1260 batch 	 Loss: 0.285038


 63% (1275 of 2000) |#############        | Elapsed Time: 0:00:53 ETA:  0:00:32

162560 samples, 1270 batch 	 Loss: 0.166562


 64% (1284 of 2000) |#############        | Elapsed Time: 0:00:53 ETA:  0:00:32

163840 samples, 1280 batch 	 Loss: 0.191259


 64% (1292 of 2000) |#############        | Elapsed Time: 0:00:54 ETA:  0:00:30

165120 samples, 1290 batch 	 Loss: 0.285110


 65% (1302 of 2000) |#############        | Elapsed Time: 0:00:54 ETA:  0:00:29

166400 samples, 1300 batch 	 Loss: 0.193868


 65% (1313 of 2000) |#############        | Elapsed Time: 0:00:55 ETA:  0:00:31

167680 samples, 1310 batch 	 Loss: 0.104786


 66% (1324 of 2000) |#############        | Elapsed Time: 0:00:55 ETA:  0:00:31

168960 samples, 1320 batch 	 Loss: 0.207987


 66% (1333 of 2000) |#############        | Elapsed Time: 0:00:56 ETA:  0:00:30

170240 samples, 1330 batch 	 Loss: 0.189642


 67% (1345 of 2000) |##############       | Elapsed Time: 0:00:56 ETA:  0:00:30

171520 samples, 1340 batch 	 Loss: 0.178366


 67% (1354 of 2000) |##############       | Elapsed Time: 0:00:56 ETA:  0:00:29

172800 samples, 1350 batch 	 Loss: 0.255310


 68% (1363 of 2000) |##############       | Elapsed Time: 0:00:57 ETA:  0:00:26

174080 samples, 1360 batch 	 Loss: 0.101000


 68% (1373 of 2000) |##############       | Elapsed Time: 0:00:57 ETA:  0:00:26

175360 samples, 1370 batch 	 Loss: 0.142319


 69% (1385 of 2000) |##############       | Elapsed Time: 0:00:58 ETA:  0:00:25

176640 samples, 1380 batch 	 Loss: 0.111390


 69% (1393 of 2000) |##############       | Elapsed Time: 0:00:58 ETA:  0:00:24

177920 samples, 1390 batch 	 Loss: 0.160374


 70% (1405 of 2000) |##############       | Elapsed Time: 0:00:58 ETA:  0:00:24

179200 samples, 1400 batch 	 Loss: 0.157156


 70% (1414 of 2000) |##############       | Elapsed Time: 0:00:59 ETA:  0:00:24

180480 samples, 1410 batch 	 Loss: 0.148960


 71% (1423 of 2000) |##############       | Elapsed Time: 0:00:59 ETA:  0:00:23

181760 samples, 1420 batch 	 Loss: 0.097193


 71% (1435 of 2000) |###############      | Elapsed Time: 0:01:00 ETA:  0:00:22

183040 samples, 1430 batch 	 Loss: 0.299849


 72% (1444 of 2000) |###############      | Elapsed Time: 0:01:00 ETA:  0:00:22

184320 samples, 1440 batch 	 Loss: 0.155926


 72% (1453 of 2000) |###############      | Elapsed Time: 0:01:00 ETA:  0:00:22

185600 samples, 1450 batch 	 Loss: 0.125284


 73% (1465 of 2000) |###############      | Elapsed Time: 0:01:01 ETA:  0:00:21

186880 samples, 1460 batch 	 Loss: 0.140768


 73% (1474 of 2000) |###############      | Elapsed Time: 0:01:01 ETA:  0:00:21

188160 samples, 1470 batch 	 Loss: 0.111292


 74% (1483 of 2000) |###############      | Elapsed Time: 0:01:02 ETA:  0:00:21

189440 samples, 1480 batch 	 Loss: 0.153813


 74% (1494 of 2000) |###############      | Elapsed Time: 0:01:02 ETA:  0:00:21

190720 samples, 1490 batch 	 Loss: 0.163353


 75% (1503 of 2000) |###############      | Elapsed Time: 0:01:03 ETA:  0:00:20

192000 samples, 1500 batch 	 Loss: 0.140292


 75% (1515 of 2000) |###############      | Elapsed Time: 0:01:03 ETA:  0:00:20

193280 samples, 1510 batch 	 Loss: 0.129969


 76% (1524 of 2000) |################     | Elapsed Time: 0:01:03 ETA:  0:00:20

194560 samples, 1520 batch 	 Loss: 0.210544


 76% (1533 of 2000) |################     | Elapsed Time: 0:01:04 ETA:  0:00:19

195840 samples, 1530 batch 	 Loss: 0.110404


 77% (1545 of 2000) |################     | Elapsed Time: 0:01:04 ETA:  0:00:19

197120 samples, 1540 batch 	 Loss: 0.144353


 77% (1554 of 2000) |################     | Elapsed Time: 0:01:05 ETA:  0:00:18

198400 samples, 1550 batch 	 Loss: 0.252424


 78% (1563 of 2000) |################     | Elapsed Time: 0:01:05 ETA:  0:00:18

199680 samples, 1560 batch 	 Loss: 0.161297


 78% (1573 of 2000) |################     | Elapsed Time: 0:01:05 ETA:  0:00:17

200960 samples, 1570 batch 	 Loss: 0.244937


 79% (1585 of 2000) |################     | Elapsed Time: 0:01:06 ETA:  0:00:17

202240 samples, 1580 batch 	 Loss: 0.291475


 79% (1594 of 2000) |################     | Elapsed Time: 0:01:06 ETA:  0:00:16

203520 samples, 1590 batch 	 Loss: 0.100711


 80% (1603 of 2000) |################     | Elapsed Time: 0:01:07 ETA:  0:00:16

204800 samples, 1600 batch 	 Loss: 0.150947


 80% (1615 of 2000) |################     | Elapsed Time: 0:01:07 ETA:  0:00:16

206080 samples, 1610 batch 	 Loss: 0.090968


 81% (1624 of 2000) |#################    | Elapsed Time: 0:01:08 ETA:  0:00:15

207360 samples, 1620 batch 	 Loss: 0.141007


 81% (1633 of 2000) |#################    | Elapsed Time: 0:01:08 ETA:  0:00:15

208640 samples, 1630 batch 	 Loss: 0.214211


 82% (1644 of 2000) |#################    | Elapsed Time: 0:01:08 ETA:  0:00:14

209920 samples, 1640 batch 	 Loss: 0.105638


 82% (1655 of 2000) |#################    | Elapsed Time: 0:01:09 ETA:  0:00:14

211200 samples, 1650 batch 	 Loss: 0.197800


 83% (1664 of 2000) |#################    | Elapsed Time: 0:01:09 ETA:  0:00:13

212480 samples, 1660 batch 	 Loss: 0.111538


 83% (1673 of 2000) |#################    | Elapsed Time: 0:01:10 ETA:  0:00:13

213760 samples, 1670 batch 	 Loss: 0.244398


 84% (1685 of 2000) |#################    | Elapsed Time: 0:01:10 ETA:  0:00:13

215040 samples, 1680 batch 	 Loss: 0.113625


 84% (1693 of 2000) |#################    | Elapsed Time: 0:01:10 ETA:  0:00:12

216320 samples, 1690 batch 	 Loss: 0.164892


 85% (1705 of 2000) |#################    | Elapsed Time: 0:01:11 ETA:  0:00:12

217600 samples, 1700 batch 	 Loss: 0.114851


 85% (1714 of 2000) |#################    | Elapsed Time: 0:01:11 ETA:  0:00:12

218880 samples, 1710 batch 	 Loss: 0.129427


 86% (1722 of 2000) |##################   | Elapsed Time: 0:01:12 ETA:  0:00:11

220160 samples, 1720 batch 	 Loss: 0.206059


 86% (1734 of 2000) |##################   | Elapsed Time: 0:01:12 ETA:  0:00:11

221440 samples, 1730 batch 	 Loss: 0.160156


 87% (1743 of 2000) |##################   | Elapsed Time: 0:01:13 ETA:  0:00:10

222720 samples, 1740 batch 	 Loss: 0.208569


 87% (1752 of 2000) |##################   | Elapsed Time: 0:01:13 ETA:  0:00:10

224000 samples, 1750 batch 	 Loss: 0.223442


 88% (1764 of 2000) |##################   | Elapsed Time: 0:01:13 ETA:  0:00:09

225280 samples, 1760 batch 	 Loss: 0.230788


 88% (1773 of 2000) |##################   | Elapsed Time: 0:01:14 ETA:  0:00:09

226560 samples, 1770 batch 	 Loss: 0.204989


 89% (1785 of 2000) |##################   | Elapsed Time: 0:01:14 ETA:  0:00:09

227840 samples, 1780 batch 	 Loss: 0.182290


 89% (1794 of 2000) |##################   | Elapsed Time: 0:01:15 ETA:  0:00:08

229120 samples, 1790 batch 	 Loss: 0.162863


 90% (1803 of 2000) |##################   | Elapsed Time: 0:01:15 ETA:  0:00:08

230400 samples, 1800 batch 	 Loss: 0.174652


 90% (1814 of 2000) |###################  | Elapsed Time: 0:01:16 ETA:  0:00:08

231680 samples, 1810 batch 	 Loss: 0.179020


 91% (1823 of 2000) |###################  | Elapsed Time: 0:01:16 ETA:  0:00:07

232960 samples, 1820 batch 	 Loss: 0.169205


 91% (1835 of 2000) |###################  | Elapsed Time: 0:01:17 ETA:  0:00:07

234240 samples, 1830 batch 	 Loss: 0.154007


 92% (1843 of 2000) |###################  | Elapsed Time: 0:01:17 ETA:  0:00:06

235520 samples, 1840 batch 	 Loss: 0.142637


 92% (1855 of 2000) |###################  | Elapsed Time: 0:01:17 ETA:  0:00:06

236800 samples, 1850 batch 	 Loss: 0.091786


 93% (1864 of 2000) |###################  | Elapsed Time: 0:01:18 ETA:  0:00:05

238080 samples, 1860 batch 	 Loss: 0.195316


 93% (1873 of 2000) |###################  | Elapsed Time: 0:01:18 ETA:  0:00:05

239360 samples, 1870 batch 	 Loss: 0.245078


 94% (1885 of 2000) |###################  | Elapsed Time: 0:01:19 ETA:  0:00:04

240640 samples, 1880 batch 	 Loss: 0.125540


 94% (1894 of 2000) |###################  | Elapsed Time: 0:01:19 ETA:  0:00:04

241920 samples, 1890 batch 	 Loss: 0.158658


 95% (1905 of 2000) |#################### | Elapsed Time: 0:01:20 ETA:  0:00:04

243200 samples, 1900 batch 	 Loss: 0.208983


 95% (1914 of 2000) |#################### | Elapsed Time: 0:01:20 ETA:  0:00:03

244480 samples, 1910 batch 	 Loss: 0.180885


 96% (1923 of 2000) |#################### | Elapsed Time: 0:01:20 ETA:  0:00:03

245760 samples, 1920 batch 	 Loss: 0.136758


 96% (1934 of 2000) |#################### | Elapsed Time: 0:01:21 ETA:  0:00:02

247040 samples, 1930 batch 	 Loss: 0.157808


 97% (1943 of 2000) |#################### | Elapsed Time: 0:01:21 ETA:  0:00:02

248320 samples, 1940 batch 	 Loss: 0.129676


 97% (1953 of 2000) |#################### | Elapsed Time: 0:01:22 ETA:  0:00:01

249600 samples, 1950 batch 	 Loss: 0.072462


 98% (1965 of 2000) |#################### | Elapsed Time: 0:01:22 ETA:  0:00:01

250880 samples, 1960 batch 	 Loss: 0.074557


 98% (1974 of 2000) |#################### | Elapsed Time: 0:01:22 ETA:  0:00:01

252160 samples, 1970 batch 	 Loss: 0.165799


 99% (1983 of 2000) |#################### | Elapsed Time: 0:01:23 ETA:  0:00:00

253440 samples, 1980 batch 	 Loss: 0.128508


 99% (1995 of 2000) |#################### | Elapsed Time: 0:01:23 ETA:  0:00:00

254720 samples, 1990 batch 	 Loss: 0.126635


100% (2000 of 2000) |#####################| Elapsed Time: 0:01:23 ETA:  0:00:00

256000 samples, 2000 batch 	 Loss: 0.193492


In [30]:
from sklearn import metrics

def test(model, x_list, y_list, l_sample, nb_iter=1, batch_size=2048):
    model.eval()
    test_gen = batch_generator(x_list, y_list, batch_size, l_sample, train=False, nb_iter=nb_iter, categorical=False, seed=0)
    y_list = []
    py_list = []
    for X, y in test_gen:
        tX = Variable(torch.from_numpy(X.astype('float32')), requires_grad=False).cuda()
        py = model(tX).max(1)[1].data.cpu().numpy().reshape(-1)
        y_list.append(y)
        py_list.append(py)
    
    y = np.concatenate(y_list)
    py = np.concatenate(py_list)
    print(metrics.accuracy_score(y, py))
    print(metrics.f1_score(y, py, average='macro'))
    print(metrics.precision_score(y, py, average=None))
    print(metrics.recall_score(y, py, average=None))

In [32]:
test(model, x_test_list, y_test_list, l_sample, nb_iter=10, batch_size=batch_size)

0.77109375
0.561022799081
[ 0.91304348  0.83333333  0.82352941  0.76666667  0.42253521  0.22857143
  0.28571429  0.36666667  0.5         1.          0.35294118  0.33333333
  1.          0.66666667  0.83333333  0.67741935  0.80952381  0.85028571]
[ 0.7         0.74074074  0.60869565  0.88461538  0.90909091  0.42105263
  0.15384615  0.91666667  0.46153846  0.36842105  0.35294118  0.22222222
  0.23529412  0.45454545  0.83333333  0.32307692  0.60714286  0.90953545]


In [11]:
# keras model for comparison
# which is much slower than pytorch

In [12]:
import keras
if keras.__version__ >= '2.0.0':

    from keras.models import Sequential
    from keras.layers import Dense, Convolution2D, Dropout, Activation, Flatten
    from keras.layers import MaxPooling2D

    from keras.optimizers import SGD

    model = Sequential()
    model.add(Convolution2D(50, kernel_size=(1, 3), input_shape=[1, nb_modal, l_sample]))
    model.add(Activation('relu'))
    model.add(MaxPool2D(pool_size=(1, 2)))
    model.add(Convolution2D(40, (1, 3)))
    model.add(Activation('relu'))
    model.add(MaxPool2D(pool_size=(1, 2)))
    model.add(Convolution2D(20, (1, 3)))
    model.add(Activation('relu'))
    model.add(MaxPool2D(pool_size=(1, 2)))
    model.add(Flatten())
    model.add(Dense(nb_unit))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(nb_out))
    model.add(Activation('softmax'))
    model.summary()

    optimizer = SGD(lr=0.01, momentum=0.5)

    model.compile(optimizer=optimizer, loss='categorical_crossentropy')
    
    p = ProgressBar(max_value=report_each)
    gen = batch_generator(x_train_list, y_train_list, batch_size, l_sample, train=True, nb_iter=nb_iter, categorical=train_set.nb_class)
    for batch_idx, (X, y) in enumerate(gen):
        batch_idx += 1
        loss = model.train_on_batch(X, y)
        if batch_idx % report_each == 0:
            print("Seen {} samples, {} batch \t Loss: {:.6f}".format(batch_idx * batch_size, batch_idx, float(loss)))
            p.update(report_each)
        else:
            p.update((batch_idx % report_each))
else:
    print("TBA")

Using Theano backend.
 https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29



TBA


Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN Mixed dnn version. The header is from one version, but we link with a different version (4007, 6021))
