In [1]:
# References
# https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/01-basics/pytorch_basics/main.py
# http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class
import torch
import numpy as np
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

In [2]:
class DiabetesDataset(Dataset):
    """ Diabetes dataset."""

    # Initialize your data, download, etc.
    def __init__(self):
        xy = np.loadtxt('./data/diabetes.csv.gz',
                        delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


dataset = DiabetesDataset()

In [3]:
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=0)

In [4]:
class Model(torch.nn.Module):

    def __init__(self):
        """
        In the constructor we instantiate two nn.Linear module
        """
        super(Model, self).__init__()
        self.l1 = torch.nn.Linear(8, 6)
        self.l2 = torch.nn.Linear(6, 4)
        self.l3 = torch.nn.Linear(4, 1)

        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """
        In the forward function we accept a Variable of input data and we must return
        a Variable of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Variables.
        """
        out1 = self.sigmoid(self.l1(x))
        out2 = self.sigmoid(self.l2(out1))
        y_pred = self.sigmoid(self.l3(out2))
        return y_pred

# our model
model = Model()

In [5]:
# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)



In [7]:
# Training loop
for epoch in range(100):
    for i, data in enumerate(train_loader, 0):
        # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)

        # Forward pass: Compute predicted y by passing x to the model
        y_pred = model(inputs)

        # Compute and print loss
        loss = criterion(y_pred, labels)
        print(epoch, i, loss.data[0])

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

  from ipykernel import kernelapp as app


0 0 tensor(0.6258)
0 1 tensor(0.6260)
0 2 tensor(0.6258)
0 3 tensor(0.6076)
0 4 tensor(0.6066)
0 5 tensor(0.5858)
0 6 tensor(0.6036)
0 7 tensor(0.6428)
0 8 tensor(0.6020)
0 9 tensor(0.5819)
0 10 tensor(0.7267)
0 11 tensor(0.7236)
0 12 tensor(0.6443)
0 13 tensor(0.6428)
0 14 tensor(0.5846)
0 15 tensor(0.6820)
0 16 tensor(0.7020)
0 17 tensor(0.6997)
0 18 tensor(0.6613)
0 19 tensor(0.6992)
0 20 tensor(0.5895)
0 21 tensor(0.6805)
0 22 tensor(0.6985)
0 23 tensor(0.6458)
1 0 tensor(0.5714)
1 1 tensor(0.6050)
1 2 tensor(0.6613)
1 3 tensor(0.6814)
1 4 tensor(0.7360)
1 5 tensor(0.6082)
1 6 tensor(0.6255)
1 7 tensor(0.6430)
1 8 tensor(0.6620)
1 9 tensor(0.5321)
1 10 tensor(0.6823)
1 11 tensor(0.6050)
1 12 tensor(0.6623)
1 13 tensor(0.6624)
1 14 tensor(0.6039)
1 15 tensor(0.7028)
1 16 tensor(0.6804)
1 17 tensor(0.6245)
1 18 tensor(0.6046)
1 19 tensor(0.7011)
1 20 tensor(0.5853)
1 21 tensor(0.7020)
1 22 tensor(0.7199)
1 23 tensor(0.6188)
2 0 tensor(0.6985)
2 1 tensor(0.5879)
2 2 tensor(0.7182)
2 3

17 17 tensor(0.6789)
17 18 tensor(0.5521)
17 19 tensor(0.5283)
17 20 tensor(0.6421)
17 21 tensor(0.6012)
17 22 tensor(0.7038)
17 23 tensor(0.5596)
18 0 tensor(0.7043)
18 1 tensor(0.6415)
18 2 tensor(0.5609)
18 3 tensor(0.6616)
18 4 tensor(0.6824)
18 5 tensor(0.6823)
18 6 tensor(0.6819)
18 7 tensor(0.5834)
18 8 tensor(0.6212)
18 9 tensor(0.6631)
18 10 tensor(0.5006)
18 11 tensor(0.6637)
18 12 tensor(0.6420)
18 13 tensor(0.6432)
18 14 tensor(0.6845)
18 15 tensor(0.6410)
18 16 tensor(0.6429)
18 17 tensor(0.6424)
18 18 tensor(0.5588)
18 19 tensor(0.6416)
18 20 tensor(0.6638)
18 21 tensor(0.7256)
18 22 tensor(0.7228)
18 23 tensor(0.5913)
19 0 tensor(0.6827)
19 1 tensor(0.6998)
19 2 tensor(0.6421)
19 3 tensor(0.6225)
19 4 tensor(0.7392)
19 5 tensor(0.6235)
19 6 tensor(0.6598)
19 7 tensor(0.6240)
19 8 tensor(0.5670)
19 9 tensor(0.5636)
19 10 tensor(0.6014)
19 11 tensor(0.6821)
19 12 tensor(0.6635)
19 13 tensor(0.7827)
19 14 tensor(0.6783)
19 15 tensor(0.5680)
19 16 tensor(0.6417)
19 17 tensor

35 23 tensor(0.6135)
36 0 tensor(0.6606)
36 1 tensor(0.7020)
36 2 tensor(0.6210)
36 3 tensor(0.7182)
36 4 tensor(0.6603)
36 5 tensor(0.6593)
36 6 tensor(0.6415)
36 7 tensor(0.6609)
36 8 tensor(0.6764)
36 9 tensor(0.6053)
36 10 tensor(0.6594)
36 11 tensor(0.5835)
36 12 tensor(0.5812)
36 13 tensor(0.6401)
36 14 tensor(0.6995)
36 15 tensor(0.6192)
36 16 tensor(0.6400)
36 17 tensor(0.6800)
36 18 tensor(0.6003)
36 19 tensor(0.5824)
36 20 tensor(0.5795)
36 21 tensor(0.7014)
36 22 tensor(0.5980)
36 23 tensor(0.6439)
37 0 tensor(0.6407)
37 1 tensor(0.6197)
37 2 tensor(0.7036)
37 3 tensor(0.6408)
37 4 tensor(0.6027)
37 5 tensor(0.5787)
37 6 tensor(0.6178)
37 7 tensor(0.6194)
37 8 tensor(0.6831)
37 9 tensor(0.7856)
37 10 tensor(0.6008)
37 11 tensor(0.5813)
37 12 tensor(0.6595)
37 13 tensor(0.5780)
37 14 tensor(0.7446)
37 15 tensor(0.5803)
37 16 tensor(0.6593)
37 17 tensor(0.6372)
37 18 tensor(0.6790)
37 19 tensor(0.6003)
37 20 tensor(0.6999)
37 21 tensor(0.6616)
37 22 tensor(0.6408)
37 23 tensor

53 0 tensor(0.6358)
53 1 tensor(0.5824)
53 2 tensor(0.5382)
53 3 tensor(0.6785)
53 4 tensor(0.5534)
53 5 tensor(0.7212)
53 6 tensor(0.7007)
53 7 tensor(0.5784)
53 8 tensor(0.6756)
53 9 tensor(0.6574)
53 10 tensor(0.6379)
53 11 tensor(0.6193)
53 12 tensor(0.7198)
53 13 tensor(0.5813)
53 14 tensor(0.6166)
53 15 tensor(0.6996)
53 16 tensor(0.6361)
53 17 tensor(0.6949)
53 18 tensor(0.6765)
53 19 tensor(0.5245)
53 20 tensor(0.5764)
53 21 tensor(0.6947)
53 22 tensor(0.6555)
53 23 tensor(0.7123)
54 0 tensor(0.7116)
54 1 tensor(0.7078)
54 2 tensor(0.6196)
54 3 tensor(0.6189)
54 4 tensor(0.7096)
54 5 tensor(0.5348)
54 6 tensor(0.6007)
54 7 tensor(0.6587)
54 8 tensor(0.6171)
54 9 tensor(0.5994)
54 10 tensor(0.6170)
54 11 tensor(0.6925)
54 12 tensor(0.5997)
54 13 tensor(0.5775)
54 14 tensor(0.6174)
54 15 tensor(0.6772)
54 16 tensor(0.6370)
54 17 tensor(0.5358)
54 18 tensor(0.6355)
54 19 tensor(0.6597)
54 20 tensor(0.5762)
54 21 tensor(0.7000)
54 22 tensor(0.7591)
54 23 tensor(0.6896)
55 0 tensor(

71 16 tensor(0.6686)
71 17 tensor(0.6704)
71 18 tensor(0.6755)
71 19 tensor(0.5351)
71 20 tensor(0.6250)
71 21 tensor(0.6961)
71 22 tensor(0.5718)
71 23 tensor(0.6827)
72 0 tensor(0.6248)
72 1 tensor(0.6101)
72 2 tensor(0.5835)
72 3 tensor(0.6303)
72 4 tensor(0.6014)
72 5 tensor(0.6681)
72 6 tensor(0.5877)
72 7 tensor(0.7089)
72 8 tensor(0.6289)
72 9 tensor(0.7049)
72 10 tensor(0.6745)
72 11 tensor(0.5897)
72 12 tensor(0.5833)
72 13 tensor(0.6742)
72 14 tensor(0.6624)
72 15 tensor(0.6150)
72 16 tensor(0.6505)
72 17 tensor(0.6496)
72 18 tensor(0.5544)
72 19 tensor(0.6474)
72 20 tensor(0.5927)
72 21 tensor(0.5915)
72 22 tensor(0.6647)
72 23 tensor(0.6248)
73 0 tensor(0.6886)
73 1 tensor(0.5898)
73 2 tensor(0.6828)
73 3 tensor(0.6375)
73 4 tensor(0.7110)
73 5 tensor(0.6836)
73 6 tensor(0.5798)
73 7 tensor(0.6921)
73 8 tensor(0.6029)
73 9 tensor(0.6266)
73 10 tensor(0.6664)
73 11 tensor(0.6401)
73 12 tensor(0.6241)
73 13 tensor(0.6936)
73 14 tensor(0.6356)
73 15 tensor(0.5404)
73 16 tensor

89 5 tensor(0.5681)
89 6 tensor(0.7004)
89 7 tensor(0.5962)
89 8 tensor(0.6319)
89 9 tensor(0.5643)
89 10 tensor(0.7156)
89 11 tensor(0.5838)
89 12 tensor(0.6282)
89 13 tensor(0.6184)
89 14 tensor(0.5624)
89 15 tensor(0.5855)
89 16 tensor(0.5855)
89 17 tensor(0.5799)
89 18 tensor(0.6325)
89 19 tensor(0.6274)
89 20 tensor(0.6049)
89 21 tensor(0.6618)
89 22 tensor(0.5444)
89 23 tensor(0.6304)
90 0 tensor(0.6047)
90 1 tensor(0.5820)
90 2 tensor(0.6293)
90 3 tensor(0.6504)
90 4 tensor(0.6095)
90 5 tensor(0.5588)
90 6 tensor(0.6936)
90 7 tensor(0.6379)
90 8 tensor(0.6823)
90 9 tensor(0.6539)
90 10 tensor(0.5900)
90 11 tensor(0.5921)
90 12 tensor(0.5721)
90 13 tensor(0.5929)
90 14 tensor(0.6336)
90 15 tensor(0.5929)
90 16 tensor(0.5373)
90 17 tensor(0.5963)
90 18 tensor(0.5996)
90 19 tensor(0.5767)
90 20 tensor(0.6294)
90 21 tensor(0.5750)
90 22 tensor(0.5329)
90 23 tensor(0.6650)
91 0 tensor(0.7127)
91 1 tensor(0.6034)
91 2 tensor(0.6002)
91 3 tensor(0.5895)
91 4 tensor(0.6513)
91 5 tensor(

In [None]:
notice bad cost pattern above