In [45]:
import random
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.ndimage import label as cc_label
from torch.utils.data import Dataset, DataLoader

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# --- SAM2 model configuration -----------------------------------------------
# Model config passed to build_sam2, and the fine-tuned weights loaded afterwards.
CFG_PATH = "configs/sam2.1/sam2.1_hiera_b+.yaml"
MODEL_PATH = Path(
    "/storage01/miroslavm/3d_segmentation_models/seg_anything/sam2.1_finetuned_b+.pt"
)
DEVICE = "cuda:1"  # device for the model and every tensor moved with .to(DEVICE)

In [46]:
def augment(img: np.ndarray, msk: np.ndarray, pt: list | None):
    h, w = msk.shape

    # --- H-flip ----------------------------------------------------------
    if random.random() < 0.5:
        img, msk = img[:, ::-1], msk[:, ::-1]
        if pt is not None:
            pt[0] = w - 1 - pt[0]

    # --- V-flip ----------------------------------------------------------
    if random.random() < 0.5:
        img, msk = img[::-1, :], msk[::-1, :]
        if pt is not None:
            pt[1] = h - 1 - pt[1]

    # --- ±90° rotation ---------------------------------------------------
    if random.random() < 0.25:
        k = random.choice([1, 3])           # +90° или −90°
        img, msk = np.rot90(img, k), np.rot90(msk, k)
        if pt is not None:
            if k == 1:
                pt = [pt[1], w - 1 - pt[0]]
            else:
                pt = [h - 1 - pt[1], pt[0]]

    img = np.ascontiguousarray(img)
    msk = np.ascontiguousarray(msk)
    return img, msk, pt

In [47]:
# Build the SAM2 architecture from the config only (ckpt_path=None: no weights
# are loaded here); the fine-tuned weights are loaded manually below.
sam2_model = build_sam2(str(CFG_PATH), ckpt_path=None, device=DEVICE)

# weights_only=True restricts torch.load to tensors and plain containers instead
# of arbitrary pickled objects, so a tampered checkpoint cannot execute code.
sd = torch.load(str(MODEL_PATH), map_location="cpu", weights_only=True)
# Accept both a bare state_dict and one wrapped as {"model": state_dict}.
if isinstance(sd, dict) and isinstance(sd.get("model"), dict):
    sd = sd["model"]
sam2_model.load_state_dict(sd)

# Predictor wrapper around the loaded model (used for image encoding below).
predictor = SAM2ImagePredictor(sam2_model, device=DEVICE)

In [48]:
# Validation data root: JPEG images under images/, PNG masks under masks/.
ROOT = Path("/storage01/miroslavm/3d_segmentation_models/data_rat/mose_jpeg/")
IMG_DIR, MSK_DIR = ROOT / "images", ROOT / "masks"
# Module-level augmentation flag (disabled).
do_augment = False
class MoseDataset(Dataset):
    """Instance-level dataset built from JPEG images and grayscale PNG masks.

    * One PNG mask may contain several objects.
    * Every connected component of the mask foreground (pixels > 0) becomes a
      separate dataset item.
    * Each item carries one positive point prompt sampled inside its component
      and, optionally, one negative point prompt sampled outside it.

    Args:
        img_dir: directory with ``<name>.jpg`` images.
        msk_dir: directory with ``<name>.png`` masks (same stem as the image).
        neg_prompt: if True, also sample a negative point from background pixels
            (skipped when the component leaves no background pixel).
        do_augment: if True, apply ``augment`` (random flips / rotations) to the
            image and mask before the prompts are sampled.

    Items are dicts with keys "image" (H x W x 3 RGB ndarray), "mask" (1 x H x W
    float32 tensor of 0/1), "point" (N x 2 float tensor of pixel (x, y)) and
    "label" (N float tensor, 1 = positive, 0 = negative).
    """

    def __init__(self, img_dir, msk_dir, neg_prompt=True, do_augment=False):
        self.img_dir, self.msk_dir = Path(img_dir), Path(msk_dir)
        self.neg_prompt = neg_prompt
        self.do_augment = do_augment
        self.items = []                         # [(img_name, component_id), ...]

        # sorted(): item order must not depend on the filesystem listing order.
        for img_path in sorted(self.img_dir.glob("*.jpg")):
            name = img_path.stem
            msk_path = self.msk_dir / f"{name}.png"
            if not msk_path.exists():
                continue
            msk = cv2.imread(str(msk_path), cv2.IMREAD_GRAYSCALE)
            if msk is None or msk.max() == 0:
                continue
            _, n_components = cc_label(msk > 0)
            self.items.extend((name, cid) for cid in range(1, n_components + 1))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        name, cid = self.items[idx]
        img_path = self.img_dir / f"{name}.jpg"
        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            raise FileNotFoundError(f"cannot read image {img_path}")
        img = img_bgr[:, :, ::-1]                                  # BGR -> RGB
        msk_full = cv2.imread(str(self.msk_dir / f"{name}.png"),
                              cv2.IMREAD_GRAYSCALE)
        lbl, _ = cc_label(msk_full > 0)
        msk = (lbl == cid).astype(np.float32)                      # (H,W) 0/1

        # Augment BEFORE sampling prompts: every prompt is then already expressed
        # in the augmented frame, so no point needs to be transformed.
        if self.do_augment:
            img, msk, _ = augment(img, msk, None)

        # positive prompt: uniform over the object's pixels
        ys, xs = np.nonzero(msk)
        j = np.random.randint(len(xs))
        pts = [[int(xs[j]), int(ys[j])]]
        labs = [1.]

        # optional negative prompt: uniform over background pixels. Sampling from
        # the background set directly always terminates, even for a full-frame object.
        if self.neg_prompt:
            bg_ys, bg_xs = np.nonzero(msk == 0)
            if len(bg_xs):
                k = np.random.randint(len(bg_xs))
                pts.append([int(bg_xs[k]), int(bg_ys[k])])
                labs.append(0.)

        return {
            "image": img.copy(),                                   # contiguous H x W x 3
            "mask" : torch.from_numpy(msk).unsqueeze(0),           # (1,H,W)
            "point": torch.tensor(pts, dtype=torch.float32),       # (N,2)
            "label": torch.tensor(labs, dtype=torch.float32),      # (N,)
        }

# Validation set: positive prompt only (neg_prompt=False), do_augment=False,
# iterated in dataset order (shuffle=False), one item per batch.
dataset = MoseDataset(
    img_dir=IMG_DIR,
    msk_dir=MSK_DIR,
    neg_prompt=False,
    do_augment=False,
)
val_loader = DataLoader(dataset, shuffle=False, batch_size=1)

In [50]:
# Visualise the fine-tuned model's prediction for one validation sample:
# input + prompts, ground-truth mask, and every predicted mask with its score.
sam2_model.eval()  # disable dropout and other training-only behaviour

# Take the first batch from the loader and pick one item inside it.
sample = next(iter(val_loader))
idx = random.randint(0, len(sample["image"]) - 1)

# ---------- data ----------
# np.asarray accepts both a collated CPU tensor and a numpy array, so the
# predictor always receives a numpy (H, W, 3) image.
img_np  = np.asarray(sample["image"][idx])               # (H,W,3) RGB
gt_mask = sample["mask"][idx].to(DEVICE)                 # (1,H,W)
pts     = sample["point"][idx].unsqueeze(0).to(DEVICE)   # (1,N,2) pixel (x, y)
lbs     = sample["label"][idx].unsqueeze(0).to(DEVICE)   # (1,N) 1 = pos, 0 = neg

# ---------- forward ----------
predictor.set_image_batch([img_np])
with torch.no_grad(), torch.amp.autocast(device_type="cuda"):
    # NOTE(review): `pts` are raw pixel coordinates of the original image. The prompt
    # encoder presumably expects coordinates in the model's input frame
    # (SAM2ImagePredictor normally rescales prompts via `_prep_prompts`); this is only
    # right if fine-tuning fed raw pixel coordinates the same way. TODO confirm.
    sp_emb, dn_emb = predictor.model.sam_prompt_encoder(
        points=(pts, lbs), boxes=None, masks=None)

    # High-res features of the single image in the batch, with a batch dim restored.
    hi = [f[-1].unsqueeze(0) for f in predictor._features["high_res_feats"]]
    low_logits, scores, _, _ = predictor.model.sam_mask_decoder(
        predictor._features["image_embed"],
        predictor.model.sam_prompt_encoder.get_dense_pe(),
        sp_emb, dn_emb, False, False, hi)

    # Upsample all predicted mask logits to the ground-truth resolution.
    logits_up_all = torch.nn.functional.interpolate(
        low_logits, size=gt_mask.shape[-2:], mode="bilinear", align_corners=False)

# Drop the batch dim -> (M,H,W) probabilities and (M,) scores, one per predicted
# mask. .float() first: under autocast these may be half precision.
prob_all = torch.sigmoid(logits_up_all[0].float()).cpu().numpy()
scores_np = scores[0].float().cpu().numpy()

# ---------- prompt points (for plotting) ----------
pts_np = sample["point"][idx].numpy()      # (N,2)
labels_np = sample["label"][idx].numpy()   # (N,)
pos_pts = pts_np[labels_np == 1]
neg_pts = pts_np[labels_np == 0]

# ---------- visualisation ----------
MASK_THRESHOLD = 0.4  # probability cut-off used to binarise predicted masks
n_masks = prob_all.shape[0]
fig, axes = plt.subplots(1, n_masks + 2, figsize=(5 * (n_masks + 2), 4))

# 1. input image and 2. GT mask, both overlaid with the prompt points
axes[0].imshow(img_np)
axes[0].set_title("Input")
axes[1].imshow(gt_mask.squeeze(0).cpu().numpy(), cmap="gray")
axes[1].set_title("GT Mask")
for ax in axes[:2]:
    ax.scatter(pos_pts[:, 0], pos_pts[:, 1], c="g", marker="+", s=120, lw=2, label="pos")
    if len(neg_pts):
        ax.scatter(neg_pts[:, 0], neg_pts[:, 1], c="r", marker="x", s=120, lw=2, label="neg")
    ax.axis("off")
axes[0].legend(loc="upper right")

# 3+. every predicted mask, thresholded, with its predicted score
for i, ax in enumerate(axes[2:]):
    ax.imshow(prob_all[i] > MASK_THRESHOLD, cmap="gray")
    ax.set_title(f"Pred #{i+1}\n(score={scores_np[i]:.2f})")
    ax.axis("off")

fig.tight_layout()
plt.show()
