# Visualize

In [4]:
# Other scene files noted by the author:
#   big:   scene_1000_2000_1000, scene_1000_1500_1000
#   small: scene_1000_2000_150_128_5, scene_1000_1500_70_128_5

# Obstacle/scene file parsed by read_obstacle_file(), and the classifier
# checkpoint that predict() loads before every inference call.
filename = 'dhaka_1000_20000_1000_128_coverage_0_33.txt'
checkpoint_path_load = "best_line_classifier.pt"

import pygame
import sys
import numpy as np

from typing import List
import math

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
from typing import Optional, Tuple, List, Dict
import time
from tqdm import tqdm

class TransformerEncoderLayer(nn.Module):
    """
    Pre-norm transformer encoder layer.

    LayerNorm is applied *before* the self-attention and the feed-forward
    sub-blocks, and each sub-block's output (after dropout) is added back to
    its input as a residual. Attention is created with ``batch_first=True``.
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-5,
    ):
        super().__init__()
        # Self-attention sub-block (inputs are batch-first)
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

        # Position-wise feed-forward: d_model -> dim_feedforward -> d_model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # Pre-norms and residual dropouts for the two sub-blocks
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Non-linearity used inside the feed-forward sub-block
        self.activation = _get_activation_fn(activation)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # Self-attention sub-block: normalize, attend, add residual
        normed = self.norm1(src)
        attn_out = self.self_attn(
            normed, normed, normed,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        hidden = src + self.dropout1(attn_out)

        # Feed-forward sub-block: normalize, expand, activate, project back, add residual
        ff = self.linear1(self.norm2(hidden))
        ff = self.dropout(self.activation(ff))
        ff = self.linear2(ff)
        return hidden + self.dropout2(ff)


class TransformerEncoder(nn.Module):
    """
    Stack of ``num_layers`` pre-norm encoder layers followed by one final
    LayerNorm over the last layer's output.
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-5,
    ):
        super().__init__()

        # Every layer shares the same hyper-parameters
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
        )
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(**layer_kwargs) for _ in range(num_layers)]
        )

        # Final normalization (needed because the layers are pre-norm)
        self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(
                hidden,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
            )
        return self.norm(hidden)


class LineSequenceClassifier(nn.Module):
    """
    Transformer model for binary classification of 2D line sequences.

    Each line is 4 numbers (x1, y1, x2, y2). Lines are linearly embedded,
    encoded with a padding-aware transformer, and the hidden state of the
    last *valid* line of each sequence is fed to an MLP head producing one
    logit per sequence.
    """
    def __init__(
        self,
        line_dim: int = 4,  # Dimension of each line (x1, y1, x2, y2)
        d_model: int = 128, # Embedding dimension
        nhead: int = 8,
        num_encoder_layers: int = 4,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-5,
    ):
        super().__init__()

        self.d_model = d_model

        # Per-line projection into the model dimension
        self.line_embedding = nn.Linear(line_dim, d_model)

        # Sequence encoder
        self.transformer_encoder = TransformerEncoder(
            d_model=d_model,
            nhead=nhead,
            num_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps
        )

        # MLP head: d_model -> 128 -> 64 -> 16 -> 1, each hidden layer
        # followed by ReLU and dropout (layer order matches saved checkpoints).
        head_layers = []
        for fan_in, fan_out in ((d_model, 128), (128, 64), (64, 16)):
            head_layers.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(dropout)])
        head_layers.append(nn.Linear(16, 1))
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, src, src_lengths):
        # src: [batch_size, seq_len, line_dim]; src_lengths: [batch_size]
        batch_size, max_len, _ = src.shape

        # True marks padded positions that attention must ignore
        positions = torch.arange(max_len, device=src.device).expand(batch_size, max_len)
        padding_mask = positions >= src_lengths.unsqueeze(1)

        encoded = self.transformer_encoder(
            self.line_embedding(src),
            src_key_padding_mask=padding_mask,
        )

        # Pick the hidden state of the final real (non-padded) line per sequence
        rows = torch.arange(encoded.size(0), device=encoded.device)
        last_hidden = encoded[rows, src_lengths - 1]

        return self.classifier(last_hidden).squeeze(-1)

    def predict(self, src, src_lengths):
        """
        Return boolean predictions: sigmoid(logit) >= 0.5.
        """
        probabilities = torch.sigmoid(self.forward(src, src_lengths))
        return probabilities >= 0.5


class LineSequenceDataset(Dataset):
    """Dataset of ``(obstacle segments + one query segment, length, label)`` samples.

    ``data`` is a whitespace-separated token string: first any number of
    obstacle segments given as 4 floats ``x1 y1 x2 y2``, then any number of
    queries of the form ``q x1 y1 x2 y2 label`` (label parsed with ``int``).
    Each query becomes one sample whose sequence is every obstacle segment
    followed by that query segment.

    Raises:
        ValueError: if ``data`` is malformed (truncated obstacle/query, or a
            token other than ``q`` where a query record should start).
    """
    def __init__(self, data):
        self.data = data
        self.cache = []  # parsed scenes: (obstacles, queries)
        self.samples = []  # (scene_idx, query_idx) for every individual sample

        # Parse the text once, up front
        self._parse_file()

    def _parse_file(self):
        """Parse ``self.data`` into one cached scene and index its queries."""
        tokens = self.data.strip().split()
        n_tokens = len(tokens)

        # Obstacles: groups of 4 floats until the first 'q' marker
        obstacles = []
        idx = 0
        while idx < n_tokens and tokens[idx] != 'q':
            if idx + 4 > n_tokens:
                raise ValueError(f"Truncated obstacle segment at token {idx}")
            x, y, x1, y1 = (float(tok) for tok in tokens[idx:idx + 4])
            obstacles.append((x, y, x1, y1))
            idx += 4

        # Queries: 'q' followed by 4 floats and an integer label
        queries = []
        while idx < n_tokens:
            # Every remaining record must start with 'q'. Without this check an
            # unexpected token is never consumed and the loop spins forever.
            if tokens[idx] != 'q':
                raise ValueError(f"Expected 'q' at token {idx}, got {tokens[idx]!r}")
            if idx + 6 > n_tokens:
                raise ValueError(f"Truncated query at token {idx}")
            x, y, x1, y1 = (float(tok) for tok in tokens[idx + 1:idx + 5])
            label = int(tokens[idx + 5])
            queries.append(((x, y, x1, y1), label))
            idx += 6

        # Cache the parsed scene and index one sample per query
        scene_idx = len(self.cache)
        self.cache.append((obstacles, queries))
        self.samples.extend((scene_idx, query_idx) for query_idx in range(len(queries)))

    def __len__(self):
        """Total number of samples (queries)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(obstacles + [query], sequence_length, label)`` for sample ``idx``."""
        scene_idx, query_idx = self.samples[idx]
        obstacles, queries = self.cache[scene_idx]
        query, label = queries[query_idx]
        sequence = obstacles + [query]
        return sequence, len(sequence), label
    

def _get_activation_fn(activation):
    """Helper function to get activation function by name."""
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise ValueError(f"Activation function {activation} not supported")


def collate_fn(batch):
    """
    Pad a batch of ``(lines, length, label)`` samples into dense tensors.

    Returns ``(sequences, lengths, labels)`` where ``sequences`` is a
    float32 tensor of shape [batch, longest_length, 4] zero-padded after each
    sample's real lines, ``lengths`` is int64 and ``labels`` is float32.
    """
    sequences, sequence_lengths, labels = zip(*batch)

    longest = max(sequence_lengths)
    padded = np.zeros((len(sequences), longest, 4))
    for row, (seq, seq_len) in enumerate(zip(sequences, sequence_lengths)):
        padded[row, :seq_len] = seq

    return (
        torch.tensor(padded, dtype=torch.float32),
        torch.tensor(sequence_lengths, dtype=torch.long),
        torch.tensor(labels, dtype=torch.float32),
    )


def calculate_metrics(y_true, y_pred):
    """
    Compute accuracy, recall, precision, specificity and F1 for binary labels.

    Tensors are moved to CPU and converted to numpy first. Precision, recall
    and F1 use ``zero_division=0``. Specificity is TN / (TN + FP), or 1.0
    when there are no actual negatives, or 0.0 when ``y_true`` is empty.
    """
    def _as_numpy(values):
        return values.cpu().numpy() if isinstance(values, torch.Tensor) else values

    y_true = _as_numpy(y_true)
    y_pred = _as_numpy(y_pred)

    accuracy = accuracy_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred, zero_division=0)
    precision = precision_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    if len(y_true) == 0:
        specificity = 0.0
    else:
        tn, fp, _fn, _tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        negatives = tn + fp
        specificity = tn / negatives if negatives > 0 else 1.0

    return {
        'accuracy': accuracy,
        'recall': recall,
        'precision': precision,
        'specificity': specificity,
        'f1_score': f1
    }

def predict(model,test_loader,device,resume_from=None):
    """Run ``model`` over ``test_loader`` and return thresholded predictions.

    Args:
        model: module called as ``model(sequences, seq_lengths)`` returning one
            logit per sample.
        test_loader: iterable of ``(sequences, seq_lengths, labels)`` batches
            (the shape produced by ``collate_fn``); labels are ignored.
        device: torch device the inputs are moved to.
        resume_from: optional checkpoint path. If it exists, its
            ``'model_state_dict'`` entry is loaded into ``model`` first.

    Returns:
        List of per-sample 0.0/1.0 predictions (1.0 where sigmoid(logit) >= 0.5),
        in loader order.
    """
    if resume_from is not None and os.path.exists(resume_from):
        print(f"Loading checkpoint from {resume_from}")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(resume_from, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    model.eval()
    test_preds = []
    with torch.no_grad():
        for sequences, seq_lengths, _labels in test_loader:
            sequences = sequences.to(device, non_blocking=True)
            seq_lengths = seq_lengths.to(device, non_blocking=True)

            logits = model(sequences, seq_lengths)

            preds = (torch.sigmoid(logits) >= 0.5).float()
            test_preds.extend(preds.cpu().numpy())

    return test_preds

# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Global classifier used by generate_pygame_data(); its weights are loaded
# from checkpoint_path_load inside predict().
model = LineSequenceClassifier(
    line_dim=4,
    d_model=128,
    nhead=8,
    num_encoder_layers=4,
    dim_feedforward=512,
    dropout=0.1,
).to(device)
    

def round_down(value: float) -> int:
    """Round ``value`` to the nearest integer, breaking exact .5 ties downward.

    Despite the name this is round-to-nearest with ties toward -inf, not
    floor: ``round_down(2.5) == 2`` but ``round_down(2.6) == 3``.

    ``math.floor`` (not ``int``) provides the integer part so negative inputs
    follow the same rule as positive ones; ``int`` truncates toward zero,
    which made e.g. ``-1.7`` round to ``-1`` instead of ``-2``.
    """
    integer_part = math.floor(value)
    fractional_part = value - integer_part  # in [0, 1] after floor

    if fractional_part <= 0.5:
        return integer_part  # Round down (including exact ties)
    return integer_part + 1  # Round up
    

def bound(value, min_value, max_value):
    """Clip ``value`` into [min_value, max_value] (min_value wins if the range is inverted)."""
    capped = min(value, max_value)
    return max(capped, min_value)


class SceneEncoder:
    """Obstacle scene plus its bounds and a normalized text encoding.

    ``scene_description`` lists every polygon edge as four numbers
    ``x1 y1 x2 y2``, each followed by a single space. Coordinates are rescaled
    so the rectangle [minX, maxX] x [minY, maxY] maps to the unit square.
    Consecutive vertices form an edge, and the last vertex is joined back to
    the first. The string is the obstacle prefix that ``LineSequenceDataset``
    parses.
    """
    def __init__(self, minX, maxX, minY, maxY, obstacles):
        self.minX = minX
        self.minY = minY
        self.maxX = maxX
        self.maxY = maxY
        # dict: obstacle id -> list of (x, y) vertices
        self.obstacles = obstacles

        span_x = self.maxX - self.minX
        span_y = self.maxY - self.minY

        # Collect tokens in a list and join once (repeated `+=` on a str is quadratic).
        tokens = []
        for vertices in self.obstacles.values():
            count = len(vertices)
            for i in range(count):
                start = vertices[i]
                end = vertices[(i + 1) % count]  # wrap around to close the polygon
                tokens.append(str((start[0] - self.minX) / span_x))
                tokens.append(str((start[1] - self.minY) / span_y))
                tokens.append(str((end[0] - self.minX) / span_x))
                tokens.append(str((end[1] - self.minY) / span_y))
        self.scene_description = "".join(token + " " for token in tokens)


def generate_pygame_data(grid,scene,cell):
    """Classify the line from ``cell`` to every other grid cell and colour the result.

    Cell centres are expressed in the unit square (cell size 1/cols by
    1/rows). For each cell other than ``cell``, a query ``q tx ty sx sy 0`` is
    appended to ``grid.scene_description``. Here (tx, ty) is the target cell
    centre and (sx, sy) the source cell centre. The queries are run through
    the global ``model`` via ``predict`` (which loads ``checkpoint_path_load``).

    Args:
        grid: SceneEncoder whose ``scene_description`` holds the normalized
            obstacle segments.
        scene: array indexed ``[row, col, channel]`` with at least 2 channels.
        cell: ``(row, col)`` of the source cell.

    Returns:
        A copy of ``scene`` whose channel 1 holds 2 for the source cell, 3
        where the prediction is 1, and 4 otherwise.
    """
    rows, cols = scene.shape[0], scene.shape[1]
    updated_scene = scene.copy()
    width = 1/cols
    height = 1/rows

    # Source-cell centre is the same for every query
    source_x = str(cell[1]*width+width/2)
    source_y = str(cell[0]*height+height/2)

    # Build all queries in a list and join once instead of repeated `+=`.
    query_parts = []
    for i in range(rows):
        for j in range(cols):
            if (i,j)==cell:
                continue
            target_x = str(j*width+width/2)
            target_y = str(i*height+height/2)
            query_parts.append(f"q {target_x} {target_y} {source_x} {source_y} 0 ")
    query = "".join(query_parts)

    test_dataset = LineSequenceDataset(grid.scene_description+query)
    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,
        collate_fn=collate_fn
    )

    # Predictions come back in the same row-major order the queries were built in.
    predictions = iter(predict(model,test_loader,device,checkpoint_path_load))
    for i in range(rows):
        for j in range(cols):
            if (i,j)==cell:
                updated_scene[i,j,1] = 2
                continue
            updated_scene[i,j,1] = 3 if next(predictions) == 1 else 4
    return updated_scene

def is_inside(vertex, x_min, x_max, y_min, y_max, edge):
    """Return whether ``vertex`` lies on the kept side of one clip-rectangle edge.

    edge 0: x >= x_min (left), 1: x <= x_max (right),
    edge 2: y >= y_min (bottom), 3: y <= y_max (top).
    Points on the boundary count as inside; any other edge value gives False.
    """
    px, py = vertex
    if edge not in (0, 1, 2, 3):
        return False
    coord = px if edge in (0, 1) else py
    if edge in (0, 2):
        return coord >= (x_min if edge == 0 else y_min)
    return coord <= (x_max if edge == 1 else y_max)

def intersect(v1, v2, x_min, x_max, y_min, y_max, edge):
    """Intersect segment v1->v2 with the infinite line of one clip edge.

    edge 0/1: vertical line x = x_min / x = x_max (segment must not be vertical).
    edge 2/3: horizontal line y = y_min / y = y_max (segment must not be horizontal).
    Any other edge value yields (0, 0).
    """
    ax, ay = v1
    bx, by = v2
    if edge == 0 or edge == 1:
        line_x = x_min if edge == 0 else x_max
        dy_dx = (by - ay) / (bx - ax)
        return (line_x, ay + dy_dx * (line_x - ax))
    if edge == 2 or edge == 3:
        line_y = y_min if edge == 2 else y_max
        dx_dy = (bx - ax) / (by - ay)
        return (ax + dx_dy * (line_y - ay), line_y)
    return (0, 0)

def sutherland_hodgman(poly, x_min, x_max, y_min, y_max):
    """Clip polygon ``poly`` (sequence of (x, y)) to an axis-aligned rectangle.

    Classic Sutherland–Hodgman: the polygon is clipped against the left,
    right, bottom and top edges in turn. Returns the clipped vertex list,
    which is empty when the polygon lies entirely outside.
    """
    subject = poly[:]
    for edge in range(4):
        kept = []
        for i in range(len(subject)):
            cur = subject[i]
            prev = subject[i - 1]  # i == 0 wraps to the last vertex

            cur_in = is_inside(cur, x_min, x_max, y_min, y_max, edge)
            prev_in = is_inside(prev, x_min, x_max, y_min, y_max, edge)

            if cur_in:
                if not prev_in:
                    # entering the kept region: add the crossing point first
                    kept.append(intersect(prev, cur, x_min, x_max, y_min, y_max, edge))
                kept.append(cur)
            elif prev_in:
                # leaving the kept region: only the crossing point survives
                kept.append(intersect(prev, cur, x_min, x_max, y_min, y_max, edge))
        subject = kept
    return subject


def read_obstacle_file(filename):
    """Parse a scene file into a list of ``SceneEncoder`` objects.

    Line formats (whitespace separated):
        ``a scene_id obstacle_id vertex_id x y`` - append vertex (x, y) to
            obstacle ``obstacle_id`` of the scene being built. Vertices are
            kept in file order; ``vertex_id`` is not used.
        ``b min_x max_x min_y max_y`` - close the current scene with these
            bounds and start a new, empty one.
    Blank lines are skipped. Other lines are ignored. Obstacles after the last
    ``b`` line are discarded.
    """
    obstacles = {}
    scenes = []
    with open(filename, 'r') as file:
        for raw_line in file:
            fields = raw_line.split()
            if not fields:
                continue  # blank line: previously crashed on fields[0]
            tag = fields[0]
            if tag == "a":
                _tag, _scene_id, obstacle_id, _vertex_id, x, y = fields
                obstacles.setdefault(int(obstacle_id), []).append((float(x), float(y)))
            elif tag == 'b':
                _tag, min_x, max_x, min_y, max_y = fields
                min_x, max_x, min_y, max_y = float(min_x), float(max_x), float(min_y), float(max_y)
                scenes.append(SceneEncoder(min_x, max_x, min_y, max_y, obstacles))
                obstacles = {}
    return scenes


def transform_coordinates(vertices, offset, scale, min_x, min_y, screen_height):
    """Map world vertices to screen pixels, flipping y so world-up is screen-up.

    screen_x = offset[0] + (x - min_x) * scale
    screen_y = screen_height - (offset[1] + (y - min_y) * scale)
    """
    screen_points = []
    for world_x, world_y in vertices:
        sx = offset[0] + (world_x - min_x) * scale
        sy = screen_height - (offset[1] + (world_y - min_y) * scale)
        screen_points.append((sx, sy))
    return screen_points

def clamp(value, min_value, max_value):
    """Restrict ``value`` to [min_value, max_value]; min_value wins if the range is inverted."""
    upper_clipped = min(value, max_value)
    return max(min_value, upper_clipped)

def get_corner_coordinates(min_x, max_x, min_y, max_y, scale, offset, screen_size):
    """World coordinates shown at the four corners of the square screen.

    Inverts the screen mapping used by ``transform_coordinates``. Returned
    order: top-left, top-right, bottom-left, bottom-right. ``max_x`` and
    ``max_y`` are accepted for signature symmetry but not used.
    """
    left_x = min_x + (0 - offset[0]) / scale
    right_x = min_x + (screen_size - offset[0]) / scale
    top_y = min_y + (screen_size - offset[1]) / scale
    bottom_y = min_y + (0 - offset[1]) / scale

    return [(left_x, top_y), (right_x, top_y), (left_x, bottom_y), (right_x, bottom_y)]


def draw_coordinates_on_screen(screen, font, min_x, max_x, min_y, max_y, scale, offset, screen_size):
    """Draw corner coordinates (Cartesian) on the screen.

    Each screen corner is labelled, in black with two decimals, with the world
    coordinate currently displayed there.
    """
    top_left, top_right, bottom_left, bottom_right = get_corner_coordinates(
        min_x, max_x, min_y, max_y, scale, offset, screen_size
    )

    right_column = screen_size - 180
    bottom_row = screen_size - 30
    placements = (
        (top_left, (10, 10)),
        (top_right, (right_column, 10)),
        (bottom_left, (10, bottom_row)),
        (bottom_right, (right_column, bottom_row)),
    )

    for (world_x, world_y), position in placements:
        label = font.render(f"({world_x:.2f}, {world_y:.2f})", True, (0, 0, 0))
        screen.blit(label, position)


def draw_scene_id(screen, font, screen_size, idx):
    """Render the scene index in red near the top, starting at the horizontal centre."""
    label_surface = font.render(f"{idx}", True, (255, 0, 0))
    anchor = (screen_size // 2, 10)
    screen.blit(label_surface, anchor)



# Basic RGB colours
WHITE = (255, 255, 255)
BLACK = (0, 0, 0)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
GRAY = (200, 200, 200)

# Scene rendering colours
background_color = (255, 255, 255)     # white
obstacle_fill_color = (150, 150, 150)  # gray
obstacle_outline_color = (0, 0, 255)   # blue

# Alpha (0-255) of the translucent wash draw_grid paints over marked cells
glass_alpha = 20

# Grid cell code -> colour, used by draw_grid for both channels.
# Channel-1 codes set by generate_pygame_data: 2 = source cell,
# 3 = prediction 1, 4 = prediction 0. Code 1 presumably marks obstacles (TODO confirm).
color_map = {0: WHITE, 1: obstacle_fill_color, 2: BLUE, 3: GREEN, 4: RED}

def draw_grid(DIS,data,CELL_WIDTH,CELL_HEIGHT,screen_size):
    """Paint the classification grid onto ``DIS``.

    ``data`` is indexed ``data[row][col][channel]``. Row 0 is drawn at the
    bottom of the screen. Channel 0 picks each cell's opaque base colour via
    ``color_map``. Channel 1 adds an overlay:
      * value > 2: translucent wash plus a small centred square (1/4 cell);
      * value == 2: translucent wash plus a larger centred square (1/2 cell);
      * smaller values: no overlay.
    """
    # One reusable translucent surface. Surface.fill overwrites every pixel,
    # including alpha, so re-filling it per cell renders the same as
    # allocating a fresh surface for every cell on every frame.
    glass_surface = pygame.Surface((CELL_WIDTH, CELL_HEIGHT), pygame.SRCALPHA)

    for row in range(data.shape[0]):
        for col in range(data.shape[1]):
            base, overlay = data[row][col][0], data[row][col][1]
            x = col * CELL_WIDTH
            y = screen_size - (row + 1) * CELL_HEIGHT  # flip: row 0 at the bottom

            pygame.draw.rect(DIS, color_map[base], (x, y, CELL_WIDTH, CELL_HEIGHT))

            if overlay > 2:
                glass_surface.fill((*color_map[overlay], glass_alpha))
                DIS.blit(glass_surface, (x, y))
                pygame.draw.rect(DIS, color_map[overlay], (x + (CELL_WIDTH-CELL_WIDTH//4)//2, y + (CELL_HEIGHT-CELL_HEIGHT//4)//2, CELL_WIDTH//4, CELL_HEIGHT//4))
            elif overlay == 2:
                glass_surface.fill((*color_map[overlay], glass_alpha))
                DIS.blit(glass_surface, (x, y))
                pygame.draw.rect(DIS, color_map[overlay], (x + CELL_WIDTH//4, y + CELL_HEIGHT//4, CELL_WIDTH//2, CELL_HEIGHT//2))

def get_cell_under_mouse(mouse_pos,CELL_WIDTH,CELL_HEIGHT,ROWS,COLS):
    """Return the ``(row, col)`` screen cell under ``mouse_pos``, or None past the grid.

    Rows are counted from the top of the screen.
    """
    px, py = mouse_pos
    col = px // CELL_WIDTH
    row = py // CELL_HEIGHT
    if row >= ROWS or col >= COLS:
        return None
    return (int(row), int(col))

def draw_select(screen,start_pos,end_pos):
    """Overlay a translucent gray rectangle spanning the two selection corners."""
    left = min(start_pos[0], end_pos[0])
    top = min(start_pos[1], end_pos[1])
    width = abs(start_pos[0] - end_pos[0])
    height = abs(start_pos[1] - end_pos[1])

    overlay = pygame.Surface((width, height), pygame.SRCALPHA)
    overlay.fill((100, 100, 100, 100))
    screen.blit(overlay, (left, top))

def main():
    """Interactive viewer for obstacle scenes and line-classifier predictions.

    Normal mode:
      * ``i`` / ``k``: next / previous scene (resets the view state)
      * arrow keys or left-drag: pan; ``+``/``=``/``-`` (held) or mouse wheel: zoom
      * hold ``g``, move the mouse, release ``g``: zoom to the selected square
      * ``0``: enter special mode
    Special mode:
      * obstacles are clipped to the world window currently on screen and
        encoded with ``SceneEncoder``; the screen is split into a grid
      * left-click a cell: run ``generate_pygame_data`` from that cell
      * ``b``: toggle obstacle outlines; ``e``: cycle grid resolution; ``0``: leave
    Calls ``sys.exit()`` once the window is closed.
    """
    # Initialize pygame
    pygame.init()

    # Screen dimensions (square screen)
    screen_size = 704
    screen = pygame.display.set_mode((screen_size, screen_size))
    pygame.display.set_caption("Obstacle Visualization with Zoom and Pan")

    # Create fonts once rather than calling SysFont on every frame.
    coord_font = pygame.font.SysFont("Arial", 18)
    scene_id_font = pygame.font.SysFont("Arial", 20)

    # Load obstacles
    scenes = read_obstacle_file(filename)

    idx = 0

    # Calculate scene bounds and scale
    min_x, max_x, min_y, max_y = scenes[idx].minX, scenes[idx].maxX, scenes[idx].minY, scenes[idx].maxY
    obstacles = scenes[idx].obstacles
    scene_width = max_x - min_x
    scene_height = max_y - min_y

    # Fit the shorter scene side to the screen; this is also the zoom-out limit.
    scale = screen_size / min(scene_width, scene_height)

    offset = [0, 0]

    # Movement and zoom variables
    zoom_factor = 1.05
    dragging = False
    drag_start = (0, 0)
    move_speed = 5
    moving = {"up": False, "down": False, "left": False, "right": False}
    # Held-key zoom: active from KEYDOWN until the matching KEYUP.
    zooming = {"in": False, "out": False}
    # Mouse-wheel zoom: one step per wheel event, cleared once applied.
    wheel_zoom = {"in": False, "out": False}
    current_scale = scale
    special_mode = False
    right_dragging = False
    start_pos = None
    end_pos = None
    bounding_box = True

    cellCount = [32, 64]
    cellCount_idx = 0
    cellXCount = cellYCount = cellCount[cellCount_idx]
    CELL_WIDTH = screen_size / cellXCount
    CELL_HEIGHT = screen_size / cellYCount
    grid = None
    grid_marked = None
    new_obstacles = None

    # Main loop
    clock = pygame.time.Clock()
    running = True
    while running:
        mouse_pos = pygame.mouse.get_pos()
        for event in pygame.event.get():

            if event.type == pygame.QUIT:
                running = False

            # Key down events
            elif event.type == pygame.KEYDOWN:
                if (event.key == pygame.K_i or event.key == pygame.K_k) and not special_mode:
                    if event.key == pygame.K_i:
                        idx = (idx + 1) % len(scenes)
                    elif event.key == pygame.K_k:
                        idx = (idx - 1) % len(scenes)
                    min_x, max_x, min_y, max_y = scenes[idx].minX, scenes[idx].maxX, scenes[idx].minY, scenes[idx].maxY
                    obstacles = scenes[idx].obstacles
                    scene_width = max_x - min_x
                    scene_height = max_y - min_y

                    scale = screen_size / min(scene_width, scene_height)

                    offset = [0, 0]

                    # Reset movement and zoom state for the new scene
                    zoom_factor = 1.05
                    dragging = False
                    drag_start = (0, 0)
                    move_speed = 5
                    moving = {"up": False, "down": False, "left": False, "right": False}
                    zooming = {"in": False, "out": False}
                    wheel_zoom = {"in": False, "out": False}
                    current_scale = scale
                    special_mode = False
                    right_dragging = False
                    start_pos = None
                    end_pos = None
                    bounding_box = True

                if event.key == pygame.K_b and special_mode:  # toggle obstacle outlines
                    bounding_box = not bounding_box

                if event.key == pygame.K_e and special_mode:  # cycle grid resolution
                    cellCount_idx = (cellCount_idx + 1) % len(cellCount)
                    cellXCount = cellYCount = cellCount[cellCount_idx]
                    CELL_WIDTH = screen_size / cellXCount
                    CELL_HEIGHT = screen_size / cellYCount

                    grid_marked = np.zeros((cellYCount, cellXCount, 2))

                if event.key == pygame.K_g and not special_mode:  # start box-zoom selection
                    start_pos = pygame.mouse.get_pos()
                    end_pos = start_pos
                    right_dragging = True

                if event.key == pygame.K_0:  # Toggle special mode
                    special_mode = not special_mode
                    if special_mode:
                        # Clip every obstacle to the world rectangle currently on screen.
                        top_left, _top_right, bottom_left, bottom_right = get_corner_coordinates(min_x, max_x, min_y, max_y, current_scale, offset, screen_size)
                        view_min_x, view_max_x = bottom_left[0], bottom_right[0]
                        view_min_y, view_max_y = bottom_left[1], top_left[1]
                        new_obstacles = {}
                        new_polygon_id = 0
                        for obstacle_id in obstacles:
                            clipped_vertices = sutherland_hodgman(obstacles[obstacle_id], view_min_x, view_max_x, view_min_y, view_max_y)
                            if clipped_vertices:  # skip obstacles entirely outside the view
                                new_obstacles[new_polygon_id] = clipped_vertices
                                new_polygon_id += 1
                        grid = SceneEncoder(view_min_x, view_max_x, view_min_y, view_max_y, new_obstacles)
                        grid_marked = np.zeros((cellYCount, cellXCount, 2))
                if not special_mode:
                    if event.key == pygame.K_PLUS or event.key == pygame.K_EQUALS:
                        zooming["in"] = True
                    elif event.key == pygame.K_MINUS:
                        zooming["out"] = True
                    elif event.key == pygame.K_UP:
                        moving["up"] = True
                    elif event.key == pygame.K_DOWN:
                        moving["down"] = True
                    elif event.key == pygame.K_LEFT:
                        moving["left"] = True
                    elif event.key == pygame.K_RIGHT:
                        moving["right"] = True

            # Key up events
            elif event.type == pygame.KEYUP:
                if not special_mode:
                    if event.key == pygame.K_PLUS or event.key == pygame.K_EQUALS:
                        zooming["in"] = False
                    elif event.key == pygame.K_MINUS:
                        zooming["out"] = False
                    elif event.key == pygame.K_UP:
                        moving["up"] = False
                    elif event.key == pygame.K_DOWN:
                        moving["down"] = False
                    elif event.key == pygame.K_LEFT:
                        moving["left"] = False
                    elif event.key == pygame.K_RIGHT:
                        moving["right"] = False
                    elif event.key == pygame.K_g and right_dragging:
                        # Only finish a selection that was actually started (start_pos is set).
                        end_pos = pygame.mouse.get_pos()
                        right_dragging = False
                        # Bottom-left corner of the selection, in screen pixels
                        x, y = min(end_pos[0], start_pos[0]), max(end_pos[1], start_pos[1])
                        # Use the shorter side so the zoomed region stays square
                        max_moved = min(abs(end_pos[0] - start_pos[0]), abs(end_pos[1] - start_pos[1]))
                        if max_moved != 0:
                            top_left, top_right, bottom_left, _bottom_right = get_corner_coordinates(min_x, max_x, min_y, max_y, current_scale, offset, screen_size)
                            view_width = top_right[0] - top_left[0]
                            view_height = top_left[1] - bottom_left[1]
                            current_scale = screen_size / ((max_moved / screen_size) * view_width)
                            x_actual = (x / screen_size) * view_width + top_left[0]
                            y_actual = ((screen_size - y) / screen_size) * view_height + bottom_left[1]
                            offset[0] = -(x_actual - min_x) * current_scale
                            # NOTE(review): `1 -` shifts the view by one pixel; possibly
                            # meant to be `-(...)` like offset[0] - confirm.
                            offset[1] = 1 - (y_actual - min_y) * current_scale

            # Mouse drag
            elif event.type == pygame.MOUSEBUTTONDOWN and not special_mode:
                if event.button == 1:  # Left mouse button
                    dragging = True
                    drag_start = event.pos

            elif event.type == pygame.MOUSEBUTTONDOWN and special_mode:
                if event.button == 1:  # Left mouse button
                    mouse_pos = pygame.mouse.get_pos()
                    clicked_cell = get_cell_under_mouse(mouse_pos, CELL_WIDTH, CELL_HEIGHT, cellYCount, cellXCount)
                    if clicked_cell is not None:
                        # Screen rows count from the top; grid rows count from the bottom.
                        clicked_cell = (cellYCount - 1 - clicked_cell[0], clicked_cell[1])
                        grid_marked = generate_pygame_data(grid, np.zeros((cellYCount, cellXCount, 2)), clicked_cell)
            elif event.type == pygame.MOUSEBUTTONUP and not special_mode:
                if event.button == 1:  # Left mouse button
                    dragging = False

            elif event.type == pygame.MOUSEMOTION and not special_mode:
                if dragging:
                    dx, dy = event.rel
                    offset[0] += dx
                    offset[1] -= dy
                elif right_dragging:
                    end_pos = pygame.mouse.get_pos()

            # Mouse wheel: request a single zoom step (never left latched on)
            elif event.type == pygame.MOUSEWHEEL and not special_mode:
                if event.y > 0:
                    wheel_zoom["in"] = True
                elif event.y < 0:
                    wheel_zoom["out"] = True

        # Apply zoom so the world point under the cursor stays under the cursor:
        # screen = offset + world * scale  =>  offset' = offset - world * (new - old)
        zoom_in = zooming["in"] or wheel_zoom["in"]
        zoom_out = zooming["out"] or wheel_zoom["out"]
        wheel_zoom["in"] = wheel_zoom["out"] = False

        if zoom_in:
            mouse_world_x = (mouse_pos[0] - offset[0]) / current_scale
            mouse_world_y = ((screen_size - mouse_pos[1]) - offset[1]) / current_scale
            new_scale = current_scale * zoom_factor
            offset[0] -= mouse_world_x * (new_scale - current_scale)
            offset[1] -= mouse_world_y * (new_scale - current_scale)
            current_scale = new_scale

        if zoom_out:
            mouse_world_x = (mouse_pos[0] - offset[0]) / current_scale
            mouse_world_y = ((screen_size - mouse_pos[1]) - offset[1]) / current_scale
            # Never zoom out past the initial fit-to-screen scale.
            new_scale = max(current_scale / zoom_factor, scale)
            offset[0] -= mouse_world_x * (new_scale - current_scale)
            offset[1] -= mouse_world_y * (new_scale - current_scale)
            current_scale = new_scale

        # Clamp offset to stay within the initial scene
        offset[0] = clamp(offset[0], screen_size - scene_width * current_scale, 0)
        offset[1] = clamp(offset[1], screen_size - scene_height * current_scale, 0)

        # Handle continuous movement
        if moving["up"]:
            offset[1] = clamp(offset[1] - move_speed, screen_size - scene_height * current_scale, 0)
        if moving["down"]:
            offset[1] = clamp(offset[1] + move_speed, screen_size - scene_height * current_scale, 0)
        if moving["left"]:
            offset[0] = clamp(offset[0] + move_speed, screen_size - scene_width * current_scale, 0)
        if moving["right"]:
            offset[0] = clamp(offset[0] - move_speed, screen_size - scene_width * current_scale, 0)

        # Clear screen
        screen.fill(background_color)
        if special_mode:
            screen.fill(WHITE)

            # Grid with the latest prediction colours
            draw_grid(screen, grid_marked, CELL_WIDTH, CELL_HEIGHT, screen_size)

            # Which cell is hovered (screen rows, counted from the top)
            mouse_pos = pygame.mouse.get_pos()
            hovered_cell = get_cell_under_mouse(mouse_pos, CELL_WIDTH, CELL_HEIGHT, cellYCount, cellXCount)

            # Clipped obstacles on top of the grid
            for obstacle_id, vertices in new_obstacles.items():
                transformed_vertices = transform_coordinates(vertices, offset, current_scale, min_x, min_y, screen_size)
                pygame.draw.polygon(screen, obstacle_fill_color, transformed_vertices)
                if bounding_box:
                    pygame.draw.polygon(screen, obstacle_outline_color, transformed_vertices, width=2)

            if hovered_cell:
                row, col = hovered_cell
                pygame.draw.rect(screen, GRAY, (col * CELL_WIDTH, row * CELL_HEIGHT, CELL_WIDTH, CELL_HEIGHT))

        else:
            for obstacle_id, vertices in obstacles.items():
                transformed_vertices = transform_coordinates(vertices, offset, current_scale, min_x, min_y, screen_size)
                pygame.draw.polygon(screen, obstacle_fill_color, transformed_vertices)
                pygame.draw.polygon(screen, obstacle_outline_color, transformed_vertices, width=2)
            # Dynamic corner coordinates and scene index
            draw_coordinates_on_screen(screen, coord_font, min_x, max_x, min_y, max_y, current_scale, offset, screen_size)
            draw_scene_id(screen, scene_id_font, screen_size, idx)
            if right_dragging:
                draw_select(screen, start_pos, end_pos)

        # Update display
        pygame.display.flip()
        clock.tick(60)

    pygame.quit()
    sys.exit()

if __name__ == "__main__":
    # Launch the viewer. main() ends with sys.exit(), which raises SystemExit
    # when this runs inside a notebook cell.
    main() 


Using device: cpu
Loading checkpoint from best_line_classifier.pt


  checkpoint = torch.load(resume_from, map_location=device)


Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classifier.pt
Loading checkpoint from best_line_classi

SystemExit: 

In [None]:

# intersect 3 on 