In [1]:
import torch
import torch.nn as nn


class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()

        self.channels_list = [64, 128, 256, 512]
        self.layers = len(self.channels_list)

        # Contracting Part
        self.contracting_conv_blocks = []
        self.max_pooling_blocks = []
        for channels in self.channels_list:
            self.contracting_conv_blocks.append(self.conv_block(in_channels, channels))
            self.max_pooling_blocks.append(nn.MaxPool2d(2))
            in_channels = channels

        # Connecting Part
        self.connect = self.conv_block(
            self.channels_list[-1], 2 * self.channels_list[-1]
        )

        # Expansive Part
        in_channels = 1024
        self.expansive_conv_blocks = []
        self.up_conv_blocks = []
        for channels in reversed(self.channels_list):
            self.expansive_conv_blocks.append(self.conv_block(in_channels, channels))
            self.up_conv_blocks.append(
                nn.ConvTranspose2d(in_channels, channels, 2, stride=2)
            )
            in_channels = channels

        # Final Part
        self.final = nn.Conv2d(self.channels_list[0], out_channels, 1, padding=0)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=0),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=0),
            nn.ReLU(),
        )

    def forward(self, x):
        # Contracting Part
        encoder_outputs = []
        for layer in range(self.layers):
            conv_block = self.contracting_conv_blocks[layer]
            x = conv_block(x)
            encoder_outputs.append(x)
            # TODO: Check if the 'x' object is modified or uniquely generated
            max_pool = self.max_pooling_blocks[layer]
            x = max_pool(x)

        # Connecting Part
        x = self.connect(x)

        # Expansive Part
        for reversed_layer in reversed(range(self.layers)):
            layer = self.layers - reversed_layer - 1
            up_conv_block = self.up_conv_blocks[layer]
            x = up_conv_block(x)
            # Concatenate with the corresponding cropped feature map
            x = torch.cat((x, encoder_outputs[reversed_layer]), dim=1)
            conv_block = self.expansive_conv_blocks[layer]
            x = conv_block(x)

        # Final Part
        x = self.final(x)
        return x


# Create a U-Net model with 1 input channel and 2 output channels for binary segmentation
model = UNet(in_channels=1, out_channels=2)

# Print the model architecture
print(model)

UNet(
  (connect): Sequential(
    (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
  )
  (final): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1))
)


In [2]:
import torchvision.transforms as transforms
from PIL import Image

# Set the model to evaluation mode
model.eval()

# Load and preprocess your input image
image_path = './Datasets/Dataset 1/data/BMMC_1.tif'
image = Image.open(image_path).convert('L')  # Convert to grayscale
transform = transforms.Compose([transforms.ToTensor()])
input_image = transform(image).unsqueeze(0)  # Add a batch dimension

# Forward pass through the model
with torch.no_grad():
    output = model(input_image)

# At this point, 'output' contains the model's prediction for the input image.
# You can post-process the output as needed for your specific task.


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 112 but got size 120 for tensor number 1 in the list.