In [2]:
from torch.utils.data import TensorDataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

class NetWork(nn.Module):
    def __init__(self, input_feature, output):
        super().__init__()
        self.fc1 = nn.Linear(input_feature, 300, bias=True)
        self.fc2 = nn.Linear(300, output, bias=True)
        self.fc3 = nn.Softmax(dim=1)
        nn.init.xavier_normal_(self.fc1.weight)
        nn.init.xavier_normal_(self.fc2.weight)
    
    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

# 加载数据
X_train = torch.load("X_train.pt")
Y_train = torch.load("Y_train.pt")
X_test = torch.load("X_test.pt")
Y_test = torch.load("Y_test.pt")

# 确认是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 将数据移到设备上
X_train = X_train.to(device)
Y_train = Y_train.to(device)
X_test = X_test.to(device)
Y_test = Y_test.to(device)

# 定义损失函数和优化器
loss = nn.CrossEntropyLoss(reduction="mean")
model = NetWork(300, 4).to(device)
ds = TensorDataset(X_train, Y_train)
DL = DataLoader(ds, batch_size=256, shuffle=True)
optimizer = optim.SGD(params=model.parameters(), lr=0.001)
writer = SummaryWriter(log_dir="logs")

# 训练模型
epoch = 1000
for ep in range(epoch):
    for X, Y in DL:
        X = X.to(device)
        Y = Y.to(device)
        Y_pred = model(X)
        CEloss = loss(Y_pred, Y)
        optimizer.zero_grad()
        CEloss.backward()
        optimizer.step()

# 切换到CPU上进行评估
X_train = X_train.to("cpu")
Y_train = Y_train.to("cpu")
X_test = X_test.to("cpu")
Y_test = Y_test.to("cpu")
model = model.to("cpu")

# 评估训练数据
Y_pred = model(X_train)
result = torch.max(Y_pred.data, dim=1).indices
accuracy = result.eq(Y_train).sum().numpy()/len(Y_pred)
print("学習データ:", accuracy)

# 评估测试数据
Y_pred = model(X_test)
result = torch.max(Y_pred.data, dim=1).indices
accuracy = result.eq(Y_test).sum().numpy()/len(Y_pred)
print("評価データ:", accuracy)


学習データ: 0.7847436978727392
評価データ: 0.7901049475262368
