# Importing Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Mirrored UNet

In [2]:
class MirroredUNet(nn.Module):
    """Three-level convolutional encoder-decoder with a mirrored layout.

    Encoder: three 3x3 conv + ReLU stages (64, 128, 256 channels), each
    followed by 2x2 max-pooling, so the bottleneck is at 1/8 of the input
    resolution. Decoder: three stride-2 transposed convolutions
    (128, 64, 32 channels) with ReLU, then a final 3x3 conv.

    By default there are NO skip connections: the decoder only sees the
    bottleneck features. Pass ``skip_connections=True`` to concatenate each
    decoder stage with the encoder feature map of the same resolution, as in
    a classic U-Net.

    The output has the same spatial size as the input, including odd sizes,
    because every transposed conv is told the matching encoder size via
    ``output_size``. The output is unbounded (no final activation).

    Args:
        in_channels: number of input channels (default 3).
        out_channels: number of output channels (default 3).
        skip_connections: enable U-Net style concatenation skips. The default
            ``False`` reproduces the original architecture, parameter shapes
            and state_dict keys.
    """

    def __init__(self, in_channels=3, out_channels=3, skip_connections=False):
        super().__init__()
        self.skip_connections = skip_connections

        # Encoder: "same"-padded conv + ReLU, then halve the resolution.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder: each transposed conv doubles the resolution. With skips
        # enabled, the layer after each upsampling also receives the encoder
        # features concatenated at that resolution (256 ch at 1/4, 128 ch at
        # 1/2, 64 ch at full resolution).
        skip5, skip6, skip_out = (256, 128, 64) if skip_connections else (0, 0, 0)
        self.conv4 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv5 = nn.ConvTranspose2d(128 + skip5, 64, kernel_size=2, stride=2)
        self.conv6 = nn.ConvTranspose2d(64 + skip6, 32, kernel_size=2, stride=2)

        # Output projection back to image channels (no activation).
        self.output_conv = nn.Conv2d(32 + skip_out, out_channels, kernel_size=3, padding=1)

    def _merge(self, decoded, encoded):
        """Concatenate the same-resolution encoder features if skips are enabled."""
        if self.skip_connections:
            # dim=-3 is the channel axis for both batched and unbatched input.
            return torch.cat([decoded, encoded], dim=-3)
        return decoded

    def forward(self, x):
        """Map an image tensor with ``in_channels`` channels to one with
        ``out_channels`` channels at the same height and width."""
        # Encoder
        x1 = F.relu(self.conv1(x))   # full resolution, 64 ch
        x2 = self.pool1(x1)
        x3 = F.relu(self.conv2(x2))  # 1/2 resolution, 128 ch
        x4 = self.pool2(x3)
        x5 = F.relu(self.conv3(x4))  # 1/4 resolution, 256 ch
        x6 = self.pool3(x5)          # 1/8 resolution bottleneck

        # Decoder: output_size restores the exact encoder size, so odd H/W
        # (floored by max-pooling) round-trip back to the input size.
        x7 = F.relu(self.conv4(x6, output_size=x5.shape[-2:]))
        x8 = F.relu(self.conv5(self._merge(x7, x5), output_size=x3.shape[-2:]))
        x9 = F.relu(self.conv6(self._merge(x8, x3), output_size=x1.shape[-2:]))

        # Output convolution
        return self.output_conv(self._merge(x9, x1))

In [9]:
import cv2
import numpy as np
import torch
from torchvision import transforms
from PIL import Image

# Load the image. cv2.imread returns None (it does not raise) when the file
# is missing or unreadable, so check explicitly for a clear error message.
IMAGE_PATH = 'samp.png'
img = cv2.imread(IMAGE_PATH)
if img is None:
    raise FileNotFoundError(f"Could not read image: {IMAGE_PATH!r}")

# OpenCV decodes to BGR; convert to RGB
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Resize image to match the expected input shape of the model.
# NOTE: cv2.resize takes (width, height).
input_size = (256, 256)  # Assuming the model expects input size of 256x256
img = cv2.resize(img, input_size)

# ImageNet statistics, shared by the input normalization and the output
# denormalization below so the two cannot drift apart.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.ToTensor(),                     # HWC uint8 [0, 255] -> CHW float [0, 1]
    transforms.Normalize(mean=MEAN, std=STD),  # Assuming ImageNet normalization
])
img_tensor = transform(img).unsqueeze(0)  # Add batch dimension

# Build the model. NOTE(review): no trained weights are loaded, so the output
# comes from a randomly initialised network; seed it so re-runs reproduce.
# TODO: load trained weights, e.g. model.load_state_dict(torch.load(...)).
torch.manual_seed(0)
model = MirroredUNet()
model.eval()

# Perform inference without tracking gradients
with torch.no_grad():
    output_tensor = model(img_tensor)

# (1, C, H, W) tensor -> (H, W, C) numpy array
output_img = output_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()

# Undo the input normalization. This assumes the model predicts images in the
# same normalized space as its input -- TODO confirm against the training setup.
output_img = output_img * np.asarray(STD, dtype=np.float32) + np.asarray(MEAN, dtype=np.float32)

# Map [0, 1] -> [0, 255] uint8, clipping out-of-range values
output_img = np.clip(output_img * 255, 0, 255).astype(np.uint8)

# Display inline: cv2.imshow + cv2.waitKey(0) blocks the notebook kernel
# until a key is pressed and fails on headless Jupyter servers.
from IPython.display import display
display(Image.fromarray(output_img))