In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset #抽象类，只能被其他类继承，不可实例化 获取数据及其label
from torch.utils.data import DataLoader #可以实例化

In [2]:
#01 prepare dataset
class DiabetesDataset(Dataset):
    def __init__(self,filepath):
        xy=np.loadtxt(filepath,delimiter=',',dtype=np.float32)
        
        self.len=xy.shape[0]                      #xy是个N*9的矩阵，通过xy.shape就是【N,9】，所以self.len=N
        self.x_data=torch.from_numpy(xy[:,:-1])
        self.y_data=torch.from_numpy(xy[:,[-1]])
        
    def __getitem__(self,index):
        return self.x_data[index], self.y_data[index] #返回是一个元组(x,y)
    
    def __len__(self):
        return self.len
    
dataset=DiabetesDataset('diabetes.csv.gz')

train_loader=DataLoader(dataset=dataset,batch_size=32,shuffle=True,num_workers=0)
#num_workers 并行导入几个数据，根据GPU能力


In [3]:
#2 define model

class Model(torch.nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear1=torch.nn.Linear(8,6)
        self.linear2=torch.nn.Linear(6,4)
        self.linear3=torch.nn.Linear(4,1)
        self.sigmoid=torch.nn.Sigmoid()
        
    def forward(self,x):
        x=self.sigmoid(self.linear1(x))
        x=self.sigmoid(self.linear2(x))
        x=self.sigmoid(self.linear3(x))  #其实这里是x1,x2,y_pred但都写成x是惯例，防止打字错误
        return x
    
model=Model()

In [4]:
#3 construct loss and optimizer(using pytorch api
criterion=torch.nn.BCELoss(reduction='sum') #求交叉熵
optimizer=torch.optim.SGD(model.parameters(),lr=0.0001)

In [5]:
#04 training cycle
if __name__=='__main__':
    for epoch in range(100):
        for i,data in enumerate(train_loader,0):
            inputs,labels=data
        
            y_pred=model(inputs)
            loss=criterion(y_pred,labels)
            print(epoch,i,loss.item())
        
            optimizer.zero_grad() 
            loss.backward()
            optimizer.step() #梯度下降更新 w=w-lr*梯度

0 0 25.842941284179688
0 1 22.589515686035156
0 2 23.98602294921875
0 3 24.888507843017578
0 4 23.951366424560547
0 5 25.327625274658203
0 6 24.854982376098633
0 7 25.75986671447754
0 8 25.28711700439453
0 9 25.287694931030273
0 10 23.011581420898438
0 11 26.15660285949707
0 12 27.01373863220215
0 13 25.219581604003906
0 14 26.5321044921875
0 15 25.180627822875977
0 16 25.165836334228516
0 17 25.15283203125
0 18 23.402742385864258
0 19 25.586971282958984
0 20 25.10636329650879
0 21 26.402503967285156
0 22 26.370620727539062
0 23 19.2391414642334
1 0 22.478443145751953
1 1 26.3094539642334
1 2 21.65290641784668
1 3 24.20158576965332
1 4 25.85275650024414
1 5 25.406333923339844
1 6 24.981422424316406
1 7 24.144044876098633
1 8 24.96331024169922
1 9 26.583826065063477
1 10 25.34787368774414
1 11 24.90070343017578
1 12 26.105886459350586
1 13 26.888599395751953
1 14 26.056774139404297
1 15 25.24532127380371
1 16 25.60834312438965
1 17 24.016157150268555
1 18 23.21843719482422
1 19 24.39534

28 7 21.213214874267578
28 8 20.21945571899414
28 9 20.701078414916992
28 10 22.20100212097168
28 11 20.958234786987305
28 12 20.949724197387695
28 13 20.691808700561523
28 14 21.183124542236328
28 15 20.69133758544922
28 16 22.20083999633789
28 17 19.66209602355957
28 18 20.67711067199707
28 19 22.201866149902344
28 20 21.427318572998047
28 21 21.945777893066406
28 22 21.936763763427734
28 23 15.761672973632812
29 0 21.430885314941406
29 1 22.199357986450195
29 2 21.434402465820312
29 3 22.215944290161133
29 4 22.19876480102539
29 5 21.687944412231445
29 6 21.689603805541992
29 7 20.153581619262695
29 8 21.9429874420166
29 9 21.69316864013672
29 10 20.648914337158203
29 11 21.168399810791016
29 12 20.38814926147461
29 13 20.642345428466797
29 14 20.119632720947266
29 15 21.423452377319336
29 16 20.105093002319336
29 17 20.614587783813477
29 18 20.61568832397461
29 19 21.146696090698242
29 20 20.601707458496094
29 21 21.94025993347168
29 22 21.40946388244629
29 23 14.94974136352539
30 

55 11 20.199426651000977
55 12 21.675983428955078
55 13 19.72454071044922
55 14 20.210071563720703
55 15 20.198808670043945
55 16 19.227542877197266
55 17 22.175765991210938
55 18 21.683048248291016
55 19 21.66908836364746
55 20 20.696115493774414
55 21 19.708728790283203
55 22 18.719860076904297
55 23 15.903177261352539
56 0 20.199256896972656
56 1 21.1815242767334
56 2 19.722705841064453
56 3 20.696744918823242
56 4 23.158016204833984
56 5 20.688753128051758
56 6 20.199342727661133
56 7 20.182462692260742
56 8 17.228557586669922
56 9 21.688934326171875
56 10 20.68610191345215
56 11 21.19027328491211
56 12 19.20212173461914
56 13 23.667457580566406
56 14 22.17552947998047
56 15 22.169342041015625
56 16 20.67986297607422
56 17 19.698387145996094
56 18 22.66702651977539
56 19 19.70716094970703
56 20 19.702014923095703
56 21 19.685199737548828
56 22 19.697784423828125
56 23 15.907111167907715
57 0 22.18161392211914
57 1 20.687700271606445
57 2 20.186168670654297
57 3 20.678131103515625
5

81 5 22.93061637878418
81 6 18.90139389038086
81 7 20.048786163330078
81 8 21.20525550842285
81 9 20.616230010986328
81 10 18.88600730895996
81 11 20.038238525390625
81 12 20.62453842163086
81 13 20.618425369262695
81 14 20.617656707763672
81 15 18.30059051513672
81 16 19.45496368408203
81 17 23.5300350189209
81 18 22.353893280029297
81 19 23.519542694091797
81 20 22.360599517822266
81 21 18.31430435180664
81 22 20.61739158630371
81 23 13.711692810058594
82 0 21.2005672454834
82 1 19.46854591369629
82 2 20.62541961669922
82 3 22.36142349243164
82 4 21.19568634033203
82 5 21.76522445678711
82 6 20.048492431640625
82 7 21.195270538330078
82 8 18.31661033630371
82 9 19.46143913269043
82 10 19.463478088378906
82 11 22.34693145751953
82 12 20.036502838134766
82 13 22.94597625732422
82 14 20.626251220703125
82 15 20.61804962158203
82 16 22.359302520751953
82 17 20.61975860595703
82 18 19.46058464050293
82 19 20.032590866088867
82 20 20.03980255126953
82 21 20.618741989135742
82 22 20.6110382