In [88]:
import numpy as np
import numpy as pn
import torch
import torch.nn as nn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from torch.optim import Adam

In [89]:
class IrisDataset(Dataset):

    def __init__(self, X:np.ndarray, y:np.ndarray):
        super(IrisDataset, self).__init__()
        X = X.astype('float32')
        y = y.astype('int64')
        self.X = torch.from_numpy(X)
        self.y = torch.from_numpy(y)

    def __getitem__(self, index):
        return self.X[index], self.y[index]

    def __len__(self):
        return len(self.X)


In [90]:
class Net(nn.Module):

    def __init__(self, feature_nums:int, layer_num:int):
        """
        :param feature_nums:
        :param layer_num:
        """
        super(Net, self).__init__()
        self.classifier = nn.Sequential()
        input_nums = feature_nums
        output_nums = 16
        for layer_index in range(layer_num - 1):
            self.classifier.add_module(f'Linear_layer{layer_index}',nn.Linear(input_nums, output_nums))
            self.classifier.add_module(f'Relu{layer_index}',nn.ReLU())
            input_nums = output_nums
            output_nums *= 2

        self.classifier.add_module(f'Linear_laye{layer_num - 1}', nn.Linear(input_nums, feature_nums))


    def forward(self, x)->torch.Tensor:
        return self.classifier(x)



In [91]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
x = torch.rand((120, 4)).to(device)
net = Net(4, 6).to("cuda")
# print(net)
print(net(x))


cuda
tensor([[0.0152, 0.0143, 0.0225, 0.0271],
        [0.0186, 0.0132, 0.0140, 0.0319],
        [0.0161, 0.0145, 0.0196, 0.0312],
        [0.0150, 0.0141, 0.0220, 0.0269],
        [0.0160, 0.0152, 0.0224, 0.0289],
        [0.0146, 0.0145, 0.0207, 0.0298],
        [0.0124, 0.0146, 0.0201, 0.0261],
        [0.0150, 0.0148, 0.0199, 0.0301],
        [0.0136, 0.0125, 0.0213, 0.0281],
        [0.0171, 0.0141, 0.0225, 0.0291],
        [0.0161, 0.0142, 0.0178, 0.0309],
        [0.0160, 0.0134, 0.0202, 0.0310],
        [0.0153, 0.0132, 0.0180, 0.0305],
        [0.0155, 0.0139, 0.0224, 0.0274],
        [0.0158, 0.0148, 0.0203, 0.0304],
        [0.0158, 0.0142, 0.0228, 0.0267],
        [0.0126, 0.0140, 0.0203, 0.0261],
        [0.0141, 0.0151, 0.0225, 0.0272],
        [0.0143, 0.0147, 0.0192, 0.0298],
        [0.0164, 0.0148, 0.0195, 0.0312],
        [0.0148, 0.0137, 0.0217, 0.0266],
        [0.0148, 0.0147, 0.0210, 0.0296],
        [0.0151, 0.0143, 0.0210, 0.0302],
        [0.0163, 0.0138, 0.02

In [92]:
iris_data = load_iris(return_X_y=True)
print(iris_data[0].shape, iris_data[1].shape)

X = iris_data[0]
y = iris_data[1]

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=40)

iris_train = IrisDataset(X_train, y_train)
iris_test = IrisDataset(X_test, y_test)



(150, 4) (150,)


In [93]:
# train

net.train()

train_loader = DataLoader(iris_train, batch_size=10, shuffle=True)
test_loader = DataLoader(iris_test, batch_size=10,shuffle=False)

EPOCH = 100
net = Net(4, 4)
opt = Adam(lr=0.001, params=net.parameters())
loss_fn = nn.CrossEntropyLoss()

for epoch in range(EPOCH):

    for batch_id, (Sample_data, target) in enumerate(train_loader):

        print(Sample_data.dtype)
        predict = net(Sample_data)

        loss = loss_fn(predict, target).requires_grad_(True)
        opt.zero_grad()
        loss.backward()
        opt.step()

        print(f"epoch:{epoch}, batch_id:{batch_id}, loss:{loss}")


        # print((idx, input_data, target))

torch.float32
epoch:0, batch_id:0, loss:1.4903762340545654
torch.float32
epoch:0, batch_id:1, loss:1.4536776542663574
torch.float32
epoch:0, batch_id:2, loss:1.4409438371658325
torch.float32
epoch:0, batch_id:3, loss:1.3790836334228516
torch.float32
epoch:0, batch_id:4, loss:1.3325238227844238
torch.float32
epoch:0, batch_id:5, loss:1.388898253440857
torch.float32
epoch:0, batch_id:6, loss:1.3694090843200684
torch.float32
epoch:0, batch_id:7, loss:1.3824855089187622
torch.float32
epoch:0, batch_id:8, loss:1.341977596282959
torch.float32
epoch:0, batch_id:9, loss:1.2835955619812012
torch.float32
epoch:0, batch_id:10, loss:1.3229113817214966
torch.float32
epoch:0, batch_id:11, loss:1.3375552892684937
torch.float32
epoch:1, batch_id:0, loss:1.253288984298706
torch.float32
epoch:1, batch_id:1, loss:1.3224456310272217
torch.float32
epoch:1, batch_id:2, loss:1.2365986108779907
torch.float32
epoch:1, batch_id:3, loss:1.2567509412765503
torch.float32
epoch:1, batch_id:4, loss:1.240922689437866

In [94]:
net.eval()
prob_func = nn.Softmax(dim=1)
# x = torch.rand((4, 3))
# x = prob_func(x)
# print(x)
# print(torch.argmax(x,dim=1))
with torch.no_grad():

    for epoch in range(EPOCH):
        num_cnt = 0
        for batch_id, (Sample_data, labels) in enumerate(test_loader):
            total = Sample_data.shape[0]
            pred = net(Sample_data)
            pred = prob_func(pred)
            idx = torch.argmax(pred,dim=1)
            for index in range(total):
                if idx[index] == labels[index]:
                    num_cnt += 1

            acc = num_cnt / total
            num_cnt = 0
            print(f"epoch:{epoch}, batch_id:{batch_id}, acc:{acc}")


epoch:0, batch_id:0, acc:1.0
epoch:0, batch_id:1, acc:1.0
epoch:0, batch_id:2, acc:1.0
epoch:1, batch_id:0, acc:1.0
epoch:1, batch_id:1, acc:1.0
epoch:1, batch_id:2, acc:1.0
epoch:2, batch_id:0, acc:1.0
epoch:2, batch_id:1, acc:1.0
epoch:2, batch_id:2, acc:1.0
epoch:3, batch_id:0, acc:1.0
epoch:3, batch_id:1, acc:1.0
epoch:3, batch_id:2, acc:1.0
epoch:4, batch_id:0, acc:1.0
epoch:4, batch_id:1, acc:1.0
epoch:4, batch_id:2, acc:1.0
epoch:5, batch_id:0, acc:1.0
epoch:5, batch_id:1, acc:1.0
epoch:5, batch_id:2, acc:1.0
epoch:6, batch_id:0, acc:1.0
epoch:6, batch_id:1, acc:1.0
epoch:6, batch_id:2, acc:1.0
epoch:7, batch_id:0, acc:1.0
epoch:7, batch_id:1, acc:1.0
epoch:7, batch_id:2, acc:1.0
epoch:8, batch_id:0, acc:1.0
epoch:8, batch_id:1, acc:1.0
epoch:8, batch_id:2, acc:1.0
epoch:9, batch_id:0, acc:1.0
epoch:9, batch_id:1, acc:1.0
epoch:9, batch_id:2, acc:1.0
epoch:10, batch_id:0, acc:1.0
epoch:10, batch_id:1, acc:1.0
epoch:10, batch_id:2, acc:1.0
epoch:11, batch_id:0, acc:1.0
epoch:11, 