In [1]:
from PIL import Image

def pad_image(image, target_size, type='rgb', resample=None):
    """
    Resize an image to fit inside ``target_size`` while keeping its aspect
    ratio, then centre it on a black canvas of exactly ``target_size``.

    :param image: input PIL image
    :param target_size: a tuple (width, height) giving the output canvas size
    :param type: 'rgb' pastes onto an 'RGB' canvas; any other value pastes
        onto a single-channel 'L' canvas
    :param resample: resampling filter forwarded to ``image.resize``. ``None``
        (the default) means ``Image.BICUBIC``. Pass ``Image.NEAREST`` for label
        images whose pixel values must not be interpolated.
    :return: new image of size ``target_size``
    """
    iw, ih = image.size
    w, h = target_size

    if resample is None:
        resample = Image.BICUBIC

    # Largest uniform scale at which the whole image still fits in the canvas.
    scale = min(w / iw, h / ih)

    # Round to the nearest pixel, but never let a side collapse to 0: a very
    # elongated input (e.g. 1000x1 into 10x10) would otherwise round to a
    # zero-sized resize, which Pillow rejects.
    nw = max(1, int(iw * scale + 0.5))
    nh = max(1, int(ih * scale + 0.5))

    image = image.resize((nw, nh), resample)

    if type == 'rgb':
        mode, fill = 'RGB', (0, 0, 0)
    else:
        mode, fill = 'L', 0
    new_image = Image.new(mode, (w, h), fill)

    # Centre the resized image; the remaining border stays black.
    new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))

    return new_image

In [3]:
import os
from PIL import Image

# Folder whose PNG images are resized and padded in place
folder_path = '../datasets/aim500/test/trimap'

# Names of all PNG files directly inside the folder
image_files = [file for file in os.listdir(folder_path) if file.endswith('.png')]

# WARNING: every image is overwritten with its 512x512 padded version, so the
# originals are lost. Work on a copy of the dataset if they are still needed.
for file_name in image_files:
    file_path = os.path.join(folder_path, file_name)

    # Use a context manager so the source file handle is released before the
    # same path is written to below.
    with Image.open(file_path) as image:
        # NOTE(review): pad_image is called with its default type ('rgb'), so
        # the saved files use an RGB canvas and a bicubic resize — confirm
        # that is what is wanted for trimaps.
        padded_image = pad_image(image, (512, 512))

    # Write the padded result back over the original file
    padded_image.save(file_path)

In [14]:
import torch
from torch import nn
class Self_Attn(nn.Module):
    """Self-attention computed jointly over two feature maps.

    Both inputs go through the same downsampling stem. Their query, key and
    value projections are summed element-wise, so the attention map and the
    attended values mix information from both inputs. The attended result is
    scaled by the learnable ``gamma`` (initialised to 0), added to the
    downsampled ``x``, and projected once more by ``value_conv``.

    :param in_dim: channel count of both inputs (and of the output)
    """

    def __init__(self, in_dim):
        super().__init__()
        self.chanel_in = in_dim

        # Shared stem: two (3x3 conv -> ReLU -> 2x2 max-pool) stages. Channels
        # are unchanged; each pooling step halves height and width.
        stem_layers = [
            nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.reduce_dim = nn.Sequential(*stem_layers)

        # 1x1 projections; query/key use in_dim // 8 channels.
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)

        # Starts at zero, so the attention branch contributes nothing at init.
        self.gamma = nn.Parameter(torch.zeros(1))

        # Normalises each row of the energy matrix.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y):
        """Fuse ``x`` and ``y`` and return a tensor of shape (B, C, N).

        ``N`` is the number of spatial positions left after the stem. ``y`` is
        reshaped using the sizes read from ``x``, so both inputs are expected
        to have the same shape.
        """
        x = self.reduce_dim(x)
        y = self.reduce_dim(y)

        batch, channels, width, height = x.size()
        n_pos = width * height

        # Queries: (B, N, C//8) after the transpose, summed over both inputs.
        query_x = self.query_conv(x).view(batch, -1, n_pos).transpose(1, 2)
        query_y = self.query_conv(y).view(batch, -1, n_pos).transpose(1, 2)
        query = query_x + query_y

        # Keys: (B, C//8, N), summed over both inputs.
        key_x = self.key_conv(x).view(batch, -1, n_pos)
        key_y = self.key_conv(y).view(batch, -1, n_pos)
        key = key_x + key_y

        # (B, N, N) similarity between positions, softmax over the last axis.
        attention = self.softmax(torch.bmm(query, key))

        # Values: (B, C, N), summed over both inputs.
        value_x = self.value_conv(x).view(batch, -1, n_pos)
        value_y = self.value_conv(y).view(batch, -1, n_pos)
        value = value_x + value_y

        attended = torch.bmm(value, attention.transpose(1, 2))
        attended = attended.view(batch, channels, width, height)

        # Residual onto the downsampled x, then one more value projection.
        out = self.value_conv(self.gamma * attended + x)

        return out.view(out.size(0), out.size(1), -1)
    
    
# Shape check: feed two random 128-channel 128x128 maps through Self_Attn
# and print the shape of the fused output.
x = torch.randn(1, 128, 128, 128)
y = torch.randn(1, 128, 128, 128)
model = Self_Attn(128)
print(model(x=x, y=y).shape)

torch.Size([1, 128, 1024])


In [2]:
import torch
from torch import nn
import torch.nn.functional as F 
class encoderfusion(nn.Module):
    """Encode a single-channel prompt and fuse it with a 256-channel feature map.

    The prompt passes through three stride-2 (conv -> batch-norm -> ReLU)
    stages, each halving its height and width, ending at 256 channels. The
    result is concatenated with ``feat`` along the channel axis (512 channels)
    and reduced back to 256 channels by a 3x3 conv followed by batch-norm
    (no activation on the output).
    """

    def __init__(self):
        super().__init__()

        # Prompt encoder: 1 -> 64 -> 128 -> 256 channels, stride 2 at each stage.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

        # Fusion head: 512 concatenated channels -> 256, spatial size kept.
        self.reduce = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

    def forward(self, feat, prompt):
        """Return the fused map; ``feat`` must match the encoded prompt spatially.

        :param feat: feature map that is concatenated with the encoded prompt
            (the two together must provide the 512 channels ``reduce`` expects)
        :param prompt: single-channel input; it leaves the encoder at 1/8 of
            its height and width
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )

        encoded = prompt
        for conv, norm in stages:
            encoded = F.relu(norm(conv(encoded)))

        fused = torch.cat([feat, encoded], dim=1)
        return self.bn4(self.reduce(fused))


# Shape check: a 256-channel 64x64 feature map plus a 1-channel 512x512
# prompt should come out as a single fused tensor.
x = torch.randn(1, 256, 64, 64)
y = torch.randn(1, 1, 512, 512)
model = encoderfusion()
print(model(feat=x, prompt=y).shape)

torch.Size([1, 256, 64, 64])


In [7]:
import os
import cv2
import numpy as np

def process_images(folder_path, threshold=2):
    """Zero out near-black pixels in every image of a folder, in place.

    Each regular file directly inside ``folder_path`` is read as a
    single-channel grayscale image. Pixels whose value is strictly below
    ``threshold`` are set to 0 and the result is written back over the same
    path. Files that cannot be read as images are reported and left untouched.

    NOTE: files are overwritten. Because they are read with
    ``cv2.IMREAD_GRAYSCALE``, whatever is written back is single-channel.

    :param folder_path: directory whose files are processed (not recursive)
    :param threshold: pixels with value < threshold become 0 (default 2)
    :return: None; one status line is printed per file
    """
    for file in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file)

        # Skip sub-directories and anything else that is not a regular file.
        if not os.path.isfile(file_path):
            continue

        # Read the image as grayscale; a None result means it was not readable.
        image = cv2.imread(file_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            print(f'Could not read {file_path}')
            continue

        # Set every pixel below the threshold to 0.
        image[image < threshold] = 0

        # Save over the original file and report a failed write instead of
        # claiming success.
        if cv2.imwrite(file_path, image):
            print(f'Processed {file_path}')
        else:
            print(f'Could not write {file_path}')

# Directory of mask images to threshold in place; replace with your own folder path
folder_path = '../datasets/aim500/test/mask'
process_images(folder_path=folder_path)


Processed ../datasets/aim500/test/mask/o_1c321c56.jpg
Processed ../datasets/aim500/test/mask/o_0a09b978.jpg
Processed ../datasets/aim500/test/mask/o_0c33063a.jpg
Processed ../datasets/aim500/test/mask/o_1a9abc07.jpg
Processed ../datasets/aim500/test/mask/o_1b224771.jpg
Processed ../datasets/aim500/test/mask/o_0b7228ec.jpg
Processed ../datasets/aim500/test/mask/o_1ae3ae29.jpg
Processed ../datasets/aim500/test/mask/o_2b0e2eed.jpg
Processed ../datasets/aim500/test/mask/o_1ea2b894.jpg
Processed ../datasets/aim500/test/mask/o_0df5178f.jpg
Processed ../datasets/aim500/test/mask/o_1b4c1dfc.jpg
Processed ../datasets/aim500/test/mask/o_1edbc402.jpg
Processed ../datasets/aim500/test/mask/o_0a0ae43d.jpg
Processed ../datasets/aim500/test/mask/o_1f836c45.jpg
Processed ../datasets/aim500/test/mask/o_0bf712f6.jpg


In [24]:
import torch

def compute_mse(pred, target, trimap):
    """Mean squared error restricted to pixels where ``trimap == 128``.

    The difference ``pred - target`` is divided by 255 before squaring
    (presumably because both are on a 0-255 scale — confirm with the caller).
    Only positions whose trimap value equals 128 contribute (presumably the
    trimap's "unknown" band — confirm). A small epsilon in the denominator
    keeps the result finite (0) when no pixel is selected.

    :param pred: predicted tensor
    :param target: ground-truth tensor, moved to ``pred``'s device
    :param trimap: trimap tensor, moved to ``pred``'s device
    :return: scalar tensor with the masked mean squared error
    """
    # Make sure all tensors live on the same device as the prediction.
    dev = pred.device
    target, trimap = target.to(dev), trimap.to(dev)

    # 1.0 where the trimap equals 128, 0.0 elsewhere.
    selected = trimap.eq(128).float()

    # Squared, 255-scaled error per element.
    squared_error = ((pred - target) / 255.0).pow(2)

    # Average over the selected pixels only.
    numerator = (squared_error * selected).sum()
    denominator = selected.sum() + 1e-8
    return numerator / denominator

# Example usage. Runs on the GPU when one is available and falls back to the
# CPU otherwise, so the cell no longer crashes on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pred = torch.randn(1, 1, 512, 512).to(device)  # example prediction tensor
target = torch.randn(1, 1, 512, 512).to(device)  # example target tensor
trimap = torch.randint(0, 256, (1, 1, 512, 512)).to(device)  # example trimap tensor with values in [0, 255]

mse_loss = compute_mse(pred, target, trimap)
print(mse_loss)


tensor(3.4420e-05, device='cuda:0')
