In [None]:
import os
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, applications
import ssl

# SECURITY: this turns off TLS certificate verification for HTTPS connections
# that use the standard library's default context (e.g. urllib). The effect is
# process-wide, not limited to the ImageNet weight download triggered by
# ResNet50(weights='imagenet'). It is presumably a workaround for a missing CA
# bundle (on a python.org macOS install, running "Install Certificates.command"
# fixes that); prefer fixing the certificates and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context


def build_resnet50_unet(input_size=(256, 256, 3), output_channels=3):
    """Build a U-Net whose encoder is an ImageNet-pretrained ResNet-50.

    The ResNet-50 backbone (without its classification head) is the encoder.
    Three of its intermediate activations are reused as U-Net skip
    connections, and a transposed-convolution decoder upsamples back to the
    input resolution.

    Args:
        input_size: ``(height, width, channels)`` of the input images. Height
            and width should be multiples of 16 so that every decoder stage
            matches the spatial size of its skip connection.
        output_channels: Number of channels produced by the final 1x1
            convolution (default 3, e.g. an RGB image).

    Returns:
        An uncompiled Keras ``Model`` mapping an image of shape ``input_size``
        to a sigmoid-activated map of shape ``(height, width, output_channels)``.

    NOTE: ``weights='imagenet'`` downloads the pretrained weights on first use.
    NOTE(review): Keras' ResNet50 ImageNet weights expect inputs prepared with
    ``applications.resnet50.preprocess_input`` (RGB->BGR, per-channel mean
    subtraction, 0-255 range). No such preprocessing is applied here, so
    images scaled to [0, 1] do not match what the backbone was trained on --
    confirm this is intended.
    """
    base_model = applications.ResNet50(
        include_top=False, weights='imagenet', input_shape=input_size
    )
    inputs = base_model.input

    # Encoder feature maps reused as skip connections, shallow to deep.
    # Sizes are relative to the input (values shown for a 256x256 input).
    skip_layer_names = [
        'conv1_relu',        # 1/2  -> 128x128x64
        'conv2_block3_out',  # 1/4  ->  64x64x256
        'conv3_block4_out',  # 1/8  ->  32x32x512
    ]
    # Deepest feature map, used as the bottleneck: 1/16 -> 16x16x1024.
    bottleneck_layer_name = 'conv4_block6_out'

    skips = [base_model.get_layer(name).output for name in skip_layer_names]
    x = base_model.get_layer(bottleneck_layer_name).output

    # Decoder: each stage doubles the spatial size, concatenates the matching
    # skip connection (deepest first), then refines with two 3x3 convolutions.
    # One filter count per skip connection.
    decoder_filters = [256, 128, 64]
    for filters, skip in zip(decoder_filters, reversed(skips)):
        x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)

    # The shallowest skip is at 1/2 resolution, so one final 2x upsampling
    # restores the input height and width.
    x = layers.UpSampling2D(size=(2, 2))(x)

    # Output layer: per-pixel values in [0, 1].
    outputs = layers.Conv2D(output_channels, (1, 1), activation='sigmoid', padding='same')(x)
    model = models.Model(inputs, outputs)

    return model

# Build the model.
model = build_resnet50_unet(input_size=(256, 256, 3))
# This is per-pixel image regression (sigmoid outputs trained with MSE), so
# report mean absolute error. 'accuracy' would be mapped by Keras to a
# classification accuracy metric, which is meaningless for continuous targets.
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])

# Print the model architecture.
model.summary()

# Load the training data.
def _load_rgb_array(path, target_size):
    """Load *path* as RGB, resize it to *target_size* and scale to [0, 1]."""
    # The context manager releases the file handle once the pixels are read.
    with Image.open(path) as img:
        return np.array(img.convert('RGB').resize(target_size)) / 255.0


def load_images_from_nested_folders(folder, target_size=(256, 256),
                                    input_name='0.png', target_name='4.png'):
    """Load (input, target) image pairs from a two-level folder tree.

    Expected layout::

        folder/<subfolder>/<inner_subfolder>/<input_name>
        folder/<subfolder>/<inner_subfolder>/<target_name>

    Non-directory entries at either level are ignored. An inner folder is
    skipped unless both image files exist as regular files.

    Args:
        folder: Root directory to scan.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``; the
            resulting arrays have shape ``(height, width, 3)``.
        input_name: File name of the model input in each inner folder
            (default ``'0.png'``).
        target_name: File name of the training target in each inner folder
            (default ``'4.png'``).

    Returns:
        A list of ``(x, y)`` tuples of float64 arrays with values in [0, 1].
        Pairs are ordered by sorted folder names so runs are reproducible
        (``os.listdir`` order is arbitrary).
    """
    images = []
    for subfolder in sorted(os.listdir(folder)):
        subfolder_path = os.path.join(folder, subfolder)
        if not os.path.isdir(subfolder_path):
            continue
        for inner_subfolder in sorted(os.listdir(subfolder_path)):
            inner_subfolder_path = os.path.join(subfolder_path, inner_subfolder)
            if not os.path.isdir(inner_subfolder_path):
                continue
            img_x_path = os.path.join(inner_subfolder_path, input_name)
            img_y_path = os.path.join(inner_subfolder_path, target_name)
            # isfile (not exists): a directory named like an image must be
            # skipped rather than handed to PIL.
            if not (os.path.isfile(img_x_path) and os.path.isfile(img_y_path)):
                continue
            images.append((_load_rgb_array(img_x_path, target_size),
                           _load_rgb_array(img_y_path, target_size)))
    return images

# Load the data. The folder can be overridden with the TRAIN_FOLDER
# environment variable; the default is the original local path.
train_folder = os.environ.get("TRAIN_FOLDER", "/Users/txh/Desktop/test")
data = load_images_from_nested_folders(train_folder)
print(f"Total pairs loaded: {len(data)}")
if not data:
    # np.array([]) would have shape (0,), and model.fit would then fail with a
    # confusing shape error; fail early with a clear message instead.
    raise ValueError(f"No image pairs found under {train_folder!r}")

# Stack into (N, H, W, 3) arrays. float32 is Keras' default float type and
# halves memory compared with float64.
x_train = np.array([item[0] for item in data], dtype=np.float32)
y_train = np.array([item[1] for item in data], dtype=np.float32)
print(f"x_train shape: {x_train.shape}, y_train shape: {y_train.shape}")

# Train the model.
model.fit(x_train, y_train, batch_size=3, epochs=10)




In [None]:
# Run the trained model on one image and display the result.
# NOTE(review): this sample lives under the training folder, so the plot shows
# how well the model fits its training data, not how it generalizes -- use a
# held-out image for a real evaluation.
x_test_path = '/Users/txh/Desktop/test/new_data_list/35f3c444/0.png'
with Image.open(x_test_path) as img:
    x_test = np.array(img.convert('RGB').resize((256, 256))) / 255.0
# Add a batch axis: (256, 256, 3) -> (1, 256, 256, 3).
x_test = np.expand_dims(x_test, axis=0)
y_pred = model.predict(x_test)

# Display the prediction.
import matplotlib.pyplot as plt

# Panels: input, prediction and -- when present -- the ground-truth target,
# which training pairs store as '4.png' next to the input '0.png'.
panels = [('Input Image', x_test[0]), ('Predicted Image', y_pred[0])]
y_true_path = os.path.join(os.path.dirname(x_test_path), '4.png')
if os.path.isfile(y_true_path):
    with Image.open(y_true_path) as img:
        y_true = np.array(img.convert('RGB').resize((256, 256))) / 255.0
    panels.append(('Target Image', y_true))

for position, (title, image) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(image)
    plt.title(title)
plt.show()