Image Segmentation with U-Net


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import tifffile
import numpy as np

In [2]:
! pip install Pillow
! pip install --upgrade tifffile imagecodecs
! pip install --force-reinstall tifffile imagecodecs

Collecting tifffile
  Using cached tifffile-2024.2.12-py3-none-any.whl.metadata (31 kB)
Collecting imagecodecs
  Using cached imagecodecs-2024.1.1-cp312-cp312-win_amd64.whl.metadata (20 kB)
Collecting numpy (from tifffile)
  Using cached numpy-1.26.4-cp312-cp312-win_amd64.whl.metadata (61 kB)
Using cached tifffile-2024.2.12-py3-none-any.whl (224 kB)
Using cached imagecodecs-2024.1.1-cp312-cp312-win_amd64.whl (25.4 MB)
Using cached numpy-1.26.4-cp312-cp312-win_amd64.whl (15.5 MB)
Installing collected packages: numpy, tifffile, imagecodecs
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
  Attempting uninstall: tifffile
    Found existing installation: tifffile 2024.2.12
    Uninstalling tifffile-2024.2.12:
      Successfully uninstalled tifffile-2024.2.12
  Attempting uninstall: imagecodecs
    Found existing installation: imagecodecs 2024.1.1
    Uninstalling imagecodecs-2024.1.1:
    

  You can safely remove it manually.
  You can safely remove it manually.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
segmentation-models 1.0.1 requires efficientnet==1.0.0, but you have efficientnet 1.1.1 which is incompatible.


In [3]:
! pip install imagecodecs



In [4]:
! pip install -U imagecodecs[all]




In [5]:
! pip install rasterio




In [6]:
import numpy as np
import torch
import torchvision.transforms as transforms
import rasterio


def load_image(file_path, bands=None):
    # Open TIF file using rasterio
    with rasterio.open(file_path) as dataset:
        if bands is None:
            # Read all bands
            image_bands = dataset.read().transpose(1, 2, 0)
        else:
            # Read the specified bands of the image
            image_bands = dataset.read(bands).transpose(1, 2, 0)

    # Convert image to float
    # Assuming 16-bit TIF, adjust if needed
    image_bands = image_bands.astype(np.float32) / 65535.0

    # Check if image has only one channel, if so, duplicate it to create three channels
    if image_bands.shape[-1] == 1:
        image_bands = np.repeat(image_bands, 3, axis=-1)

    # Normalize the image bands with custom mean and std values
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.2, 0.2, 0.2]),
    ])

    input_tensor = transform(image_bands).unsqueeze(0)
    return input_tensor


# Example usage:
file_path = "D:\College files\TISS\Sentinal.tif"
input_image = load_image(file_path)
print("Input image shape:", input_image.shape)

  file_path = "D:\College files\TISS\Sentinal.tif"


Input image shape: torch.Size([1, 3, 846, 1262])


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import tifffile
import numpy as np


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = DoubleConv(in_channels, 64)  # Update input channels
        self.down_conv_2 = DoubleConv(64, 128)
        self.down_conv_3 = DoubleConv(128, 256)
        self.down_conv_4 = DoubleConv(256, 512)
        self.down_conv_5 = DoubleConv(512, 1024)

        self.up_trans_1 = nn.ConvTranspose2d(
            1024, 512, kernel_size=2, stride=2)
        self.up_conv_1 = DoubleConv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv_2 = DoubleConv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv_3 = DoubleConv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv_4 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.down_conv_1(x)
        x2 = self.max_pool_2x2(x1)
        x3 = self.down_conv_2(x2)
        x4 = self.max_pool_2x2(x3)
        x5 = self.down_conv_3(x4)
        x6 = self.max_pool_2x2(x5)
        x7 = self.down_conv_4(x6)
        x8 = self.max_pool_2x2(x7)
        x9 = self.down_conv_5(x8)

        # Decoder
        x = self.up_trans_1(x9)
        x = F.interpolate(x, size=x7.size()[
                          2:], mode='bilinear', align_corners=True)
        x = self.up_conv_1(torch.cat([x, x7], 1))

        x = self.up_trans_2(x)
        x = F.interpolate(x, size=x5.size()[
                          2:], mode='bilinear', align_corners=True)
        x = self.up_conv_2(torch.cat([x, x5], 1))

        x = self.up_trans_3(x)
        x = F.interpolate(x, size=x3.size()[
                          2:], mode='bilinear', align_corners=True)
        x = self.up_conv_3(torch.cat([x, x3], 1))

        x = self.up_trans_4(x)
        x = F.interpolate(x, size=x1.size()[
                          2:], mode='bilinear', align_corners=True)
        x = self.up_conv_4(torch.cat([x, x1], 1))

        x = self.out(x)
        return x


def load_image(file_path):
    # Load the image using tifffile
    image = tifffile.imread(file_path)

    # Convert the image to grayscale
    # Assuming it's a single-channel image
    image = image[np.newaxis, ...]  # Add batch dimension

    # Transform the image to tensor
    image_tensor = torch.from_numpy(image).float()

    return image_tensor


# Example usage:
# Update with your image file path
file_path = "D:\College files\TISS\Sentinal.tif"
# file_path= gdal.Open("D:\College files\TISS\Data\Sentinal.tif").ReadAsArray()
image = load_image(file_path)
# Adjust out_channels if needed
model = UNet(in_channels=image.size(1), out_channels=2)
output = model(image)
print(output.size())

  file_path = "D:\College files\TISS\Sentinal.tif"
  file_path = "D:\College files\TISS\Sentinal.tif"


RuntimeError: Given groups=1, weight of size [64, 846, 3, 3], expected input[1, 1, 846, 1262] to have 846 channels, but got 1 channels instead