In [12]:
import numpy as np
import pandas as pd
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F


In [4]:
# Raw training arrays: parameter rows (inputs.npy) and their target feature rows (outputs.npy).
# NOTE: hard-coded absolute local path — adjust DATA_DIR for your machine.
DATA_DIR = r"C:\Users\Admin\Desktop\WHK ENDECODE\traindata_npy"
input_path = os.path.join(DATA_DIR, "inputs.npy")
out_path = os.path.join(DATA_DIR, "outputs.npy")

x = np.load(input_path).astype(np.float32)
y = np.load(out_path).astype(np.float32)

# The two files must be row-aligned: exactly one output row per input row.
assert x.shape[0] == y.shape[0], f"row count mismatch: inputs {x.shape} vs outputs {y.shape}"

In [9]:
class SolderJointDataset(Dataset):
    """Pair each solder-joint parameter row with its feature row and its image.

    The parameter array (``input_path``) and the feature array (``output_path``)
    are row-aligned ``.npy`` files. Each parameter row is unpacked into
    ``(t, a, ym)`` and turned into an image filename by :meth:`construct_filename`;
    that image is loaded from ``image_dir`` on demand in :meth:`__getitem__`.

    Args:
        input_path: Path to the ``.npy`` file of parameter rows; each row must
            unpack into exactly three values (T, A, YM).
        output_path: Path to the ``.npy`` file of feature rows, one per parameter row.
        image_dir: Directory containing the images named by :meth:`construct_filename`.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.

    Raises:
        ValueError: If the two arrays do not have the same number of rows.
        FileNotFoundError: If ``image_dir`` does not exist (raised by ``os.listdir``).
    """

    def __init__(self, input_path, output_path, image_dir, transform=None):
        # Load the parameter rows and the matching feature rows.
        self.x = np.load(input_path)     # parameter rows: T, A, YM ((351, 3) per the original note — TODO confirm)
        self.y = np.load(output_path)    # feature vectors ((351, 48) per the original note — TODO confirm)
        if len(self.x) != len(self.y):
            # Fail at construction instead of deep inside a DataLoader worker.
            raise ValueError(
                f"inputs and outputs have different row counts: {len(self.x)} vs {len(self.y)}"
            )
        self.image_dir = image_dir
        self.transform = transform

        # Names of every entry in image_dir (also fails early if the directory is missing).
        self.image_files = os.listdir(image_dir)

        # Map constructed filename -> row index. __getitem__ does not use it;
        # kept for callers that need a reverse lookup.
        self.filename_map = {
            self.construct_filename(*row): i for i, row in enumerate(self.x)
        }

    def __len__(self):
        """Number of parameter rows."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict.

        Returns:
            dict with keys
              - ``"input"``: the feature row as a float32 tensor,
              - ``"condition"``: ``[t, a, ym]`` as a float32 tensor of shape (3,),
              - ``"image"``: the RGB image after ``transform``. Without a transform
                this is a ``PIL.Image``; with a transform ending in ``ToTensor``
                (as in this notebook) it is a (3, H, W) tensor.

        Raises:
            FileNotFoundError: If the image file for this row does not exist.
        """
        t, a, ym = self.x[idx]
        features = self.y[idx]

        filename = self.construct_filename(t, a, ym)
        image_path = os.path.join(self.image_dir, filename)

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"图像文件不存在: {image_path}")

        # Image.open is lazy and keeps the file handle open until garbage
        # collection; convert() reads the pixels into a new image, so the
        # file can be closed immediately.
        with Image.open(image_path) as pil_img:
            image = pil_img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        return {
            "input": torch.tensor(features, dtype=torch.float32),
            "condition": torch.tensor([t, a, ym], dtype=torch.float32),  # shape: (3,)
            "image": image,
        }

    @staticmethod
    def float_to_comma_str(val, decimals=2):
        """Format ``val`` with ``decimals`` fixed decimals and ',' as decimal separator.

        Example: ``-19.25 -> "-19,25"``; ``float_to_comma_str(0.2, 3) -> "0,200"``.
        """
        return f"{val:.{decimals}f}".replace('.', ',')

    def construct_filename(self, t, a, ym):
        """Build the image filename for one parameter row.

        T and YM use 2 decimals, A uses 3, all with ',' as the decimal separator,
        e.g. ``(-19.25, 0.2, 20.0) -> "T-19,25 A0,200 YM20,00.png"``.
        """
        t_str = self.float_to_comma_str(t)
        a_str = self.float_to_comma_str(a, decimals=3)
        ym_str = self.float_to_comma_str(ym)
        return f"T{t_str} A{a_str} YM{ym_str}.png"

In [11]:
# Resize every image to a fixed 256x256 and convert it to a (3, H, W) float tensor.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# NOTE: hard-coded absolute local path — adjust DATA_ROOT for your machine.
DATA_ROOT = r"C:\Users\Admin\Desktop\WHK ENDECODE"
dataset = SolderJointDataset(
    input_path=os.path.join(DATA_ROOT, "traindata_npy", "inputs.npy"),
    output_path=os.path.join(DATA_ROOT, "traindata_npy", "outputs.npy"),
    image_dir=os.path.join(DATA_ROOT, "only image"),
    transform=transform
)

# Seed the shuffle so batch order is reproducible across Restart & Run All.
SEED = 42
loader_rng = torch.Generator().manual_seed(SEED)
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=16, shuffle=True, generator=loader_rng
)

# Placeholder training loop. As written it only unpacks each batch, i.e. it is
# one full pass over the data (which also surfaces any missing image file).
for batch in dataloader:
    input_vec = batch["input"]        # (B, 48) for this notebook's data
    condition = batch["condition"]    # (B, 3): T, A, YM
    label_img = batch["image"]        # (B, 3, 256, 256) after the transform above
    # Feed the batch to the model here for training.


In [None]:


# Pull one batch and inspect its shapes, a few values, and the first image.
batch = next(iter(dataloader))

input_vec = batch["input"]        # (B, 48)
condition = batch["condition"]    # (B, 3): T, A, YM
label_img = batch["image"]        # (B, 3, 256, 256)

# Shapes first, then a peek at the first sample's values.
for caption, tensor in (("Input vector shape:", input_vec),
                        ("Condition shape:", condition),
                        ("Image shape:", label_img)):
    print(caption, tensor.shape)

print("Sample input vector:", input_vec[0][:5])  # first 5 features
print("Sample condition:", condition[0])         # T, A, YM

# Display the first image, titled with its conditioning parameters.
img_tensor = label_img[0]
img = F.to_pil_image(img_tensor)
t0, a0, ym0 = condition[0]
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title(f"T={t0:.2f}, A={a0:.3f}, YM={ym0:.2f}")
ax.axis('off')
plt.show()


In [16]:
# Spot-check: print the filenames derived from the first five parameter rows.
for row_idx in range(5):
    row_t, row_a, row_ym = dataset.x[row_idx]
    print(f"Sample {row_idx}: {dataset.construct_filename(row_t, row_a, row_ym)}")

Sample 0: T-19,25 A0,200 YM20,00.png
Sample 1: T-19,25 A0,200 YM23,50.png
Sample 2: T-19,25 A0,200 YM27,00.png
Sample 3: T-19,25 A0,290 YM20,00.png
Sample 4: T-19,25 A0,290 YM23,50.png
