# In [None]:
#!/usr/bin/env conda run -n pytorch02 python


import os
from PIL import Image
import torch
from torchvision import transforms

# File extensions treated as images. Anything else found in a class folder
# (e.g. hidden ".DS_Store" files or sub-directories) is skipped rather than
# making Image.open raise and abort the whole split.
_IMAGE_EXTENSIONS = (".jpeg", ".jpg", ".png", ".bmp", ".tif", ".tiff")


def _list_image_files(folder_path):
    """Return the image file names in *folder_path*, sorted for a reproducible order."""
    return sorted(
        name
        for name in os.listdir(folder_path)
        if not name.startswith(".")
        and name.lower().endswith(_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(folder_path, name))
    )


def _load_grayscale(image_path, transform):
    """Open *image_path*, convert it to grayscale ("L") and apply *transform*."""
    # The context manager guarantees the file handle is closed; the transform
    # runs inside it so the pixel data is read before the file goes away.
    with Image.open(image_path) as img:
        return transform(img.convert("L"))


def _pneumonia_label(file_name):
    """Return 1 (bacterial) if "bacteria" is in *file_name*, else 2 (viral)."""
    return 1 if "bacteria" in file_name else 2


def preprocess_dataset(normal_folder_path, pneumonia_folder_path, output_file):
    """Load, transform and label one chest X-ray split, then save it with torch.save.

    Every image is converted to grayscale, resized to 224x224, turned into a
    tensor and normalized with a single-channel mean/std.

    Labels:
        0 -- every image in *normal_folder_path* (Normal)
        1 -- pneumonia images whose file name contains "bacteria"
        2 -- all other pneumonia images (treated as viral pneumonia)

    Args:
        normal_folder_path: Directory containing the NORMAL images.
        pneumonia_folder_path: Directory containing the PNEUMONIA images.
        output_file: Destination passed to torch.save. When it is a path, its
            parent directory is created if it does not exist yet.

    The saved object is a list of (image_tensor, label) tuples: NORMAL images
    first, then PNEUMONIA images, each group in sorted file-name order.
    Non-image entries in either folder are ignored.
    """
    # Single-channel pipeline: images are converted to "L", so Normalize takes
    # one mean and one std. NOTE(review): 0.485 / 0.229 match the first (red)
    # channel of the common ImageNet statistics rather than values measured on
    # this dataset -- confirm that is intended.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ])

    preprocessed_data = []

    # Normal images.
    for image_file in _list_image_files(normal_folder_path):
        image_path = os.path.join(normal_folder_path, image_file)
        preprocessed_data.append((_load_grayscale(image_path, transform), 0))

    # Pneumonia images: the sub-type is encoded in the file name.
    for image_file in _list_image_files(pneumonia_folder_path):
        image_path = os.path.join(pneumonia_folder_path, image_file)
        preprocessed_data.append(
            (_load_grayscale(image_path, transform), _pneumonia_label(image_file))
        )

    # torch.save does not create missing directories. Only do this for real
    # paths so file-like destinations keep working.
    if isinstance(output_file, (str, os.PathLike)):
        output_dir = os.path.dirname(os.fspath(output_file))
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    torch.save(preprocessed_data, output_file)

# Example usage:
# preprocess_dataset("path_to_normal_images", "path_to_pneumonia_images", "output.pth")


# Input folders for each split (NORMAL and PNEUMONIA class directories).
train_normal_folder = "datasets/chest_xray/train/NORMAL"
train_pneumonia_folder = "datasets/chest_xray/train/PNEUMONIA"

val_normal_folder = "datasets/chest_xray/val/NORMAL"
val_pneumonia_folder = "datasets/chest_xray/val/PNEUMONIA"

test_normal_folder = "datasets/chest_xray/test/NORMAL"
test_pneumonia_folder = "datasets/chest_xray/test/PNEUMONIA"

# Output files written by torch.save, one per split.
train_output_file = "preprocessed_datasets/preprocessed_train_data.pt"
val_output_file = "preprocessed_datasets/preprocessed_val_data.pt"
test_output_file = "preprocessed_datasets/preprocessed_test_data.pt"

# torch.save will not create missing directories, so make sure the output
# folder exists (all three output files live in the same directory).
os.makedirs(os.path.dirname(train_output_file), exist_ok=True)

# Preprocess every split.
preprocess_dataset(train_normal_folder, train_pneumonia_folder, train_output_file)
preprocess_dataset(val_normal_folder, val_pneumonia_folder, val_output_file)
preprocess_dataset(test_normal_folder, test_pneumonia_folder, test_output_file)
