Hi Team Atlas!

I thought the easiest way to handle this would be to share the notebook on Kaggle, so it’s super simple to run and check that everything works.

Please enable the GPU before running the notebook, since I used Segment Anything (SAM) for Task 2. It falls back to the CPU otherwise, which is much slower. (Yes, I know it’s overkill, but that’s how I started, and afterwards I didn’t feel like reworking it.)

The task you said was the easiest (Task 1, the path-finding one) actually turned out to be the most annoying and tedious for me! I used breadth-first search for it, and, well… it is what it is.

But I really enjoyed Task 4 the most. Even though I had to rack my brain a bit, this was the best I could come up with. Overall, the task instructions weren’t very strict, so I figured that whatever wasn’t explicitly forbidden was fair game.

## Task 4

In [None]:
import cv2
import numpy as np
from skimage.measure import label, regionprops
from skimage.morphology import remove_small_objects
from skimage.util import invert
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

def preprocess_image(image_path, edge_threshold=8, min_hole_size=1500):
    """Load an image and derive an edge mask and a filled region mask from it.

    The grayscale image is inverted, and its Sobel gradient magnitude is
    rescaled to 0-255. Two masks are built from it:

    * ``edges``: pixels whose rescaled magnitude is above ``edge_threshold``;
    * ``result_mask``: every pixel with non-zero magnitude, with background
      "holes" smaller than ``min_hole_size`` pixels filled in.

    Args:
        image_path: path to the input image (read as grayscale).
        edge_threshold: threshold on the rescaled 0-255 gradient magnitude
            used for the edge mask (default 8).
        min_hole_size: background components smaller than this many pixels
            are absorbed into ``result_mask`` (default 1500).

    Returns:
        ``(img, edges, result_mask)``: the original grayscale image and two
        boolean masks of the same shape.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = invert(img)

    sobelx = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=3)
    magnitude = np.sqrt(sobelx**2 + sobely**2)

    # Rescale to 0-255. A perfectly flat image has zero gradient everywhere,
    # which would otherwise divide by zero and produce NaNs.
    peak = magnitude.max()
    if peak > 0:
        magnitude = (magnitude * 255 / peak).astype(np.uint8)
    else:
        magnitude = np.zeros(magnitude.shape, dtype=np.uint8)

    edges = magnitude > edge_threshold
    binary = magnitude > 0

    # remove_small_objects on the inverted mask drops small background
    # components, i.e. it fills small holes inside the foreground.
    filled = remove_small_objects(~binary, min_size=min_hole_size)
    result_mask = ~filled

    return img, edges, result_mask

def get_shape_features(region):
    """Build a shape descriptor for one labelled region.

    The descriptor is the log-scaled Hu moments of ``region.image``, followed
    by the region's eccentricity, solidity and extent.

    Args:
        region: a ``skimage.measure.regionprops`` entry.

    Returns:
        1-D NumPy float array (Hu moments first, then the three scalars).
    """
    moments = cv2.moments(region.image.astype(np.uint8))
    hu_moments = cv2.HuMoments(moments).flatten()

    # Log-scale so moments spanning many orders of magnitude become comparable.
    # A moment that is exactly zero would give log10(0) = -inf, and then
    # sign(0) * -inf = NaN, which breaks the later standardisation and KMeans.
    # Zero moments are therefore mapped to 0.
    abs_hu = np.abs(hu_moments)
    log_hu = np.zeros_like(hu_moments)
    nonzero = abs_hu > 0
    log_hu[nonzero] = -np.sign(hu_moments[nonzero]) * np.log10(abs_hu[nonzero])

    features = [
        region.eccentricity,
        region.solidity,
        region.extent,
    ]

    return np.concatenate([log_hu, features])

def find_and_analyze_contours(edges_image, masks_image, min_contour_area=100, fill_ratio=0.5):
    """Use closed edge contours to fill regions that are mostly covered in a mask.

    Each contour of ``edges_image`` with area >= ``min_contour_area`` is
    rasterised as a filled polygon. If more than ``fill_ratio`` of the pixels
    inside it are already set in ``masks_image``, the whole polygon is set in
    the output mask. A 2x2 diagnostic figure (raw edges, raw mask, numbered
    contours, resulting mask) is displayed.

    Args:
        edges_image: edge mask (boolean, 0/1 or 0/255).
        masks_image: region mask of the same shape (boolean, 0/1 or 0/255).
        min_contour_area: contours with ``cv2.contourArea`` below this are
            ignored (default 100).
        fill_ratio: minimum fraction of already-set pixels needed to fill a
            contour (default 0.5).

    Returns:
        ``(contour_image, result_mask)``:
        - ``contour_image``: a BGR image with every contour drawn in red and
          the indices of the processed contours written in green;
        - ``result_mask``: the updated uint8 mask with values 0/255.
    """
    # Normalise via "> 0" so inputs that are already 0/255 do not overflow
    # uint8 when multiplied by 255.
    edges = (np.asarray(edges_image) > 0).astype(np.uint8) * 255
    masks = (np.asarray(masks_image) > 0).astype(np.uint8) * 255

    contours, _hierarchy = cv2.findContours(edges, cv2.RETR_TREE,
                                            cv2.CHAIN_APPROX_SIMPLE)

    contour_image = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)
    cv2.drawContours(contour_image, contours, -1, (0, 0, 255), 2)

    result_mask = masks.copy()

    for i, contour in enumerate(contours):
        if cv2.contourArea(contour) < min_contour_area:
            continue

        filled = np.zeros_like(edges)
        cv2.drawContours(filled, [contour], -1, 255, -1)
        inside = filled == 255

        mask_pixels = masks[inside]
        if mask_pixels.size > 0:
            white_percentage = np.count_nonzero(mask_pixels == 255) / mask_pixels.size
            if white_percentage > fill_ratio:
                result_mask[inside] = 255

        # Label the contour at its centroid for the diagnostic plot.
        M = cv2.moments(contour)
        if M['m00'] != 0:
            cx = int(M['m10'] / M['m00'])
            cy = int(M['m01'] / M['m00'])
            cv2.putText(contour_image, str(i), (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

    print(f"Contours: {len(contours)}")

    _, axs = plt.subplots(2, 2, figsize=(15, 10))

    axs[0, 0].imshow(edges_image, cmap='gray')
    axs[0, 0].set_title('Raw Edges')
    axs[0, 0].axis('off')

    axs[0, 1].imshow(masks_image, cmap='gray')
    axs[0, 1].set_title('Raw Masks')
    axs[0, 1].axis('off')

    axs[1, 0].imshow(cv2.cvtColor(contour_image, cv2.COLOR_BGR2RGB))
    axs[1, 0].set_title('Edges')
    axs[1, 0].axis('off')

    axs[1, 1].imshow(result_mask, cmap='gray')
    axs[1, 1].set_title('Masks')
    axs[1, 1].axis('off')

    plt.show()

    return contour_image, result_mask

def separate_large_regions(mask, min_area=1500):
    """Clean ``mask`` with erode-then-dilate and return its large components.

    The mask is eroded twice and then dilated twice with an 8x8 kernel,
    presumably to cut thin bridges between touching shapes (NOTE(review):
    confirm intent). The cleaned mask is plotted and then labelled.

    Args:
        mask: binary mask (any dtype castable to uint8).
        min_area: components smaller than this many pixels are dropped.

    Returns:
        ``(large_regions, labeled)``:
        - ``large_regions``: ``regionprops`` entries with area >= ``min_area``,
          largest first;
        - ``labeled``: the label image they refer to.
    """
    structuring = np.ones((8, 8), np.uint8)
    opened = cv2.dilate(
        cv2.erode(mask.astype(np.uint8), structuring, iterations=2),
        structuring,
        iterations=2,
    )

    plt.figure(figsize=(8, 8))
    plt.imshow(opened, cmap='gray')
    plt.title('Large ROI mask')
    plt.axis('off')
    plt.show()

    labeled = label(opened)
    large_regions = sorted(
        (r for r in regionprops(labeled) if r.area >= min_area),
        key=lambda r: r.area,
        reverse=True,
    )
    return large_regions, labeled

def classify_shapes(regions, n_clusters=3):
    """Group regions into ``n_clusters`` shape classes with KMeans.

    Features from ``get_shape_features`` are standardised per column
    (zero mean, unit variance; a tiny epsilon guards constant columns)
    before clustering.

    Args:
        regions: ``regionprops`` entries to classify.
        n_clusters: number of shape classes (default 3).
            ``visualize_classification`` only has colours for 3 classes.

    Returns:
        One integer class label per region, in ``range(n_clusters)``.
        KMeans needs at least ``n_clusters`` samples. With fewer regions,
        each region simply gets its own class (0, 1, ...).
    """
    if len(regions) < n_clusters:
        return list(range(len(regions)))

    features = np.array([get_shape_features(region) for region in regions])
    features = (features - features.mean(axis=0)) / (features.std(axis=0) + 1e-10)

    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    return kmeans.fit_predict(features)

def visualize_classification(original_img, mask, regions, labeled, shape_labels, max_regions=12):
    """Plot the shape classes assigned to the largest regions.

    Three figures are shown:
    1. the class colours on black, and over ``mask``;
    2. one panel per class;
    3. the class colours blended over the grayscale source image.

    Args:
        original_img: single-channel source image (converted with GRAY2BGR).
        mask: binary mask the regions were extracted from (same shape).
        regions: ``regionprops`` entries, largest first.
        labeled: label image whose values match ``region.label``.
        shape_labels: class index (0, 1 or 2) for each region.
        max_regions: only the first ``max_regions`` regions are drawn
            (default 12, matching ``process_image``).
    """
    colors_rgb = [
        [255, 30, 191],
        [200, 85, 221],
        [138, 18, 229],
    ]
    # All drawing below uses OpenCV's BGR channel order.
    colors_bgr = [[color[2], color[1], color[0]] for color in colors_rgb]

    colored_mask = np.zeros((*mask.shape, 3), dtype=np.uint8)
    mask_with_overlay = cv2.cvtColor(mask.astype(np.uint8), cv2.COLOR_GRAY2BGR)
    class_masks = [np.zeros_like(mask) for _ in colors_bgr]

    for region, shape_class in zip(regions[:max_regions], shape_labels[:max_regions]):
        region_mask = labeled == region.label
        color_bgr = colors_bgr[shape_class]

        class_masks[shape_class][region_mask] = 1
        colored_mask[region_mask] = color_bgr
        mask_with_overlay[region_mask] = color_bgr

    plt.figure(figsize=(15, 10))

    plt.subplot(221)
    plt.imshow(cv2.cvtColor(colored_mask, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.title('Classification results')

    plt.subplot(222)
    plt.imshow(cv2.cvtColor(mask_with_overlay, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.title('Classification with mask')

    _, axs = plt.subplots(1, len(class_masks), figsize=(20, 8))

    for i, ax in enumerate(axs):
        class_visualization = np.zeros((*mask.shape, 3), dtype=np.uint8)
        class_visualization[class_masks[i] == 1] = colors_bgr[i]

        ax.imshow(cv2.cvtColor(class_visualization, cv2.COLOR_BGR2RGB))
        ax.set_title(f'Class {i+1} object mask')
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Blend the class colours over the grayscale source image.
    final_overlay = cv2.cvtColor(original_img, cv2.COLOR_GRAY2BGR)
    overlay_color = np.zeros_like(final_overlay, dtype=np.uint8)
    for class_mask, color_bgr in zip(class_masks, colors_bgr):
        overlay_color[class_mask == 1] = color_bgr

    final_overlay = cv2.addWeighted(final_overlay, 1, overlay_color, 0.7, 0)

    plt.figure(figsize=(8, 8))
    plt.imshow(cv2.cvtColor(final_overlay, cv2.COLOR_BGR2RGB))
    plt.title('Masked Image')
    plt.axis('off')
    plt.show()


def process_image(image_path):
    """Run the full Task 4 pipeline on the image at ``image_path``.

    Steps: edge/mask preprocessing, then contour-based mask filling, then
    extraction of large regions, then KMeans shape classification of the
    12 largest regions, then plotting.

    NOTE: later cells redefine both ``preprocess_image`` (Task 3) and
    ``process_image`` (Task 2) with different signatures. Re-run this cell
    before calling it again after those cells have executed.
    """
    gray_img, edge_mask, region_mask = preprocess_image(image_path)
    _, refined_mask = find_and_analyze_contours(edge_mask, region_mask)

    regions, labeled = separate_large_regions(refined_mask)
    class_labels = classify_shapes(regions[:12])

    visualize_classification(gray_img, refined_mask, regions, labeled, class_labels)

In [None]:
# Run the Task 4 pipeline on the provided sample image.
# NOTE: the Task 3 cell later redefines `preprocess_image` with a different
# signature, so re-run the Task 4 definitions cell before re-executing this.
image_path = "/kaggle/input/3d-singleshot/4.png"
process_image(image_path)

## Task 3

In [None]:
import cv2
import numpy as np
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt
from scipy.spatial import distance
from PIL import Image

def preprocess_image(image):
    """Binarise an image for Task 3.

    RGB input (3-D array) is converted to grayscale first. Pixels brighter
    than 100 become 255, and all others become 0.

    NOTE: this shadows the Task 4 ``preprocess_image(image_path)`` defined in
    an earlier cell; the two have different signatures.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image
    return cv2.threshold(gray, 100, 255, cv2.THRESH_BINARY)[1]

def filter_hemispheres(binary_image, visualization=True):
    """Split the blobs of a binary image into kept shapes and rejected noise.

    An external contour is kept (drawn filled into the returned mask) when all
    of the following hold:
    - its area is strictly between 0.01x and 7x the median contour area;
    - its circularity ``4*pi*area / perimeter**2`` exceeds 0.01;
    - its bounding-box aspect ratio ``w / h`` is strictly between 0.01 and 7.

    All other contours go into a "noise" mask, which is only used for the
    optional plot.

    Args:
        binary_image: single-channel binary image (presumably the 0/255
            output of ``preprocess_image``).
        visualization: when True, plot the input, kept and rejected masks.

    Returns:
        Mask shaped like ``binary_image`` with the kept contours set to 255.
    """
    contours, _ = cv2.findContours(binary_image.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    filtered_mask = np.zeros_like(binary_image)
    hemisphere_mask = np.zeros_like(binary_image)

    areas = [cv2.contourArea(c) for c in contours]
    # np.median([]) is NaN (with a warning), and a zero median would make
    # every ratio inf/NaN. In both cases no contour can pass the size test,
    # so use 0 and treat the ratio as 0.
    median_area = float(np.median(areas)) if areas else 0.0

    for contour, area in zip(contours, areas):
        perimeter = cv2.arcLength(contour, True)
        circularity = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0

        _, _, w, h = cv2.boundingRect(contour)
        aspect_ratio = w / h if h > 0 else 0

        area_ratio = area / median_area if median_area > 0 else 0.0

        if (0.01 < area_ratio < 7 and
                circularity > 0.01 and
                0.01 < aspect_ratio < 7):
            cv2.drawContours(hemisphere_mask, [contour], -1, 255, -1)
        else:
            cv2.drawContours(filtered_mask, [contour], -1, 255, -1)

    if visualization:
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(binary_image, cmap='gray')
        axes[0].set_title('Original Binary')
        axes[1].imshow(hemisphere_mask, cmap='gray')
        axes[1].set_title('Detected Monkeys')
        axes[2].imshow(filtered_mask, cmap='gray')
        axes[2].set_title('Noise Only')
        plt.tight_layout()
        plt.show()

    return hemisphere_mask

def cluster_shapes(binary_mask, eps=50, min_samples=3, visualization=True):
    """Cluster blobs by centroid with DBSCAN and return a mask of clustered blobs.

    The centroids of the external contours are clustered with
    ``DBSCAN(eps, min_samples)``. Contours with ``m00 == 0`` have no centroid
    and are skipped. Every blob assigned to a cluster (label != -1) is filled
    with 255 in the returned mask; noise blobs are left out.

    Args:
        binary_mask: single-channel binary image.
        eps: DBSCAN neighbourhood radius, in pixels.
        min_samples: DBSCAN minimum cluster size.
        visualization: when True, scatter-plot the clustered centroids with a
            shaded convex hull per cluster.

    Returns:
        Mask with the same shape and dtype as ``binary_mask``.
    """
    contours, _ = cv2.findContours(binary_mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # kept_contours stays index-aligned with centroids (and hence with labels).
    kept_contours = []
    centroids = []
    for contour in contours:
        M = cv2.moments(contour)
        if M["m00"] != 0:
            kept_contours.append(contour)
            centroids.append([int(M["m10"] / M["m00"]), int(M["m01"] / M["m00"])])

    if not centroids:
        return np.zeros_like(binary_mask)

    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(centroids).labels_

    result = np.zeros_like(binary_mask)
    clustered = [c for c, lab in zip(kept_contours, labels) if lab != -1]
    if clustered:
        cv2.drawContours(result, clustered, -1, 255, -1)

    if visualization:
        plt.figure(figsize=(10, 10))
        plt.imshow(binary_mask, cmap='gray')

        # Plot centroids and clusters.
        points = np.array(centroids)
        has_clusters = False
        for color_idx, cluster_id in enumerate(set(labels)):
            if cluster_id == -1:
                continue
            has_clusters = True
            in_cluster = labels == cluster_id
            plt.scatter(points[in_cluster, 0], points[in_cluster, 1],
                        c=[f'C{color_idx}'], label=f'Cluster {cluster_id}')

            # Shade the convex hull around the cluster's centroids.
            cluster_points = points[in_cluster].astype(np.int32)
            if len(cluster_points) >= 3:
                hull = cv2.convexHull(cluster_points)
                plt.fill(hull[:, 0, 0], hull[:, 0, 1], alpha=0.2, color=f'C{color_idx}')

        if has_clusters:
            plt.legend()
        plt.title('Clustered Monkeys')
        plt.show()

    return result

def _greedy_chain(points):
    """Order 2-D points into a chain: start at the left-most point, then hop to the nearest unvisited one."""
    remaining = np.asarray(points)
    start_idx = int(np.argmin(remaining[:, 0]))
    chain = [remaining[start_idx]]
    # Remove the start point by index (not by value), so duplicate centroids
    # stay in the pool and are visited like any other point.
    remaining = np.delete(remaining, start_idx, axis=0)

    while len(remaining) > 0:
        distances = distance.cdist([chain[-1]], remaining)
        nearest_idx = int(np.argmin(distances))
        chain.append(remaining[nearest_idx])
        remaining = np.delete(remaining, nearest_idx, axis=0)

    return np.array(chain)


def create_chain_mask(image, binary_mask, centroids, labels, visualization=True):
    """Draw the two largest DBSCAN clusters as thick nearest-neighbour chains.

    For each of the (up to) two clusters with the most points, the centroids
    are ordered greedily from the left-most point. Consecutive points are
    joined with a 50-px line, with an 8-px dot on each endpoint. The result is
    lightly Gaussian-blurred.

    Args:
        image: original image, only used for the optional plot.
        binary_mask: 2-D mask; only its shape is used.
        centroids: sequence of ``[x, y]`` points.
        labels: DBSCAN label per centroid (-1 = noise).
        visualization: when True, show the original image next to the chains.

    Returns:
        uint8 ``(H, W, 3)`` image in OpenCV BGR order (it is converted
        BGR->RGB for display). It is all zeros when there is nothing to draw.
    """
    height, width = binary_mask.shape
    # BGR tuple: on screen (after BGR->RGB) this is RGB (229, 18, 138).
    elegant_red = (138, 18, 229)
    mask_rgb = np.zeros((height, width, 3), dtype=np.uint8)

    centroids = np.array(centroids)
    # Accept plain lists too; "labels == cluster_id" must be element-wise.
    labels = np.asarray(labels)
    if len(centroids) == 0:
        return mask_rgb

    unique_labels = sorted(set(labels.tolist()) - {-1})
    if not unique_labels:
        return mask_rgb

    cluster_sizes = [(cluster_id, np.sum(labels == cluster_id)) for cluster_id in unique_labels]
    two_largest = sorted(cluster_sizes, key=lambda x: x[1], reverse=True)[:2]

    for cluster_id, _ in two_largest:
        cluster_points = centroids[labels == cluster_id]
        if len(cluster_points) < 2:
            continue

        chain = _greedy_chain(cluster_points)
        for pt_a, pt_b in zip(chain[:-1], chain[1:]):
            pt1 = tuple(map(int, pt_a))
            pt2 = tuple(map(int, pt_b))

            cv2.line(mask_rgb, pt1, pt2, elegant_red, thickness=50)
            cv2.circle(mask_rgb, pt1, 8, elegant_red, -1)
            cv2.circle(mask_rgb, pt2, 8, elegant_red, -1)

    mask_rgb = cv2.GaussianBlur(mask_rgb, (5, 5), 0)

    if visualization:
        plt.figure(figsize=(12, 6))

        # Original input image
        plt.subplot(121)
        plt.imshow(image, cmap='gray')
        plt.title('Original Image')

        # Chains drawn in BGR, converted to RGB for matplotlib
        plt.subplot(122)
        plt.imshow(cv2.cvtColor(mask_rgb, cv2.COLOR_BGR2RGB))
        plt.title('Chain Cluster Visualization')

        plt.tight_layout()
        plt.show()

    return mask_rgb

def enhanced_cluster_shapes(binary_mask, eps=50, min_samples=3):
    """Return external-contour centroids and their DBSCAN cluster labels.

    Contours with zero ``m00`` moment are skipped. Centroids are truncated to
    integer ``[x, y]`` lists.

    Returns:
        ``(centroids, labels)``, or ``([], [])`` when no centroid was found.
    """
    found, _ = cv2.findContours(binary_mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    centroids = [
        [int(m["m10"] / m["m00"]), int(m["m01"] / m["m00"])]
        for m in (cv2.moments(c) for c in found)
        if m["m00"] != 0
    ]

    if not centroids:
        return [], []

    return centroids, DBSCAN(eps=eps, min_samples=min_samples).fit(centroids).labels_

def process_image_with_chain_mask(image):
    """Run the Task 3 pipeline: binarise, filter blobs, cluster, draw chains.

    Returns the BGR chain image produced by ``create_chain_mask``.
    """
    shapes_mask = filter_hemispheres(preprocess_image(image))
    points, cluster_ids = enhanced_cluster_shapes(shapes_mask)
    return create_chain_mask(image, shapes_mask, points, cluster_ids)

In [None]:
# Load the Task 3 image as an RGB array and run the chain-mask pipeline.
# `ab` holds the returned chain image (drawn with OpenCV, so channels are BGR).
image = Image.open("/kaggle/input/3d-singleshot/3.png").convert('RGB')
image = np.array(image)

ab = process_image_with_chain_mask(image)

## Task 2

In [None]:
# Load the Task 2 image as an RGB array and preview it.
image = Image.open("/kaggle/input/3d-singleshot/2.png").convert('RGB')
image = np.array(image)

plt.imshow(image)

In [None]:
!pip -q install wget
!pip -q install segment-anything

In [None]:
import torch
import os
import wget
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
destination_dir = "sam-weights"
destination_path = os.path.join(destination_dir, "sam_vit_h_4b8939.pth")

# Create the local weights folder on the first run.
if not os.path.exists(destination_dir):
    os.makedirs(destination_dir)

# Fetch the ViT-H SAM checkpoint only if it is not cached yet.
if os.path.exists(destination_path):
    print("-- weights already downloaded --")
else:
    wget.download(url, destination_path)
    print("-- weights downloaded --")

def load_sam(model_type, sam_checkpoint, device):
    """Build the SAM model ``model_type`` from ``sam_checkpoint`` and move it to ``device``."""
    print("Loading model...")
    build_model = sam_model_registry[model_type]
    model = build_model(checkpoint=sam_checkpoint)

    print(f"Shifting model to {device} device...")
    model.to(device=device)
    return model

model_type = "vit_h"
# Reuse the path the download step above writes to, so the two cannot drift apart.
sam_checkpoint = destination_path
# Prefer the GPU; SAM still runs on the CPU, just much more slowly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam = load_sam(model_type, sam_checkpoint, device)

def get_sam_masks(image, sam_model, points_per_side=16, pred_iou_thresh=0.9,
                  crop_n_layers=2, crop_n_points_downscale_factor=2,
                  min_mask_region_area=200):
    """Run SAM's automatic mask generator on ``image``.

    The keyword arguments are forwarded to ``SamAutomaticMaskGenerator``. The
    defaults are the values this notebook was tuned with.

    Returns:
        The generator's output list of mask dicts. Downstream code in this
        notebook reads their ``'segmentation'`` and ``'area'`` keys.
    """
    print("Initializing SAM mask generator...")
    mask_generator = SamAutomaticMaskGenerator(
        model=sam_model,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        crop_n_layers=crop_n_layers,
        crop_n_points_downscale_factor=crop_n_points_downscale_factor,
        min_mask_region_area=min_mask_region_area,
    )
    print("Predicting masks...")
    return mask_generator.generate(image)

def show_anns(anns):
    """Overlay SAM annotations on the current matplotlib axes.

    Annotations are drawn largest-first when they carry an ``'area'`` key,
    so smaller masks end up on top. Each one gets a random colour at alpha
    0.35 on an otherwise transparent RGBA canvas. Nothing is drawn for an
    empty list.
    """
    if not len(anns):
        return

    ordered = sorted(anns, key=lambda a: a['area'], reverse=True) if "area" in anns[0] else anns

    axes = plt.gca()
    axes.set_autoscale_on(False)

    first_segmentation = ordered[0]['segmentation']
    # RGBA canvas; alpha starts at 0 so unpainted pixels stay transparent.
    canvas = np.ones((first_segmentation.shape[0], first_segmentation.shape[1], 4))
    canvas[..., 3] = 0

    for ann in ordered:
        seg = ann["segmentation"]
        seg = seg if seg.dtype == bool else seg.astype(bool)
        rgba = np.concatenate([np.random.random(3), [0.35]])
        canvas[seg] = rgba

    axes.imshow(canvas)

In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt

def analyze_mask_contrast(image, mask):
    """Summarise the intensity spread of ``image`` inside ``mask``.

    Args:
        image: RGB (3-D, converted with ``COLOR_RGB2GRAY``) or grayscale image.
        mask: array of the image's height x width. Any non-zero / True entry
            selects a pixel.

    Returns:
        Dict with these keys (all 0 when the mask selects no pixels):
        - ``min_val``, ``max_val``, ``contrast_range`` (max - min);
        - ``std_dev``;
        - ``percentile_range`` (95th - 5th percentile);
        - ``peak_count``: the number of interior local maxima of a 50-bin
          intensity histogram.
    """
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY).astype(float)
    else:
        gray = image.astype(float)

    # Coerce to bool: an integer 0/1 mask would otherwise be used as row indices.
    mask = np.asarray(mask, dtype=bool)
    mask_values = gray[mask]

    if mask_values.size == 0:
        return {
            'min_val': 0.0,
            'max_val': 0.0,
            'contrast_range': 0.0,
            'std_dev': 0.0,
            'percentile_range': 0.0,
            'peak_count': 0,
        }

    min_val = np.min(mask_values)
    max_val = np.max(mask_values)
    p5, p95 = np.percentile(mask_values, [5, 95])

    metrics = {
        'min_val': min_val,
        'max_val': max_val,
        'contrast_range': max_val - min_val,
        'std_dev': np.std(mask_values),
        'percentile_range': p95 - p5,
    }

    hist, _ = np.histogram(mask_values, bins=50)
    interior = hist[1:-1]
    metrics['peak_count'] = int(np.count_nonzero((interior > hist[:-2]) & (interior > hist[2:])))

    return metrics

def filter_curved_surfaces(image, masks, min_area=500, threshold_factor=1.2, max_results=3):
    """Pick the SAM masks whose pixels have unusually high intensity contrast.

    Each mask with ``area >= min_area`` is scored as
    ``0.4*contrast_range + 0.3*percentile_range + 0.3*std_dev``, using the
    metrics from ``analyze_mask_contrast``. Masks scoring above
    ``threshold_factor`` times the mean score are kept, best first.

    Args:
        image: image the masks refer to.
        masks: SAM mask dicts with ``'segmentation'`` and ``'area'``.
        min_area: smaller masks are ignored (default 500).
        threshold_factor: multiple of the mean score a mask must exceed
            (default 1.2).
        max_results: maximum number of masks returned (default 3).

    Returns:
        Up to ``max_results`` dicts with keys ``index``, ``mask``, ``area``,
        ``metrics`` and ``score``, sorted by descending score. The list is
        empty when no mask is large enough.
    """
    mask_info = []

    for idx, mask_data in enumerate(masks):
        if mask_data['area'] < min_area:
            continue
        mask = mask_data['segmentation']

        metrics = analyze_mask_contrast(image, mask)

        contrast_score = (
            0.4 * metrics['contrast_range'] +
            0.3 * metrics['percentile_range'] +
            0.3 * metrics['std_dev']
        )

        mask_info.append({
            'index': idx,
            'mask': mask,
            'area': mask_data['area'],
            'metrics': metrics,
            'score': contrast_score,
        })

    # np.mean([]) would be NaN (with a RuntimeWarning); nothing to select anyway.
    if not mask_info:
        return []

    sorted_masks = sorted(mask_info, key=lambda x: x['score'], reverse=True)

    threshold = np.mean([m['score'] for m in sorted_masks]) * threshold_factor
    selected_masks = [m for m in sorted_masks if m['score'] > threshold]

    return selected_masks[:max_results]

def _local_contrast_map(gray, window_size=5):
    """Return max-min intensity over each square window of side 2*(window_size//2)+1; border pixels stay 0."""
    half = window_size // 2
    span = 2 * half + 1
    contrast_map = np.zeros(gray.shape, dtype=float)
    height, width = gray.shape[:2]
    if height < span or width < span:
        return contrast_map
    # Vectorised replacement for a per-pixel Python double loop.
    windows = np.lib.stride_tricks.sliding_window_view(gray, (span, span))
    contrast_map[half:height - half, half:width - half] = (
        windows.max(axis=(-2, -1)) - windows.min(axis=(-2, -1))
    )
    return contrast_map


def visualize_results(image, masks):
    """Plot diagnostics for the selected high-contrast masks.

    Layout on one 20x20 figure:
    - row 1: original image, local contrast map (5x5 max-min), and the
      intensity histogram of each mask;
    - row 2: one highlighted panel per mask (room for three);
    - row 3: all masks blended over the original image.

    Args:
        image: RGB image (converted with ``COLOR_RGB2GRAY``).
        masks: entries from ``filter_curved_surfaces`` (dicts with at least
            ``'mask'`` and ``'score'``).

    Returns:
        ``masks``, unchanged.
    """
    plt.figure(figsize=(20, 20))

    # Row 1: original image, local contrast map, intensity histograms.
    plt.subplot(3, 3, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    contrast_map = _local_contrast_map(gray, window_size=5)

    plt.subplot(3, 3, 2)
    plt.imshow(contrast_map, cmap='hot')
    plt.title('Local Contrast Map')
    plt.axis('off')

    plt.subplot(3, 3, 3)
    colors = ['r', 'g', 'b']
    for idx, mask_info in enumerate(masks):
        mask = mask_info['mask']
        plt.hist(gray[mask], bins=50, alpha=0.5, color=colors[idx % len(colors)],
                 label=f'Surface {idx+1}\nContrast: {mask_info["score"]:.1f}')
    plt.title('Intensity Histograms')
    if masks:
        plt.legend()
    plt.axis('on')

    # Row 2: one panel per surface; the 3x3 grid only has room for three.
    surface_color = np.array([138, 18, 229])
    for idx, mask_info in enumerate(masks[:3]):
        plt.subplot(3, 3, 4 + idx)
        mask_vis = image.copy()
        mask = mask_info['mask']
        mask_vis[~mask] = mask_vis[~mask] * 0.3  # dim the background
        mask_vis[mask] = surface_color  # paint the mask in the highlight colour
        plt.imshow(mask_vis)
        plt.title(f'Surface {idx+1}\nContrast Score: {mask_info["score"]:.1f}')
        plt.axis('off')

    # Row 3: all masks blended over the original image.
    overlay = np.zeros_like(image, dtype=np.float32)
    for mask_info in masks:
        overlay[mask_info['mask']] += surface_color

    # Overlapping masks can push a channel above 255.
    overlay = np.clip(overlay, 0, 255).astype(np.uint8)
    combined_image = cv2.addWeighted(image, 0.5, overlay, 0.5, 0)

    plt.subplot(3, 1, 3)  # full-width panel at the bottom
    plt.imshow(combined_image)
    plt.title('Combined Masks on Original')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    return masks

def process_image(image, sam_masks, verbose=False):
    """Select the high-contrast SAM masks for Task 2 and plot them.

    NOTE: this redefines the Task 4 ``process_image(image_path)`` from an
    earlier cell, with a different signature.

    Args:
        image: RGB image the masks were generated from.
        sam_masks: SAM mask dicts (with ``'segmentation'`` and ``'area'``).
        verbose: when True, also print a per-surface summary.

    Returns:
        The list of selected mask entries from ``filter_curved_surfaces``.
    """
    selected_masks = filter_curved_surfaces(image, sam_masks)
    visualize_results(image, selected_masks)

    if verbose:
        print("\nCurved surfaces with high contrast:")
        for idx, mask_info in enumerate(selected_masks):
            metrics = mask_info['metrics']
            print(f"\nSurface {idx+1}:")
            print(f"Area: {mask_info['area']:.0f} pixels")
            print(f"Contrast score: {mask_info['score']:.1f}")
            print(f"Min-Max range: {metrics['contrast_range']:.1f}")
            print(f"5th-95th percentile range: {metrics['percentile_range']:.1f}")

    return selected_masks

In [None]:
# Generate SAM masks for the Task 2 image with the model loaded above, then
# keep and plot the highest-contrast ones.
sam_masks = get_sam_masks(image, sam)

selected_masks = process_image(image, sam_masks)
plt.show()

## Task 1

In [None]:
import numpy as np
from scipy.spatial import KDTree
from PIL import Image
import matplotlib.pyplot as plt
from collections import deque
from typing import List, Tuple, Dict, Set
from collections import defaultdict

def get_unique_pixels(image: np.ndarray) -> Dict[Tuple[int, int, int], int]:
    """Count how many pixels have each distinct RGB colour.

    Args:
        image: array with 3 channels on the last axis (reshaped to ``(-1, 3)``).

    Returns:
        Mapping ``colour tuple -> pixel count``, ordered by first occurrence
        in row-major pixel order. Colour components are plain Python numbers
        (via ``ndarray.tolist``), so they print as ``(r, g, b)``.
    """
    pixels = np.asarray(image).reshape(-1, 3)
    if pixels.shape[0] == 0:
        return {}

    # Vectorised replacement for a per-pixel Python loop.
    colors, first_index, counts = np.unique(
        pixels, axis=0, return_index=True, return_counts=True
    )
    # np.unique sorts colours; restore first-occurrence order.
    order = np.argsort(first_index)
    return {tuple(colors[i].tolist()): int(counts[i]) for i in order}

def validate_colors(image: np.ndarray, color_dict: Dict[str, List[int]]) -> None:
    """Print every colour in ``image`` and flag those not in ``color_dict``.

    Each colour is listed with its pixel count and a check mark (palette
    colour) or a cross (unknown colour). Unknown colours are then listed
    again in a separate section.
    """
    pixel_counts = get_unique_pixels(image)
    allowed = set(map(tuple, color_dict.values()))

    print("Detected unique colors:")
    for rgb, n_pixels in pixel_counts.items():
        mark = "✓" if rgb in allowed else "×"
        print(f"{mark} RGB{rgb}: {n_pixels} pixels")

    unknown = set(pixel_counts.keys()) - allowed
    if not unknown:
        return
    print("\nDetected unknown colors:")
    for rgb in unknown:
        print(f"RGB{rgb}")

def find_positions(image: np.ndarray, target_color: List[int]) -> List[Tuple[int, int]]:
    """Return the ``(row, col)`` of every pixel exactly equal to ``target_color``, in row-major order."""
    hits = np.all(image == target_color, axis=2)
    rows, cols = np.nonzero(hits)
    return list(zip(rows, cols))

def get_neighbors(pos: Tuple[int, int], image: np.ndarray) -> List[Tuple[int, int]]:
    """Return the in-bounds 8-connected neighbours of ``pos`` = (row, col).

    Neighbours are listed row by row from the top-left (dy, then dx, each
    from -1 to 1).

    NOTE: a 4-connected ``get_neighbors`` with the same signature is defined
    further down in this cell. It replaces this one before ``find_path`` is
    called, so this 8-connected variant is not used by the pipeline.
    """
    row, col = pos
    n_rows, n_cols = image.shape[:2]
    return [
        (row + dr, col + dc)
        for dr in (-1, 0, 1)
        for dc in (-1, 0, 1)
        if (dr, dc) != (0, 0)
        and 0 <= row + dr < n_rows
        and 0 <= col + dc < n_cols
    ]

def map_to_closest_color(image: np.ndarray, color_dict: Dict[str, List[int]]) -> np.ndarray:
    """Snap every pixel to the nearest (Euclidean, RGB) colour in ``color_dict``.

    Returns an array of the same shape as ``image`` whose pixels are palette
    colours only.
    """
    palette = np.array(list(color_dict.values()))
    flat_pixels = image.reshape(-1, 3)
    nearest = KDTree(palette).query(flat_pixels)[1]
    return palette[nearest].reshape(image.shape)

def visualize_image(original: np.ndarray, processed: np.ndarray) -> None:
    """Show the source image and the palette-snapped image side by side."""
    _, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = ((original, "Source Image"), (processed, "Processed Image"))
    for axis, (picture, caption) in zip(axes, panels):
        axis.imshow(picture)
        axis.set_title(caption)
        axis.axis("off")

    plt.tight_layout()
    plt.show()

def get_terrain_type(color: np.ndarray, color_dict: Dict[str, List[int]]) -> str:
    """Return the first key of ``color_dict`` whose colour equals ``color``, or ``"UNKNOWN"``."""
    key = tuple(color)
    return next(
        (name for name, rgb in color_dict.items() if tuple(rgb) == key),
        "UNKNOWN",
    )

def is_valid_move(from_terrain: str, to_terrain: str, last_terrain: str) -> bool:
    """Decide whether a single step between two terrain cells is allowed.

    Rules, in order:
    - ABYSS can never be entered;
    - any step into or out of START, END or RAMP is allowed;
    - staying on the same terrain is allowed;
    - SAND <-> MOUNTAIN is allowed only when the cell before ``from_terrain``
      was RAMP (``last_terrain``);
    - anything else is rejected.
    """
    if to_terrain == 'ABYSS':
        return False

    free_terrain = {'START', 'END', 'RAMP'}
    if to_terrain in free_terrain or from_terrain in free_terrain:
        return True
    if from_terrain == to_terrain:
        return True

    if {from_terrain, to_terrain} == {'SAND', 'MOUNTAIN'}:
        return last_terrain == 'RAMP'

    return False

def get_neighbors(pos: Tuple[int, int], image: np.ndarray, connectivity: int = 4) -> List[Tuple[int, int]]:
    """Return the in-bounds neighbours of ``pos`` = (row, col).

    This definition replaces the 8-connected ``get_neighbors`` defined
    earlier in this cell. Pass ``connectivity=8`` to get that behaviour.

    Args:
        pos: ``(y, x)`` grid position.
        image: array whose first two dimensions give the grid size.
        connectivity: 4 (default) for edge neighbours in the order right,
            down, left, up; 8 to also include the four diagonals (appended
            after the edge neighbours).

    Returns:
        List of ``(y, x)`` neighbours inside the grid.

    Raises:
        ValueError: if ``connectivity`` is not 4 or 8.
    """
    edge_offsets = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    if connectivity == 4:
        offsets = edge_offsets
    elif connectivity == 8:
        offsets = edge_offsets + [(-1, -1), (-1, 1), (1, -1), (1, 1)]
    else:
        raise ValueError(f"connectivity must be 4 or 8, got {connectivity!r}")

    y, x = pos
    height, width = image.shape[:2]
    return [
        (y + dy, x + dx)
        for dy, dx in offsets
        if 0 <= y + dy < height and 0 <= x + dx < width
    ]

def find_path(image: np.ndarray, color_dict: Dict[str, List[int]]) -> List[Tuple[int, int]]:
    """Find a shortest START->END path with BFS under the terrain rules.

    ``is_valid_move`` depends on the terrain of the cell *before* the current
    one, so the BFS state is ``(position, previous terrain)`` rather than just
    the position. Marking only positions as visited would let the first
    arrival at a cell block a later arrival with a different previous terrain
    (e.g. one that came from RAMP), and valid paths could be missed.

    Args:
        image: palette-snapped RGB image.
        color_dict: terrain name -> RGB colour; must contain 'START' and 'END'.

    Returns:
        List of ``(y, x)`` positions from the START pixel to the END pixel,
        inclusive. The first matching pixel of each marker colour is used.

    Raises:
        ValueError: if a START/END marker is missing or no path exists.
    """
    start_hits = find_positions(image, color_dict['START'])
    end_hits = find_positions(image, color_dict['END'])
    if not start_hits or not end_hits:
        raise ValueError("START or END marker not found in image")
    start_pos = tuple(int(v) for v in start_hits[0])
    end_pos = tuple(int(v) for v in end_hits[0])

    start_state = (start_pos, None)
    # previous[state] -> predecessor state; doubles as the visited set.
    previous = {start_state: None}

    queue = deque([(start_pos, 'START', None)])
    goal_state = None

    while queue:
        current_pos, current_terrain, prev_terrain = queue.popleft()

        if current_pos == end_pos:
            goal_state = (current_pos, prev_terrain)
            break

        for next_pos in get_neighbors(current_pos, image):
            next_state = (next_pos, current_terrain)
            if next_state in previous:
                continue

            next_terrain = get_terrain_type(image[next_pos], color_dict)
            if not is_valid_move(current_terrain, next_terrain, prev_terrain):
                continue

            previous[next_state] = (current_pos, prev_terrain)
            queue.append((next_pos, next_terrain, current_terrain))

    if goal_state is None:
        raise ValueError("Path not found")

    path = []
    state = goal_state
    while state is not None:
        path.append(state[0])
        state = previous[state]

    return path[::-1]

def visualize_path(image: np.ndarray, path: List[Tuple[int, int]], color_dict: Dict[str, List[int]], thickness: int = 7) -> np.ndarray:
    """Return a copy of ``image`` with ``path`` painted as a thick line.

    Every path pixel is stamped with a filled disc of radius
    ``thickness // 2`` in ``color_dict['PATH']``. Disc pixels outside the
    image are clipped.

    Args:
        image: image to draw on (not modified).
        path: sequence of ``(y, x)`` positions.
        color_dict: must contain the 'PATH' colour.
        thickness: line width in pixels (default 7).

    Returns:
        The annotated copy.
    """
    result = image.copy()
    height, width = image.shape[:2]
    path_color = color_dict['PATH']

    # Disc offsets are the same for every path pixel, so compute them once.
    radius = thickness // 2
    offsets = [
        (dy, dx)
        for dy in range(-radius, radius + 1)
        for dx in range(-radius, radius + 1)
        if dy * dy + dx * dx <= radius * radius
    ]

    for y, x in path:
        for dy, dx in offsets:
            new_y, new_x = y + dy, x + dx
            if 0 <= new_y < height and 0 <= new_x < width:
                result[new_y, new_x] = path_color

    return result

In [None]:
# Task 1 palette (RGB). Every source pixel is snapped to the nearest entry.
RAMP = [116, 116, 116]
SAND = [147, 139, 101]
MOUNTAIN = [94, 81, 22]
ABYSS = [0, 0, 0]
START = [234, 51, 35]
END = [117, 251, 86]
# NOTE(review): PATH is also part of the snapping palette, so a source pixel
# nearest to it becomes PATH terrain. is_valid_move only lets a PATH cell be
# entered from START/END/RAMP/PATH cells. Confirm this is intended.
PATH = [138, 18, 229]

color_dict = {
    'RAMP': RAMP, 
    'SAND': SAND, 
    'MOUNTAIN': MOUNTAIN, 
    'ABYSS': ABYSS, 
    'START': START, 
    'END': END, 
    'PATH': PATH
}

image = Image.open("/kaggle/input/3d-singleshot/1.png").convert('RGB')
image = np.array(image)

# Snap noisy colours to the palette, show before/after, and list the
# colours present after snapping (all of them are palette entries).
processed_image = map_to_closest_color(image, color_dict)

visualize_image(image, processed_image)
validate_colors(processed_image, color_dict)

# BFS from the START pixel to the END pixel under the terrain rules.
path = find_path(processed_image, color_dict)

# Paint the path (10 px wide) over the palette-snapped image.
result_image = visualize_path(processed_image, path, {'PATH': PATH}, thickness=10)

plt.imshow(result_image)
plt.show()