In [1]:
import os
import cv2
import time
import torch
import random
import imageio
import torchvision
import py_sod_metrics
import numpy as np
import torch.nn as nn
import torch.optim as opt
import torch.nn.init as init
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.nn import Module, Parameter, Softmax
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.models import resnet34

from sam2.build_sam import build_sam2
from dataset import TestDataset,FullDataset

In [2]:
# Timestamp of this run in local time, formatted as YYYY-MM-DD_HH-MM-SS.
time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
# Select the GPUs CUDA may expose. This is assigned before the first CUDA query on the next line.
# NOTE(review): the value has a space after the comma ("0, 1") — confirm the CUDA runtime
# still reads it as devices 0 and 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1" 
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Checkpoint file that is passed to SAM2UNet(...) when the network is built.
hiera_path = './sam2_weight/sam2_hiera_large.pt'

# Test subset name and dataset root. With an empty root, os.path.join yields paths
# relative to the current working directory.
dataset_name = 'CVC-300'
dataset_base = ''

# Training image / mask folders.
# NOTE(review): the training image folder is 'image/' while the test one is 'images/' —
# confirm this matches the layout on disk.
train_image_path = os.path.join(dataset_base,'TrainDataset','image/')
train_mask_path = os.path.join(dataset_base,'TrainDataset','masks/')
# Test image / mask folders for the subset selected by dataset_name.
test_image_path = os.path.join(dataset_base,'TestDataset',dataset_name,'images/')
test_mask_path = os.path.join(dataset_base,'TestDataset',dataset_name,'masks/')

In [4]:
class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation.

    A ReLU module is created and stored as ``self.relu``, but ``forward`` never
    applies it: the returned tensor is the batch-normalised convolution output
    and may therefore contain negative values.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size, stride, padding, dilation: forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes)
        # Kept as an attribute for compatibility; not used by forward().
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``bn(conv(x))`` with no activation applied."""
        return self.bn(self.conv(x))

In [5]:
class DepthWiseConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: number of input channels (also the group count of the first conv).
        out_channels: number of channels produced by the pointwise conv.
        kernel_size, stride, padding, dilation: settings of the depthwise conv.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()

        # groups == in_channels: every input channel is filtered on its own.
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 convolution that mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise conv, then the pointwise conv."""
        return self.pointwise(self.conv1(x))

In [6]:
class LayerNorm(nn.Module):
    """Layer normalisation over the channel axis for two tensor layouts.

    ``data_format="channels_last"`` (the default) expects the channels on the
    last axis, e.g. (batch_size, height, width, channels), and delegates to
    ``F.layer_norm``. ``data_format="channels_first"`` expects
    (batch_size, channels, height, width) and normalises over axis 1 by hand.

    Raises:
        NotImplementedError: if ``data_format`` is neither of the two values.
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalise ``x`` according to the configured data format."""
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(1, keepdim=True)
            normalised = centered / torch.sqrt(variance + self.eps)
            # Broadcast the per-channel scale and shift over the two trailing axes.
            return self.weight[:, None, None] * normalised + self.bias[:, None, None]
        

class SpatialAttention(nn.Module):
    """Convolutional gating block operating on channels-first feature maps.

    The input is layer-normalised, then an attention map (1x1 conv -> GELU ->
    depthwise ``kernel_size`` conv) is multiplied element-wise with a 1x1 "value"
    projection, and the product goes through a final 1x1 projection. The channel
    count ``dim`` is unchanged.

    Args:
        dim: number of channels.
        kernel_size: kernel size of the depthwise conv (padding is ``kernel_size // 2``).
        expand_ratio: accepted but not used anywhere in this class.
    """

    def __init__(self, dim, kernel_size, expand_ratio=2):
        super().__init__()
        self.norm = LayerNorm(dim, eps=1e-6, data_format="channels_first")
        channel_mix = nn.Conv2d(dim, dim, 1)
        activation = nn.GELU()
        depthwise = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.att = nn.Sequential(channel_mix, activation, depthwise)
        self.v = nn.Conv2d(dim, dim, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Return the gated, projected feature map for ``x``."""
        # The four values are not used; the unpack fails unless x has exactly four dimensions.
        _batch, _channels, _height, _width = x.shape
        normed = self.norm(x)
        gated = self.att(normed) * self.v(normed)
        return self.proj(gated)

In [7]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> BatchNorm -> ReLU) stages.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        mid_channels: channels between the two stages; any falsy value
            (including the default ``None``) means ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages."""
        return self.double_conv(x)

In [8]:
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Return the padding that keeps the output the 'same' size.

    An explicit ``p`` is returned untouched. Otherwise the padding is half of the
    effective kernel size, where a dilation ``d > 1`` enlarges the kernel to
    ``d * (k - 1) + 1``. ``k`` may be an int or a sequence of ints; a sequence
    yields a list with one padding per entry.
    """
    is_scalar = isinstance(k, int)
    if d > 1:
        # effective kernel size once dilation is taken into account
        k = d * (k - 1) + 1 if is_scalar else [d * (side - 1) + 1 for side in k]
    if p is not None:
        return p
    return k // 2 if is_scalar else [side // 2 for side in k]

class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> activation, with args (ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # activation shared by every instance built with act=True

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Build the layer.

        ``act=True`` selects ``default_act``, an ``nn.Module`` instance is used as
        given, and anything else disables the activation (``nn.Identity``).
        Padding comes from ``autopad(k, p, d)``; the convolution has no bias.
        """
        super().__init__()
        padding = autopad(k, p, d)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalisation and the activation, in that order."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.act(normalised)

    def forward_fuse(self, x):
        """Apply convolution and the activation only, skipping batch normalisation."""
        convolved = self.conv(x)
        return self.act(convolved)


class Pinwheel_shapedConv(nn.Module):
    ''' Pinwheel-shaped Convolution using the Asymmetric Padding method.

    The input is zero-padded asymmetrically in four different ways. Two padded
    copies go through a shared (1, k) convolution and two through a shared
    (k, 1) convolution, each producing ``c2 // 4`` channels. The four results
    are concatenated along the channel axis and merged by a 2x2 convolution.
    '''

    def __init__(self, c1, c2, k, s):
        super().__init__()

        # One ZeroPad2d padding tuple per arm of the pinwheel.
        arm_padding = ((k, 0, 1, 0), (0, k, 0, 1), (0, 1, k, 0), (1, 0, 0, k))
        # Held in a plain Python list, exactly as before (not an nn.ModuleList).
        self.pad = [nn.ZeroPad2d(padding=spec) for spec in arm_padding]
        arm_channels = c2 // 4
        self.cw = Conv(c1, arm_channels, (1, k), s=s, p=0)
        self.ch = Conv(c1, arm_channels, (k, 1), s=s, p=0)
        self.cat = Conv(c2, c2, 2, s=1, p=0)

    def forward(self, x):
        """Compute the four arms and fuse them."""
        pad_w0, pad_w1, pad_h0, pad_h1 = self.pad
        arms = [
            self.cw(pad_w0(x)),
            self.cw(pad_w1(x)),
            self.ch(pad_h0(x)),
            self.ch(pad_h1(x)),
        ]
        return self.cat(torch.cat(arms, dim=1))

In [9]:
class Up(nn.Module):
    """Upsample by 2, align with a skip tensor, concatenate, then apply DoubleConv.

    Args:
        in_channels: channels of the concatenated tensor fed to the DoubleConv.
        out_channels: channels of the result; the DoubleConv uses
            ``in_channels // 2`` as its intermediate width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConv(in_channels, out_channels, mid_channels=in_channels // 2)

    def forward(self, x1, x2):
        """Fuse the upsampled ``x1`` with the skip tensor ``x2``."""
        upsampled = self.up(x1)
        # Size gap on dims 2 and 3 between the skip tensor and the upsampled one.
        gap_y = x2.size(2) - upsampled.size(2)
        gap_x = x2.size(3) - upsampled.size(3)
        before_x, before_y = gap_x // 2, gap_y // 2

        # Split each gap between the two sides; F.pad takes the last dim first.
        upsampled = F.pad(upsampled, [before_x, gap_x - before_x,
                                      before_y, gap_y - before_y])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

In [10]:
class Adapter_linear(nn.Module):
    """Wrap a block with a small residual MLP applied to the block's input.

    The MLP (Linear -> GELU -> Linear -> GELU, hidden width 256) produces a
    prompt that is added to the input before the wrapped block runs. The feature
    width is read from ``blk.attn.qkv.in_features``.
    """

    def __init__(self, blk) -> None:
        super().__init__()
        self.block = blk
        feature_dim = blk.attn.qkv.in_features
        hidden_dim = 256
        self.prompt_learn = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, feature_dim),
            nn.GELU(),
        )

    def forward(self, x):
        """Return ``block(x + prompt_learn(x))``."""
        return self.block(x + self.prompt_learn(x))

In [11]:
class RFB_modified(nn.Module):
    """Four-branch feature block with a residual shortcut.

    Branches: a 1x1 BasicConv2d, two pinwheel convolutions (k=3 and k=5), and a
    SpatialAttention followed by a 1x1 BasicConv2d. Their outputs are
    concatenated, fused by a 3x3 BasicConv2d, added to a 1x1 projection of the
    input and passed through ReLU.

    Args:
        in_channel: channels of the input feature map.
        out_channel: channels of every branch and of the output.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            # CBAM(in_channels=in_channel, ratio=16, kernel_size=7),
            BasicConv2d(in_channel, out_channel, kernel_size=1),
        )
        self.branch1 = nn.Sequential(
            Pinwheel_shapedConv(in_channel, out_channel, k=3, s=1),
        )
        self.branch2 = nn.Sequential(
            Pinwheel_shapedConv(in_channel, out_channel, k=5, s=1),
        )
        self.branch3 = nn.Sequential(
            SpatialAttention(dim=in_channel, kernel_size=3),
            BasicConv2d(in_channel, out_channel, kernel_size=1),
        )
        # Fuses the four concatenated branch outputs back to out_channel.
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, kernel_size=3, padding=1)
        # Projects the input so it can be added to the fused features.
        self.conv_res = BasicConv2d(in_channel, out_channel, kernel_size=1)

    def forward(self, x):
        """Return ``relu(conv_cat(concat(branches)) + conv_res(x))``."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        stacked = torch.cat([branch(x) for branch in branches], 1)
        fused = self.conv_cat(stacked)
        shortcut = self.conv_res(x)
        return self.relu(fused + shortcut)

In [12]:
class SAM2UNet(nn.Module):
    """U-shaped segmentation network built on the SAM2 image-encoder trunk.

    Construction:
      * a SAM2 model is built from ``sam2_hiera_l.yaml`` (with ``checkpoint_path``
        when given) and everything except ``image_encoder.trunk`` is deleted;
      * all trunk parameters are frozen, then every even-indexed trunk block is
        wrapped in an ``Adapter_linear`` (the adapter's own layers are created
        after the freeze);
      * four ``RFB_modified`` blocks reduce the trunk features (144, 288, 576 and
        1152 channels) to 64 channels, and ``Up`` blocks decode them.

    ``up4`` is created but is not used by ``forward``.

    ``forward`` returns three single-channel logit maps ``(out, out1, out2)``
    upsampled by factors 4, 16 and 8 respectively.
    """

    def __init__(self, checkpoint_path=None) -> None:
        super().__init__()
        model_cfg = "sam2_hiera_l.yaml"
        build_args = (model_cfg, checkpoint_path) if checkpoint_path else (model_cfg,)
        model = build_sam2(*build_args)

        # Only the image-encoder trunk is kept; drop every other SAM2 component.
        discarded = (
            "sam_mask_decoder",
            "sam_prompt_encoder",
            "memory_encoder",
            "memory_attention",
            "mask_downsample",
            "obj_ptr_tpos_proj",
            "obj_ptr_proj",
        )
        for attr_name in discarded:
            delattr(model, attr_name)
        del model.image_encoder.neck
        self.encoder = model.image_encoder.trunk

        # Freeze the trunk before the adapters are inserted.
        for weight in self.encoder.parameters():
            weight.requires_grad = False

        # Blocks at even positions get an adapter; odd positions stay as they are.
        adapted_blocks = [
            Adapter_linear(block) if position % 2 == 0 else block
            for position, block in enumerate(self.encoder.blocks)
        ]
        self.encoder.blocks = nn.Sequential(*adapted_blocks)

        self.rfb1 = RFB_modified(144, 64)
        self.rfb2 = RFB_modified(288, 64)
        self.rfb3 = RFB_modified(576, 64)
        self.rfb4 = RFB_modified(1152, 64)

        self.up1 = Up(128, 64)
        self.up2 = Up(128, 64)
        self.up3 = Up(128, 64)
        self.up4 = Up(128, 64)
        self.side1 = nn.Conv2d(64, 1, kernel_size=1)
        self.side2 = nn.Conv2d(64, 1, kernel_size=1)
        self.head = nn.Conv2d(64, 1, kernel_size=1)

        # 1x1 projections of the raw trunk features, used as multiplicative gates.
        self.res1 = nn.Conv2d(144, 64, kernel_size=1)
        self.res2 = nn.Conv2d(288, 64, kernel_size=1)
        self.res3 = nn.Conv2d(576, 64, kernel_size=1)

    def forward(self, x):
        """Return ``(out, out1, out2)``: the final map and the two side outputs."""
        raw1, raw2, raw3, raw4 = self.encoder(x)
        feat1 = self.rfb1(raw1)
        feat2 = self.rfb2(raw2)
        feat3 = self.rfb3(raw3)
        feat4 = self.rfb4(raw4)

        # Each decoder stage is gated by a projection of the matching raw trunk feature.
        stage3 = self.up1(feat4, feat3) * self.res3(raw3)
        out1 = F.interpolate(self.side1(stage3), scale_factor=16, mode='bilinear')

        stage2 = self.up2(stage3, feat2) * self.res2(raw2)
        out2 = F.interpolate(self.side2(stage2), scale_factor=8, mode='bilinear')

        stage1 = self.up3(stage2, feat1) * self.res1(raw1)
        out = F.interpolate(self.head(stage1), scale_factor=4, mode='bilinear')

        return out, out1, out2

In [13]:
def test_all(model, path):
    """Evaluate ``model`` on one test set and return its mean Dice, IoU and MAE.

    Images are read from ``{path}/images/`` and ground-truth masks from
    ``{path}/masks/`` through ``TestDataset``. One sample is processed per entry
    found in the mask directory.

    Args:
        model: network whose forward pass returns three outputs; only the first
            one is evaluated. It is switched to eval mode here.
        path: root directory of the test set.

    Returns:
        Tuple ``(mean_dice, mean_iou, mean_mae)`` averaged over all samples.

    Raises:
        ZeroDivisionError: if the mask directory contains no entries.
    """
    data_path = path
    image_root = '{}/images/'.format(data_path)
    gt_root = '{}/masks/'.format(data_path)
    model.eval()
    num1 = len(os.listdir(gt_root))
    # The third argument (512) is presumably the test resolution — confirm in TestDataset.
    test_loader = TestDataset(image_root, gt_root, 512)

    # Feed inputs on the device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        run_device = first_param.device
    else:
        run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Running sums of the three metrics.
    DSC = 0.0
    mIOU = 0.0
    MAE = 0.0

    # Inference only: do not record an autograd graph for every forward pass.
    with torch.no_grad():
        for _ in range(num1):
            image, gt, _ = test_loader.load_data()
            gt = np.asarray(gt, np.float32)
            gt /= (gt.max() + 1e-8)
            image = image.to(run_device)

            res, _, _ = model(image)

            # Resize the logits to the mask size, squash them and min-max normalise.
            # F.interpolate replaces the deprecated F.upsample (same arguments).
            res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
            res = res.sigmoid().cpu().numpy().squeeze()
            res = (res - res.min()) / (res.max() - res.min() + 1e-8)

            # Dice coefficient on the soft (un-thresholded) prediction.
            smooth = 1
            pred_flat = res.flatten()
            gt_flat = gt.flatten()
            intersection = (pred_flat * gt_flat).sum()
            dice = (2 * intersection + smooth) / (res.sum() + gt.sum() + smooth)
            DSC += dice

            # Mean absolute error (no threshold involved).
            mae = np.abs(res - gt).mean()
            MAE += mae

            # IoU on the prediction binarised at 0.5.
            # NOTE(review): the ground truth is only max-normalised, not thresholded;
            # confirm that is intended for masks that are not strictly binary.
            threshold = 0.5
            pred_binary = (res >= threshold).astype(np.float32)
            gt_binary = gt.astype(np.float32)

            intersection = (pred_binary * gt_binary).sum()
            union = pred_binary.sum() + gt_binary.sum() - intersection
            iou = (intersection + 1e-8) / (union + 1e-8)
            mIOU += iou

    # Average each metric over the number of samples.
    return DSC/num1, mIOU/num1, MAE/num1

In [14]:
# Repeats the `device` selection already made in the setup cell near the top of the notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [15]:
# Restore the trained network and report mean Dice / IoU / MAE on the test set.
# NOTE(review): testpath is an absolute, machine-specific path and does not use the
# dataset_base / dataset_name settings defined in the configuration cell.
testpath = '/home/mlw/SGN/Nets/datasets/TestDataset/CVC-300'
checkpoint = './model/CVC-300/Best_ours_full.pth'

# SECURITY: unpickling a checkpoint can execute arbitrary code — only load files you trust.
# weights_only=False is spelled out because newer PyTorch releases default to
# weights_only=True, which refuses a fully pickled nn.Module.
loaded = torch.load(checkpoint, map_location=device, weights_only=False)
if isinstance(loaded, nn.Module):
    # The checkpoint already holds the complete network, so there is no need to
    # build a fresh SAM2UNet (and read the SAM2 weights) only to discard it.
    model = loaded
else:
    # Otherwise treat the checkpoint as a state dict for a freshly built network.
    model = nn.DataParallel(SAM2UNet(hiera_path).to(device))
    model.load_state_dict(loaded)
model.eval()
meandice,meanIOU,meanMAE = test_all(model, testpath)
print(meandice,meanIOU,meanMAE)

  res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)


0.9189201 0.8621931 0.0062314295
