In [1]:
# All function preparation

In [None]:
from PIL import Image
import os
import torch
import torch.nn.functional as F

def generate_crop_positions(total, crop_size=1024, overlap=0.25, min_overlap=256):
    """
    Generate the [start, end] ranges of overlapping crops along one image axis.

    Full-size crops are laid out every ``int(crop_size * (1 - overlap))`` pixels.
    Whatever is left after the last full crop is either absorbed into that crop
    (when it is shorter than ``min_overlap``) or covered by one extra crop that
    ends exactly at ``total``.

    :param total: The total width or height of the image.
    :param crop_size: The size of the crop (default 1024).
    :param overlap: The overlap ratio between consecutive crops (default 0.25).
    :param min_overlap: Leftover length in pixels below which the remainder is
        merged into the last crop instead of getting its own crop (default 256).
    :return: A list of crop positions, e.g., [[start1, end1], [start2, end2], ...].
    :raises ValueError: If ``crop_size`` and ``overlap`` yield a step of zero or
        less, which would otherwise never advance along the axis.
    """
    if total <= crop_size:
        # The whole axis fits into a single crop
        return [[0, total]]

    # Distance between the starts of two consecutive crops
    step_size = int(crop_size * (1 - overlap))
    if step_size <= 0:
        # Without this guard the loop below would never terminate
        raise ValueError(
            f"crop_size={crop_size} with overlap={overlap} gives a non-positive "
            f"step ({step_size}); overlap must be smaller than 1"
        )

    # Lay out every full-size crop that fits; total > crop_size guarantees
    # that at least the crop starting at 0 is produced
    positions = []
    pos = 0
    while pos + crop_size <= total:
        positions.append([pos, pos + crop_size])
        pos += step_size

    # Pixels not covered by the last full-size crop
    remaining = total - positions[-1][1]
    if remaining > 0:
        if remaining < min_overlap:
            # Too small for a crop of its own: stretch the last crop to the border
            positions[-1][1] = total
        else:
            # Extra crop that overlaps the previous one by (crop_size - step_size)
            new_start = max(0, positions[-1][1] - (crop_size - step_size))
            positions.append([new_start, total])

    return positions

def crop_image_with_overlap(image, output_dir, crop_size=1024, overlap=0.25, min_overlap=256):
    """
    Cut an image array into overlapping tiles and write each tile as a PNG.

    Tiles are saved to ``output_dir`` as ``temp_<row>_<col>.png`` where row and
    column are 1-based indices into the crop grid.

    :param image: Image array indexed as [row, column, ...].
    :param output_dir: Output directory (created if missing).
    :param crop_size: Crop size (default 1024).
    :param overlap: Overlap ratio (default 0.25).
    :param min_overlap: Minimum overlap in pixels (default 256).
    :return: Returns the number of rows and columns of crops.
    """
    img_height, img_width = image.shape[0], image.shape[1]

    # Crop ranges along each axis
    row_spans = generate_crop_positions(img_height, crop_size, overlap, min_overlap)
    col_spans = generate_crop_positions(img_width, crop_size, overlap, min_overlap)

    os.makedirs(output_dir, exist_ok=True)

    # Fixed prefix of every tile file name
    image_name = 'temp'

    for row, (top, bottom) in enumerate(row_spans):
        for col, (left, right) in enumerate(col_spans):
            tile_path = os.path.join(output_dir, f"{image_name}_{row + 1}_{col + 1}.png")
            # Slice the tile, wrap it as a PIL image and write it to disk
            Image.fromarray(image[top:bottom, left:right]).save(tile_path)

    return len(row_spans), len(col_spans)

def mask_nms_with_scores(cls_logits, mask_logits, scores, iou_thr=0.5):
    """
    Use NMS to filter overlapping masks, returning the indices of the retained queries.

    Masks are binarised at probability 0.5 and visited from highest to lowest
    score; a mask is kept unless its IoU with an already kept mask exceeds
    ``iou_thr``. Suppression ignores class labels.

    Args:
        cls_logits (Tensor): Not used by the computation; kept so existing
            callers do not have to change.
        mask_logits (Tensor): Mask logits, shape (num_queries, h, w).
        scores (Tensor): Scores for each query, shape (num_queries,).
        iou_thr (float): IoU threshold.

    Returns:
        keep_indices (List[int]): Indices of the retained queries (sorted in ascending order).
    """
    num_masks = mask_logits.shape[0]
    if num_masks == 0:
        return []

    # Binarise and flatten each mask to one row
    mask_bool = (mask_logits.sigmoid() >= 0.5).reshape(num_masks, -1)
    mask_float = mask_bool.float()

    # Pairwise intersection via a dot product of the 0/1 rows
    intersection = torch.matmul(mask_float, mask_float.transpose(0, 1))
    areas = mask_bool.sum(dim=1)
    union = areas.view(-1, 1) + areas.view(1, -1) - intersection

    # Two empty masks have union 0; clamping turns their IoU into 0 instead of
    # NaN, so empty masks no longer suppress each other
    iou_matrix = intersection / union.clamp(min=1)

    # Move the thresholded matrix to the CPU once rather than reading single
    # elements from the device inside the loop
    exceeds = (iou_matrix > iou_thr).cpu()
    order = torch.argsort(scores, descending=True).tolist()

    suppressed = torch.zeros(num_masks, dtype=torch.bool)
    keep_indices = []
    for idx in order:
        if suppressed[idx]:
            continue
        keep_indices.append(idx)
        # Everything overlapping this kept mask too strongly is ruled out
        suppressed |= exceeds[idx]

    keep_indices.sort()
    return keep_indices

def aggregate_mask_logits(labels_final, mask_logits_final, num_classes=5):
    """
    Collapse per-query mask logits into one logit map per class.

    For every class the element-wise maximum over all queries carrying that
    label is taken. Classes without any query keep the fill value -inf.

    Args:
        labels_final (Tensor): Filtered labels, shape (num_queries,).
        mask_logits_final (Tensor): Filtered mask logits, shape (num_queries, H, W).
        num_classes (int): Number of classes (default 5).

    Returns:
        class_masks (Tensor): Aggregated masks for each class, shape (num_classes, H, W).
    """
    height, width = mask_logits_final.shape[-2:]

    # Start from -inf so that classes without any query are recognisable
    class_masks = torch.full(
        (num_classes, height, width),
        float('-inf'),
        device=mask_logits_final.device,
    )

    for class_id in range(num_classes):
        selected = labels_final == class_id
        if not selected.any():
            continue
        # Element-wise maximum over every query of this class
        class_masks[class_id] = mask_logits_final[selected].max(dim=0).values

    return class_masks

In [None]:
# Weight matrix generation with cosine transition

import numpy as np
import matplotlib.pyplot as plt

# Function to create a base weight matrix for corners (top-left corner)
def create_corner_base_weight_matrix(height, width, overlap):
    """
    Create a base weight matrix for the top-left corner mask.

    The weights are 1 everywhere except in the last ``overlap`` columns and the
    last ``overlap`` rows, where they fall from 1 to 0 along a squared-cosine
    ramp. Where both ramps meet (bottom-right block) the weight is the product
    of the row ramp and the column ramp.

    Args:
        height (int): Number of rows of the matrix.
        width (int): Number of columns of the matrix.
        overlap (int): Length of the ramp in pixels; values <= 0 disable it.

    Returns:
        numpy.ndarray: Weight matrix of shape (height, width).
    """
    weight = np.ones((height, width))
    if overlap <= 0:
        # Nothing to taper; ``-0:`` would otherwise select the whole axis
        return weight

    x = np.linspace(0, np.pi / 2, overlap)  # For cosine transition
    cos_weights = np.cos(x) ** 2  # Falls from 1 to 0

    # Right overlap
    weight[:, -overlap:] *= cos_weights
    # Bottom overlap
    weight[-overlap:, :] *= cos_weights[:, None]
    # The two passes above already leave row_ramp * col_ramp in the
    # bottom-right block; multiplying by the outer product again would square it

    return weight

# Function to create a base weight matrix for edges (top edge)
def create_edge_base_weight_matrix(height, width, overlap, edge_type):
    """
    Build the blending weights for a tile lying on one edge of the grid.

    Three sides of the matrix are tapered with a squared-cosine ramp of
    ``overlap`` pixels; the side named by ``edge_type`` is left at full weight.
    edge_type: 'top', 'bottom', 'left', 'right'
    Any other value returns a matrix of ones.
    """
    weight = np.ones((height, width))
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2  # 1 -> 0
    fade_in = fade_out[::-1]  # 0 -> 1

    # Sides to taper for each edge type, in the order they are applied
    tapered_sides = {
        'top': ('left', 'right', 'bottom'),
        'bottom': ('left', 'right', 'top'),
        'left': ('top', 'bottom', 'right'),
        'right': ('top', 'bottom', 'left'),
    }

    for side in tapered_sides.get(edge_type, ()):
        if side == 'left':
            weight[:, :overlap] *= fade_in
        elif side == 'right':
            weight[:, -overlap:] *= fade_out
        elif side == 'top':
            weight[:overlap, :] *= fade_in[:, None]
        else:  # 'bottom'
            weight[-overlap:, :] *= fade_out[:, None]

    return weight

# Function to create a base weight matrix for the center
def create_center_weight_matrix(height, width, overlap):
    """
    Build the blending weights for a tile surrounded by neighbours on all sides.

    Every side is tapered over ``overlap`` pixels with a squared-cosine ramp
    that reaches 0 at the border.
    """
    weight = np.ones((height, width))
    fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2  # 1 -> 0
    fade_in = fade_out[::-1]  # 0 -> 1

    # (region, factor) pairs: top, bottom, left, right
    tapers = (
        (np.s_[:overlap, :], fade_in[:, None]),
        (np.s_[-overlap:, :], fade_out[:, None]),
        (np.s_[:, :overlap], fade_in),
        (np.s_[:, -overlap:], fade_out),
    )
    for region, factor in tapers:
        weight[region] *= factor

    return weight

# Function to create weight matrix based on position (x, y) and grid size (x_max, y_max)
def create_weight_matrix(x, y, x_max, y_max, height, width, overlap):
    """
    Create a weight matrix for a given mask position (x, y) in a grid of size (x_max, y_max).
    The weight matrix size is determined by the input image size (height, width).

    A side of the tile is tapered with a squared-cosine ramp only when another
    tile lies beyond it. Corners therefore taper two sides, edges three and
    interior tiles four, while grids with a single row and/or column leave the
    sides without a neighbour at full weight.

    Args:
        x (int): 1-based row index of the tile in the grid.
        y (int): 1-based column index of the tile in the grid.
        x_max (int): Number of rows in the grid.
        y_max (int): Number of columns in the grid.
        height (int): Tile height in pixels.
        width (int): Tile width in pixels.
        overlap (int): Ramp length in pixels; clamped to the tile size, and
            values <= 0 disable tapering.

    Returns:
        numpy.ndarray: Weight matrix of shape (height, width).
    """
    weight = np.ones((height, width))
    if overlap <= 0:
        return weight

    # A ramp can never be longer than the side it is applied to
    v_len = min(overlap, height)
    h_len = min(overlap, width)
    v_fade = np.cos(np.linspace(0, np.pi / 2, v_len)) ** 2  # 1 -> 0 downwards
    h_fade = np.cos(np.linspace(0, np.pi / 2, h_len)) ** 2  # 1 -> 0 rightwards

    if x > 1:  # a tile lies above
        weight[:v_len, :] *= v_fade[::-1][:, None]
    if x < x_max:  # a tile lies below
        weight[-v_len:, :] *= v_fade[:, None]
    if y > 1:  # a tile lies to the left
        weight[:, :h_len] *= h_fade[::-1]
    if y < y_max:  # a tile lies to the right
        weight[:, -h_len:] *= h_fade

    return weight


In [3]:
# Main function

In [None]:
import cv2
import torch
import numpy as np
import os
from PIL import Image
from tqdm.notebook import tqdm
from mmdet.apis import init_detector, inference_detector
import torch.nn.functional as F
import shutil
from scipy.interpolate import griddata

# Initialize model
# Config and checkpoint are passed to init_detector, which builds the detector on `device`
config_file = 'mask2former_swin-l-p4-w12-384_8xb2-lsj-100e_coco-1227.py'
checkpoint_file = 'epoch_100.pth'
device = 'cuda:0'  # NOTE(review): GPU is hard-coded; presumably fails on CPU-only hosts — confirm
model = init_detector(config_file, checkpoint_file, device=device)
model.eval()
# Class names are read from the model's dataset metadata (empty list when the key is missing)
class_names = model.dataset_meta.get('classes', [])
num_classes = len(class_names)
# One (low, high) contour-area pair per class index; visualize_results draws the bounding box
# green below `low`, orange below `high` and red otherwise
risk_thresholds = [(3825, 43333), (3610, 70096), (4415, 15355), (1552, 13158), (714, 4823)]

def visualize_results(big_mask_logits, image_large, class_names, risk_thresholds):
    """
    Visualize results: Generate probability map, binary map, boundary map and segmentation map.

    Args:
        big_mask_logits (Tensor): Mask logits for the large image, shape (num_classes, H, W).
        image_large (numpy.ndarray): Original large image, shape (H, W, 3).
            NOTE(review): the final BGR->RGB conversion suggests BGR channel
            order is expected — confirm against the caller.
        class_names (list): List of class names.
        risk_thresholds (list): Area thresholds for each class in format [(low1, high1), (low2, high2), ...].

    Returns:
        dict: Dictionary containing the following key-value pairs:
            - "probability": Probability map (OpenCV format).
            - "binary": Binary map (PIL format).
            - "boundary": Boundary map (PIL format).
            - "segmentation": Segmentation map (PIL format).
            - "original": Original large image (numpy.ndarray format).
            - "json_coords": List of dicts, one per drawn contour, with keys
              "class_id" (1-based), "segmentation" (flattened x, y points),
              "bbox" ([x, y, w, h]) and "area".
    """

    # Per-pixel maximum logit over all classes
    max_probs, _ = torch.max(big_mask_logits, dim=0)

    # Mark every non-finite value as NaN so it can be filled in below
    if torch.isnan(max_probs).any() or torch.isinf(max_probs).any():
        max_probs = torch.nan_to_num(max_probs, nan=np.nan, posinf=np.nan, neginf=np.nan)

    probability_map = torch.sigmoid(max_probs).cpu().numpy()

    # Fill NaN pixels from the nearest valid pixel
    if np.isnan(probability_map).any():
        valid_mask = ~np.isnan(probability_map)  # Mask of valid values
        if valid_mask.any():
            points = np.argwhere(valid_mask)  # Coordinates of valid values
            values = probability_map[valid_mask]  # Valid values
            grid_x, grid_y = np.mgrid[0:probability_map.shape[0], 0:probability_map.shape[1]]
            probability_map = griddata(points, values, (grid_x, grid_y), method='nearest')
        else:
            # No finite logit anywhere (e.g. nothing was detected): griddata
            # cannot interpolate from zero samples, so report probability 0
            probability_map = np.zeros_like(probability_map)

    # Scale probability map to [0, 255] and convert to uint8
    probability_map = (probability_map * 255).clip(0, 255).astype(np.uint8)

    # Apply color map
    probability_image = cv2.applyColorMap(probability_map, cv2.COLORMAP_JET)

    # Generate binary map: white wherever the scaled probability is non-zero
    # NOTE(review): this marks any probability >= 1/255; a 0.5 cut-off may have
    # been intended — confirm
    binary_map = (probability_map > 0).astype(np.uint8) * 255
    binary_image = Image.fromarray(binary_map)

    # Generate boundary map and segmentation map
    logits = big_mask_logits.cpu().numpy()
    colors = [
        (255, 222, 7),    # #07deff
        (61, 207, 255),   # #ffcf3d
        (255, 74, 74),    # #4a4aff
        (127, 85, 0),     # #00557f
        (87, 15, 255)     # #ff0f57
    ]
    bbox_colors = [(0, 255, 0), (0, 165, 255), (0, 0, 255)]  # Green, orange, red

    # Drawing constants shared by every contour
    font_scale = 0.8  # Font size
    thickness = 2  # Font thickness
    alpha = 0.5  # 50% transparency for filled contours

    output_image = image_large.copy()
    segmentation_image = image_large.copy()

    # Store annotation data for each valid contour
    annotations = []

    # Process each class
    for class_idx in range(len(class_names)):
        # Positive logit == probability above 0.5
        mask = logits[class_idx] > 0
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        for contour in contours:
            area = cv2.contourArea(contour)
            if area < 50:  # Skip small contours
                continue

            # Get segmentation coordinates (flattened contour points)
            segmentation = contour.flatten().tolist()
            if len(segmentation) < 6 or len(segmentation) % 2 != 0:
                continue  # Skip if insufficient points or not in (x,y) pairs

            # Bounding box, used for the annotation and for drawing
            x, y, w, h = cv2.boundingRect(contour)

            # Store annotation data
            annotations.append({
                "class_id": class_idx + 1,  # COCO format starts class IDs from 1
                "segmentation": segmentation,
                "bbox": [x, y, w, h],
                "area": float(area)
            })

            # Draw boundary
            cv2.drawContours(output_image, [contour], -1, colors[class_idx], 2)

            # Set bbox color based on area thresholds
            if area < risk_thresholds[class_idx][0]:
                bbox_color = bbox_colors[0]  # Green
            elif area < risk_thresholds[class_idx][1]:
                bbox_color = bbox_colors[1]  # Orange
            else:
                bbox_color = bbox_colors[2]  # Red

            # Draw bounding box
            cv2.rectangle(output_image, (x, y), (x + w, y + h), bbox_color, 2)

            # Add class label in top-left corner of bbox
            text = class_names[class_idx]
            (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)

            # Calculate text position (inside bbox top-left)
            text_x = x + 5  # 5px from left edge
            text_y = y + 20  # 20px from top edge

            # Ensure text stays within bbox
            if text_y - text_height < y:  # If text goes above bbox
                text_y = y + text_height + 5  # Move text down
            if text_x + text_width > x + w:  # If text goes beyond right edge
                text_x = x + w - text_width - 5  # Move text left

            # Draw text background
            cv2.rectangle(output_image, (text_x, text_y - text_height), (text_x + text_width, text_y), bbox_color, -1)  # -1 for filled

            # Draw text
            cv2.putText(output_image, text, (text_x, text_y - 5), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)

            # Draw segmentation (filled contour) and blend it into the running image
            overlay = segmentation_image.copy()
            cv2.drawContours(overlay, [contour], -1, colors[class_idx], -1)  # Fill contour
            cv2.addWeighted(overlay, alpha, segmentation_image, 1 - alpha, 0, segmentation_image)

    # Convert OpenCV images to PIL format
    boundary_map = Image.fromarray(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB))
    segmentation_map = Image.fromarray(cv2.cvtColor(segmentation_image, cv2.COLOR_BGR2RGB))

    return {
        "probability": probability_image,
        "binary": binary_image,
        "boundary": boundary_map,
        "segmentation": segmentation_map,
        "original": image_large,
        "json_coords": annotations,  # Contour data for annotations
    }



Loads checkpoint by local backend from path: epoch_100.pth


In [None]:
import os
import shutil
import numpy as np
import torch
import cv2
from tqdm import tqdm

m = None  # Global variable to store mask logits; process_image overwrites it with its merged logits on every call

def process_image(image, confidence_threshold=0.5, crack_threshold=0.01, iou_thr_nms=0.2,
                  crop_sizes=(1000, 1400, 1800), overlap=0.5, min_overlap=256):
    """
    Main function: Process large image, generate mask logits and visualize results.

    For every crop size the image is cut into overlapping tiles (written to a
    temporary directory), each tile is run through the detector, the surviving
    masks are merged per class and blended into a full-size logit tensor. The
    tensors of all crop sizes are combined with an element-wise maximum, stored
    in the global ``m`` and passed to ``visualize_results``.

    Args:
        image (numpy.ndarray): Input large image with shape (H, W, 3).
        confidence_threshold (float): Confidence threshold (default 0.5).
        crack_threshold (float): Special threshold for crack detection, applied
            to label index 4 (default 0.01).
        iou_thr_nms (float): IoU threshold for NMS (default 0.2).
        crop_sizes (sequence): Crop sizes to process (default (1000, 1400, 1800)).
        overlap (float): Overlap ratio between crops (default 0.5).
        min_overlap (int): Minimum overlap in pixels (default 256).

    Returns:
        dict: Dictionary containing visualization results with keys:
            - "probability": Probability map (OpenCV format).
            - "binary": Binary map (PIL format).
            - "boundary": Boundary map (PIL format).
            - "segmentation": Segmentation map (PIL format).
            - "original": Original image (numpy.ndarray format).
            - "json_coords": Per-contour annotation data.
    """
    global m  # Access global variable
    temp_dir = 'temp_test2'  # Temporary directory for cropped images
    image_large = image  # Store original image
    full_height, full_width = image_large.shape[0], image_large.shape[1]
    big_mask_logits_list = []  # One full-size logit tensor per crop size

    # Process image with each crop size
    for crop_size in crop_sizes:
        # Accumulator for the blended logits of this crop size
        big_mask_logits = torch.zeros((num_classes, full_height, full_width), device=device)

        # Tile ranges along each axis for the current crop size
        crop_positions_height = generate_crop_positions(full_height, crop_size, overlap, min_overlap)
        crop_positions_width = generate_crop_positions(full_width, crop_size, overlap, min_overlap)

        try:
            # Write the tiles to disk; done inside the try so that a failure
            # half-way through still removes the temporary directory
            rows, cols = crop_image_with_overlap(image_large, temp_dir, crop_size, overlap, min_overlap)

            # Process each cropped image
            for filename in tqdm(os.listdir(temp_dir)):
                if not filename.endswith('.png'):
                    continue

                # Parse 1-based grid indices from the name (format: prefix_row_col.png)
                parts = filename.split('_')
                if len(parts) < 3:
                    print(f"Cannot parse coordinates from filename {filename}: expected prefix_row_col.png")
                    continue
                try:
                    x = int(parts[-2])  # Row index
                    y = int(parts[-1].split('.')[0])  # Column index
                except ValueError as e:
                    print(f"Cannot parse coordinates from filename {filename}: {e}")
                    continue

                # Validate coordinates; the lower bound matters because index 0
                # would silently wrap around to the last tile
                if not (1 <= x <= len(crop_positions_height)) or not (1 <= y <= len(crop_positions_width)):
                    print(f"Coordinates ({x}, {y}) out of range, skipping image")
                    continue

                # Perform inference on cropped image
                tile = cv2.imread(os.path.join(temp_dir, filename))
                result = inference_detector(model, tile)

                # Get detection results
                labels = result.pred_instances['labels']
                mask_logits = result.pred_instances['mask_logits']
                scores = result.pred_instances['scores']

                # Apply confidence thresholds (special lower threshold for cracks)
                valid_indices = torch.where(
                    (scores > confidence_threshold) |
                    ((labels == 4) & (scores > crack_threshold))
                )[0]
                labels_valid = labels[valid_indices]
                mask_logits_valid = mask_logits[valid_indices]

                # Apply NMS to filter overlapping masks
                keep_indices = mask_nms_with_scores(
                    labels_valid,
                    mask_logits_valid,
                    scores[valid_indices],
                    iou_thr=iou_thr_nms
                )
                labels_final = labels_valid[keep_indices]
                mask_logits_final = mask_logits_valid[keep_indices]

                # Aggregate masks by class
                class_masks = aggregate_mask_logits(labels_final, mask_logits_final, num_classes)

                # Create weight matrix for blending
                # NOTE(review): the ramp length is fixed at 256 px regardless of
                # `overlap` / `min_overlap` — confirm this is intended
                weight_matrix = create_weight_matrix(
                    x, y, rows, cols,
                    tile.shape[0], tile.shape[1],
                    overlap=256
                )
                weight_matrix = torch.from_numpy(weight_matrix.copy()).float().to(device)

                # Blend all classes into the large image at once; the (H, W)
                # weights broadcast over the class dimension
                start_x, end_x = crop_positions_height[x - 1]
                start_y, end_y = crop_positions_width[y - 1]
                big_mask_logits[:, start_x:end_x, start_y:end_y] += class_masks * weight_matrix

            # Store results for current crop size
            big_mask_logits_list.append(big_mask_logits)

        finally:
            # Clean up temporary directory; ignore_errors also covers the case
            # where it does not exist
            shutil.rmtree(temp_dir, ignore_errors=True)

    # Combine results from different crop sizes
    if big_mask_logits_list:
        # Take maximum values across all crop sizes
        final_big_mask_logits = torch.stack(big_mask_logits_list).max(dim=0).values
    else:
        final_big_mask_logits = torch.zeros((num_classes, full_height, full_width), device=device)

    m = final_big_mask_logits  # Store in global variable

    # Generate visualization results
    results = visualize_results(final_big_mask_logits, image_large, class_names, risk_thresholds)
    return results

In [6]:
# Define the legend HTML with CSS (matching Gradio's style)
# The class swatches are the RGB form of the BGR `colors` list used when drawing
# contours in visualize_results; keep the two in sync
legend_html = """
<div style="border: 1px solid #e0e0e0; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);">
    <h3 style="margin-top: 0;">Legend for Boundary</h3>
    <div style="display: flex; gap: 30px;">
        <!-- Left: Class -->
        <div style="display: flex; flex-direction: column; gap: 10px;">
            <h4 style="margin: 0;">Class</h4>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 4px; background-color: rgb(7, 222, 255); margin-right: 10px;"></div>
                <span>Seepage</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 4px; background-color: rgb(255, 207, 61); margin-right: 10px;"></div>
                <span>Corrosion</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 4px; background-color: rgb(74, 74, 255); margin-right: 10px;"></div>
                <span>Damaged Joint</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 4px; background-color: rgb(0, 85, 127); margin-right: 10px;"></div>
                <span>Spalling</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 4px; background-color: rgb(255, 15, 87); margin-right: 10px;"></div>
                <span>Crack</span>
            </div>
        </div>

        <!-- Right: Risk Levels -->
        <div style="display: flex; flex-direction: column; gap: 10px;">
            <h4 style="margin: 0;">Risk Levels</h4>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; border: 4px solid green; background-color: transparent; margin-right: 10px; border-radius: 4px;"></div>
                <span>Low Risk</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; border: 4px solid orange; background-color: transparent; margin-right: 10px; border-radius: 4px;"></div>
                <span>Medium Risk</span>
            </div>
            <div style="display: flex; align-items: center;">
                <div style="width: 20px; height: 20px; border: 4px solid red; background-color: transparent; margin-right: 10px; border-radius: 4px;"></div>
                <span>High Risk</span>
            </div>
        </div>
    </div>
</div>
"""


In [None]:
import os
import tempfile
from fpdf import FPDF
from PyPDF2 import PdfWriter, PdfReader, PageObject
from PIL import Image, ImageDraw, ImageFont

def _load_axis_font(size=50):
    """Return the font used for the axis labels, preferring Arial at ``size`` px."""
    try:
        return ImageFont.truetype("arial.ttf", size=size)
    except OSError:
        # arial.ttf is not installed: fall back to Pillow's built-in font.
        # Newer Pillow releases accept a size for the default font; older
        # ones raise TypeError and only offer a small fixed-size bitmap font.
        try:
            return ImageFont.load_default(size=size)
        except TypeError:
            return ImageFont.load_default()


def _draw_dashed_vline(draw, x, y_end, fill, dash_length=5):
    """Draw a dashed vertical line at ``x`` from y=0 down to ``y_end``."""
    for y in range(0, y_end, dash_length * 2):
        draw.line((x, y, x, y + dash_length), fill=fill, width=2)


def _add_scale_axes(boundary_image, font):
    """
    Return a copy of ``boundary_image`` on a white canvas with metre grid lines.

    The canvas adds a 100 px left margin (row labels) and a 50 px bottom
    margin (column labels). The image height is treated as 20 m, and the same
    pixels-per-metre scale is used for the columns, measured from the
    horizontal centre of the image.
    """
    width, height = boundary_image.size
    margin_left = 100   # Room for the "<n>m" row labels
    margin_bottom = 50  # Room for the column labels
    dash_length = 5

    new_width = width + margin_left
    new_height = height + margin_bottom
    canvas = Image.new("RGB", (new_width, new_height), color="white")
    canvas.paste(boundary_image, (margin_left, 0))
    draw = ImageDraw.Draw(canvas)

    scale_factor = height / 20  # Pixels per metre

    # Horizontal grid lines: 0 m at the bottom edge of the image, 20 m at the top.
    for i in range(0, 21):
        y = height - int(i * scale_factor)
        for x in range(margin_left, new_width, dash_length * 2):
            draw.line((x, y, x + dash_length, y), fill="gray", width=2)
        draw.text((0, y), f"{i}m", fill="black", font=font)

    # Vertical grid lines: the red centre line is 0, counted outwards in metres.
    center_x = width // 2 + margin_left
    label_y = new_height - 50
    _draw_dashed_vline(draw, center_x, new_height, "red", dash_length)
    draw.text((center_x + 5, label_y), "0", fill="black", font=font)

    if scale_factor > 0:  # A zero-height image would otherwise never advance
        # Positions are computed from the tick index so rounding error does
        # not accumulate, and the label is the tick index itself.
        tick = 1
        x = center_x + int(tick * scale_factor)
        while x < new_width:
            _draw_dashed_vline(draw, x, new_height, "gray", dash_length)
            draw.text((x + 5, label_y), f"{tick}", fill="black", font=font)
            tick += 1
            x = center_x + int(tick * scale_factor)

        tick = 1
        x = center_x - int(tick * scale_factor)
        while x > margin_left:
            _draw_dashed_vline(draw, x, new_height, "gray", dash_length)
            draw.text((x + 5, label_y), f"-{tick}", fill="black", font=font)
            tick += 1
            x = center_x - int(tick * scale_factor)

    return canvas


def export_to_pdf():
    """
    Export every cached "boundary" image to ``tunnel_damage_report.pdf``.

    Each boundary image in ``results_cache`` is annotated with a metre grid,
    placed on a copy of the first page of ``template.pdf`` and labelled with
    its (possibly truncated) file name. One page is produced per image.

    Returns:
        str: Path of the generated PDF, or an error message when nothing has
        been processed or the template file is missing.
    """
    if not results_cache:
        return "No images processed to export."

    # Fail fast: without the template there is no point in rendering images.
    template_path = "template.pdf"  # Path to your PDF template
    if not os.path.exists(template_path):
        return "Template PDF not found."

    mm_to_pt = 2.83465  # 1 mm = 2.83465 points

    temp_files = []     # (image path, original file name, (width, height))
    cleanup_paths = []  # Every temporary file created here, removed in finally
    try:
        font = _load_axis_font(50)  # Loop-invariant, so load it once

        for file_name, result in results_cache.items():
            boundary_image = result.get("boundary")
            if not boundary_image:
                continue

            annotated = _add_scale_axes(boundary_image.convert("RGB"), font)

            # Reserve a name, close the handle, then let Pillow write to it.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
                temp_path = temp_file.name
            cleanup_paths.append(temp_path)
            annotated.save(temp_path)
            temp_files.append((temp_path, file_name, annotated.size))

        writer = PdfWriter()
        template_pdf = PdfReader(template_path)
        template_page = template_pdf.pages[0]  # Original template page

        # Placement of the boundary image on the page
        boundary_x = 10      # X position in mm
        boundary_y = 48      # Y position in mm (from top)
        target_width = 170   # Target width in mm
        target_height = 220  # Target height in mm

        boundary_x_pt = boundary_x * mm_to_pt
        # PDF origin is bottom-left, so flip the Y coordinate (297 mm page height).
        boundary_y_pt = (297 - boundary_y - target_height) * mm_to_pt
        target_width_pt = target_width * mm_to_pt
        target_height_pt = target_height * mm_to_pt

        for temp_path, file_name, (img_width, img_height) in temp_files:
            # Start from a blank page and stamp the template onto it
            new_page = PageObject.create_blank_page(
                width=template_page.mediabox.width,
                height=template_page.mediabox.height,
            )
            new_page.merge_page(template_page)

            # Scale the image to fit inside the target box, keeping its aspect ratio
            aspect_ratio = img_width / img_height
            if aspect_ratio > (target_width / target_height):
                scaled_width = target_width_pt
                scaled_height = scaled_width / aspect_ratio
            else:
                scaled_height = target_height_pt
                scaled_width = scaled_height * aspect_ratio

            # Render the image and the file name on their own overlay page
            boundary_pdf = FPDF()
            boundary_pdf.add_page()
            boundary_pdf.image(
                temp_path,
                x=boundary_x,
                y=boundary_y,
                w=scaled_width / mm_to_pt,
                h=scaled_height / mm_to_pt,
            )

            filename_without_ext = os.path.splitext(file_name)[0]
            max_filename_length = 30  # Maximum characters to display
            if len(filename_without_ext) > max_filename_length:
                filename_without_ext = filename_without_ext[:max_filename_length - 3] + "..."

            boundary_pdf.set_font("helvetica", size=8)
            filename_x = 8    # mm from the left edge
            filename_y = 290  # mm from the top edge (near the bottom)
            boundary_pdf.text(filename_x, filename_y, filename_without_ext)

            # Write the overlay to a tracked temporary file so it gets removed.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as pdf_file:
                boundary_pdf_path = pdf_file.name
            cleanup_paths.append(boundary_pdf_path)
            boundary_pdf.output(boundary_pdf_path)

            # NOTE(review): the image is already placed at (boundary_x, boundary_y)
            # inside the overlay, and the whole overlay (file name included) is
            # translated again here. Kept as in the original layout; confirm the
            # double offset is intended.
            boundary_pdf_page = PdfReader(boundary_pdf_path).pages[0]
            boundary_pdf_page.add_transformation([1, 0, 0, 1, boundary_x_pt, boundary_y_pt])
            new_page.merge_page(boundary_pdf_page)

            writer.add_page(new_page)

        output_path = "tunnel_damage_report.pdf"
        with open(output_path, "wb") as output_file:
            writer.write(output_file)

        return output_path
    finally:
        # Best-effort cleanup of every temporary image and overlay PDF
        for path in cleanup_paths:
            try:
                os.remove(path)
            except OSError:
                pass

In [None]:
import json
import os
import cv2
import torch
import numpy as np

def export_annotations():
    """
    Write the cached detections to ./annotations/annotation.json in COCO layout.

    Categories are taken from ``class_names`` with ids starting at 1. Every
    entry of ``results_cache`` that has an "original" image contributes one
    image record plus one annotation record per item of its "json_coords"
    list. Image and annotation ids are consecutive integers starting at 1.

    Returns:
        str: Path of the written file, or a message when nothing was processed.
    """
    if not results_cache:
        return "No images processed to export annotations."

    coco = {
        "info": {},
        "licenses": [],
        "images": [],
        "annotations": [],
        "categories": [
            {"id": category_id, "name": label, "supercategory": "none"}
            for category_id, label in enumerate(class_names, start=1)
        ],
    }
    image_records = coco["images"]
    annotation_records = coco["annotations"]

    for file_name, entry in results_cache.items():
        source = entry.get("original")
        if source is None:
            # Entries without an original image get no record and no id
            continue

        img_h, img_w = source.shape[:2]
        current_image_id = len(image_records) + 1
        image_records.append({
            "id": current_image_id,
            "file_name": file_name,
            "width": img_w,
            "height": img_h,
        })

        # Annotation ids continue from however many were written so far
        first_annotation_id = len(annotation_records) + 1
        annotation_records.extend([
            {
                "id": ann_id,
                "image_id": current_image_id,
                "category_id": item["class_id"],
                "segmentation": [item["segmentation"]],
                "bbox": item["bbox"],
                "area": item["area"],
                "iscrowd": 0,
            }
            for ann_id, item in enumerate(entry.get("json_coords", []), start=first_annotation_id)
        ])

    target_dir = "./annotations"
    os.makedirs(target_dir, exist_ok=True)

    output_path = os.path.join(target_dir, "annotation.json")
    with open(output_path, "w") as handle:
        json.dump(coco, handle, indent=4)

    return output_path

In [None]:
import os
import json
import tempfile
import openai

# Take the API key from the OPENAI_API_KEY environment variable so that no
# secret has to be pasted into the notebook. When the variable is unset this
# falls back to the same empty string the notebook used before.
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY", ""))

# Define paths for original and processed annotation files
original_json_path = "annotations/annotation.json"
processed_json_path = "annotations/processed_annotation.json"

# Cache for processed COCO data to avoid redundant file reads
cached_coco_json = None

def preprocess_coco_json():
    """
    Return the COCO annotation data with the "segmentation" field removed.

    The stripped data is written to ``processed_json_path`` and kept in the
    module-level ``cached_coco_json``. Both caches are refreshed whenever
    ``original_json_path`` is at least as new as the processed file, so a
    fresh annotation export is picked up instead of serving stale statistics.

    Returns:
        dict | None: The processed COCO data, or None when neither the
        original nor the processed file exists and nothing is cached.
    """
    global cached_coco_json

    def _mtime(path):
        # Modification time, or None when the file does not exist
        try:
            return os.path.getmtime(path)
        except OSError:
            return None

    original_mtime = _mtime(original_json_path)
    processed_mtime = _mtime(processed_json_path)

    # Rebuild when the source exists and the processed copy is missing or not
    # strictly newer. "<=" errs towards a redundant rebuild on coarse clocks.
    needs_rebuild = original_mtime is not None and (
        processed_mtime is None or processed_mtime <= original_mtime
    )

    if not needs_rebuild:
        if cached_coco_json is not None:
            return cached_coco_json
        if processed_mtime is not None:
            with open(processed_json_path, "r", encoding="utf-8") as f:
                cached_coco_json = json.load(f)
            return cached_coco_json
        return None

    try:
        with open(original_json_path, "r", encoding="utf-8") as f:
            coco_json = json.load(f)
    except UnicodeDecodeError:
        # Fall back for files that were not written as UTF-8
        with open(original_json_path, "r", encoding="ISO-8859-1") as f:
            coco_json = json.load(f)

    # Remove segmentation
    for annotation in coco_json.get("annotations", []):
        annotation.pop("segmentation", None)

    # Save processed JSON (compact separators keep the file small)
    with open(processed_json_path, "w", encoding="utf-8") as f:
        json.dump(coco_json, f, separators=(",", ":"), ensure_ascii=False)

    cached_coco_json = coco_json
    return cached_coco_json

def clean_generated_code(code):
    """
    Extract runnable Python source from a GPT reply.

    If the reply contains a complete Markdown code fence (```lang ... ```),
    the contents of the first fenced block are returned, so any explanatory
    text before or after the fence is discarded. Otherwise a leading fence
    line and/or a trailing fence line are stripped individually, which
    covers replies whose fence was left unterminated.

    :param code: Raw text returned by the model.
    :return: The code with fences and surrounding whitespace removed.
    """
    import re

    code = code.strip()

    # Opening fence with optional language tag, then everything up to the
    # next fence. DOTALL lets "." span multiple lines.
    fenced = re.search(r"```[^\n]*\n(.*?)```", code, flags=re.DOTALL)
    if fenced:
        return fenced.group(1).strip()

    if code.startswith("```"):
        code = "\n".join(code.splitlines()[1:])
    if code.endswith("```"):
        code = "\n".join(code.splitlines()[:-1])
    return code.strip()

def analyze_coco_with_chatgpt(user_input):
    """
    Ask OpenAI GPT for Python code that analyses the COCO dataset, then run it.

    The generated code is executed in a namespace pre-populated with
    ``json_data`` (the processed COCO dict), ``result`` and ``fig``. After
    execution ``result`` is returned, and ``fig`` (if set) is saved as
    ``chart.png`` in the system temporary directory.

    :param user_input: The user's natural-language statistics question.
    :return: Tuple ``(result_or_error_message, chart_path_or_None)``.
    """
    coco_json = preprocess_coco_json()
    if not coco_json:
        return "Error: annotations.json file not found", None

    system_prompt = """
    You are a data analysis assistant specialized in analyzing COCO-format JSON annotation files. The files contain three main sections: images, annotations, and categories.
    Your task is to generate Python code to perform statistical analysis based on user queries.

    Rules:
    1. Only generate statistics based on the COCO JSON file. Do not answer questions beyond this scope.
    2. The generated code must be written in Python using `json`, `collections`, and `matplotlib` for statistics and visualization.
    3. The code must be self-contained and dynamically compute all necessary variables from `"annotations/processed_annotation.json"`.
    4. Store the statistical results in the `result` variable and any Matplotlib figure in the `fig` variable.
    5. Operate `result` and `fig` at the global scope. Do not define them inside functions or classes.
    6. Ensure all required libraries (e.g., `json`, `matplotlib.pyplot`) are explicitly imported at the beginning of the code.
    """

    try:
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"User query: {user_input}. Please generate Python code."}
            ],
            temperature=0
        )
    except Exception as e:
        return f"OpenAI API call failed: {str(e)}", None

    # The message content can be None (e.g. a refusal), so guard before .strip()
    raw_reply = response.choices[0].message.content
    if not raw_reply:
        return "OpenAI API returned an empty response", None

    generated_code = clean_generated_code(raw_reply.strip())
    print("Generated Code:\n", generated_code)

    # One dict serves as both globals and locals so that names the generated
    # code defines at top level are visible inside functions it defines.
    namespace = {
        "json_data": coco_json,
        "result": None,
        "fig": None
    }

    # SECURITY NOTE: this executes model-generated code, steered by free-form
    # user input, with the full privileges of this process. Only expose this
    # to trusted users, or run it in a sandboxed subprocess.
    try:
        exec(generated_code, namespace, namespace)
    except Exception as e:
        return f"Code execution error: {str(e)}", None

    # "result" is always present in the namespace, so test its value rather
    # than relying on a .get() default that could never be used.
    result = namespace.get("result")
    if result is None:
        result = "Unable to retrieve results"
    fig = namespace.get("fig")

    img_path = None
    if fig is not None:
        img_path = os.path.join(tempfile.gettempdir(), "chart.png")
        try:
            fig.savefig(img_path)
        except Exception as e:
            return f"Failed to save chart: {str(e)}", None

    return result, img_path


In [35]:
import gradio as gr
import tempfile
from fpdf import FPDF

# Gradio UI
# Gradio UI
# Layout: a narrow control column on the left (upload, thresholds, image
# selector, legend, exports) and a wide column on the right (visualization
# type, original/result images, statistics query).
with gr.Blocks() as demo:
    gr.Markdown("## Tunnel panoramic image damage detection")
    
    with gr.Row():
        # Left side: controls and download (scale 1 of 5, i.e. 0.2 width)
        with gr.Column(scale=1):  
            image_input = gr.File(
                label="Upload Images",
                file_count="multiple",
                file_types=["image"]
            )
            # General threshold, passed to process_image as its second argument
            confidence_slider = gr.Slider(
                0, 1, 0.2, 
                step=0.01, 
                label="Confidence Threshold"
            )
            # Separate threshold for cracks, with a lower default (0.05)
            confidence_slider_crack = gr.Slider(
                0, 1, 0.05, 
                step=0.01, 
                label="Confidence Threshold (Crack)"
            )
            submit_button = gr.Button("Process Images")
        
            # Choices are filled in by process_batch after a run
            image_selector = gr.Dropdown(
                choices=[],
                label="Select Image",
                interactive=True
            )
            
            # Add the legend section above the export button
            gr.HTML(legend_html)

            # Export annotation button and download link
            export_annotations_button = gr.Button("Export Annotations")
            download_annotations_link = gr.File(label="Download Annotations", file_types=[".json"])
            
            # Export pdf button and download link
            export_pdf_button = gr.Button("Export Boundaries to PDF")
            download_pdf_link = gr.File(label="Download PDF")

        # Middle and right side: Visualization Type, Original Image, and Visualization Result
        with gr.Column(scale=4):
            # Visualization Type moved to a new row above the images
            with gr.Row():
                output_choice = gr.Radio(
                    choices=["Segmentation", "Probability", "Binary", "Boundary"],
                    label="Visualization Type",
                    value="Segmentation"
                )
            
            with gr.Row():
                # Middle side: original image (0.4 width)
                with gr.Column(scale=1):  
                    original_display = gr.Image(label="Original Image")
                
                # Right side: output display (0.4 width)
                with gr.Column(scale=1):  
                    output_display = gr.Image(label="Visualization Result")

            # Statistics query: question and text answer on the left, chart on the right
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Statistical information query")
                    user_chat_input = gr.Textbox(label="Enter your question", placeholder="How many types of damage did we predict?")
                    chat_submit_button = gr.Button("Query Statistics")
                    chat_output = gr.Textbox(label="Statistical results", interactive=False)
                with gr.Column(scale=1):
                    chart_display = gr.Image(label="Statistical Charts", interactive=False)
        
    # Maps uploaded file name -> result dict returned by process_image.
    # Written by process_batch; read by update_display and the export functions.
    results_cache = {}  # Store processed results

    def process_batch(files, confidence_threshold, crack_threshold):
        """
        Run process_image on every uploaded file and refresh results_cache.

        :param files: Uploaded file objects from the File component (each
            exposes its path as ``.name``); may be None when nothing was uploaded.
        :param confidence_threshold: General confidence threshold.
        :param crack_threshold: Confidence threshold used for cracks.
        :return: Tuple of (dropdown update, original image, segmentation image)
            for the first processed file, or an emptied dropdown and two
            ``None`` values when there is nothing to process.
        """
        global results_cache
        results_cache.clear()
        file_names = []

        # `files` is None when the button is pressed without an upload
        for file in files or []:
            file_name = os.path.basename(file.name)
            file_names.append(file_name)

            # Close the file handle as soon as the pixels have been copied out
            with Image.open(file.name) as image:
                img_array = np.array(image)
            results_cache[file_name] = process_image(img_array, confidence_threshold, crack_threshold)

        if not file_names:
            return gr.update(choices=[], value=None), None, None

        # Select the first image so the dropdown matches what is displayed
        # and does not keep a stale name from a previous batch.
        first_result = results_cache.get(file_names[0], {})
        return (
            gr.update(choices=file_names, value=file_names[0]),
            first_result.get("original", None),
            first_result.get("segmentation", None),
        )
    
    def update_display(file_name, viz_type):
        """
        Look up the images to show for the selected file and visualization type.

        :param file_name: Key into results_cache (the dropdown selection).
        :param viz_type: One of "Segmentation", "Probability", "Binary", "Boundary".
        :return: Tuple (original image, visualization image); (None, None) when
            the file is unknown or the visualization type is not recognised.
        """
        if not (file_name and file_name in results_cache):
            return None, None

        entry = results_cache[file_name]

        # Radio label -> key of the matching image in the result dict
        label_to_key = (
            ("Segmentation", "segmentation"),
            ("Probability", "probability"),
            ("Binary", "binary"),
            ("Boundary", "boundary"),
        )
        for label, result_key in label_to_key:
            if viz_type == label:
                return entry.get("original", None), entry.get(result_key, None)

        return None, None
    
    # Connect components
    # "Process Images": run the batch, fill the dropdown and show the first result
    submit_button.click(
        fn=process_batch,
        inputs=[image_input, confidence_slider, confidence_slider_crack],
        outputs=[image_selector, original_display, output_display]
    )
    
    # Selecting another image re-renders both image panels
    image_selector.change(
        fn=update_display,
        inputs=[image_selector, output_choice],
        outputs=[original_display, output_display]
    )
    
    # Switching the visualization type re-renders for the currently selected image
    output_choice.change(
        fn=update_display,
        inputs=[image_selector, output_choice],
        outputs=[original_display, output_display]
    )

    # Connect export button
    # Both export functions take no inputs; their return value is handed to a
    # File component.
    export_annotations_button.click(
        fn=export_annotations,
        inputs=[],
        outputs=download_annotations_link
    )
    export_pdf_button.click(
        fn=export_to_pdf,
        inputs=[],
        outputs=download_pdf_link
    )

    # llm interface
    # Returns (text result, chart image path) for the two statistics outputs
    chat_submit_button.click(
        fn=analyze_coco_with_chatgpt,
        inputs=[user_chat_input],
        outputs=[chat_output, chart_display]
    )

# Start the Gradio server for the interface defined above
demo.launch()

Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.




100%|██████████| 24/24 [00:13<00:00,  1.73it/s]
100%|██████████| 12/12 [00:10<00:00,  1.19it/s]
100%|██████████| 6/6 [00:12<00:00,  2.06s/it]
