convert masks in to binary format save them as in a new nympy array 

In [2]:
import torch
import random
import PIL
import numbers
import numpy as np
import torch.nn as nn
import collections
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import torchvision.transforms.functional as TF
from PIL import Image, ImageDraw


_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
}




def ISIC2018_transform_320(sample, train_type):
    image, label = Image.fromarray(np.uint8(sample['image']), mode='RGB'),\
                   Image.fromarray(np.uint8(sample['label']), mode='L')

    if train_type == 'train':
        image, label = randomcrop(size=(224, 320))(image, label)
        image, label = randomflip_rotate(image, label, p=0.5, degrees=30)
    else:
        image, label = resize(size=(224, 320))(image, label)

    image = ts.Compose([ts.ToTensor(),
                        ts.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])(image)
    label = ts.ToTensor()(label)

    return {'image': image, 'label': label}
    
    


# these are founctional function for transform
def randomflip_rotate(img, lab, p=0.5, degrees=0):
    if random.random() < p:
        img = TF.hflip(img)
        lab = TF.hflip(lab)
    if random.random() < p:
        img = TF.vflip(img)
        lab = TF.vflip(lab)

    if isinstance(degrees, numbers.Number):
        if degrees < 0:
            raise ValueError("If degrees is a single number, it must be positive.")
        degrees = (-degrees, degrees)
    else:
        if len(degrees) != 2:
            raise ValueError("If degrees is a sequence, it must be of len 2.")
        degrees = degrees
    angle = random.uniform(degrees[0], degrees[1])
    img = TF.rotate(img, angle)
    lab = TF.rotate(lab, angle)

    return img, lab


class randomcrop(object):
    """Crop the given PIL Image and mask at a random location.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
        padding (int or sequence, optional): Optional padding on each border
            of the image. Default is 0, i.e no padding. If a sequence of length
            4 is provided, it is used to pad left, top, right, bottom borders
            respectively.
        pad_if_needed (boolean): It will pad the image if smaller than the
            desired size to avoid raising an exception.
    """

    def __init__(self, size, padding=0, pad_if_needed=False):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding
        self.pad_if_needed = pad_if_needed

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.

        Args:
            img (PIL Image): Image to be cropped.
            output_size (tuple): Expected output size of the crop.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        w, h = img.size
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw

    def __call__(self, img, lab):
        """
        Args:
            img (PIL Image): Image to be cropped.
            lab (PIL Image): Image to be cropped.

        Returns:
            PIL Image: Cropped image and mask.
        """
        if self.padding > 0:
            img = TF.pad(img, self.padding)
            lab = TF.pad(lab, self.padding)

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = TF.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
            lab = TF.pad(lab, (int((1 + self.size[1] - lab.size[0]) / 2), 0))
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = TF.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
            lab = TF.pad(lab, (0, int((1 + self.size[0] - lab.size[1]) / 2)))

        i, j, h, w = self.get_params(img, self.size)

        return TF.crop(img, i, j, h, w), TF.crop(lab, i, j, h, w)

    def __repr__(self):
        return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)


class resize(object):
    """Resize the input PIL Image and mask to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number.
            i.e, if height > width, then image will be rescaled to
            (size * height / width, size)
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, lab):
        """
        Args:
            img (PIL Image): Image to be scaled.
            lab (PIL Image): Image to be scaled.

        Returns:
            PIL Image: Rescaled image and mask.
        """
        return TF.resize(img, self.size, self.interpolation), TF.resize(lab, self.size, self.interpolation)

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)


def itensity_normalize(volume):
    """
    normalize the itensity of an nd volume based on the mean and std of nonzeor region
    inputs:
        volume: the input nd volume
    outputs:
        out: the normalized n                                                                                                                                                                 d volume
    """

    # pixels = volume[volume > 0]
    mean = volume.mean()
    std = volume.std()
    out = (volume - mean) / std
    out_random = np.random.normal(0, 1, size=volume.shape)
    out[volume == 0] = out_random[volume == 0]

    return out

In [None]:
import os
import glob
import torch
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
#from transform import ISIC2018_transform_320  # from your uploaded file

# Parameters
height, width = 224, 320
channels = 3

# Paths
dataset_root = '/kaggle/input/isic2018-challenge-task1-data-segmentation/'
input_dir = os.path.join(dataset_root, 'ISIC2018_Task1-2_Training_Input')
mask_dir = os.path.join(dataset_root, 'ISIC2018_Task1_Training_GroundTruth')

# Get image list
img_paths = sorted(glob.glob(os.path.join(input_dir, '*.jpg')))

# Initialize data containers
samples = []

print('Reading ISIC 2018')
for idx, img_path in enumerate(img_paths):
    print(f"{idx + 1}/{len(img_paths)}")

    # Load image and corresponding mask
    img = Image.open(img_path).convert("RGB")
    img_np = np.array(img)

    img_name = os.path.basename(img_path).replace('.jpg', '_segmentation.png')
    mask_path = os.path.join(mask_dir, img_name)
    mask = Image.open(mask_path).convert("L")
    mask_np = np.array(mask)

    samples.append({'image': img_np, 'label': mask_np})

print('Reading ISIC 2018 finished')

# Split the dataset: 1815 train, 259 val, rest test
train_samples, temp_samples = train_test_split(samples, train_size=1815, random_state=42)
val_samples, test_samples = train_test_split(temp_samples, test_size=len(samples) - (1815 + 259), random_state=42)

def apply_transforms(samples, train_type):
    images, masks = [], []
    for sample in samples:
        transformed = ISIC2018_transform_320(sample, train_type=train_type)
        images.append(transformed['image'])
        masks.append(transformed['label'].squeeze(0))  # remove channel dim
    return torch.stack(images), torch.stack(masks)

# Apply transforms
train_imgs, train_masks = apply_transforms(train_samples, train_type='train')
val_imgs, val_masks = apply_transforms(val_samples, train_type='val')
test_imgs, test_masks = apply_transforms(test_samples, train_type='val')  # 'val' = no augment

# Save tensors
torch.save(train_imgs, 'data_train.pt')
torch.save(val_imgs, 'data_val.pt')
torch.save(test_imgs, 'data_test.pt')

torch.save(train_masks, 'mask_train.pt')
torch.save(val_masks, 'mask_val.pt')
torch.save(test_masks, 'mask_test.pt')


Reading ISIC 2018
1/2594
2/2594
3/2594
4/2594
5/2594
6/2594
7/2594
8/2594
9/2594
10/2594
11/2594
12/2594
13/2594
14/2594
15/2594
16/2594
17/2594
18/2594
19/2594
20/2594
21/2594
22/2594
23/2594
24/2594
25/2594
26/2594
27/2594
28/2594
29/2594
30/2594
31/2594
32/2594
33/2594
34/2594
35/2594
36/2594
37/2594
38/2594
39/2594
40/2594
41/2594
42/2594
43/2594
44/2594
45/2594
46/2594
47/2594
48/2594
49/2594
50/2594
51/2594
52/2594
53/2594
54/2594
55/2594
56/2594
57/2594
58/2594
59/2594
60/2594
61/2594
62/2594
63/2594
64/2594
65/2594
66/2594
67/2594
68/2594
69/2594
70/2594
71/2594
72/2594
73/2594
74/2594
75/2594
76/2594
77/2594
78/2594
79/2594
80/2594
81/2594
82/2594
83/2594
84/2594
85/2594
86/2594
87/2594
88/2594
89/2594
90/2594
91/2594
92/2594
93/2594
94/2594
95/2594
96/2594
97/2594
98/2594
99/2594
100/2594
101/2594
102/2594
103/2594
104/2594
105/2594
106/2594
107/2594
108/2594
109/2594
110/2594
111/2594
112/2594
113/2594
114/2594
115/2594
116/2594
117/2594
118/2594
119/2594
120/2594
121/2594
1

In [None]:
!pip install thop

In [None]:
import torch

# Load masks
train_masks = torch.load('mask_train.pt')  # shape: (N, H, W)
val_masks = torch.load('mask_val.pt')
test_masks = torch.load('mask_test.pt')

# Binarize masks (threshold at 0.5)
train_masks_binary = (train_masks > 0.5).to(torch.uint8)
val_masks_binary = (val_masks > 0.5).to(torch.uint8)
test_masks_binary = (test_masks > 0.5).to(torch.uint8)

# Save binary masks
torch.save(train_masks_binary, 'mask_train_binary.pt')
torch.save(val_masks_binary, 'mask_val_binary.pt')
torch.save(test_masks_binary, 'mask_test_binary.pt')

print("Masks have been successfully binarized and saved.")


normalize images from 0-255 to 0-1 and covert float32 and save them in new numpy array(normalized)

In [None]:
import torch

# Load the original images (values in 0–255 range, shape: [N, 3, H, W])
train_images = torch.load('data_train.pt')
val_images = torch.load('data_val.pt')
test_images = torch.load('data_test.pt')

# Normalize images by dividing by 255.0 and converting to float32
train_images_normalized = train_images.float() / 255.0
val_images_normalized = val_images.float() / 255.0
test_images_normalized = test_images.float() / 255.0

# Save normalized images
torch.save(train_images_normalized, 'data_train_normalized.pt')
torch.save(val_images_normalized, 'data_val_normalized.pt')
torch.save(test_images_normalized, 'data_test_normalized.pt')

print("Normalization and saving completed.")



normalizing way two same as that of the paper

veryfing masks are binary and printing sample data

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# Load normalized images and binary masks
train_images = torch.load('data_train.pt')  # shape: (N, 3, H, W)
train_masks = torch.load('mask_train_binary.pt')       # shape: (N, H, W)
val_images = torch.load('data_val.pt')
val_masks = torch.load('mask_val_binary.pt')
test_images = torch.load('data_test.pt')
test_masks = torch.load('mask_test_binary.pt')

# Check shapes
print("Train Images Shape: ", train_images.shape)
print("Train Masks Shape: ", train_masks.shape)
print("Validation Images Shape: ", val_images.shape)
print("Validation Masks Shape: ", val_masks.shape)
print("Test Images Shape: ", test_images.shape)
print("Test Masks Shape: ", test_masks.shape)

# Visualize a sample
def display_sample(images, masks, index):
    """Displays an image and its corresponding mask (converted to numpy for matplotlib)."""
    image = images[index].permute(1, 2, 0).cpu().numpy()  # (C, H, W) -> (H, W, C)
    mask = masks[index].cpu().numpy()                    # (H, W)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image)
    ax[0].set_title('Image')
    ax[0].axis('off')
    ax[1].imshow(mask, cmap='gray')
    ax[1].set_title('Mask')
    ax[1].axis('off')
    plt.show()

# Display a sample from the training set
display_sample(train_images, train_masks, index=0)

# Verify masks are binary (only 0 and 1)
assert torch.equal(train_masks.unique(), torch.tensor([0, 1], dtype=torch.uint8)), "Train masks are not binary!"
assert torch.equal(val_masks.unique(), torch.tensor([0, 1], dtype=torch.uint8)), "Validation masks are not binary!"
assert torch.equal(test_masks.unique(), torch.tensor([0, 1], dtype=torch.uint8)), "Test masks are not binary!"

print("Dataset verification completed successfully.")


In [None]:
# Check min and max values of the entire training image set
print("Min pixel value in training images:", train_images.min().item())
print("Max pixel value in training images:", train_images.max().item())

# Check min and max values of a specific sample image
sample_image = train_images[0]
print("Min pixel value in sample image:", sample_image.min().item())
print("Max pixel value in sample image:", sample_image.max().item())


In [None]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 0 = all logs, 1 = filter info, 2 = filter warnings, 3 = filter errors
import tensorflow as tf
#to supress warnings 

defining model unet 

In [None]:
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 10 09:57:49 2019

@author: Fsl
"""
from scipy import ndimage
import torch
from torchvision import models
import torch.nn as nn
from torchsummary import summary

# from .resnet import resnet34
# from resnet import resnet34
# import resnet
from torch.nn import functional as F
import torchsummary
from torch.nn import init
import numpy as np
from functools import partial
from thop import profile
up_kwargs = {'mode': 'bilinear', 'align_corners': True}
BatchNorm2d = nn.BatchNorm2d

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class SpatialAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttentionBlock, self).__init__()
        self.query = nn.Sequential(
            nn.Conv2d(in_channels,in_channels//8,kernel_size=(1,3), padding=(0,1)),
            nn.BatchNorm2d(in_channels//8),
            nn.ReLU(inplace=True)
        )
        self.key = nn.Sequential(
            nn.Conv2d(in_channels, in_channels//8, kernel_size=(3,1), padding=(1,0)),
            nn.BatchNorm2d(in_channels//8),
            nn.ReLU(inplace=True)
        )
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: input( BxCxHxW )
        :return: affinity value + x
        """
        B, C, H, W = x.size()
        # compress x: [B,C,H,W]-->[B,H*W,C], make a matrix transpose
        proj_query = self.query(x).view(B, -1, W * H).permute(0, 2, 1)
        proj_key = self.key(x).view(B, -1, W * H)
        affinity = torch.matmul(proj_query, proj_key)
        affinity = self.softmax(affinity)
        proj_value = self.value(x).view(B, -1, H * W)
        weights = torch.matmul(proj_value, affinity.permute(0, 2, 1))
        weights = weights.view(B, C, H, W)
        out = self.gamma * weights + x
        return out
class ChannelAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttentionBlock, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: input( BxCxHxW )
        :return: affinity value + x
        """
        B, C, H, W = x.size()
        proj_query = x.view(B, C, -1)
        proj_key = x.view(B, C, -1).permute(0, 2, 1)
        affinity = torch.matmul(proj_query, proj_key)
        affinity_new = torch.max(affinity, -1, keepdim=True)[0].expand_as(affinity) - affinity
        affinity_new = self.softmax(affinity_new)
        proj_value = x.view(B, C, -1)
        weights = torch.matmul(affinity_new, proj_value)
        weights = weights.view(B, C, H, W)
        out = self.gamma * weights + x
        return out
        
class AffinityAttention2(nn.Module):
    """ Affinity attention module """

    def __init__(self, in_channels):
        super(AffinityAttention2, self).__init__()
        self.sab = SpatialAttentionBlock(in_channels)
        self.cab = ChannelAttentionBlock(in_channels)
        # self.conv1x1 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1)

    def forward(self, x):
        """
        sab: spatial attention block
        cab: channel attention block
        :param x: input tensor
        :return: sab + cab
        """
        sab = self.sab(x)
        cab = self.cab(sab)
        out = sab + cab
        return out

class UnetDsv3(nn.Module):
    def __init__(self, in_size, out_size, scale_factor):
        super(UnetDsv3, self).__init__()
        self.dsv = nn.Sequential(nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0),
                                 nn.Upsample(size=scale_factor, mode='bilinear'), )

    def forward(self, input):
        return self.dsv(input)


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x



class Scale_Aware(nn.Module):
    def __init__(self, in_channels):
        super(Scale_Aware, self).__init__()

        # self.bn = nn.ModuleList([nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels)])
        self.conv1x1 = nn.ModuleList(
            [nn.Conv2d(in_channels=2 * in_channels, out_channels=in_channels, dilation=1, kernel_size=1, padding=0),
             nn.Conv2d(in_channels=2 * in_channels, out_channels=in_channels, dilation=1, kernel_size=1, padding=0)])
        self.conv3x3_1 = nn.ModuleList(
            [nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2, dilation=1, kernel_size=3, padding=1),
             nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2, dilation=1, kernel_size=3, padding=1)])
        self.conv3x3_2 = nn.ModuleList(
            [nn.Conv2d(in_channels=in_channels // 2, out_channels=2, dilation=1, kernel_size=3, padding=1),
             nn.Conv2d(in_channels=in_channels // 2, out_channels=2, dilation=1, kernel_size=3, padding=1)])
        self.conv_last = ConvBnRelu(in_planes=in_channels, out_planes=in_channels, ksize=1, stride=1, pad=0, dilation=1)

        self.relu = nn.ReLU()
    def forward(self, x_l, x_h):
        feat = torch.cat([x_l, x_h], dim=1)
        # feat=feat_cat.detach()
        feat = self.relu(self.conv1x1[0](feat))
        feat = self.relu(self.conv3x3_1[0](feat))
        att = self.conv3x3_2[0](feat)
        att = F.softmax(att, dim=1)

        att_1 = att[:, 0, :, :].unsqueeze(1)
        att_2 = att[:, 1, :, :].unsqueeze(1)

        fusion_1_2 = att_1 * x_l + att_2 * x_h
        return fusion_1_2




class BaseNetHead(nn.Module):
    def __init__(self, in_planes, out_planes, scale,
                 is_aux=False, norm_layer=nn.BatchNorm2d):
        super(BaseNetHead, self).__init__()
        if is_aux:
            self.conv_1x1_3x3=nn.Sequential(
                ConvBnRelu(in_planes, 64, 1, 1, 0,
                                       has_bn=True, norm_layer=norm_layer,
                                       has_relu=True, has_bias=False),
                ConvBnRelu(64, 64, 3, 1, 1,
                                       has_bn=True, norm_layer=norm_layer,
                                       has_relu=True, has_bias=False))
        else:
            self.conv_1x1_3x3=nn.Sequential(
                ConvBnRelu(in_planes, 32, 1, 1, 0,
                                       has_bn=True, norm_layer=norm_layer,
                                       has_relu=True, has_bias=False),
                ConvBnRelu(32, 32, 3, 1, 1,
                                       has_bn=True, norm_layer=norm_layer,
                                       has_relu=True, has_bias=False))
        # self.dropout = nn.Dropout(0.1)
        if is_aux:
            self.conv_1x1_2 = nn.Conv2d(64, out_planes, kernel_size=1,
                                      stride=1, padding=0)
        else:
            self.conv_1x1_2 = nn.Conv2d(32, out_planes, kernel_size=1,
                                      stride=1, padding=0)
        self.scale = scale
            
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.normal_(m.weight.data, 1.0, 0.02)
                init.constant_(m.bias.data, 0.0)

    def forward(self, x):

        if self.scale > 1:
            x = F.interpolate(x, scale_factor=self.scale,
                                   mode='bilinear',
                                   align_corners=True)
        fm = self.conv_1x1_3x3(x)
        # fm = self.dropout(fm)
        output = self.conv_1x1_2(fm)
        return output



class ConvBnRelu(nn.Module):
    def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
                 groups=1, has_bn=True, norm_layer=nn.BatchNorm2d,
                 has_relu=True, inplace=True, has_bias=False):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
                              stride=stride, padding=pad,
                              dilation=dilation, groups=groups, bias=has_bias)
        self.has_bn = has_bn
        if self.has_bn:
            self.bn = nn.BatchNorm2d(out_planes)
        self.has_relu = has_relu
        if self.has_relu:
            self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.bn(x)
        if self.has_relu:
            x = self.relu(x)

        return x
    
class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        """Global average pooling over the input's spatial dimensions"""
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        in_size = inputs.size()
        inputs = inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
        inputs = inputs.view(in_size[0], in_size[1], 1, 1)

        return inputs

class Local_Channel(nn.Module):
    def __init__(self, in_channel):
        super(Local_Channel, self).__init__()
        self.attn = nn.Sequential(GlobalAvgPool2d(), nn.Conv2d(in_channel, in_channel, 1), nn.Sigmoid())
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        attn_map = self.attn(x)
        return x * (1 - self.gamma) + attn_map * x * self.gamma, attn_map

class Local_Spatial(nn.Module):
    def __init__(self, in_channel, mid_channel):
        super(Local_Spatial, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channel, mid_channel, 1)
        self.branch1 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 1, 1)
        self.branch2 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 2, 2)
        self.branch3 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 3, 3)
        self.attn = nn.Conv2d(3 * mid_channel, 1, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        mid = self.conv1x1(x)
        branch1 = self.branch1(mid)
        branch2 = self.branch2(mid)
        branch3 = self.branch3(mid)
        branch123 = torch.cat([branch1, branch2, branch3], dim=1)
        attn_map = self.attn(branch123)
        return x * (1 - self.gamma) + attn_map * x * self.gamma, attn_map

nonlinearity = partial(F.relu, inplace=True)



class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)


    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        # diffY = x2.size()[2] - x1.size()[2]
        # diffX = x2.size()[3] - x1.size()[3]
        #
        # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
        #                 diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)





class CBAM_Module2(nn.Module):
    def __init__(self, channels=512, reduction=2):
        super(CBAM_Module2, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
                             padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
                             padding=0)
        self.sigmoid_channel = nn.Sigmoid()
        self.conv_after_concat = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3)
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        # Channel Attention module
        module_input = x
        avg = self.avg_pool(x)
        mx = self.max_pool(x)
        avg = self.fc1(avg)
        mx = self.fc1(mx)
        avg = self.relu(avg)
        mx = self.relu(mx)
        avg = self.fc2(avg)
        mx = self.fc2(mx)
        x = avg + mx
        x = self.sigmoid_channel(x)
        # Spatial Attention module
        x = module_input * x + module_input
        module_input = x
        avg = torch.mean(x, 1, True)
        mx, _ = torch.max(x, 1, True)
        x = torch.cat((avg, mx), 1)
        x = self.conv_after_concat(x)
        x = self.sigmoid_spatial(x)
        x = module_input * x + module_input
        return x

class Bridge(nn.Module):
    def __init__(self, in_channels_1, in_channels_2, in_channels_3, mid_channels):
        super(Bridge, self).__init__()
        self.mid_channels = mid_channels
        self.conv_qk1 = nn.Conv2d(in_channels_1, mid_channels, 1, 1, 0)
        self.conv_qk2 = nn.Conv2d(in_channels_2, mid_channels, 1, 1, 0)
        self.conv_qk3 = nn.Conv2d(in_channels_3, mid_channels, 1, 1, 0)

        self.conv_v1 = nn.Conv2d(in_channels_1, mid_channels, 1, 1, 0)
        self.conv_v2 = nn.Conv2d(in_channels_2, mid_channels, 1, 1, 0)
        self.conv_v3 = nn.Conv2d(in_channels_3, mid_channels, 1, 1, 0)

        self.conv_out1 = nn.Conv2d(2 * mid_channels + in_channels_1, in_channels_1, 1, 1, 0)
        self.conv_out2 = nn.Conv2d(2 * mid_channels + in_channels_2, in_channels_2, 1, 1, 0)
        self.conv_out3 = nn.Conv2d(2 * mid_channels + in_channels_3, in_channels_3, 1, 1, 0)

    def forward(self, f1, f2, f3):
        batch_size = f1.size(0)
        qk1 = self.conv_qk1(f1).view(batch_size, self.mid_channels, -1)
        qk2 = self.conv_qk2(f2).view(batch_size, self.mid_channels, -1)
        qk3 = self.conv_qk3(f3).view(batch_size, self.mid_channels, -1)

        v1 = self.conv_v1(f1).view(batch_size, self.mid_channels, -1)
        v2 = self.conv_v2(f2).view(batch_size, self.mid_channels, -1)
        v3 = self.conv_v3(f3).view(batch_size, self.mid_channels, -1)

        sim12 = torch.matmul(qk1.permute(0, 2, 1), qk2)
        sim23 = torch.matmul(qk2.permute(0, 2, 1), qk3)
        sim31 = torch.matmul(qk3.permute(0, 2, 1), qk1)

        attn12 = F.softmax(sim12, dim=-1)
        attn21 = F.softmax(sim12.permute(0, 2, 1), dim=-1)
        attn23 = F.softmax(sim23, dim=-1)
        attn32 = F.softmax(sim23.permute(0, 2, 1), dim=-1)
        attn31 = F.softmax(sim31, dim=-1)
        attn13 = F.softmax(sim31.permute(0, 2, 1), dim=-1)

        y12 = torch.matmul(v1, attn12).contiguous()
        y13 = torch.matmul(v1, attn13).contiguous()
        y21 = torch.matmul(v2, attn21).contiguous()
        y23 = torch.matmul(v2, attn23).contiguous()
        y31 = torch.matmul(v3, attn31).contiguous()
        y32 = torch.matmul(v3, attn32).contiguous()

        y12 = y12.view(batch_size, self.mid_channels, int(f2.size()[2]), int(f2.size()[3]))
        y13 = y13.view(batch_size, self.mid_channels, int(f3.size()[2]), int(f3.size()[3]))
        y21 = y21.view(batch_size, self.mid_channels, int(f1.size()[2]), int(f1.size()[3]))
        y23 = y23.view(batch_size, self.mid_channels, int(f3.size()[2]), int(f3.size()[3]))
        y31 = y31.view(batch_size, self.mid_channels, int(f1.size()[2]), int(f1.size()[3]))
        y32 = y32.view(batch_size, self.mid_channels, int(f2.size()[2]), int(f2.size()[3]))

        out1 = self.conv_out1(torch.cat([f1, y31, y21], dim=1))
        out2 = self.conv_out2(torch.cat([f2, y12, y32], dim=1))
        out3 = self.conv_out3(torch.cat([f3, y23, y13], dim=1))

        return out1, out2, out3



class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):

        return self.conv_block(x) + self.conv_skip(x)



class ConvBnRelu(nn.Module):
    def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
                 groups=1, has_bn=True, norm_layer=nn.BatchNorm2d,
                 has_relu=True, inplace=True, has_bias=False):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
                              stride=stride, padding=pad,
                              dilation=dilation, groups=groups, bias=has_bias)
        self.has_bn = has_bn
        if self.has_bn:
            self.bn = nn.BatchNorm2d(out_planes)
        self.has_relu = has_relu
        if self.has_relu:
            self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.bn(x)
        if self.has_relu:
            x = self.relu(x)

        return x




class DecoderBlock(nn.Module):
    def __init__(self, in_planes, out_planes,
                 norm_layer=nn.BatchNorm2d,scale=2,relu=True,last=False):
        super(DecoderBlock, self).__init__()
       

        self.conv_3x3 = ConvBnRelu(in_planes, in_planes, 3, 1, 1,
                                   has_bn=True, norm_layer=norm_layer,
                                   has_relu=True, has_bias=False)
        self.conv_1x1 = ConvBnRelu(in_planes, out_planes, 1, 1, 0,
                                   has_bn=True, norm_layer=norm_layer,
                                   has_relu=True, has_bias=False)
       
        self.sap=SAPblock(in_planes)
        self.scale=scale
        self.last=last

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.normal_(m.weight.data, 1.0, 0.02)
                init.constant_(m.bias.data, 0.0)

    def forward(self, x):

        if self.last==False:
            x = self.conv_3x3(x)
            # x=self.sap(x)
        if self.scale>1:
            x=F.interpolate(x,scale_factor=self.scale,mode='bilinear',align_corners=True)
        x=self.conv_1x1(x)
        return x






class msca_net(nn.Module):
    def __init__(self, classes=1, channels=3, ccm=True, norm_layer=nn.BatchNorm2d, is_training=True, expansion=2,
                 base_channel=32):
        super(msca_net, self).__init__()
        self.backbone = models.resnet34(pretrained=True)
        # self.backbone =resnet34(pretrained=False)
        self.expansion = expansion
        self.base_channel = base_channel
        if self.expansion == 4 and self.base_channel == 64:
            expan = [512, 1024, 2048]
            spatial_ch = [128, 256]
        elif self.expansion == 4 and self.base_channel == 32:
            expan = [256, 512, 1024]
            spatial_ch = [32, 128]
            conv_channel_up = [256, 384, 512]
        elif self.expansion == 2 and self.base_channel == 32:
            expan = [128, 256, 512]
            spatial_ch = [64, 64]
            conv_channel_up = [128, 256, 512]

        conv_channel = expan[0]

        self.is_training = is_training
        # self.sap = SAPblock(expan[-1])

        # self.decoder5 = DecoderBlock(expan[-1], expan[-2], relu=False, last=True)  # 256
        # self.decoder4 = DecoderBlock(expan[-2], expan[-3], relu=False)  # 128
        # self.decoder3 = DecoderBlock(expan[-3], spatial_ch[-1], relu=False)  # 64
        # self.decoder2 = DecoderBlock(spatial_ch[-1], spatial_ch[-2])  # 32

        bilinear =True
        factor = 2
        self.up1 = Up(768, 512 // factor, bilinear)
        self.up2 = Up(384, 256 // factor, bilinear)
        self.up3 = Up(192, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)

        self.main_head = BaseNetHead(64, classes, 2,
                                     is_aux=False, norm_layer=norm_layer)

        # self.relu = nn.ReLU()

        # self.fpt = FPT(feature_dim=4)

        filters = [64, 64, 128, 256]
        self.out_size = (112, 160)
        self.dsv4 = UnetDsv3(in_size=filters[3], out_size=64, scale_factor=self.out_size)
        self.dsv3 = UnetDsv3(in_size=filters[2], out_size=64, scale_factor=self.out_size)
        self.dsv2 = UnetDsv3(in_size=filters[1], out_size=64, scale_factor=self.out_size)
        self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=64, kernel_size=1)

        self.sw1 = Scale_Aware(in_channels=64)
        self.sw2 = Scale_Aware(in_channels=64)
        self.sw3 = Scale_Aware(in_channels=64)

        self.affinity_attention = AffinityAttention2(512)
        self.cbam = CBAM_Module2()
        self.gamma1 = nn.Parameter(torch.zeros(1))
        self.gamma2 = nn.Parameter(torch.zeros(1))
        self.gamma3 = nn.Parameter(torch.zeros(1))

        self.bridge = Bridge(64, 128, 256, 64)
    def forward(self, x):

        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        c1 = self.backbone.relu(x)  # 1/2  64

        x = self.backbone.maxpool(c1)
        c2 = self.backbone.layer1(x)  # 1/4   64
        c3 = self.backbone.layer2(c2)  # 1/8   128
        c4 = self.backbone.layer3(c3)  # 1/16   256
        c5 = self.backbone.layer4(c4)  # 1/32   512
        # d_bottom=self.bottom(c5)

        # m1, m2, m3, m4 = self.fpt(c1, c2, c3, c4)
        m2, m3, m4 = self.bridge(c2, c3, c4)

        # c5 = self.sap(c5)
        attention = self.affinity_attention(c5)
        cbam_attn = self.cbam(c5)
        # l_channel, _ = self.l_channel(c5)
        # l_spatial, _ = self.l_spatial(c5)
        c5 = self.gamma1 * attention + self.gamma2 * cbam_attn + self.gamma3 * c5#多种并行方式， 用不用bn relu, 用不用scale aware

        # d5=d_bottom+c5           #512

        # d4 = self.relu(self.decoder5(c5) + m4)  # 256
        # d3 = self.relu(self.decoder4(d4) + m3)  # 128
        # d2 = self.relu(self.decoder3(d3) + m2)  # 64
        # d1 = self.decoder2(d2) + m1  # 32
        d4 = self.up1(c5, m4)
        d3 = self.up2(d4, m3)
        d2 = self.up3(d3, m2)
        d1 = self.up4(d2, c1)

        dsv4 = self.dsv4(d4)
        dsv3 = self.dsv3(d3)
        dsv2 = self.dsv2(d2)
        dsv1 = self.dsv1(d1)

        dsv43 = self.sw1(dsv4, dsv3)
        dsv432 = self.sw2(dsv43, dsv2)
        dsv4321 = self.sw3(dsv432, dsv1)

        main_out = self.main_head(dsv4321)

        final = F.sigmoid(main_out)

        return final


model = msca_net().to(device)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


summary(model, input_size=(3,224,320),device=device.type)  # Input size should match your input data



In [None]:
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_slow_operation_alarm=false"


In [None]:
import torch
import matplotlib.pyplot as plt

# Load the datasets
train_images = torch.load('data_train.pt')  # Shape: (N, 3, 256, 256)
train_masks = torch.load('mask_train_binary.pt')       # Shape: (N, 256, 256)
val_images = torch.load('data_val.pt')
val_masks = torch.load('mask_val_binary.pt')
test_images = torch.load('data_test.pt')
test_masks = torch.load('mask_test_binary.pt')

# Check shapes
print("Train Images Shape: ", train_images.shape)
print("Train Masks Shape: ", train_masks.shape)
print("Validation Images Shape: ", val_images.shape)
print("Validation Masks Shape: ", val_masks.shape)
print("Test Images Shape: ", test_images.shape)
print("Test Masks Shape: ", test_masks.shape)

# Ensure the image shape is (3, 256, 256) and the mask shape is (256, 256)
assert train_images.shape[1:] == (3, 224, 320), "Training images shape mismatch!"
assert train_masks.shape[1:] == (224,320), "Training masks shape mismatch!"

# Visualize a few samples to check if the images and masks are correct
def display_sample(images, masks, index):
    """Displays an image and its corresponding mask (converted to NumPy for visualization)."""
    image = images[index].permute(1, 2, 0).cpu().numpy()  # (C, H, W) → (H, W, C)
    mask = masks[index].cpu().numpy()                    # (H, W)

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax[0].imshow(image)
    ax[0].set_title('Image')
    ax[0].axis('off')
    ax[1].imshow(mask, cmap='gray')
    ax[1].set_title('Mask')
    ax[1].axis('off')
    plt.show()

# Display a random training sample
display_sample(train_images, train_masks, index=0)

# Check that masks are binary
assert torch.equal(train_masks.unique(), torch.tensor([0, 1], dtype=torch.uint8)), "Train masks are not binary!"
assert torch.equal(val_masks.unique(), torch.tensor([0, 1], dtype=torch.uint8)), "Validation masks are not binary!"
assert torch.equal(test_masks.unique(), torch.tensor([0, 1], dtype=torch.uint8)), "Test masks are not binary!"

print("Dataset verification completed successfully.")


In [None]:
# Add channel dimension to the masks: (N, H, W) → (N, 1, H, W)
train_masks = train_masks.unsqueeze(1)
val_masks = val_masks.unsqueeze(1)
test_masks = test_masks.unsqueeze(1)


above is all preprocessing the dataset which already loaded in the datasets(input tab)

In [None]:

# Check shapes of image and mask tensors
print("Train Images Shape: ", train_images.shape)
print("Train Masks Shape: ", train_masks.shape)
print("Validation Images Shape: ", val_images.shape)
print("Validation Masks Shape: ", val_masks.shape)
print("Test Images Shape: ", test_images.shape)
print("Test Masks Shape: ", test_masks.shape)


In [None]:
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_slow_operation_alarm=false"


train

In [None]:
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_slow_operation_alarm=false"


In [None]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os

# Hyperparameters
batch_size = 8
epochs = 50
patience = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare datasets and dataloaders
train_dataset = TensorDataset(train_images, train_masks)
val_dataset = TensorDataset(val_images, val_masks)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Instantiate model
model = msca_net(channels=3).to(device)

# Loss and optimizer
criterion = nn.BCEWithLogitsLoss()  # Assumes raw logits from model
optimizer = Adam(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=7, verbose=True)

# Model checkpoint directory
checkpoint_path = 'msca_model_50_2.pt'
best_val_loss = float('inf')
early_stop_counter = 0

# Add this before your training loop
history = {
    'train_loss': [],
    'val_loss': []
    # Optionally: 'train_accuracy': [], 'val_accuracy': []
}

# Training loop
for epoch in range(epochs):
    model.train()
    train_loss = 0.0

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device).float()
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device).float()
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item() * images.size(0)

    val_loss /= len(val_loader.dataset)
    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)


    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Save best model
    if val_loss < best_val_loss - 1e-4:
        best_val_loss = val_loss
        torch.save(model.state_dict(), checkpoint_path)
        print("Validation loss improved. Model saved.")
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        print(f"No improvement in val_loss for {early_stop_counter} epochs.")

    # Early stopping
    if early_stop_counter >= patience:
        print("Early stopping triggered.")
        break

# Save final model
torch.save(model.state_dict(), 'msca_model_50_2.pt')
print("Training completed successfully.")


In [None]:
import matplotlib.pyplot as plt

def plot_training_history(history):
    plt.figure(figsize=(12, 4))
    
    # Loss plot
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')
    
    # Accuracy plot (optional, only if you're tracking it)
    if 'train_accuracy' in history and 'val_accuracy' in history:
        plt.subplot(1, 2, 2)
        plt.plot(history['train_accuracy'], label='Training Accuracy')
        plt.plot(history['val_accuracy'], label='Validation Accuracy')
        plt.title('Accuracy over Epochs')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend(loc='lower right')

    plt.tight_layout()
    plt.show()

# Call the function
plot_training_history(history)


evaluation--matrixs

In [None]:
import torch
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, confusion_matrix

# Assume device and model are already defined (same as training)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load your saved PyTorch model weights
model = UCM_Net(channels=3).to(device)
model.load_state_dict(torch.load('msca_model_50_2.pt'))
model.eval()

# Convert test data (numpy arrays) to torch tensors
test_images_torch = torch.from_numpy(test_images).to(device).float()   # shape: (N, 256, 256, 3)
test_masks_torch = torch.from_numpy(test_masks).to(device).float()     # shape: (N, 256, 256, 1)

# Rearrange images to (N, C, H, W) format expected by PyTorch models
test_images_torch = test_images_torch.permute(0, 3, 1, 2)  # from NHWC to NCHW

# Inference
with torch.no_grad():
    outputs = model(test_images_torch)  # raw logits, shape: (N, 1, 256, 256)
    probs = torch.sigmoid(outputs)      # convert logits to probabilities

# Binarize predictions at threshold 0.5
preds = (probs > 0.5).cpu().numpy().astype(np.uint8)  # shape: (N, 1, 256, 256)
true_masks = test_masks.astype(np.uint8)               # shape: (N, 256, 256, 1)

# Flatten predictions and ground truths for metrics calculation
preds_flat = preds.reshape(-1)
true_flat = true_masks.reshape(-1)

# Compute metrics using sklearn
accuracy = accuracy_score(true_flat, preds_flat)
f1 = f1_score(true_flat, preds_flat)
iou = jaccard_score(true_flat, preds_flat)
conf_matrix = confusion_matrix(true_flat, preds_flat)

# Print results
print(f"Accuracy: {accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Jaccard Index (IoU): {iou:.4f}")
print("Confusion Matrix:")
print(conf_matrix)


result --generated mask,true mask,image

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def display_test_sample(images, true_masks, predicted_masks, index):
    """Displays test image, true mask, and predicted mask side by side."""
    # Convert tensors to numpy arrays and transpose if needed
    # images shape: (N, C, H, W) -> (H, W, C) for plt.imshow
    image = images[index].cpu().numpy().transpose(1, 2, 0)
    true_mask = true_masks[index].cpu().numpy().squeeze()        # (H, W)
    predicted_mask = predicted_masks[index].cpu().numpy().squeeze()  # (H, W)

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(image)
    ax[0].set_title('Test Image')
    ax[0].axis('off')

    ax[1].imshow(true_mask, cmap='gray')
    ax[1].set_title('True Mask')
    ax[1].axis('off')

    ax[2].imshow(predicted_mask, cmap='gray')
    ax[2].set_title('Predicted Mask')
    ax[2].axis('off')

    plt.show()

# Assuming:
# test_images_torch: torch tensor of shape (N, 3, 256, 256)
# test_masks_torch: torch tensor of shape (N, 1, 256, 256)
# preds: predicted masks tensor of shape (N, 1, 256, 256), binary (0 or 1)

# Visualize first 15 test samples
for i in range(15):
    display_test_sample(test_images_torch, test_masks_torch, preds, i)

In [None]:
import shutil

shutil.copy2('/kaggle/input/ucm-utils/utils.py', '/kaggle/working/')
#COpying file from input dir to working dir 

exact copy of the implemented in paper