In [73]:
import warnings

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
%matplotlib inline 
from matplotlib_inline import backend_inline 
backend_inline.set_matplotlib_formats('svg') 

In [74]:
class DeceasedPredictor(nn.Module):
    """Feed-forward binary classifier producing one score in [0, 1] per sample.

    Architecture: in_features -> hidden1 -> hidden2 -> 1, with ReLU after each
    hidden layer and a sigmoid on the output (presumably P(Deceased) — confirm
    against the training notebook). The attribute names fc1/fc2/fc3 are the
    keys of the saved state dict, so they must not be renamed.

    Args:
        in_features: number of input features per sample. The default 13
            matches the saved checkpoint and the feature columns used below.
        hidden1: width of the first hidden layer.
        hidden2: width of the second hidden layer.
    """

    def __init__(self, in_features=13, hidden1=128, hidden2=64):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)  # single output unit

    def forward(self, x):
        """Map an (..., in_features) tensor to (..., 1) scores in [0, 1]."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Sigmoid squashes the logit into [0, 1].
        return torch.sigmoid(self.fc3(x))


# Instantiate with the default 13 -> 128 -> 64 -> 1 layout of the saved checkpoint.
model = DeceasedPredictor()

In [75]:
# Restore the trained weights and switch the network to inference mode.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine; nothing in this notebook moves tensors to a GPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
MODEL_PATH = r'C:\Users\qyypy\Desktop\机器学习综合实践\Save_Model\DNN.pth'
state_dict = torch.load(MODEL_PATH, map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

DeceasedPredictor(
  (fc1): Linear(in_features=13, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=1, bias=True)
)

In [76]:
# Load the processed sheet. Judging from its output, labeled patients
# ('Deceased' = Yes/No) come first, followed by unlabeled test patients.
TEST_XLSX = r'C:\Users\qyypy\Desktop\机器学习综合实践\Test\Sheet2\testDataProcess.xlsx'
data = pd.read_excel(TEST_XLSX)
data

Unnamed: 0,Patient Code,Deceased,Glucose,Urea,Creatinine,Sodium,Potassium,TB,DB,ALT,...,WBC,Platelet,CRP,PCT,IL-6,PT,D-Dimer,Troponin,CPK-MB,LDH
0,P 473,No,68.0,52.0,2.0,145.0,4.0,0.9,,200.0,...,23500.0,243000.0,183.0,19.3,,,10.1,,,3091.0
1,P 982,Yes,450.0,240.0,3.8,150.0,5.5,1.2,0.4,198.0,...,8800.0,76000.0,,21.8,,,,,,2125.0
2,P 258,No,128.0,23.0,,140.0,4.0,,0.2,65.0,...,11100.0,250000.0,95.6,0.2,,,0.7,,,2111.0
3,P 969,No,128.0,178.0,2.4,129.0,4.8,1.2,0.4,200.0,...,8500.0,56000.0,,,,13.0,,,,2105.0
4,P 253,No,95.0,83.0,1.0,148.0,,,0.3,45.0,...,16800.0,311000.0,,0.1,,,,,,2016.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1226,242,,,,,,,,,,...,,,,,,,,,,
1227,243,,,,,,,,,,...,,,,,,,,,,
1228,244,,,,,,,,,,...,,,,,,,,,,
1229,245,,,,,,,,,,...,,,,,,,,,,


In [77]:
# Fill missing values with KNN imputation.
# Numeric columns are standardized first so that large-magnitude columns
# (e.g. Platelet, WBC) do not dominate the neighbor distance. The filled matrix
# is then mapped back to the original units and written over `data`.
# NOTE(review): the imputer is fit on every row of the sheet — labeled and test
# rows together. Test rows with no observed values apparently all receive the
# same fallback values. Confirm this matches how the training data was imputed.
from sklearn.impute import KNNImputer
from sklearn.preprocessing import StandardScaler

numerical_cols = data.select_dtypes(include=['float64', 'int64']).columns

scaler = StandardScaler()
knn_imputer = KNNImputer(n_neighbors=8)

scaled_data = scaler.fit_transform(data[numerical_cols])
imputed_data = scaler.inverse_transform(knn_imputer.fit_transform(scaled_data))

data[numerical_cols] = imputed_data


In [78]:
# Re-inspect the sheet after imputation.
# NOTE(review): the trailing test rows that were empty in the raw sheet now all
# show identical numbers (presumably an imputer fallback). Predictions for those
# patients therefore carry no patient-specific information.
data

Unnamed: 0,Patient Code,Deceased,Glucose,Urea,Creatinine,Sodium,Potassium,TB,DB,ALT,...,WBC,Platelet,CRP,PCT,IL-6,PT,D-Dimer,Troponin,CPK-MB,LDH
0,P 473,No,68.000000,52.000000,2.000000,145.00000,4.000000,0.900000,0.30000,200.000000,...,23500.000000,243000.000000,183.000000,19.300000,321.766667,14.250000,10.100000,,67.637500,3091.000000
1,P 982,Yes,450.000000,240.000000,3.800000,150.00000,5.500000,1.200000,0.40000,198.000000,...,8800.000000,76000.000000,142.975000,21.800000,321.766667,14.250000,4.500000,,37.287500,2125.000000
2,P 258,No,128.000000,23.000000,0.887500,140.00000,4.000000,0.962500,0.20000,65.000000,...,11100.000000,250000.000000,95.600000,0.200000,321.766667,13.625000,0.700000,,25.762500,2111.000000
3,P 969,No,128.000000,178.000000,2.400000,129.00000,4.800000,1.200000,0.40000,200.000000,...,8500.000000,56000.000000,126.625000,8.400000,321.766667,13.000000,3.162500,,66.362500,2105.000000
4,P 253,No,95.000000,83.000000,1.000000,148.00000,3.037500,0.912500,0.30000,45.000000,...,16800.000000,311000.000000,62.650000,0.100000,321.766667,13.500000,0.812500,,54.262500,2016.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1226,242,,163.068126,53.225353,1.442653,139.85103,4.142292,1.021581,0.32205,66.389734,...,10836.271186,231315.254237,60.743726,3.147126,321.766667,17.809249,1.760509,,34.397805,910.844086
1227,243,,163.068126,53.225353,1.442653,139.85103,4.142292,1.021581,0.32205,66.389734,...,10836.271186,231315.254237,60.743726,3.147126,321.766667,17.809249,1.760509,,34.397805,910.844086
1228,244,,163.068126,53.225353,1.442653,139.85103,4.142292,1.021581,0.32205,66.389734,...,10836.271186,231315.254237,60.743726,3.147126,321.766667,17.809249,1.760509,,34.397805,910.844086
1229,245,,163.068126,53.225353,1.442653,139.85103,4.142292,1.021581,0.32205,66.389734,...,10836.271186,231315.254237,60.743726,3.147126,321.766667,17.809249,1.760509,,34.397805,910.844086


In [79]:
# Rows from TEST_START_ROW onward are the unlabeled test patients; their
# 'Deceased' cell is empty in this cell's output. .iloc makes the positional
# slice explicit, and .copy() gives an independent frame rather than a view.
# NOTE(review): data[data['Deceased'].isna()] would avoid the hard-coded row
# number — first confirm that no labeled row is missing 'Deceased'.
TEST_START_ROW = 985
df = data.iloc[TEST_START_ROW:].copy()
df

Unnamed: 0,Patient Code,Deceased,Glucose,Urea,Creatinine,Sodium,Potassium,TB,DB,ALT,...,WBC,Platelet,CRP,PCT,IL-6,PT,D-Dimer,Troponin,CPK-MB,LDH
985,1,,399.000000,108.000000,4.100000,132.00000,3.100000,0.800000,0.20000,18.000000,...,14400.000000,201000.000000,89.000000,2.762500,321.766667,13.750000,2.475000,,60.387500,782.000000
986,2,,78.000000,57.000000,1.100000,141.00000,2.600000,0.887500,0.31250,114.625000,...,12100.000000,197000.000000,55.375000,0.912500,321.766667,13.375000,1.362500,,28.012500,725.000000
987,3,,198.000000,86.000000,1.200000,143.00000,3.000000,1.000000,0.40000,34.000000,...,16700.000000,188000.000000,103.125000,0.275000,321.766667,14.625000,4.950000,,52.237500,1162.500000
988,4,,152.375000,48.125000,1.000000,137.00000,3.787500,1.162500,0.28750,116.625000,...,9800.000000,184000.000000,57.925000,0.200000,321.766667,13.750000,0.612500,,29.625000,681.750000
989,5,,147.000000,64.000000,1.300000,134.00000,3.800000,0.800000,0.20000,112.000000,...,11900.000000,180000.000000,76.125000,0.212500,321.766667,13.375000,1.037500,,23.512500,858.625000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1226,242,,163.068126,53.225353,1.442653,139.85103,4.142292,1.021581,0.32205,66.389734,...,10836.271186,231315.254237,60.743726,3.147126,321.766667,17.809249,1.760509,,34.397805,910.844086
1227,243,,163.068126,53.225353,1.442653,139.85103,4.142292,1.021581,0.32205,66.389734,...,10836.271186,231315.254237,60.743726,3.147126,321.766667,17.809249,1.760509,,34.397805,910.844086
1228,244,,163.068126,53.225353,1.442653,139.85103,4.142292,1.021581,0.32205,66.389734,...,10836.271186,231315.254237,60.743726,3.147126,321.766667,17.809249,1.760509,,34.397805,910.844086
1229,245,,163.068126,53.225353,1.442653,139.85103,4.142292,1.021581,0.32205,66.389734,...,10836.271186,231315.254237,60.743726,3.147126,321.766667,17.809249,1.760509,,34.397805,910.844086


In [80]:
# Select the model inputs explicitly.
# These are the same 13 columns, in the same order, that the original drop-list
# produced (see the X.info() listing). Presumably this is also the training
# order — confirm. Naming the features, rather than dropping everything else,
# keeps the input stable if the spreadsheet gains or reorders columns.
FEATURE_COLS = [
    'Glucose', 'Urea', 'Creatinine', 'Sodium', 'Potassium', 'TB', 'DB',
    'ALT', 'AST', 'ALP', 'Hemoglobin', 'WBC', 'Platelet',
]
X = df[FEATURE_COLS]
X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 246 entries, 985 to 1230
Data columns (total 13 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   Glucose     246 non-null    float64
 1   Urea        246 non-null    float64
 2   Creatinine  246 non-null    float64
 3   Sodium      246 non-null    float64
 4   Potassium   246 non-null    float64
 5   TB          246 non-null    float64
 6   DB          246 non-null    float64
 7   ALT         246 non-null    float64
 8   AST         246 non-null    float64
 9   ALP         246 non-null    float64
 10  Hemoglobin  246 non-null    float64
 11  WBC         246 non-null    float64
 12  Platelet    246 non-null    float64
dtypes: float64(13)
memory usage: 25.1 KB


In [81]:
# Convert the feature frame into a float32 tensor for the model.
# np.asarray accepts either the DataFrame or an already-converted tensor, so
# re-running this cell is harmless instead of failing on `.values`.
X = torch.tensor(np.asarray(X, dtype=np.float32))
# The first layer fixes the number of inputs. Fail here with a clear message
# rather than with a matmul shape error at prediction time.
assert X.ndim == 2 and X.shape[1] == model.fc1.in_features, (
    f"expected {model.fc1.in_features} feature columns, got shape {tuple(X.shape)}"
)

In [83]:
# Inference only: no_grad skips building the autograd graph.
with torch.no_grad():
    y = model(X)

# Sanity check: identical scores for every patient mean the model is almost
# certainly being fed inputs unlike its training data.
# NOTE(review): the recorded output of this cell is 0.0 for every row, i.e. the
# sigmoid is saturated. X holds raw lab values (Platelet is ~2e5 here). If the
# model was trained on scaled features, the training-time scaler must be applied
# to X first. joblib is imported but unused; perhaps the scaler was saved with
# joblib.dump. TODO confirm against the training notebook.
if y.numel() > 0 and bool((y == y.flatten()[0]).all()):
    warnings.warn(
        f"All {y.shape[0]} predictions are identical ({y.flatten()[0].item():.4g}); "
        "check that X is preprocessed the same way as the training data."
    )

y

tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
      