# Preprocessing

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import random


class SceneDataset(Dataset):
    """Map-style dataset that yields one (scene, query) pair per index.

    Every non-blank line of the input file describes one scene as
    whitespace-separated tokens::

        p x y x y ... p x y ... q x1 y1 x2 y2 label q x1 y1 x2 y2 label ...

    Each ``p`` starts an obstacle given as a run of ``x y`` vertex pairs, and
    each ``q`` starts a query made of two ``x y`` pairs followed by an integer
    label. The whole file is parsed once in the constructor and kept in memory.

    Args:
        file_path: Path of the text file to parse.
        do_augmentation: When true, ``__getitem__`` returns two random cyclic
            vertex orders per obstacle and three random obstacle orders;
            otherwise one of each.

    Raises:
        ValueError: If a line does not follow the token layout above.
    """

    def __init__(self, file_path, do_augmentation=True):
        self.file_path = file_path
        self.do_augmentation = do_augmentation
        self.cache = []    # One (obstacles, queries) tuple per parsed scene
        self.samples = []  # Flat index: one (scene_idx, query_idx) per query

        # Read and parse the file during initialization
        self._parse_file()

    @staticmethod
    def _read_point(tokens, pos, line_no):
        """Return the ``(x, y)`` float pair stored at ``tokens[pos:pos + 2]``."""
        if pos + 1 >= len(tokens):
            raise ValueError(
                f"line {line_no}: truncated coordinate pair at token {pos}"
            )
        try:
            return float(tokens[pos]), float(tokens[pos + 1])
        except ValueError as err:
            raise ValueError(
                f"line {line_no}: expected a numeric 'x y' pair at token {pos}, "
                f"got {tokens[pos]!r} {tokens[pos + 1]!r}"
            ) from err

    @classmethod
    def _parse_line(cls, tokens, line_no):
        """Split one tokenized scene line into ``(obstacles, queries)``."""
        total = len(tokens)
        pos = 0

        # Leading run of 'p' sections: one obstacle (list of vertices) each.
        obstacles = []
        while pos < total and tokens[pos] == 'p':
            pos += 1  # Skip 'p'
            obstacle = []
            while pos < total and tokens[pos] not in ('p', 'q'):
                obstacle.append(cls._read_point(tokens, pos, line_no))
                pos += 2
            if not obstacle:
                raise ValueError(f"line {line_no}: obstacle has no vertices")
            obstacles.append(obstacle)

        # Everything that remains must be 'q' sections.
        queries = []
        while pos < total:
            if tokens[pos] != 'q':
                # Without this check a stray token would never be consumed and
                # the loop would spin forever.
                raise ValueError(
                    f"line {line_no}: unexpected token {tokens[pos]!r} at "
                    f"position {pos}; expected 'q'"
                )
            pos += 1  # Skip 'q'
            query = []
            for _ in range(2):  # Each query has 2 coordinate pairs
                query.append(cls._read_point(tokens, pos, line_no))
                pos += 2
            if pos >= total:
                raise ValueError(f"line {line_no}: query is missing its label")
            try:
                label = int(tokens[pos])  # Label follows the query
            except ValueError as err:
                raise ValueError(
                    f"line {line_no}: expected an integer label at token {pos}, "
                    f"got {tokens[pos]!r}"
                ) from err
            pos += 1
            queries.append((query, label))

        return obstacles, queries

    def _parse_file(self):
        """Reads the file and caches scenes and their queries.

        Blank lines are skipped. Every query of every scene gets one entry in
        ``self.samples`` pointing back into ``self.cache``.
        """
        with open(self.file_path, 'r') as file:
            for line_no, raw_line in enumerate(file, start=1):
                tokens = raw_line.split()
                if not tokens:
                    continue

                obstacles, queries = self._parse_line(tokens, line_no)

                # Cache the parsed scene and index its individual samples
                scene_idx = len(self.cache)
                self.cache.append((obstacles, queries))
                self.samples.extend(
                    (scene_idx, query_idx) for query_idx in range(len(queries))
                )

    def _generate_vertex_order(self, vertices):
        """Generates cyclic random orders for vertices.

        Returns a list of index lists, each a rotation of ``range(len(vertices))``
        starting at a random vertex: two rotations with augmentation, one
        without (the single rotation is still random).

        Raises:
            ValueError: If ``vertices`` is empty.
        """
        n = len(vertices)
        if n == 0:
            raise ValueError("cannot order an obstacle that has no vertices")
        repeats = 2 if self.do_augmentation else 1
        orders = []
        for _ in range(repeats):
            start_idx = random.randrange(n)
            orders.append(list(range(start_idx, n)) + list(range(start_idx)))
        return orders

    def _generate_obstacle_order(self, num_obstacles, num_order):
        """Generates random orders for obstacles.

        Returns ``num_order`` random permutations of ``range(num_obstacles)``,
        or a single one when augmentation is disabled.
        """
        if not self.do_augmentation:
            num_order = 1
        return [
            random.sample(range(num_obstacles), num_obstacles)
            for _ in range(num_order)
        ]

    def __len__(self):
        """Total number of samples (queries)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Returns a single sample: obstacles, query, label, and augmentation orders.

        The orders are drawn afresh from ``random`` on every call, so the same
        index can return different orders each time.
        """
        scene_idx, query_idx = self.samples[idx]
        obstacles, queries = self.cache[scene_idx]
        query, label = queries[query_idx]

        # Generate augmentation orders
        vertex_orders = [self._generate_vertex_order(obstacle) for obstacle in obstacles]
        obstacle_order = self._generate_obstacle_order(len(obstacles), 3)

        return {
            'obstacles': obstacles,           # Original obstacle coordinates
            'query': query,                   # Original query coordinates
            'label': label,                   # Integer label parsed from the file
            'vertex_orders': vertex_orders,   # Vertex augmentation orders per obstacle
            'obstacle_order': obstacle_order  # Obstacle augmentation orders
        }

# Custom collate function for batching
def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    The ragged fields (obstacles, queries, vertex orders, obstacle orders) stay
    as plain Python lists with one entry per sample; only the labels are
    stacked into a float tensor.
    """
    def column(key):
        # Collect one field across the batch, preserving sample order.
        return [sample[key] for sample in batch]

    labels = torch.tensor(column('label'), dtype=torch.float)

    return {
        'obstacles': column('obstacles'),            # List of obstacles for each sample
        'queries': column('query'),                  # List of queries for each sample
        'labels': labels,                            # Labels as tensor
        'vertex_orders': column('vertex_orders'),    # Vertex augmentation orders
        'obstacle_orders': column('obstacle_order')  # Obstacle augmentation orders
    }


# Build the three datasets at module level; each constructor reads and parses
# its whole file immediately. Augmentation stays on (the default) for the
# training split only.
train_dataset = SceneDataset("/kaggle/input/data0007/train_3500.txt")
val_dataset = SceneDataset("/kaggle/input/data0001/val_50.txt",do_augmentation=False)
# NOTE(review): the test split is read from a file named train_6000.txt —
# confirm this is the intended held-out data.
test_dataset = SceneDataset("/kaggle/input/data0010/train_6000.txt",do_augmentation=False)

# Training and Testing

In [None]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
from tqdm import tqdm

# Define the model
class SceneQueryModel(nn.Module):
    """Two-level LSTM encoder that scores a query against a scene of obstacles.

    Pipeline implemented by ``forward``:

    1. Every requested vertex ordering of every obstacle is run through
       ``rnn_obstacle``; the final hidden states are projected to 32-d and
       averaged per obstacle.
    2. Every requested obstacle ordering of every scene is run through
       ``rnn_scene`` over those obstacle embeddings; the final hidden states
       are projected to 128-d and averaged per scene.
    3. The query's point sequence is encoded with the *same* ``rnn_obstacle``
       and projection, concatenated with the scene embedding, and classified.

    Args:
        vertex_input_dim: Number of features per vertex / query point.
        obstacle_hidden_dim: Hidden size of the vertex-level LSTM.
        scene_hidden_dim: Hidden size of the obstacle-level LSTM.
        output_dim: Number of sigmoid outputs.
    """

    def __init__(self, vertex_input_dim, obstacle_hidden_dim, scene_hidden_dim, output_dim):
        super().__init__()

        # Shared RNN block for obstacles and queries.
        # No `dropout` argument: nn.LSTM only applies it between stacked
        # layers, so with num_layers=1 it did nothing except emit a warning.
        self.rnn_obstacle = nn.LSTM(input_size=vertex_input_dim, hidden_size=obstacle_hidden_dim,
                                    num_layers=1, batch_first=True)

        self.obstacle_embedding_fc = nn.Sequential(
            nn.Linear(obstacle_hidden_dim, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )

        # Scene-level RNN block (consumes the 32-d obstacle embeddings)
        self.rnn_scene = nn.LSTM(input_size=32, hidden_size=scene_hidden_dim,
                                 num_layers=1, batch_first=True)

        self.scene_embedding_fc = nn.Sequential(
            nn.Linear(scene_hidden_dim, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128)
        )

        # Final classification block: 32-d query embedding + 128-d scene embedding
        self.classifier = nn.Sequential(
            nn.Linear(128 + 32, 32),
            nn.ReLU(),
            nn.BatchNorm1d(32),
            nn.Linear(32, 8),
            nn.ReLU(),
            nn.BatchNorm1d(8),
            nn.Linear(8, output_dim),
            nn.Sigmoid()
        )

    @staticmethod
    def _pad_orders(orders, device):
        """Right-pad index lists with 0 into one long tensor.

        Returns ``(padded, lengths)`` where ``padded`` has shape
        ``(len(orders), max_len)`` on ``device`` and ``lengths`` is an int64
        CPU tensor, which is what ``pack_padded_sequence`` requires.
        """
        if not orders:
            raise ValueError("forward() received a batch with no orders to encode")
        lengths = [len(order) for order in orders]
        width = max(lengths)
        # Padding with index 0 is harmless: padded steps are dropped by packing.
        rows = [list(order) + [0] * (width - len(order)) for order in orders]
        padded = torch.tensor(rows, dtype=torch.long, device=device)
        return padded, torch.tensor(lengths, dtype=torch.int64)

    def forward(self, obstacles, queries, vertex_orders, obstacle_orders):
        """Return sigmoid scores of shape ``(batch_size, output_dim)``.

        Args:
            obstacles: Per sample, a list of obstacles; each obstacle is a
                sequence of coordinate pairs.
            queries: Per sample, an equally long sequence of coordinate pairs
                (converted as one dense tensor).
            vertex_orders: Per sample and per obstacle, a list of index lists
                giving the vertex visiting orders to encode.
            obstacle_orders: Per sample, a list of index lists giving the
                obstacle visiting orders to encode.

        Raises:
            ValueError: If the batch contains no vertex or obstacle orders.
        """
        device = next(self.parameters()).device  # Automatically get the device of the model

        # ---- Stage 1: encode every (obstacle, vertex order) pair ----
        obstacle_tensors = []    # one (num_vertices, 2) tensor per obstacle
        vertex_order_rows = []   # one index list per requested ordering
        row_obstacle = []        # row -> position in obstacle_tensors
        row_sample = []          # row -> sample index in the batch
        row_slot = []            # row -> obstacle index inside its sample

        for sample_idx, (obstacle_set, order_set) in enumerate(zip(obstacles, vertex_orders)):
            for slot_idx, (obstacle, orders) in enumerate(zip(obstacle_set, order_set)):
                flat_idx = len(obstacle_tensors)
                obstacle_tensors.append(torch.tensor(obstacle, dtype=torch.float, device=device))
                for order in orders:
                    vertex_order_rows.append(order)
                    row_obstacle.append(flat_idx)
                    row_sample.append(sample_idx)
                    row_slot.append(slot_idx)

        padded_orders, sequence_lengths = self._pad_orders(vertex_order_rows, device)

        # Zero-pad all obstacles to (num_obstacles, max_vertices, 2)
        padded_obstacles = nn.utils.rnn.pad_sequence(obstacle_tensors, batch_first=True)

        # Gather vertices in the requested order for all rows at once:
        # result[k, t] = padded_obstacles[row_obstacle[k], padded_orders[k, t]]
        obstacle_rows = torch.tensor(row_obstacle, dtype=torch.long, device=device)
        ordered_vertices = padded_obstacles[obstacle_rows.unsqueeze(1), padded_orders]

        # Pack the sequences for RNN
        packed_vertices = nn.utils.rnn.pack_padded_sequence(
            ordered_vertices, sequence_lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.rnn_obstacle(packed_vertices)

        # Compute embeddings for each order
        embeddings = self.obstacle_embedding_fc(h_n[-1])  # Shape: (total_orders, 32)

        # Aggregate embeddings back to obstacle level (mean over orders)
        max_obstacles = max(len(obstacle_set) for obstacle_set in obstacles)
        obstacle_embeddings = torch.zeros((len(obstacles), max_obstacles, embeddings.size(-1)), device=device)
        order_counts = torch.zeros_like(obstacle_embeddings[..., 0])  # For averaging

        for row, (sample_idx, slot_idx) in enumerate(zip(row_sample, row_slot)):
            obstacle_embeddings[sample_idx, slot_idx] += embeddings[row]
            order_counts[sample_idx, slot_idx] += 1

        # Avoid division by zero (unused padding slots) and compute the mean
        obstacle_embeddings /= order_counts.unsqueeze(-1).clamp(min=1)

        # ---- Stage 2: encode every (scene, obstacle order) pair ----
        obstacle_order_rows = []
        scene_of_row = []
        for sample_idx, orders in enumerate(obstacle_orders):
            for order in orders:
                obstacle_order_rows.append(order)
                scene_of_row.append(sample_idx)

        padded_orders, sequence_lengths = self._pad_orders(obstacle_order_rows, device)

        # result[k, t] = obstacle_embeddings[scene_of_row[k], padded_orders[k, t]]
        scene_rows = torch.tensor(scene_of_row, dtype=torch.long, device=device)
        ordered_obstacles = obstacle_embeddings[scene_rows.unsqueeze(1), padded_orders]

        # Pack the sequences for RNN
        packed_obstacles = nn.utils.rnn.pack_padded_sequence(
            ordered_obstacles, sequence_lengths, batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.rnn_scene(packed_obstacles)

        # Compute embeddings for each order
        flat_scene_embeddings = self.scene_embedding_fc(h_n[-1])  # Shape: (total_orders, 128)

        # Aggregate embeddings back to scene level (mean over orders)
        scene_embeddings = torch.zeros((len(obstacles), flat_scene_embeddings.size(-1)), device=device)
        order_counts = torch.zeros_like(scene_embeddings[..., 0])  # For averaging

        for row, sample_idx in enumerate(scene_of_row):
            scene_embeddings[sample_idx] += flat_scene_embeddings[row]
            order_counts[sample_idx] += 1

        # Avoid division by zero and compute the mean
        scene_embeddings /= order_counts.unsqueeze(-1).clamp(min=1)

        # ---- Stage 3: encode the queries with the shared vertex-level RNN ----
        # Shape: [batch_size, seq_len, feature_dim]
        queries_tensor = torch.tensor(queries, dtype=torch.float, device=device)

        # h_n shape: [num_layers * num_directions, batch_size, hidden_size]
        _, (h_n, _) = self.rnn_obstacle(queries_tensor)

        # Last layer's hidden state -> same projection as the obstacles
        query_embeddings = self.obstacle_embedding_fc(h_n[-1])  # Shape: [batch_size, 32]

        # Concatenate query and scene embeddings
        combined = torch.cat((query_embeddings, scene_embeddings), dim=1)

        # Classification
        outputs = self.classifier(combined)

        return outputs


def calculate_metrics(outputs, labels):
    """Compute binary classification metrics from probabilities and 0/1 labels.

    All five metrics are derived from a single pass that counts the confusion
    matrix cells, instead of re-scanning the data once per metric.

    Args:
        outputs: Tensor of predicted probabilities; values strictly above 0.5
            count as the positive class.
        labels: Tensor with the same number of elements; entries equal to 1
            are positives, everything else is treated as negative.

    Returns:
        Tuple ``(accuracy, precision, recall, specificity, f1)`` of Python
        floats. Any ratio whose denominator is zero is reported as ``0.0``.
    """
    # Flatten and move to CPU once; boolean masks keep the counting cheap.
    predicted_pos = (outputs.detach() > 0.5).reshape(-1).cpu()
    actual_pos = (labels.detach().reshape(-1) == 1).cpu()

    tp = int((predicted_pos & actual_pos).sum())
    fp = int((predicted_pos & ~actual_pos).sum())
    fn = int((~predicted_pos & actual_pos).sum())
    tn = int((~predicted_pos & ~actual_pos).sum())

    def ratio(numerator, denominator):
        # Always a float, so callers can format every metric the same way.
        return numerator / denominator if denominator > 0 else 0.0

    accuracy = ratio(tp + tn, tp + tn + fp + fn)
    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    specificity = ratio(tn, tn + fp)
    f1 = ratio(2 * tp, 2 * tp + fp + fn)

    return accuracy, precision, recall, specificity, f1

def train(model, train_loader, val_loader, val1_loader=None, num_epochs=15, learning_rate=0.001):
    """Train ``model`` with BCE loss and Adam, validating and checkpointing every epoch.

    Reads the module-level names ``checkpoint_path_load`` (resumed from when the
    file exists) and ``checkpoint_path_save`` (overwritten after each epoch).

    Args:
        model: Model called as ``model(obstacles, queries, vertex_orders,
            obstacle_orders)`` and returning probabilities of shape (batch, 1).
        train_loader: Iterable of batch dicts with the keys 'obstacles',
            'queries', 'labels', 'vertex_orders' and 'obstacle_orders'.
        val_loader: Same batch format, evaluated after every epoch.
        val1_loader: Unused; kept (now optional) for call compatibility.
        num_epochs: Total number of epochs, counting any already completed in
            a resumed checkpoint.
        learning_rate: Adam learning rate.
    """
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Follow the model's own device instead of relying on a global.
    device = next(model.parameters()).device

    # Check if a checkpoint exists and load it
    start_epoch = 0
    if os.path.exists(checkpoint_path_load):
        print("Loading checkpoint...")
        # map_location lets a checkpoint saved on GPU be resumed on CPU.
        checkpoint = torch.load(checkpoint_path_load, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        print(f"Resuming training from epoch {start_epoch + 1}.")

    for epoch in range(start_epoch, num_epochs):
        model.train()
        train_loss = 0
        all_outputs = []
        all_labels = []
        # Defined up front so the epoch summary still prints for an empty loader.
        train_accuracy = train_precision = train_recall = train_specificity = train_f1 = 0.0

        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Use tqdm for batch-level progress tracking
        with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Training") as batch_bar:
            # Count steps ourselves: tqdm only refreshes its `n` attribute at
            # display intervals, so it is not a reliable batch counter.
            for step, batch in enumerate(batch_bar, start=1):
                obstacles = batch['obstacles']
                queries = batch['queries']
                labels = batch['labels'].to(device)
                vertex_orders = batch['vertex_orders']
                obstacle_orders = batch['obstacle_orders']

                labels = labels.unsqueeze(1)  # Match the (batch, 1) model output for BCELoss

                optimizer.zero_grad()
                outputs = model(obstacles, queries, vertex_orders, obstacle_orders)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

                # Move outputs and labels to CPU to save GPU memory and store them
                all_outputs.append(outputs.detach().cpu())
                all_labels.append(labels.detach().cpu())

                # NOTE(review): metrics are recomputed over everything seen so
                # far on every batch, so this cost grows with the epoch length.
                cumulative_outputs = torch.cat(all_outputs)
                cumulative_labels = torch.cat(all_labels)
                train_accuracy, train_precision, train_recall, train_specificity, train_f1 = calculate_metrics(cumulative_outputs, cumulative_labels)

                # Update progress bar with running loss and metrics
                batch_bar.set_postfix(
                    loss=train_loss / step,
                    accuracy=train_accuracy,
                    precision=train_precision,
                    recall=train_recall,
                    specificity=train_specificity,
                    f1=train_f1
                )

        # Average the training loss (guarding against an empty loader)
        train_loss /= max(len(train_loader), 1)

        # Print training metrics for the epoch
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Train Precision: {train_precision:.4f}, "
              f"Train Recall: {train_recall:.4f}, Train Specificity: {train_specificity:.4f}, Train F1: {train_f1:.4f}")

        # Validation loop
        model.eval()
        val_loss = 0
        all_outputs = []
        all_labels = []

        with torch.no_grad():
            for batch in val_loader:
                obstacles = batch['obstacles']
                queries = batch['queries']
                labels = batch['labels'].to(device)
                vertex_orders = batch['vertex_orders']
                obstacle_orders = batch['obstacle_orders']

                labels = labels.unsqueeze(1)

                outputs = model(obstacles, queries, vertex_orders, obstacle_orders)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                # Store outputs and labels for metric calculation
                all_outputs.append(outputs.cpu())
                all_labels.append(labels.cpu())

        if not all_outputs:
            raise ValueError("val_loader produced no batches")

        # Concatenate all outputs and labels
        all_outputs = torch.cat(all_outputs)
        all_labels = torch.cat(all_labels)

        # Compute metrics using the provided function
        val_accuracy, val_precision, val_recall, val_specificity, val_f1 = calculate_metrics(all_outputs, all_labels)

        # Average the validation loss
        val_loss /= len(val_loader)

        # Print validation metrics for the epoch
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, "
              f"Validation Precision: {val_precision:.4f}, Validation Recall: {val_recall:.4f}, "
              f"Validation Specificity: {val_specificity:.4f}, Validation F1: {val_f1:.4f}")

        # Save a checkpoint after every epoch; 'epoch' is the 0-based index of
        # the epoch just finished, which the resume logic above adds 1 to.
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()
        }
        torch.save(checkpoint, checkpoint_path_save)
        print(f"Checkpoint saved at epoch {epoch + 1}.")

def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print loss and metrics.

    If the file named by the module-level ``checkpoint_path_load`` exists, its
    'model_state_dict' replaces the model's current weights before evaluating;
    otherwise the model is evaluated as passed in.

    Args:
        model: Model called as ``model(obstacles, queries, vertex_orders,
            obstacle_orders)`` and returning probabilities of shape (batch, 1).
        test_loader: Iterable of batch dicts with the keys 'obstacles',
            'queries', 'labels', 'vertex_orders' and 'obstacle_orders'.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    criterion = nn.BCELoss()
    # Follow the model's own device instead of relying on a global.
    device = next(model.parameters()).device

    # Check if a checkpoint exists and load it
    if os.path.exists(checkpoint_path_load):
        print("Loading checkpoint...")
        # map_location lets a checkpoint saved on GPU be evaluated on CPU.
        checkpoint = torch.load(checkpoint_path_load, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint['model_state_dict'])

    # Set model to evaluation mode
    model.eval()
    test_loss = 0
    all_outputs = []
    all_labels = []

    with torch.no_grad():
        for batch in test_loader:
            obstacles = batch['obstacles']
            queries = batch['queries']
            labels = batch['labels'].to(device)
            vertex_orders = batch['vertex_orders']
            obstacle_orders = batch['obstacle_orders']

            labels = labels.unsqueeze(1)  # Match the (batch, 1) model output

            outputs = model(obstacles, queries, vertex_orders, obstacle_orders)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Store outputs and labels for metric calculation
            all_outputs.append(outputs.cpu())
            all_labels.append(labels.cpu())

    if not all_outputs:
        raise ValueError("test_loader produced no batches")

    # Concatenate all outputs and labels
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)

    # Compute metrics using the provided function
    test_accuracy, test_precision, test_recall, test_specificity, test_f1 = calculate_metrics(all_outputs, all_labels)

    # Average the test loss
    test_loss /= len(test_loader)

    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}, "
          f"Test Precision: {test_precision:.4f}, Test Recall: {test_recall:.4f}, "
          f"Test Specificity: {test_specificity:.4f}, Test F1: {test_f1:.4f}")


# Model instantiation and dataloader setup
if __name__ == "__main__":
    # Model hyper-parameters
    vertex_input_dim = 2       # (x, y) per vertex / query point
    obstacle_hidden_dim = 128
    scene_hidden_dim = 512
    output_dim = 1             # single sigmoid probability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SceneQueryModel(vertex_input_dim, obstacle_hidden_dim, scene_hidden_dim, output_dim).to(device)

    # NOTE(review): the model contains BatchNorm1d layers, which reject a
    # single-row batch in training mode; if the training set size leaves a
    # final batch of one sample, consider drop_last=True here.
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

    # train() resumes from checkpoint_path_load when that file exists and
    # writes a fresh checkpoint to checkpoint_path_save after every epoch.
    checkpoint_path_save ='checkpoint.pth'
    checkpoint_path_load = '/kaggle/input/model12/pytorch/default/1/checkpoint (7).pth'

    # train() declares a val1_loader parameter; the previous call omitted it,
    # which raises TypeError. It is unused, so pass None explicitly.
    train(model, train_loader, val_loader, val1_loader=None, num_epochs=15, learning_rate=0.001)

    # test() reloads whatever checkpoint_path_load names. Point it at the
    # checkpoint just written so evaluation uses the newly trained weights
    # rather than restoring the pre-training checkpoint over them.
    checkpoint_path_load = checkpoint_path_save

    test(model,test_loader)
