In [2]:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
 
# prepare dataset
class DiabetesDataset(Dataset):
    """Map-style dataset over a comma-separated, all-numeric file.

    Every column except the last is treated as a feature; the last column is
    the target. All values are parsed as float32 and held in memory.
    """

    def __init__(self, filepath):
        """Load the whole file into two tensors.

        Args:
            filepath: path to a comma-delimited file containing only numeric
                cells (``np.loadtxt`` raises ``ValueError`` otherwise, e.g.
                on a header row).
        """
        # ndmin=2 keeps a single-row file 2-D; without it np.loadtxt returns a
        # 1-D array and the column slicing below raises IndexError.
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32, ndmin=2)
        self.len = xy.shape[0]  # number of rows (samples)
        # Features: shape (len, n_columns - 1).
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Indexing with the list [-1] keeps the column axis: shape (len, 1),
        # which matches a (batch, 1) model output.
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair for row ``index``."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of rows loaded from the file."""
        return self.len
 
 
dataset = DiabetesDataset('./dataset/diabetes.csv')
# Reshuffle the samples every epoch and yield mini-batches of 32.
# num_workers=2 loads batches in two worker subprocesses (not threads).
train_loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
 
 
# design model using class
class Model(torch.nn.Module):
    """Fully connected binary classifier with layer widths 8 -> 6 -> 4 -> 1.

    Every linear layer is followed by a sigmoid, so outputs lie in [0, 1]
    and can be passed straight to ``BCELoss``.
    """

    def __init__(self):
        super().__init__()
        # Registration order fixes both the order of model.parameters() and
        # the order in which the layers draw their random initial weights.
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map inputs of shape (..., 8) to sigmoid outputs of shape (..., 1)."""
        for layer in (self.linear1, self.linear2, self.linear3):
            x = self.sigmoid(layer(x))
        return x
 
 
model = Model()

# Loss and optimizer: binary cross-entropy averaged over each batch, plain SGD.
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Training cycle: forward, backward, update.
if __name__ == '__main__':
    for epoch in range(100):
        # train_loader reshuffles at the start of each pass, then yields mini-batches.
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Forward pass.
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, batch_idx, loss.item())

            # Backward pass: clear gradients left over from the previous step first,
            # since backward() accumulates into .grad.
            optimizer.zero_grad()
            loss.backward()

            # Parameter update.
            optimizer.step()

145854949951
65 23 0.6191020011901855
66 0 0.6034243702888489
66 1 0.5637224912643433
66 2 0.6428331732749939
66 3 0.623440146446228
66 4 0.7226921916007996
66 5 0.6826580166816711
66 6 0.7210690975189209
66 7 0.583355188369751
66 8 0.6631215214729309
66 9 0.5836985111236572
66 10 0.544350802898407
66 11 0.6428607106208801
66 12 0.6426671743392944
66 13 0.6030682921409607
66 14 0.6625211834907532
66 15 0.7028594017028809
66 16 0.543725311756134
66 17 0.7419840097427368
66 18 0.7020857930183411
66 19 0.6826244592666626
66 20 0.6624529361724854
66 21 0.7014200687408447
66 22 0.6040106415748596
66 23 0.6454955339431763
67 0 0.6232517957687378
67 1 0.7019836902618408
67 2 0.6237810254096985
67 3 0.6229473352432251
67 4 0.7814610600471497
67 5 0.6235973238945007
67 6 0.6038619875907898
67 7 0.6432237029075623
67 8 0.662602961063385
67 9 0.722165048122406
67 10 0.6234211921691895
67 11 0.6032957434654236
67 12 0.6426170468330383
67 13 0.662743091583252
67 14 0.5841829776763916
67 15 0.583055