In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

In [4]:



class DoubleConv(nn.Module):
    """Pair of 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes (``in_channels`` ->
    ``out_channels``).

    Args:
        in_channels: Channels expected on the input tensor.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First conv maps in_channels -> out_channels, second keeps
        # out_channels; each is immediately followed by its ReLU.
        for source_channels in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(source_channels, out_channels, kernel_size=3, padding=1)
            )
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv-ReLU-conv-ReLU to ``x``."""
        features = self.double_conv(x)
        return features


class UNet(nn.Module):
    """U-Net encoder-decoder producing a per-pixel output map.

    The encoder applies five ``DoubleConv`` stages (64, 128, 256, 512, 1024
    channels) with 2x2 max-pooling between them. The decoder mirrors it with
    four 2x2 stride-2 transposed convolutions. Each is followed by
    concatenation with the matching encoder feature map and another
    ``DoubleConv``. A final 1x1 convolution maps 64 channels to
    ``out_channels``.

    Args:
        in_channels: Channels of the input tensor. Defaults to 1, the value
            previously hard-coded (e.g. a single grayscale band).
        out_channels: Channels of the output tensor. Defaults to 2, the value
            previously hard-coded.
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()

        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = DoubleConv(in_channels, 64)
        self.down_conv_2 = DoubleConv(64, 128)
        self.down_conv_3 = DoubleConv(128, 256)
        self.down_conv_4 = DoubleConv(256, 512)
        self.down_conv_5 = DoubleConv(512, 1024)

        # Each up_conv_k takes twice the channels of up_trans_k's output:
        # it receives [upsampled, skip] concatenated along the channel axis.
        self.up_trans_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.up_conv_1 = DoubleConv(1024, 512)

        self.up_trans_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv_2 = DoubleConv(512, 256)

        self.up_trans_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv_3 = DoubleConv(256, 128)

        self.up_trans_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv_4 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    @staticmethod
    def _decode_step(up_trans, up_conv, x, skip):
        """Upsample ``x``, align it to ``skip``'s spatial size, fuse, convolve."""
        x = up_trans(x)
        # Max-pooling floors odd sizes, so the 2x transposed conv can come out
        # one pixel short of the skip map; resize only when sizes differ.
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(
                x, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return up_conv(torch.cat([x, skip], 1))

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: 4-D tensor laid out as (batch, channels, H, W); dims 2: are
                treated as spatial and dim 1 as channels. Channels must equal
                ``in_channels``.

        Returns:
            Tensor of shape (batch, out_channels, H, W), the same spatial size
            as ``x``.
        """
        # Encoder: keep each pre-pooling feature map as a skip connection.
        skip1 = self.down_conv_1(x)
        skip2 = self.down_conv_2(self.max_pool_2x2(skip1))
        skip3 = self.down_conv_3(self.max_pool_2x2(skip2))
        skip4 = self.down_conv_4(self.max_pool_2x2(skip3))
        bottleneck = self.down_conv_5(self.max_pool_2x2(skip4))

        # Decoder: walk back up, fusing with the matching encoder features.
        y = self._decode_step(self.up_trans_1, self.up_conv_1, bottleneck, skip4)
        y = self._decode_step(self.up_trans_2, self.up_conv_2, y, skip3)
        y = self._decode_step(self.up_trans_3, self.up_conv_3, y, skip2)
        y = self._decode_step(self.up_trans_4, self.up_conv_4, y, skip1)

        return self.out(y)


def load_image(file_path, mode="L"):
    """Load an image file as a batched float tensor.

    Args:
        file_path: Path passed to ``PIL.Image.open``.
        mode: PIL mode to convert to before tensor conversion. Defaults to
            ``"L"`` (grayscale), which yields a single-channel tensor.

    Returns:
        Tensor of shape (1, C, H, W): ``ToTensor``'s (C, H, W) output with a
        leading batch dimension added.

    Raises:
        PIL.UnidentifiedImageError: If PIL cannot decode the file, e.g. a
            multi-band TIFF with more samples per pixel than PIL supports.
    """
    # Open with a context manager so the file handle is closed. convert()
    # returns a new, fully loaded image, so it stays valid after the with block.
    with Image.open(file_path) as source:
        image = source.convert(mode)

    # ToTensor produces (C, H, W); unsqueeze(0) adds the batch dimension.
    image_tensor = transforms.ToTensor()(image)
    return image_tensor.unsqueeze(0)


# Example usage:
# Update with your image file path.
# Raw string: "\C", "\T", "\D" and "\S" are invalid escape sequences in a
# normal string literal and raise a SyntaxWarning; the path value is unchanged.
file_path = r"D:\College files\TISS\Data\Sentinal.tif"
# NOTE(review): PIL failed on this file ("More samples per pixel than can be
# decoded: 23" -> UnidentifiedImageError). A multi-band raster like this
# likely needs a dedicated raster reader — TODO confirm and replace.
image = load_image(file_path)
height, width = image.shape[2], image.shape[3]  # Get image height and width
model = UNet()
# Inference only: disable autograd so activations are not kept for backward.
with torch.no_grad():
    output = model(image)
print(output.size())

  file_path = "D:\College files\TISS\Data\Sentinal.tif"
More samples per pixel than can be decoded: 23
  file_path = "D:\College files\TISS\Data\Sentinal.tif"


UnidentifiedImageError: cannot identify image file 'D:\\College files\\TISS\\Data\\Sentinal.tif'