# Visualize

# In [5]:
# big
# scene_1000_2000_1000
# scene_1000_1500_1000

# small 
# scene_1000_2000_150_128_5
# scene_1000_1500_70_128_5


# Scene file that main() loads with read_obstacle_file(); the commented names
# above appear to be other scene files that can be swapped in.
filename = "dhaka_1000_20000_1000_128_coverage_0_33.txt"
# Checkpoint that predict() loads into the model when the file exists.
checkpoint_path_load = "checkpoint.pth"

import pygame
import sys
import numpy as np

from typing import List
import math

import torch
from torch.utils.data import Dataset, DataLoader
import random


class SceneDataset(Dataset):
    """Dataset of (scene, query, label) samples parsed from one text description.

    The description is a whitespace-separated token stream made of:

    * zero or more obstacles, each ``p x1 y1 x2 y2 ...`` (polygon vertices), then
    * zero or more queries, each ``q xa ya xb yb label`` (two points followed by
      an integer label).

    Every query is one sample. ``__getitem__`` additionally returns randomly
    generated vertex and obstacle orderings that the model averages over.
    """

    def __init__(self, data, do_augmentation=True):
        """
        Args:
            data: scene description string in the format above.
            do_augmentation: when True, several random orderings are generated
                per obstacle / per scene; otherwise a single one.

        Raises:
            ValueError: if a token other than ``q`` appears where a query is
                expected, or a number/label cannot be parsed.
        """
        self.data = data
        self.do_augmentation = do_augmentation
        self.cache = []  # parsed scenes as (obstacles, queries) tuples
        self.samples = []  # (scene_idx, query_idx) for every individual sample

        # Parse the description once, up front.
        self._parse_data()

    def _parse_data(self):
        """Parse ``self.data`` into one cached scene and index its queries."""
        tokens = self.data.strip().split()
        n_tokens = len(tokens)

        # Obstacles: each 'p' is followed by x/y pairs up to the next marker.
        obstacles = []
        idx = 0
        while idx < n_tokens and tokens[idx] == 'p':
            idx += 1  # skip 'p'
            obstacle = []
            while idx < n_tokens and tokens[idx] not in ('p', 'q'):
                obstacle.append((float(tokens[idx]), float(tokens[idx + 1])))
                idx += 2
            obstacles.append(obstacle)

        # Queries: 'q', two coordinate pairs, then an integer label.
        queries = []
        while idx < n_tokens:
            if tokens[idx] != 'q':
                # Without this guard idx never advances and the loop spins forever.
                raise ValueError(
                    f"expected 'q' at token {idx}, got {tokens[idx]!r}"
                )
            idx += 1  # skip 'q'
            query = []
            for _ in range(2):  # each query has two coordinate pairs
                query.append((float(tokens[idx]), float(tokens[idx + 1])))
                idx += 2
            label = int(tokens[idx])  # label follows the query
            idx += 1
            queries.append((query, label))

        # Cache the parsed scene.
        scene_idx = len(self.cache)
        self.cache.append((obstacles, queries))

        # Index individual samples.
        for query_idx in range(len(queries)):
            self.samples.append((scene_idx, query_idx))

    def _generate_vertex_order(self, vertices):
        """Return random cyclic rotations of ``range(len(vertices))``.

        One rotation is returned, or two when augmentation is on. Each rotation
        keeps the polygon's vertex cycle and only changes the start vertex.
        An empty ``vertices`` raises ValueError (``random.randint(0, -1)``).
        """
        n = len(vertices)
        repeats = 2 if self.do_augmentation else 1
        orders = []
        for _ in range(repeats):
            start_idx = random.randint(0, n - 1)
            orders.append(list(range(start_idx, n)) + list(range(0, start_idx)))
        return orders

    def _generate_obstacle_order(self, num_obstacles, num_order):
        """Return ``num_order`` random permutations of ``range(num_obstacles)``.

        Only a single permutation is returned when augmentation is off.
        """
        if not self.do_augmentation:
            num_order = 1
        return [random.sample(range(num_obstacles), num_obstacles) for _ in range(num_order)]

    def __len__(self):
        """Total number of samples (queries)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample: obstacles, query, label and augmentation orders."""
        scene_idx, query_idx = self.samples[idx]
        obstacles, queries = self.cache[scene_idx]
        query, label = queries[query_idx]

        # Fresh random orderings on every access.
        vertex_orders = [self._generate_vertex_order(obstacle) for obstacle in obstacles]
        obstacle_order = self._generate_obstacle_order(len(obstacles), 3)

        return {
            'obstacles': obstacles,          # original obstacle coordinates
            'query': query,                  # original query coordinates
            'label': label,                  # integer label parsed from the text
            'vertex_orders': vertex_orders,  # vertex orderings per obstacle
            'obstacle_order': obstacle_order  # obstacle orderings for the scene
        }

# Custom collate function for batching
def collate_fn(batch):
    """Merge a list of SceneDataset samples into one batch dictionary.

    Ragged per-sample fields (obstacles, queries, orderings) stay Python lists
    with one entry per sample; only the labels are stacked into a float tensor.
    Keys are renamed on the way: 'query' -> 'queries', 'label' -> 'labels',
    'obstacle_order' -> 'obstacle_orders'.
    """
    samples = list(batch)

    def column(key):
        return [sample[key] for sample in samples]

    return {
        'obstacles': column('obstacles'),
        'queries': column('query'),
        'labels': torch.tensor(column('label'), dtype=torch.float),
        'vertex_orders': column('vertex_orders'),
        'obstacle_orders': column('obstacle_order'),
    }

import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
from tqdm import tqdm

# Define the model
class SceneQueryModel(nn.Module):
    """Binary classifier for (scene, query) pairs built from nested LSTM encoders.

    * Each obstacle polygon is encoded by ``rnn_obstacle`` over its vertices
      (once per supplied vertex ordering) followed by ``obstacle_embedding_fc``;
      the per-ordering embeddings are averaged.
    * The scene is encoded by ``rnn_scene`` over the obstacle embeddings (once
      per supplied obstacle ordering) followed by ``scene_embedding_fc``; the
      per-ordering embeddings are averaged again.
    * The query's point sequence reuses ``rnn_obstacle`` / ``obstacle_embedding_fc``;
      its embedding is concatenated with the scene embedding and passed
      through ``classifier``, which ends in a sigmoid.

    Args:
        vertex_input_dim: features per vertex / query point.
        obstacle_hidden_dim: hidden size of ``rnn_obstacle``.
        scene_hidden_dim: hidden size of ``rnn_scene``.
        output_dim: number of sigmoid outputs.
    """

    def __init__(self, vertex_input_dim, obstacle_hidden_dim, scene_hidden_dim, output_dim):
        super(SceneQueryModel, self).__init__()

        # Shared LSTM for obstacles and queries. nn.LSTM applies `dropout` only
        # between stacked layers, so with num_layers=1 a non-zero value has no
        # effect and only emits a UserWarning; it is therefore not passed.
        # The parameter set (state_dict keys/shapes) is unaffected.
        self.rnn_obstacle = nn.LSTM(input_size=vertex_input_dim, hidden_size=obstacle_hidden_dim,
                                    num_layers=1, batch_first=True)

        self.obstacle_embedding_fc = nn.Sequential(
            nn.Linear(obstacle_hidden_dim, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )

        # Scene-level LSTM over obstacle embeddings (same dropout note as above).
        self.rnn_scene = nn.LSTM(input_size=32, hidden_size=scene_hidden_dim,
                                 num_layers=1, batch_first=True)

        self.scene_embedding_fc = nn.Sequential(
            nn.Linear(scene_hidden_dim, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128)
        )

        # Final classification block: 32-d query embedding + 128-d scene embedding.
        self.classifier = nn.Sequential(
            nn.Linear(128 + 32, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Linear(32, 8),
            nn.ReLU(),
            nn.BatchNorm1d(8),
            nn.Linear(8, output_dim),
            nn.Sigmoid()
        )

    def forward(self, obstacles, queries, vertex_orders, obstacle_orders):
        """Score each query in the batch against its scene.

        Args (index i is the batch element, as produced by ``collate_fn``):
            obstacles: ``obstacles[i]`` is a list of polygons, each a list of
                (x, y) pairs.
            queries: ``queries[i]`` is a list of points; all queries must have
                the same length so that they stack into one tensor.
            vertex_orders: ``vertex_orders[i][j]`` is a list of vertex index
                permutations for obstacle j of scene i.
            obstacle_orders: ``obstacle_orders[i]`` is a list of obstacle index
                permutations for scene i.

        Returns:
            Tensor of shape (batch, output_dim) with sigmoid outputs.

        NOTE(review): a batch in which no scene has obstacles makes
        ``max(order_lengths)`` raise ValueError; callers must avoid that.
        """
        device = next(self.parameters()).device  # run on the model's device

        # --- Obstacle level: flatten (scene, obstacle, ordering) triples. ---
        flat_obstacles = []
        flat_orders = []
        batch_indices = []
        obstacle_indices = []
        flat_obstacle_indices = []
        order_lengths = []

        m = 0
        for i, (obstacle_set, vertex_order_set) in enumerate(zip(obstacles, vertex_orders)):
            for j, (obstacle, orders) in enumerate(zip(obstacle_set, vertex_order_set)):
                flat_obstacles.append(torch.tensor(obstacle, dtype=torch.float, device=device))
                for order in orders:
                    flat_obstacle_indices.append(m)
                    flat_orders.append(order)
                    batch_indices.append(i)
                    obstacle_indices.append(j)
                    order_lengths.append(len(order))
                m += 1

        # Pad vertex orders; padded slots point at index 0 but are ignored by packing.
        max_order_length = max(order_lengths)
        padded_orders = torch.zeros((len(flat_orders), max_order_length), dtype=torch.long, device=device)
        for idx, order in enumerate(flat_orders):
            padded_orders[idx, :len(order)] = torch.tensor(order, dtype=torch.long, device=device)

        # Pad obstacles to a common vertex count.
        max_vertices = max(len(obs) for obs in flat_obstacles)
        padded_obstacles = torch.zeros((len(flat_obstacles), max_vertices, 2), dtype=torch.float, device=device)
        for idx, obs in enumerate(flat_obstacles):
            padded_obstacles[idx, :len(obs)] = obs

        # Reorder vertices according to each ordering.
        ordered_vertices = []
        for k in range(len(padded_orders)):
            order = padded_orders[k]
            vertices = padded_obstacles[flat_obstacle_indices[k]]
            ordered_vertices.append(vertices[order])

        ordered_vertices = torch.stack(ordered_vertices)

        # pack_padded_sequence expects CPU int64 lengths; build them there directly.
        sequence_lengths = torch.tensor(order_lengths, dtype=torch.int64)

        packed_vertices = nn.utils.rnn.pack_padded_sequence(ordered_vertices, sequence_lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.rnn_obstacle(packed_vertices)

        # One embedding per (obstacle, ordering).
        embeddings = self.obstacle_embedding_fc(h_n[-1])  # (total_orders, 32)

        # Average the orderings back to one embedding per (scene, obstacle).
        obstacle_embeddings = torch.zeros((len(obstacles), max([len(o) for o in obstacles]), embeddings.size(-1)), device=device)
        order_counts = torch.zeros_like(obstacle_embeddings[..., 0])

        for i, (batch_idx, obstacle_idx) in enumerate(zip(batch_indices, obstacle_indices)):
            obstacle_embeddings[batch_idx, obstacle_idx] += embeddings[i]
            order_counts[batch_idx, obstacle_idx] += 1

        # clamp avoids division by zero for padded (non-existent) obstacles.
        obstacle_embeddings /= order_counts.unsqueeze(-1).clamp(min=1)

        # --- Scene level: one sequence per (scene, obstacle ordering). ---
        flat_orders = []
        batch_indices = []
        order_lengths = []

        for i, orders in enumerate(obstacle_orders):
            for order in orders:
                flat_orders.append(order)
                batch_indices.append(i)
                order_lengths.append(len(order))

        max_order_length = max(order_lengths)
        padded_orders = torch.zeros((len(flat_orders), max_order_length), dtype=torch.long, device=device)
        for idx, order in enumerate(flat_orders):
            padded_orders[idx, :len(order)] = torch.tensor(order, dtype=torch.long, device=device)

        ordered_obstacles = []
        for k in range(len(padded_orders)):
            order = padded_orders[k]
            embed = obstacle_embeddings[batch_indices[k]]
            ordered_obstacles.append(embed[order])

        ordered_obstacles = torch.stack(ordered_obstacles)

        sequence_lengths = torch.tensor(order_lengths, dtype=torch.int64)

        packed_obstacles = nn.utils.rnn.pack_padded_sequence(ordered_obstacles, sequence_lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.rnn_scene(packed_obstacles)

        flat_scene_embeddings = self.scene_embedding_fc(h_n[-1])  # (total_orders, 128)

        # Average the orderings back to one embedding per scene.
        scene_embeddings = torch.zeros((len(obstacles), flat_scene_embeddings.size(-1)), device=device)
        order_counts = torch.zeros_like(scene_embeddings[..., 0])

        for i, batch_idx in enumerate(batch_indices):
            scene_embeddings[batch_idx] += flat_scene_embeddings[i]
            order_counts[batch_idx] += 1

        scene_embeddings /= order_counts.unsqueeze(-1).clamp(min=1)

        # --- Query level: encode with the shared obstacle encoder. ---
        queries_tensor = torch.tensor(queries, dtype=torch.float, device=device)  # (batch, points, features)
        _, (h_n, _) = self.rnn_obstacle(queries_tensor)
        query_embeddings = self.obstacle_embedding_fc(h_n[-1])  # (batch, 32)

        combined = torch.cat((query_embeddings, scene_embeddings), dim=1)
        outputs = self.classifier(combined)

        return outputs


def calculate_metrics(outputs, labels):
    """Compute binary-classification metrics for thresholded predictions.

    Args:
        outputs: tensor of predicted probabilities; thresholded at 0.5.
        labels: tensor of 0/1 targets with the same number of elements.

    Returns:
        (accuracy, precision, recall, specificity, f1). Precision, recall and
        F1 use ``zero_division=0``; specificity is 0 when there are no
        actual negatives (tn + fp == 0).
    """
    # Threshold, flatten and move to CPU once for all sklearn calls.
    # reshape (unlike view) also accepts non-contiguous tensors.
    y_pred = (outputs > 0.5).float().reshape(-1).cpu()
    y_true = labels.reshape(-1).cpu()

    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent.
    tn, fp, _fn, _tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    f1 = f1_score(y_true, y_pred, zero_division=0)

    return accuracy, precision, recall, specificity, f1

def predict(model, test_loader, checkpoint_path=None, map_location=None):
    """Run ``model`` over every batch of ``test_loader`` and return its outputs.

    If the checkpoint file exists, its ``'model_state_dict'`` entry is loaded
    into ``model`` first (printing "Loading checkpoint..."). The model is then
    switched to eval mode and run under ``torch.no_grad()``.

    Args:
        model: module called as ``model(obstacles, queries, vertex_orders,
            obstacle_orders)``.
        test_loader: iterable of batch dicts with keys 'obstacles', 'queries',
            'vertex_orders' and 'obstacle_orders' (as built by ``collate_fn``).
        checkpoint_path: checkpoint to load; defaults to the module-level
            ``checkpoint_path_load``. Nothing is loaded if the file is missing.
        map_location: ``torch.load`` map location; defaults to the
            module-level ``device``.

    Returns:
        CPU tensor of all batch outputs concatenated along dim 0, or an empty
        tensor when the loader yields no batches.
    """
    if checkpoint_path is None:
        checkpoint_path = checkpoint_path_load
    if map_location is None:
        map_location = device

    # Load weights if a checkpoint is available.
    if os.path.exists(checkpoint_path):
        print("Loading checkpoint...")
        checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=map_location)
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    all_outputs = []

    with torch.no_grad():
        for batch in test_loader:
            # Labels are not needed for inference, so they are not touched.
            outputs = model(batch['obstacles'], batch['queries'],
                            batch['vertex_orders'], batch['obstacle_orders'])
            all_outputs.append(outputs.cpu())

    # torch.cat([]) raises, so handle an empty loader explicitly.
    if not all_outputs:
        return torch.empty(0)
    return torch.cat(all_outputs)

# Model hyper-parameters. They must agree with the checkpoint predict() loads,
# since load_state_dict checks parameter shapes.
vertex_input_dim = 2          # features per vertex / query point
obstacle_hidden_dim = 128
scene_hidden_dim = 512
output_dim = 1                # single sigmoid output
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SceneQueryModel(
    vertex_input_dim=vertex_input_dim,
    obstacle_hidden_dim=obstacle_hidden_dim,
    scene_hidden_dim=scene_hidden_dim,
    output_dim=output_dim,
).to(device)


def round_down(value: float) -> int:
    """Round ``value`` to the nearest integer, breaking exact .5 ties downward.

    Examples: 2.4 -> 2, 2.5 -> 2, 2.6 -> 3, -1.2 -> -1, -1.5 -> -2, -1.7 -> -2.
    """
    # Splitting into int(value) + fraction misrounds negatives because int()
    # truncates toward zero (-1.7 would give -1); ceil(value - 0.5) is
    # correct for both signs.
    return math.ceil(value - 0.5)
    

def bound(value, min_value, max_value):
    """Limit ``value`` to ``max_value`` from above, then ``min_value`` from below.

    The upper limit is applied first, so if ``min_value > max_value`` the
    result is ``min_value``.
    """
    capped = max_value if max_value < value else value
    return min_value if min_value > capped else capped


class SceneEncoder:
    """Scene bounds plus obstacles, serialised into SceneDataset's text format.

    ``scene_description`` holds one ``p x1 y1 x2 y2 ...`` group per obstacle
    (each token followed by a space). Every coordinate is rescaled from
    [minX, maxX] x [minY, maxY] to [0, 1]; points outside the bounds map
    outside [0, 1].

    Args:
        minX, maxX, minY, maxY: scene bounds. If any vertex is present, the
            spans must be non-zero, otherwise ZeroDivisionError is raised.
        obstacles: mapping of obstacle id -> list of (x, y) vertices,
            serialised in the mapping's iteration order.
    """

    def __init__(self, minX, maxX, minY, maxY, obstacles):
        self.minX = minX
        self.minY = minY
        self.maxX = maxX
        self.maxY = maxY
        self.obstacles = obstacles

        span_x = self.maxX - self.minX
        span_y = self.maxY - self.minY
        # Collect pieces and join once. Repeated `+=` on an instance attribute
        # cannot use CPython's in-place string growth and is quadratic.
        parts = []
        for obstacle_id in self.obstacles:
            parts.append("p ")
            for x, y in self.obstacles[obstacle_id]:
                parts.append(str((x - self.minX) / span_x) + " ")
                parts.append(str((y - self.minY) / span_y) + " ")
        self.scene_description = "".join(parts)


def generate_pygame_data(grid,scene,cell):
    """Mark each grid cell with the model's prediction for (that cell, ``cell``).

    Each grid cell's centre is expressed in normalised [0, 1] coordinates. For
    every cell other than ``cell``, one query pairs that cell's centre with
    ``cell``'s centre; all queries are scored with the module-level ``model``
    via ``predict``. A copy of ``scene`` is returned with channel 1 set to
    2 for ``cell`` itself, 3 where the prediction exceeds 0.5, and 4 elsewhere.

    Args:
        grid: object exposing ``scene_description`` (main passes a SceneEncoder).
        scene: array indexed as [row, col, channel] with at least 2 channels;
            presumably NumPy (main passes ``np.zeros((rows, cols, 2))``).
        cell: (row, col) of the selected cell.
    """
    n_rows, n_cols = scene.shape[0], scene.shape[1]
    cell_w = 1 / n_cols
    cell_h = 1 / n_rows
    target_x = str(cell[1] * cell_w + cell_w / 2)
    target_y = str(cell[0] * cell_h + cell_h / 2)

    # One query per other cell, in row-major order; the trailing "0" is a
    # dummy label required by the text format.
    query_parts = []
    for r in range(n_rows):
        for c in range(n_cols):
            if (r, c) == cell:
                continue
            fields = ("q", str(c * cell_w + cell_w / 2), str(r * cell_h + cell_h / 2), target_x, target_y, "0")
            query_parts.append(" ".join(fields) + " ")

    dataset = SceneDataset(grid.scene_description + "".join(query_parts), do_augmentation=False)
    loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
    is_positive = (predict(model, loader) > 0.5).float()

    # Walk the grid in the same row-major order to pair cells with predictions.
    updated_scene = scene.copy()
    k = 0
    for r in range(n_rows):
        for c in range(n_cols):
            if (r, c) == cell:
                updated_scene[r, c, 1] = 2
                continue
            updated_scene[r, c, 1] = 3 if is_positive[k] == 1 else 4
            k += 1
    return updated_scene

def is_inside(vertex, x_min, x_max, y_min, y_max, edge):
    """Tell whether ``vertex`` lies on the inner side of one clip-rectangle edge.

    ``edge`` selects the half-plane: 0 -> x >= x_min (left), 1 -> x <= x_max
    (right), 2 -> y >= y_min (bottom), 3 -> y <= y_max (top). Points exactly
    on the edge count as inside; any other ``edge`` value yields False.
    """
    px, py = vertex
    if edge in (0, 1):
        return px >= x_min if edge == 0 else px <= x_max
    if edge in (2, 3):
        return py >= y_min if edge == 2 else py <= y_max
    return False

def intersect(v1, v2, x_min, x_max, y_min, y_max, edge):
    """Return where segment v1-v2 meets one edge line of the clip rectangle.

    ``edge`` 0/1 are the vertical lines x = x_min / x = x_max, and 2/3 are the
    horizontal lines y = y_min / y = y_max. The endpoints must differ in the
    coordinate perpendicular to that line, otherwise the slope divides by
    zero. An unknown ``edge`` yields (0, 0).
    """
    (x1, y1), (x2, y2) = v1, v2
    if edge in (0, 1):
        x_line = x_min if edge == 0 else x_max
        return (x_line, y1 + (y2 - y1) / (x2 - x1) * (x_line - x1))
    if edge in (2, 3):
        y_line = y_min if edge == 2 else y_max
        return (x1 + (x2 - x1) / (y2 - y1) * (y_line - y1), y_line)
    return (0, 0)

def sutherland_hodgman(poly, x_min, x_max, y_min, y_max):
    """Clip polygon ``poly`` to the rectangle [x_min, x_max] x [y_min, y_max].

    Sutherland-Hodgman clipping: the vertex list is clipped in turn against
    the left, right, bottom and top edges (``edge`` 0..3). Returns the new
    vertex list, which is empty when nothing of the polygon remains.
    """
    vertices = poly[:]
    for edge in range(4):
        clipped = []
        for k, current in enumerate(vertices):
            previous = vertices[k - 1]  # k == 0 wraps around to the last vertex
            cur_in = bool(is_inside(current, x_min, x_max, y_min, y_max, edge))
            prev_in = bool(is_inside(previous, x_min, x_max, y_min, y_max, edge))
            # Crossing the edge in either direction contributes the crossing point...
            if prev_in != cur_in:
                clipped.append(intersect(previous, current, x_min, x_max, y_min, y_max, edge))
            # ...and every inside vertex is kept, after any crossing point.
            if cur_in:
                clipped.append(current)
        vertices = clipped
    return vertices


def read_obstacle_file(filename):
    """Load scenes from a text file of vertex records and bounds records.

    Recognised lines (whitespace separated):
        ``a scene_id obstacle_id vertex_id x y``: one vertex, appended to
            obstacle ``obstacle_id`` in file order.
        ``b min_x max_x min_y max_y``: closes the current scene. All vertices
            read since the previous ``b`` line become one SceneEncoder with
            these bounds.

    Blank lines and lines starting with any other token are ignored. Vertices
    after the last ``b`` line are discarded.

    Returns:
        List of SceneEncoder objects in file order.
    """
    obstacles = {}
    scenes = []
    with open(filename, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue  # blank line, e.g. a trailing newline; used to raise IndexError
            record = fields[0]
            if record == "a":
                _, _scene_id, obstacle_id, _vertex_id, x, y = fields
                obstacles.setdefault(int(obstacle_id), []).append((float(x), float(y)))
            elif record == "b":
                _, min_x, max_x, min_y, max_y = fields
                scenes.append(SceneEncoder(float(min_x), float(max_x), float(min_y), float(max_y), obstacles))
                obstacles = {}
    return scenes


def transform_coordinates(vertices, offset, scale, min_x, min_y, screen_height):
    """Map world-space vertices to screen pixels.

    x becomes ``offset[0] + (x - min_x) * scale``. y is scaled and offset the
    same way, then flipped (``screen_height - ...``) so that world y grows
    upward on screen.
    """
    points = []
    for x, y in vertices:
        px = offset[0] + (x - min_x) * scale
        py = screen_height - (offset[1] + (y - min_y) * scale)
        points.append((px, py))
    return points

def clamp(value, min_value, max_value):
    """Limit ``value`` to ``max_value`` from above, then ``min_value`` from below.

    If ``min_value > max_value`` the result is ``min_value``.
    """
    upper_limited = max_value if max_value < value else value
    return upper_limited if upper_limited > min_value else min_value

def get_corner_coordinates(min_x, max_x, min_y, max_y, scale, offset, screen_size):
    """Return the world coordinates of the four corners of the square screen.

    This inverts ``transform_coordinates`` (with ``screen_height ==
    screen_size``) at screen x in {0, screen_size} and at the top and bottom
    screen rows. ``max_x`` and ``max_y`` are accepted but unused.

    Returns:
        [top_left, top_right, bottom_left, bottom_right], each an (x, y) tuple.
    """
    left = min_x + (0 - offset[0]) / scale
    right = min_x + (screen_size - offset[0]) / scale
    top = min_y + (screen_size - offset[1]) / scale
    bottom = min_y + (0 - offset[1]) / scale
    return [(left, top), (right, top), (left, bottom), (right, bottom)]


def draw_coordinates_on_screen(screen, font, min_x, max_x, min_y, max_y, scale, offset, screen_size):
    """Draw the world (Cartesian) coordinates of each screen corner next to it, in black."""
    top_left, top_right, bottom_left, bottom_right = get_corner_coordinates(
        min_x, max_x, min_y, max_y, scale, offset, screen_size)

    labels = [f"({x:.2f}, {y:.2f})" for x, y in (top_left, top_right, bottom_left, bottom_right)]

    # Text anchors inset from each corner.
    right_col = screen_size - 180
    bottom_row = screen_size - 30
    anchors = [(10, 10), (right_col, 10), (10, bottom_row), (right_col, bottom_row)]

    for text, anchor in zip(labels, anchors):
        screen.blit(font.render(text, True, (0, 0, 0)), anchor)


def draw_scene_id(screen, font, screen_size, idx):
    """Render the scene index in red at the top centre of the screen."""
    anchor = (screen_size // 2, 10)
    label = font.render(f"{idx}", True, (255, 0, 0))
    screen.blit(label, anchor)



# Basic RGB colours.
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
GRAY = (200, 200, 200)

# Scene palette.
background_color = WHITE
obstacle_fill_color = (150, 150, 150)  # gray
obstacle_outline_color = BLUE

# Alpha for a translucent cell overlay; in the visible code it is only
# referenced from commented-out lines in draw_grid.
glass_alpha = 20

# Grid cell codes -> colour, used by draw_grid for both channels.
# Codes 2/3/4 are written by generate_pygame_data: 2 = selected cell,
# 3 = prediction above 0.5, 4 = otherwise. Code 1 presumably marks
# obstacle cells (not set in the visible code).
color_map = {
    0: WHITE,
    1: obstacle_fill_color,
    2: BLUE,
    3: GREEN,
    4: RED,
}

def draw_grid(DIS,data,CELL_WIDTH,CELL_HEIGHT,screen_size):
    """Paint every grid cell and, where set, a centred marker square.

    ``data[row][col][0]`` selects the cell's fill colour from ``color_map``.
    ``data[row][col][1]`` selects a marker: a half-size square for code 2 and
    a quarter-size square for codes above 2, coloured via ``color_map``.
    Row 0 is drawn at the bottom of the screen.
    """
    quarter_w, quarter_h = CELL_WIDTH // 4, CELL_HEIGHT // 4
    half_w, half_h = CELL_WIDTH // 2, CELL_HEIGHT // 2
    # Inset that centres the quarter-size marker inside the cell.
    small_dx = (CELL_WIDTH - quarter_w) // 2
    small_dy = (CELL_HEIGHT - quarter_h) // 2

    for row in range(data.shape[0]):
        top = screen_size - (row + 1) * CELL_HEIGHT
        for col in range(data.shape[1]):
            left = col * CELL_WIDTH
            fill_code = data[row][col][0]
            marker = data[row][col][1]
            pygame.draw.rect(DIS, color_map[fill_code], (left, top, CELL_WIDTH, CELL_HEIGHT))
            if marker > 2:
                pygame.draw.rect(DIS, color_map[marker], (left + small_dx, top + small_dy, quarter_w, quarter_h))
            elif marker == 2:
                pygame.draw.rect(DIS, color_map[marker], (left + quarter_w, top + quarter_h, half_w, half_h))

def get_cell_under_mouse(mouse_pos,CELL_WIDTH,CELL_HEIGHT,ROWS,COLS):
    """Return the (row, col) grid cell under ``mouse_pos``, or None.

    Rows are counted from the top of the window (screen y axis). Positions
    past the last row or column give None; negative positions are not checked.
    """
    px, py = mouse_pos
    col = px // CELL_WIDTH
    row = py // CELL_HEIGHT
    if row >= ROWS or col >= COLS:
        return None
    return (int(row), int(col))

def draw_select(screen,start_pos,end_pos):
    """Overlay a translucent grey rectangle spanning the two corner positions."""
    left = min(start_pos[0], end_pos[0])
    top = min(start_pos[1], end_pos[1])
    size = (abs(start_pos[0] - end_pos[0]), abs(start_pos[1] - end_pos[1]))
    overlay = pygame.Surface(size, pygame.SRCALPHA)
    overlay.fill((100, 100, 100, 100))  # RGBA: grey at alpha 100
    screen.blit(overlay, (left, top))

def main():
    # Initialize pygame
    pygame.init()

    # Screen dimensions (square screen)
    screen_size = 704
    screen = pygame.display.set_mode((screen_size, screen_size))
    pygame.display.set_caption("Obstacle Visualization with Zoom and Pan")

    # Colors
    
    # Load obstacles
    scenes = read_obstacle_file(filename)
    
    idx = 0
    
    # Calculate scene bounds and scale
    min_x, max_x,min_y, max_y = scenes[idx].minX, scenes[idx].maxX, scenes[idx].minY, scenes[idx].maxY
    obstacles = scenes[idx].obstacles
    scene_width = max_x - min_x
    scene_height = max_y - min_y

    scale = screen_size / min(scene_width, scene_height)

    offset = [0, 0]

    # Movement and zoom variables
    zoom_factor = 1.05
    dragging = False
    drag_start = (0, 0)
    move_speed = 5
    moving = {"up": False, "down": False, "left": False, "right": False}
    zooming = {"in": False, "out": False}
    current_scale = scale
    special_mode = False
    right_dragging = False
    start_pos = None
    end_pos = None
    bounding_box = True

    cellCount = [32,64,128]
    cellCount_idx = 0
    cellXCount = cellYCount = cellCount[cellCount_idx]
    CELL_WIDTH = screen_size/cellXCount
    CELL_HEIGHT = screen_size/cellYCount
    grid = None
    grid_marked = None
    new_obstacles = None
    model = None

    # Main loop
    clock = pygame.time.Clock()
    running = True
    while running:
        mouse_pos = pygame.mouse.get_pos()
        for event in pygame.event.get():
            
            if event.type == pygame.QUIT:
                running = False

            # Key down events
            elif event.type == pygame.KEYDOWN:
                if (event.key == pygame.K_i or event.key == pygame.K_k) and not special_mode:
                    if event.key == pygame.K_i:
                        idx = (idx + 1)%len(scenes)
                    elif event.key == pygame.K_k:
                        idx = (idx - 1)%len(scenes)
                    min_x, max_x, min_y, max_y = scenes[idx].minX, scenes[idx].maxX, scenes[idx].minY, scenes[idx].maxY
                    obstacles = scenes[idx].obstacles
                    scene_width = max_x - min_x
                    scene_height = max_y - min_y

                    scale = screen_size / min(scene_width, scene_height)

                    offset = [0, 0]

                    # Movement and zoom variables
                    zoom_factor = 1.05
                    dragging = False
                    drag_start = (0, 0)
                    move_speed = 5
                    moving = {"up": False, "down": False, "left": False, "right": False}
                    zooming = {"in": False, "out": False}
                    current_scale = scale
                    special_mode = False
                    right_dragging = False
                    start_pos = None
                    end_pos = None
                    bounding_box = True

                if event.key == pygame.K_b and special_mode:  # Right mouse button
                    bounding_box = not bounding_box

                if event.key == pygame.K_e and special_mode:  # Right mouse button
                    cellCount_idx = (cellCount_idx+1)%len(cellCount)
                    cellXCount = cellYCount = cellCount[cellCount_idx]
                    CELL_WIDTH = screen_size/cellXCount
                    CELL_HEIGHT = screen_size/cellYCount
    
                    grid_marked = np.zeros((cellYCount,cellXCount,2))
                    
                
                if event.key == pygame.K_g and not special_mode:  # Right mouse button
                    start_pos = pygame.mouse.get_pos()
                    end_pos = start_pos
                    right_dragging = True

                if event.key == pygame.K_0:  # Toggle special mode
                    special_mode = not special_mode
                    if special_mode:
                        [(top_left_x,top_left_y),(top_right_x,top_right_y),(bottom_left_x,bottom_left_y),(bottom_right_x,bottom_right_y)] = get_corner_coordinates(min_x, max_x, min_y, max_y, current_scale, offset, screen_size)
                        new_obstacles = {}
                        new_polygon_id = 0
                        for id in obstacles:
                            clipped_vertices =  sutherland_hodgman(obstacles[id], bottom_left_x, bottom_right_x, bottom_left_y, top_left_y)
                            if clipped_vertices:  # Check if the list of vertices is not empty
                                new_obstacles[new_polygon_id] = clipped_vertices
                                new_polygon_id += 1
                        grid = SceneEncoder(bottom_left_x, bottom_right_x, bottom_left_y, top_left_y,new_obstacles)
                        grid_marked = np.zeros((cellYCount,cellXCount,2))
                if not special_mode:
                    if event.key == pygame.K_PLUS or event.key == pygame.K_EQUALS:
                        zooming["in"] = True
                    elif event.key == pygame.K_MINUS:
                        zooming["out"] = True
                    elif event.key == pygame.K_UP:
                        moving["up"] = True
                    elif event.key == pygame.K_DOWN:
                        moving["down"] = True
                    elif event.key == pygame.K_LEFT:
                        moving["left"] = True
                    elif event.key == pygame.K_RIGHT:
                        moving["right"] = True

            # Key up events
            elif event.type == pygame.KEYUP:
                if not special_mode:
                    if event.key == pygame.K_PLUS or event.key == pygame.K_EQUALS:
                        zooming["in"] = False
                    elif event.key == pygame.K_MINUS:
                        zooming["out"] = False
                    elif event.key == pygame.K_UP:
                        moving["up"] = False
                    elif event.key == pygame.K_DOWN:
                        moving["down"] = False
                    elif event.key == pygame.K_LEFT:
                        moving["left"] = False
                    elif event.key == pygame.K_RIGHT:
                        moving["right"] = False
                    elif event.key == pygame.K_g:
                        print("h")
                        end_pos = pygame.mouse.get_pos()
                        right_dragging = False
                        x,y = min(end_pos[0],start_pos[0]), max(end_pos[1],start_pos[1])
                        max_moved = min(abs(end_pos[0] - start_pos[0]), abs(end_pos[1] - start_pos[1]))
                        if (max_moved!=0):
                            print("up")
                            [(top_left_x,top_left_y),(top_right_x,top_right_y),(bottom_left_x,bottom_left_y),(bottom_right_x,bottom_right_y)] = get_corner_coordinates(min_x, max_x, min_y, max_y, current_scale, offset, screen_size)
                            current_scale = screen_size/((max_moved/screen_size)*(top_right_x-top_left_x))
                            x_actual, y_actual = (x/screen_size)*(top_right_x-top_left_x) + top_left_x , ((screen_size-y)/screen_size)*(top_left_y-bottom_left_y)+bottom_left_y
                            offset_x = -(x_actual - min_x) * current_scale 
                            offset_y = 1 - (y_actual - min_y) * current_scale
                            # current_scale = new_scale
                            offset[0] = offset_x
                            offset[1] = offset_y

            # Mouse drag
            elif event.type == pygame.MOUSEBUTTONDOWN and not special_mode:
                if event.button == 1:  # Left mouse button
                    dragging = True
                    drag_start = event.pos
                
            elif event.type == pygame.MOUSEBUTTONDOWN and special_mode:
                if event.button == 1:  # Left mouse button
                    mouse_pos = pygame.mouse.get_pos()
                    clicked_cell = get_cell_under_mouse(mouse_pos,CELL_WIDTH,CELL_HEIGHT,cellYCount,cellXCount)
                    clicked_cell = (cellYCount-1-clicked_cell[0],clicked_cell[1])
                    if clicked_cell:
                        grid_marked = generate_pygame_data(grid,np.zeros((cellYCount,cellXCount,2)),clicked_cell)
            elif event.type == pygame.MOUSEBUTTONUP and not special_mode:
                if event.button == 1:  # Left mouse button
                    dragging = False
                
            elif event.type == pygame.MOUSEMOTION and not special_mode:
                if dragging:
                    dx, dy = event.rel
                    offset[0] += dx
                    offset[1] -= dy
                elif right_dragging:
                    end_pos = pygame.mouse.get_pos()

            # Mouse wheel
            elif event.type == pygame.MOUSEWHEEL and not special_mode:
                if event.y > 0:
                    zooming["in"] = True
                elif event.y < 0:
                    zooming["out"] = True

        # Handle continuous zoom
        if zooming["in"]:
            mouse_world_x = (mouse_pos[0] - offset[0]) / current_scale
            mouse_world_y = ((screen_size-mouse_pos[1]) - offset[1]) / current_scale
            current_scale = current_scale * zoom_factor
            offset[0] -= mouse_world_x * (zoom_factor - 1) * current_scale
            offset[1] -= mouse_world_y * (zoom_factor - 1) * current_scale

        if zooming["out"]:
            mouse_world_x = (mouse_pos[0] - offset[0]) / current_scale
            mouse_world_y = ((screen_size-mouse_pos[1]) - offset[1]) / current_scale
            current_scale = max(current_scale / zoom_factor , scale)
            # offset[0] += mouse_world_x * (1 - 1 / zoom_factor) * current_scale
            # offset[1] += mouse_world_y * (1 - 1 / zoom_factor) * current_scale
            offset[0] += mouse_world_x * (zoom_factor - 1) * current_scale
            offset[1] += mouse_world_y * (zoom_factor - 1) * current_scale

        # Clamp offset to stay within the initial scene
        offset[0] = clamp(offset[0], screen_size - scene_width * current_scale, 0)
        offset[1] = clamp(offset[1], screen_size - scene_height * current_scale, 0)

        # Handle continuous movement
        if moving["up"]:
            offset[1] = clamp(offset[1] - move_speed, screen_size - scene_height * current_scale, 0)
        if moving["down"]:
            offset[1] = clamp(offset[1] + move_speed, screen_size - scene_height * current_scale, 0)
        if moving["left"]:
            offset[0] = clamp(offset[0] + move_speed, screen_size - scene_width * current_scale, 0)
        if moving["right"]:
            offset[0] = clamp(offset[0] - move_speed, screen_size - scene_width * current_scale, 0)

        # Clear screen
        screen.fill(background_color)
        # Draw obstacles
        if special_mode:
            screen.fill(WHITE)

            # Draw the grid with the initial colors
            draw_grid(screen,grid_marked,CELL_WIDTH,CELL_HEIGHT,screen_size)

            # Get the mouse position and check which cell is hovered
            mouse_pos = pygame.mouse.get_pos()
            hovered_cell = get_cell_under_mouse(mouse_pos,CELL_WIDTH,CELL_HEIGHT,cellYCount,cellXCount)

            
            for obstacle_id, vertices in new_obstacles.items():
                transformed_vertices = transform_coordinates(vertices, offset, current_scale, min_x, min_y,screen_size)
                pygame.draw.polygon(screen, obstacle_fill_color, transformed_vertices)
                if bounding_box:
                    pygame.draw.polygon(screen, obstacle_outline_color, transformed_vertices, width=2)

            if hovered_cell:
                row, col = hovered_cell
                pygame.draw.rect(screen, GRAY, (col * CELL_WIDTH, row * CELL_HEIGHT, CELL_WIDTH, CELL_HEIGHT))

        else:
            for obstacle_id, vertices in obstacles.items():
                transformed_vertices = transform_coordinates(vertices, offset, current_scale, min_x, min_y,screen_size)
                pygame.draw.polygon(screen, obstacle_fill_color, transformed_vertices)
                pygame.draw.polygon(screen, obstacle_outline_color, transformed_vertices, width=2)
                # Calculate and display dynamic corner coordinates
            draw_coordinates_on_screen(screen, pygame.font.SysFont("Arial", 18), min_x, max_x, min_y, max_y, current_scale, offset, screen_size)
            draw_scene_id(screen, pygame.font.SysFont("Arial", 20), screen_size, idx)
            if right_dragging:
                draw_select(screen,start_pos,end_pos)

        
        # Update display
        pygame.display.flip()
        clock.tick(60)

    pygame.quit()
    sys.exit()

if __name__ == "__main__":
    # Script entry point. A notebook kernel also executes cells with
    # __name__ == "__main__", so running this cell launches the viewer.
    # main() finishes with sys.exit(), which is why the cell output below
    # ends in SystemExit.
    main()


Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...
Loading checkpoint...


SystemExit: 

In [None]:

# intersect 3 on 