In [None]:
from src.emcfsys.EMCellFiner.model import UNet
# Build a UNet using its default constructor arguments.
# NOTE(review): confirm these defaults match the training configuration used later
# in this notebook (in_channels=1, classes_num=2) before reusing the saved weights.
model = UNet()


In [None]:
import torch
# Persist the freshly built model's parameters (state_dict only, not the module object)
# to model.pt in the current working directory.
torch.save(obj=model.state_dict(), f="model.pt")

In [None]:
from pathlib import Path
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from skimage.transform import resize
import torch

class ImageMaskDataset(Dataset):
    """Paired (image, binary mask) dataset for 2-D segmentation.

    Every file under ``images_dir`` (searched recursively) whose extension is listed
    in ``image_ext`` is kept only if ``masks_dir`` holds a mask with the same stem and
    extension ``mask_ext``. The mask lookup is flat, by stem only. Files are sorted so
    that index ``i`` always maps to the same pair.

    Each item is a pair of float32 tensors of shape (1, H, W):
      * the image, read as 8-bit grayscale and min-max scaled to [0, 1]
        (a constant image becomes all zeros);
      * the mask, binarized so that any non-zero label becomes 1.

    Args:
        images_dir: Root directory searched recursively for images.
        masks_dir: Directory containing the masks.
        image_ext: Image extension(s), without the dot. A single string is accepted.
        mask_ext: Mask file extension, with or without a leading dot.
        target_size: Optional (H, W). The image and the mask are each resized to it
            when their own shape differs.
    """

    def __init__(self, images_dir, masks_dir,
                 image_ext=("png", "jpg", "jpeg", "tif", "tiff"),
                 mask_ext="png",
                 target_size=None):
        self.images_dir = Path(images_dir)
        self.masks_dir = Path(masks_dir)
        self.mask_ext = mask_ext.lstrip(".")
        self.target_size = target_size

        # A bare string such as "tif" would otherwise be iterated character by character.
        if isinstance(image_ext, str):
            image_ext = (image_ext,)

        # Collect candidate images. The set drops duplicates from overlapping patterns.
        candidates = set()
        for ext in image_ext:
            candidates.update(p for p in self.images_dir.glob(f"**/*.{ext.lstrip('.')}") if p.is_file())

        # Keep only images that have a matching mask. Sort so indexing is reproducible
        # (raw glob order depends on the filesystem).
        self.files = sorted(p for p in candidates if self._mask_path(p).exists())

    def _mask_path(self, img_path):
        """Return the expected mask path for ``img_path``: same stem, ``mask_ext`` extension."""
        return self.masks_dir / f"{img_path.stem}.{self.mask_ext}"

    @staticmethod
    def _read_label(mask_path):
        """Load a mask file as a 2-D integer array in which non-zero marks foreground.

        The file is converted to palette ("P") mode, as before, so palette masks keep
        their indices.
        NOTE(review): RGB/RGBA masks get quantised to a palette by PIL, so index 0 is
        not guaranteed to be the background for them. Confirm the mask file format.
        """
        with Image.open(mask_path) as pil_mask:
            return np.array(pil_mask.convert("P"))

    @staticmethod
    def _minmax_normalize(im):
        """Return ``im`` as float32 scaled to [0, 1]. A constant image becomes all zeros."""
        im = np.asarray(im, dtype="float32")
        lo, hi = im.min(), im.max()
        im = im - lo
        if hi > lo:
            im /= (hi - lo)
        return im

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = self.files[idx]
        mask_path = self._mask_path(img_path)

        # Image as 8-bit grayscale. The context manager closes the file handle.
        # NOTE(review): convert("L") may clip 16-bit TIFFs. Confirm the input bit depth.
        with Image.open(img_path) as pil_img:
            im = np.array(pil_img.convert("L"))
        m = self._read_label(mask_path)

        if self.target_size is not None:
            size = tuple(self.target_size)
            # Check image and mask independently: an image already at the target size
            # must not leave a differently sized mask un-resized.
            if im.shape != size:
                im = resize(im, size, preserve_range=True)
            if m.shape != size:
                # Nearest neighbour with anti-aliasing explicitly off, so label values
                # are never blended (some scikit-image versions smooth by default).
                m = resize(m, size, order=0, preserve_range=True, anti_aliasing=False)

        im = self._minmax_normalize(im)[np.newaxis, ...]   # (1, H, W)
        m = (m > 0).astype("float32")[np.newaxis, ...]    # binarize: any label -> 1

        return torch.from_numpy(im), torch.from_numpy(m)


In [None]:
# Inspection dataset built from the `image_low` folder (the training call later reads `image`).
# NOTE(review): absolute Windows paths. Move them to a configurable data root before sharing.
dataset = ImageMaskDataset(r"D:\napari_EMCF\EMCFsys\emcfsys\image_low",
                           r"D:\napari_EMCF\EMCFsys\emcfsys\label",
                           target_size=(256, 256))

print("dataset size: ", len(dataset))
# An empty dataset almost always means a wrong path or image/mask stems that do not match.
assert len(dataset) > 0, "no image/mask pairs found - check the image and label directories"

In [None]:
# Peek at the first sample (if there is one) to confirm the tensor shapes.
# Both image and mask should be (1, H, W).
if len(dataset) > 0:
    img, mask = dataset[0]
    print(img.shape, mask.shape)

In [None]:
# Display the first sample's image tensor (float32, shape (1, H, W), min-max normalised).
img

In [None]:
# Accumulators filled by `cb`.
logs = []          # one tuple per callback invocation
epoch_times = []   # `epoch_time` values reported for finished epochs (units set by train_loop)
metrics_all = []   # every non-None `metrics` payload, in call order


def cb(epoch, batch, n_batches, loss, finished_epoch=False, epoch_time=None, model_dict=None, metrics=None):
    """Record training progress into the module-level lists above.

    Appends one tuple (epoch, batch, n_batches, loss, finished_epoch, epoch_time, metrics)
    to ``logs`` on every call. Appends ``metrics`` to ``metrics_all`` when it is given.
    Appends ``epoch_time`` to ``epoch_times`` when an epoch has finished and a time is
    reported. ``model_dict`` is accepted for signature compatibility but ignored.

    NOTE(review): the ``train_loop`` call in this notebook is given ``my_callback``,
    not this function.

    Returns:
        The shared ``logs`` list, after this call's entry has been appended.
    """
    if metrics is not None:
        metrics_all.append(metrics)
    # Fill the per-epoch timing list, which was declared but never populated.
    if finished_epoch and epoch_time is not None:
        epoch_times.append(epoch_time)
    logs.append((epoch, batch, n_batches, loss, finished_epoch, epoch_time, metrics))
    return logs

In [None]:
from src.emcfsys.EMCellFiner.train import train_loop

# Start from an empty log list. `my_callback` below only prints; it does not append here.
logs = []


def my_callback(epoch, batch, n_batches, loss, finished_epoch=False, epoch_time=None, model_dict=None, metrics=None):
    """Print one progress line whenever the trainer reports metrics; otherwise do nothing."""
    if metrics is None:
        return
    print(f"[Epoch {epoch}] batch {batch}/{n_batches}, loss={loss}, metric={metrics}")


# NOTE(review): absolute Windows paths. To start from an existing checkpoint, set
# pretrained_model, e.g. r"D:\napari_EMCF\EMCFsys\emcfsys\save\best_model_epoch_21.pth".
train_loop(
    images_dir=r"D:\napari_EMCF\EMCFsys\emcfsys\image",
    masks_dir=r"D:\napari_EMCF\EMCFsys\emcfsys\label",
    save_path=r"D:\napari_EMCF\EMCFsys\emcfsys\save",
    pretrained_model=None,
    lr=1e-3,
    batch_size=4,
    epochs=100,
    device=None,
    callback=my_callback,
    target_size=(512, 512),
    in_channels=1,
    classes_num=2,
    ignore_index=-1,
)

In [None]:
import torch
# Load a saved checkpoint for inspection. map_location="cpu" lets a checkpoint written
# during a GPU run be opened on a machine without CUDA.
# NOTE(review): torch.load unpickles the file. Only load checkpoints you produced yourself.
checkpoint = torch.load(r"D:\napari_EMCF\EMCFsys\emcfsys\save\best_model_epoch_1.pth", map_location="cpu")
# Show the top-level keys rather than dumping every tensor into the notebook output.
list(checkpoint) if isinstance(checkpoint, dict) else type(checkpoint)