In [None]:
import os
import cv2
import torch
import torch.nn as nn
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

In [None]:
# Define the model architecture (UNet) 

class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> batch norm -> ReLU) stages.

    With ``padding=1`` the spatial size is unchanged; only the channel count
    goes from ``in_channels`` to ``out_channels``. The six layers live in one
    ``nn.Sequential`` called ``double_conv`` at indices 0-5. Those names are the
    ``state_dict`` keys of saved checkpoints, so they must not change.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in->out channels, second stage keeps out channels.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages: (N, C_in, H, W) -> (N, C_out, H, W)."""
        return self.double_conv(x)

class UNet(nn.Module):
    """U-Net with three down/up levels producing a 1-channel probability map.

    Encoder widths are 64/128/256/512 channels, with 2x2 max-pooling between
    levels. Each decoder stage does three things:
    - bilinearly upsamples by 2;
    - concatenates the encoder feature map of the same resolution (skip connection);
    - refines the result with a ``DoubleConv``.
    A 1x1 conv and a sigmoid then yield per-pixel probabilities at the input
    resolution.

    Input height and width must be divisible by 8 (three poolings). Otherwise the
    upsampled tensor and its skip tensor differ in size and ``torch.cat`` fails.

    Attribute names (``encoder1`` ... ``final_conv``) are the ``state_dict``
    prefixes of saved checkpoints and must stay unchanged.
    """

    def __init__(self):
        super().__init__()
        w1, w2, w3, w4 = 64, 128, 256, 512

        self.encoder1 = DoubleConv(3, w1)
        self.encoder2 = DoubleConv(w1, w2)
        self.encoder3 = DoubleConv(w2, w3)
        self.encoder4 = DoubleConv(w3, w4)

        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder inputs = upsampled deeper features + same-level skip features.
        self.decoder3 = DoubleConv(w4 + w3, w3)
        self.decoder2 = DoubleConv(w3 + w2, w2)
        self.decoder1 = DoubleConv(w2 + w1, w1)

        self.final_conv = nn.Conv2d(w1, 1, kernel_size=1)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, 1, H, W) foreground probabilities."""
        # Encoder: keep each level's output for the skip connections.
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool(skip1))
        skip3 = self.encoder3(self.pool(skip2))
        features = self.encoder4(self.pool(skip3))  # deepest level

        # Decoder: upsample, concatenate the same-resolution skip, refine.
        for decoder, skip in (
            (self.decoder3, skip3),
            (self.decoder2, skip2),
            (self.decoder1, skip1),
        ):
            features = decoder(torch.cat([self.upsample(features), skip], dim=1))

        return torch.sigmoid(self.final_conv(features))

def get_bounding_box(binary_mask):
    """Return the tight bounding box of the positive pixels in ``binary_mask``.

    Args:
        binary_mask: Array whose first two axes are (row, column). Any value
            ``> 0`` counts as foreground.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as inclusive pixel indices, or ``None``
        when the mask has no positive pixel. The max values are the last
        foreground column/row, not one past it.
    """
    nonzero = np.nonzero(binary_mask > 0)
    if nonzero[0].size == 0:
        return None
    rows, cols = nonzero[0], nonzero[1]
    return cols.min(), rows.min(), cols.max(), rows.max()

In [None]:
# Inference-time preprocessing, applied in order:
#   1. resize to the 512x512 network input;
#   2. normalize with ImageNet per-channel mean/std;
#   3. convert to a (3, 512, 512) torch tensor for the model.
transform = A.Compose(
    [
        A.Resize(height=512, width=512),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
)

In [None]:
def process_and_crop_image(image_path, model, transform, device, threshold=0.5):
    """Segment an image with ``model`` and crop the original to the mask's bounding box.

    Processing steps:
    1. Read the image with OpenCV and convert it to RGB.
    2. Pass it through ``transform``, whose ``'image'`` output must be a CHW tensor.
    3. Run ``model`` on a batch of one.
    4. Threshold the output into a binary mask.
    5. Map the mask's bounding box back to the original resolution and crop the
       full-resolution RGB image with it.

    Args:
        image_path: Path of the image file to read.
        model: Segmentation network returning per-pixel probabilities, e.g. the
            ``UNet`` above, which ends in a sigmoid.
        transform: Albumentations pipeline producing the model input tensor.
        device: Torch device to run inference on.
        threshold: Probability above which a mask pixel counts as foreground.

    Returns:
        The cropped RGB image (NumPy array, a view into the decoded image), or
        ``None`` if the file could not be read or no foreground was predicted.
    """
    image_orig = cv2.imread(image_path)
    if image_orig is None:
        print(f"Failed to read {image_path}")
        return None
    image_orig = cv2.cvtColor(image_orig, cv2.COLOR_BGR2RGB)
    orig_h, orig_w = image_orig.shape[:2]

    # Add a batch dimension: (3, H', W') -> (1, 3, H', W').
    image_tensor = transform(image=image_orig)['image'].unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        pred = model(image_tensor)

    # Drop batch/channel dims -> 2-D probability map at the model's input size.
    pred_np = pred.squeeze().cpu().numpy()
    binary_mask = (pred_np > threshold).astype(np.uint8)

    bbox = get_bounding_box(binary_mask)
    if bbox is None:
        print(f"No mask detected for {image_path}. Skipping cropping.")
        return None
    x_min, y_min, x_max, y_max = bbox

    # Use the real mask size rather than a hard-coded 512, so the mapping stays
    # correct if the transform's resize target changes.
    mask_h, mask_w = binary_mask.shape[:2]
    scale_x = orig_w / mask_w
    scale_y = orig_h / mask_h

    # bbox maxima are inclusive: mask pixel i covers [i, i + 1), so map the far
    # edge (max + 1) and round outward. Slicing is end-exclusive, so this keeps
    # the last masked row/column. It also never yields an empty crop when the
    # original image is smaller than the mask.
    x_min_orig = max(0, int(np.floor(x_min * scale_x)))
    y_min_orig = max(0, int(np.floor(y_min * scale_y)))
    x_max_orig = min(orig_w, int(np.ceil((x_max + 1) * scale_x)))
    y_max_orig = min(orig_h, int(np.ceil((y_max + 1) * scale_y)))

    return image_orig[y_min_orig:y_max_orig, x_min_orig:x_max_orig]

def display_image(image, title="Image"):
    """Show ``image`` in a 6x6-inch matplotlib figure titled ``title``, axes hidden.

    Optional helper for visually inspecting a cropped result.
    """
    _, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
    plt.show()

In [None]:
# Setup: pick the device, load the trained weights, and prepare directories.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = UNet().to(device)
model_weights = 'best_model.pth'
if not os.path.exists(model_weights):
    # Fail fast. Continuing with randomly initialised weights would make the
    # cells below silently write meaningless crops for every image.
    raise FileNotFoundError(f"Model weights file {model_weights} not found. Please check the path.")

# NOTE(review): torch.load unpickles the file. Only load checkpoints from a
# trusted source (consider weights_only=True on torch versions that support it).
checkpoint = torch.load(model_weights, map_location=device)
# Accept a training checkpoint ({'model_state_dict': ...}) as well as a bare state_dict.
if isinstance(checkpoint, dict):
    state_dict = checkpoint.get('model_state_dict', checkpoint)
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
model.eval()
print(f"Loaded model weights from {model_weights}")

# Input and output directories.
input_dir = 'pictures'
output_dir = 'cropped'
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Process each image in the input directory and save the cropped version.
supported_ext = ['.jpg', '.jpeg', '.png', '.bmp']
# sorted() makes processing order and logs reproducible; os.listdir order is arbitrary.
image_files = sorted(
    f for f in os.listdir(input_dir)
    if os.path.splitext(f)[1].lower() in supported_ext
    and os.path.isfile(os.path.join(input_dir, f))
)

print(f"Found {len(image_files)} image(s) in '{input_dir}' directory.")

saved_count = 0
for filename in tqdm(image_files, desc="Processing images"):
    input_path = os.path.join(input_dir, filename)
    cropped = process_and_crop_image(input_path, model, transform, device)
    if cropped is None or cropped.size == 0:
        # The reason (unreadable file or empty mask) has already been printed;
        # an empty array would make cvtColor/imwrite raise.
        print(f"Skipping {filename}: no crop produced.")
        continue
    # OpenCV writes images in BGR channel order.
    cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    output_path = os.path.join(output_dir, filename)
    if cv2.imwrite(output_path, cropped_bgr):
        saved_count += 1
    else:
        print(f"Failed to write {output_path}")
    # Optional preview: display_image(cropped, title=filename)

print(f"Saved {saved_count} of {len(image_files)} cropped image(s) to the '{output_dir}' directory.")

In [None]:
# Crop a single image. Set input_filename to the name of an image inside the
# input directory ('pictures'). A path relative to the working directory is
# also accepted as a fallback.
input_filename = "test2.jpg"
candidate_path = os.path.join(input_dir, input_filename)
input_path = candidate_path if os.path.isfile(candidate_path) else input_filename

cropped = process_and_crop_image(input_path, model, transform, device)
if cropped is not None and cropped.size > 0:
    # Output name "<stem>_cropped<ext>", e.g. "test2_cropped.jpg". basename()
    # keeps the file directly inside output_dir even if a sub-path was given.
    filename, ext = os.path.splitext(os.path.basename(input_filename))
    output_filename = filename + "_cropped" + ext
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_filename)

    # Processing uses RGB; OpenCV writes BGR.
    cropped_bgr = cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)
    if cv2.imwrite(output_path, cropped_bgr):
        print(f"Cropped image saved as {output_path}")
    else:
        print(f"Failed to write {output_path}")
else:
    print("No crop produced (unreadable image or no mask detected). Image was not cropped.")