In [1]:
import torch

from src.emcfsys.EMCellFiner.model import UNet

# Seed torch's RNG before building the network so that its initial weights,
# and hence any checkpoint saved from them, are reproducible across runs.
# NOTE(review): assumes UNet initialises its parameters via torch's default RNG.
SEED = 0
torch.manual_seed(SEED)
model = UNet()


In [2]:
import torch

# Persist only the parameter tensors (state_dict), not the pickled module,
# so the weights can later be loaded into a freshly constructed UNet.
checkpoint_path = "model.pt"
weights = model.state_dict()
torch.save(weights, checkpoint_path)

In [39]:
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from skimage.transform import resize
import torch

class ImageMaskDataset(Dataset):
    """Paired grayscale-image / binary-mask dataset for segmentation.

    An image is kept only if a mask with the same file stem and the
    extension ``mask_ext`` exists directly inside ``masks_dir``.

    Parameters
    ----------
    images_dir : str or Path
        Directory searched recursively for images.
    masks_dir : str or Path
        Directory holding the masks (looked up by stem, not recursively).
    image_ext : str or iterable of str
        Image file extension(s) to collect, with or without a leading dot.
    mask_ext : str
        Mask file extension, with or without a leading dot.
    target_size : sequence of two ints, optional
        Output ``(H, W)``; images and masks are resized to it when their
        size differs. ``None`` keeps the native size.

    Each item is a pair ``(image, mask)`` of float32 tensors of shape
    ``(1, H, W)``. The image is min-max normalised to [0, 1] (left unscaled
    when it is constant). The mask is 1.0 wherever the label is non-zero,
    else 0.0.
    """

    # Pillow modes whose samples exceed 8 bits. Converting these to "L" clips
    # values above 255, so they are read at their native depth instead.
    _HIGH_DEPTH_MODES = ("I;16", "I;16L", "I;16B", "I;16N", "I", "F")

    def __init__(self, images_dir, masks_dir,
                 image_ext=("png", "jpg", "jpeg", "tif", "tiff"),
                 mask_ext="png",
                 target_size=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)

        # Accept a single extension string as well as a collection of them
        # (iterating a bare string would otherwise yield single characters).
        if isinstance(image_ext, str):
            image_ext = (image_ext,)
        image_exts = [e.lstrip(".") for e in image_ext]
        self.mask_ext = mask_ext.lstrip(".")
        self.target_size = None if target_size is None else tuple(target_size)

        # Collect candidate images recursively. Dedupe and sort so that sample
        # indices are stable across runs and platforms (glob order is not).
        candidates = {p for e in image_exts for p in self.images_dir.glob(f"**/*.{e}")}
        # Keep only the images that have a matching mask.
        self.files = sorted(p for p in candidates if self._mask_path(p).exists())

    def _mask_path(self, img_path):
        """Return the mask path paired with ``img_path`` (same stem, ``mask_ext``)."""
        return self.masks_dir / f"{img_path.stem}.{self.mask_ext}"

    def _load_image(self, img_path):
        """Read ``img_path`` as a 2-D grayscale array, preserving >8-bit depth."""
        with Image.open(img_path) as pil_img:
            if pil_img.mode in self._HIGH_DEPTH_MODES:
                return np.array(pil_img)
            return np.array(pil_img.convert("L"))

    def _load_mask(self, mask_path):
        """Read ``mask_path`` as a 2-D array of palette indices (discrete labels)."""
        with Image.open(mask_path) as pil_mask:
            return np.array(pil_mask.convert("P"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        im = self._load_image(img_path)
        m = self._load_mask(self._mask_path(img_path))

        if self.target_size is not None:
            # Resize image and mask independently, so that a mask whose size
            # differs from its (already correctly sized) image is still fixed.
            if im.shape != self.target_size:
                im = resize(im, self.target_size, preserve_range=True)
            if m.shape != self.target_size:
                # Nearest neighbour (order=0) with anti-aliasing explicitly off,
                # so no smoothing blends label values. Some scikit-image versions
                # enable anti-aliasing on downsampling by default.
                m = resize(m, self.target_size, order=0,
                           preserve_range=True, anti_aliasing=False)

        # Min-max normalise to [0, 1]; a constant image is left as-is to avoid
        # dividing by zero.
        im = im.astype("float32")
        lo, hi = im.min(), im.max()
        if hi > lo:
            im = (im - lo) / (hi - lo)

        im = im[np.newaxis, ...]                         # (1, H, W)
        m = (m > 0).astype("float32")[np.newaxis, ...]   # binarise -> (1, H, W)

        return torch.from_numpy(im), torch.from_numpy(m)


In [40]:
import os
from pathlib import Path

# Dataset location. Set the EMCF_DATA_ROOT environment variable to point
# elsewhere instead of editing this machine-specific default.
DATA_ROOT = Path(os.environ.get("EMCF_DATA_ROOT", r"D:\napari_EMCF\EMCFsys\emcfsys"))
TARGET_SIZE = (256, 256)

dataset = ImageMaskDataset(DATA_ROOT / "image_low",
                           DATA_ROOT / "label",
                           target_size=TARGET_SIZE)

print("dataset size: ", len(dataset))
# Fail early: an empty dataset usually means a wrong path or no image/mask pairs.
assert len(dataset) > 0, f"no image/mask pairs found under {DATA_ROOT}"

dataset size:  5


In [41]:
# Inspect the first sample to check tensor shapes; both should be (1, H, W).
# Index directly instead of a loop that breaks immediately. On an empty
# dataset that loop silently left `img`/`mask` undefined for later cells.
assert len(dataset) > 0, "dataset is empty - check the image/mask directories"
img, mask = dataset[0]
print(img.shape, mask.shape)

torch.Size([1, 256, 256]) torch.Size([1, 256, 256])


In [42]:
# Display the first sample's image tensor, shape (1, H, W). ImageMaskDataset
# min-max normalises non-constant images, so values are expected in [0, 1].
img

tensor([[[0.4228, 0.4383, 0.4874,  ..., 0.5364, 0.6123, 0.5702],
         [0.3930, 0.4217, 0.3769,  ..., 0.5039, 0.6302, 0.4923],
         [0.3505, 0.3374, 0.3301,  ..., 0.4491, 0.5586, 0.5121],
         ...,
         [0.4873, 0.3922, 0.2926,  ..., 0.5086, 0.4271, 0.4279],
         [0.4900, 0.3937, 0.2902,  ..., 0.6072, 0.5319, 0.4870],
         [0.4719, 0.3616, 0.3080,  ..., 0.6700, 0.6121, 0.5937]]])