In [None]:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import dataset
import torchvision.models as models
from torch.utils.data import DataLoader

# Preprocessing pipelines for the training and validation splits.
# Shared normalisation: with a single mean/std of 0.5, Normalize computes
# (x - 0.5) / 0.5 on the one channel. Assuming ToTensor yields values in
# [0, 1] (PIL / uint8 input — TODO confirm against dataset.Pdataset), this maps
# them to [-1, 1].
normMean = [0.5]
normStd = [0.5]
normTransform = transforms.Normalize(normMean, normStd)


def _tensor_then_normalize():
    """Return a fresh pipeline tail: convert to a tensor, then normalise."""
    return [transforms.ToTensor(), normTransform]


# Training: resize, then a random 224x224 crop (4px padding) for augmentation.
train_transform = transforms.Compose(
    [transforms.Resize(256), transforms.RandomCrop(224, padding=4)]
    + _tensor_then_normalize()
)

# Validation: resize, then a deterministic centre crop of the same size.
val_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224)]
    + _tensor_then_normalize()
)

# Load the datasets.
# Raw string: "\P", "\p" and "\c" in a normal string literal are invalid escape
# sequences, which Python warns about and will eventually reject. The
# resulting path strings are exactly the same as before.
data_root = r"I:\PythonCode\pytorch\chest_xray"

# Training set.
train_set = dataset.Pdataset(root_dir=data_root + "/train/", transform=train_transform)
# Validation set, read from the test/ directory.
# NOTE(review): if test/ is meant to be a held-out test split, evaluating on it
# every epoch effectively turns it into a validation set — confirm intent.
val_set = dataset.Pdataset(root_dir=data_root + "/test/", transform=val_transform)

# Data loaders: shuffle the training data each epoch; keep validation order fixed.
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32, shuffle=False)

# Use an ImageNet-pretrained ResNet-50 as the backbone.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of `weights=models.ResNet50_Weights.DEFAULT`; left unchanged because the
# installed torchvision version is not known here.
model = models.resnet50(pretrained=True)

# Swap the 3-channel stem for a 1-channel conv with the same geometry so the
# network accepts single-channel inputs. Instead of leaving the new layer
# randomly initialised (which throws away the pretrained first layer), seed it
# with the pretrained kernels summed over the RGB axis. On a grayscale image
# copied into all three channels, this gives the same response as the
# original stem.
rgb_conv1_weight = model.conv1.weight.detach().sum(dim=1, keepdim=True)
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
with torch.no_grad():
    model.conv1.weight.copy_(rgb_conv1_weight)

# Replace the pretrained classification head with a 2-class linear layer.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Train the model for `num_epochs` epochs. After each epoch, report training
# and validation loss/accuracy.
num_epochs = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _run_epoch(model, loader, criterion, device, optimizer=None, show_progress=False):
    """Make one pass over ``loader`` and return ``(mean_loss, corrects, total)``.

    When ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in eval mode and run with gradients
    disabled.

    ``mean_loss`` is the per-sample average of ``criterion`` over the pass
    (0.0 when the loader yields no samples). ``corrects`` counts argmax
    predictions equal to the label, and ``total`` counts samples seen. With
    ``show_progress``, the running accuracy is printed on one line that is
    rewritten with a carriage return.
    """
    is_train = optimizer is not None
    model.train(is_train)
    loss_sum = 0.0
    corrects = 0
    total = 0
    with torch.set_grad_enabled(is_train):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            if is_train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            preds = torch.argmax(outputs, dim=1)
            # Assumes criterion averages over the batch (reduction="mean", the
            # nn.CrossEntropyLoss default). Weighting by batch size gives a true
            # per-sample mean even when the last batch is smaller.
            loss_sum += loss.item() * batch_size
            corrects += (preds == labels).sum().item()
            total += batch_size
            if show_progress:
                print(f"accuracy:{corrects / total:.4f}", end="\r")
    if show_progress and total:
        # End the "\r" progress line so the next print does not overwrite it.
        print()
    mean_loss = loss_sum / total if total else 0.0
    return mean_loss, corrects, total


for epoch in range(num_epochs):
    # Training pass.
    training_loss, training_corrects, training_total = _run_epoch(
        model, train_loader, criterion, device, optimizer=optimizer, show_progress=True
    )
    accuracy = training_corrects / training_total if training_total else 0.0
    print(f"epoch:{epoch}")
    print(f"accuracy:{accuracy:.4f}")
    print(f"training_loss:{training_loss:.4f}")
    print(f"training_corrects:{training_corrects}")

    # Validation pass: metrics are reported once per epoch, not once per batch.
    val_loss, val_corrects, val_total = _run_epoch(model, val_loader, criterion, device)
    val_accuracy = val_corrects / val_total if val_total else 0.0
    print(f"val_accuracy:{val_accuracy:.4f}")
    print(f"val_loss:{val_loss:.4f}")
    print(f"val_corrects:{val_corrects}")



  root_dir="I:\PythonCode\pytorch\chest_xray/train/", transform=train_transform
  root_dir="I:\PythonCode\pytorch\chest_xray/test/", transform=val_transform


epoch:0y:0.9281
accuracy:0.9281
training_loss:376.0
training_corrects:4856
accuracy:0.7500
val_loss:8.0
val_corrects:24
accuracy:0.7812
val_loss:14.0
val_corrects:50
accuracy:0.8125
val_loss:18.0
val_corrects:78
accuracy:0.7422
val_loss:33.0
val_corrects:95
accuracy:0.6687
val_loss:53.0
val_corrects:107
accuracy:0.6823
val_loss:61.0
val_corrects:131
accuracy:0.6741
val_loss:73.0
val_corrects:151
accuracy:0.7070
val_loss:75.0
val_corrects:181
accuracy:0.7361
val_loss:76.0
val_corrects:212
accuracy:0.7594
val_loss:77.0
val_corrects:243
accuracy:0.7812
val_loss:77.0
val_corrects:275
accuracy:0.7943
val_loss:79.0
val_corrects:305
accuracy:0.8029
val_loss:82.0
val_corrects:334
accuracy:0.8170
val_loss:82.0
val_corrects:366
accuracy:0.8292
val_loss:82.0
val_corrects:398
accuracy:0.8398
val_loss:82.0
val_corrects:430
accuracy:0.8474
val_loss:83.0
val_corrects:461
accuracy:0.8559
val_loss:83.0
val_corrects:493
accuracy:0.8618
val_loss:84.0
val_corrects:524
accuracy:0.8654
val_loss:84.0
val_cor