# PyTorch DataLoader

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import random


class SceneDataset(Dataset):
    """Dataset of (scene, query) samples parsed from a whitespace-delimited file.

    Each line of the file describes one scene as a flat token stream::

        p x y x y ... [p x y ...] q x1 y1 x2 y2 label [q x1 y1 x2 y2 label ...]

    Every ``p`` token starts an obstacle, given as a run of ``x y`` pairs.
    Every ``q`` token starts a query: two ``x y`` pairs followed by an
    integer label. One dataset sample corresponds to one query; all queries
    of a scene share that scene's obstacles.

    Attributes are filled once, in ``__init__``:
        cache:   list of ``(obstacles, queries)`` tuples, one per file line.
        samples: list of ``(scene_idx, query_idx)`` pairs, one per query.
    """

    # Number of obstacle permutations drawn per sample when augmentation is on.
    NUM_OBSTACLE_ORDERS = 3

    def __init__(self, file_path, do_augmentation=True):
        """Parse ``file_path`` eagerly and index every query it contains.

        Args:
            file_path: Path of the scene file to read.
            do_augmentation: When True, ``__getitem__`` returns several random
                orders per obstacle / per scene; when False it returns one.
        """
        self.file_path = file_path
        self.do_augmentation = do_augmentation
        self.cache = []  # Parsed scenes: (obstacles, queries) per line
        self.samples = []  # Flat index: (scene_idx, query_idx) per query

        # Read and parse the file during initialization
        self._parse_file()

    def _parse_file(self):
        """Read the file, caching each line's scene and indexing its queries.

        Raises:
            ValueError: If a token that should start a query is not ``'q'``,
                or if a coordinate / label token is not a number.
            IndexError: If a line ends in the middle of a coordinate pair or
                of a query.
        """
        with open(self.file_path, 'r') as file:
            for line_idx, line in enumerate(file):
                tokens = line.split()

                obstacles, idx = self._parse_obstacles(tokens)
                queries = self._parse_queries(tokens, idx, line_idx)

                # Cache the parsed scene
                scene_idx = len(self.cache)
                self.cache.append((obstacles, queries))

                # Index individual samples
                self.samples.extend(
                    (scene_idx, query_idx) for query_idx in range(len(queries))
                )

    @staticmethod
    def _parse_obstacles(tokens):
        """Parse the leading ``p ...`` groups; return (obstacles, next index)."""
        obstacles = []
        idx = 0
        count = len(tokens)
        while idx < count and tokens[idx] == 'p':
            idx += 1  # Skip 'p'
            obstacle = []
            while idx < count and tokens[idx] not in ('p', 'q'):
                obstacle.append((float(tokens[idx]), float(tokens[idx + 1])))
                idx += 2
            obstacles.append(obstacle)
        return obstacles, idx

    def _parse_queries(self, tokens, idx, line_idx):
        """Parse ``q x1 y1 x2 y2 label`` groups from ``tokens[idx:]``."""
        queries = []
        while idx < len(tokens):
            if tokens[idx] != 'q':
                # Without this check the loop would never advance past the
                # stray token and would spin forever.
                raise ValueError(
                    "{}: line {}: expected 'q' but found {!r} (token {})".format(
                        self.file_path, line_idx + 1, tokens[idx], idx
                    )
                )
            idx += 1  # Skip 'q'
            query = []
            for _ in range(2):  # Each query has 2 coordinate pairs
                query.append((float(tokens[idx]), float(tokens[idx + 1])))
                idx += 2
            label = int(tokens[idx])  # Label follows the query
            idx += 1
            queries.append((query, label))
        return queries

    def _generate_vertex_order(self, vertices):
        """Return random cyclic rotations of ``range(len(vertices))``.

        Two rotations are produced when augmentation is enabled, one
        otherwise. An obstacle without vertices yields empty orders.

        NOTE(review): the single order returned with augmentation disabled is
        still a random rotation, not the identity - confirm this is intended.
        """
        n = len(vertices)
        num_orders = 2 if self.do_augmentation else 1
        if n == 0:
            # random.randint(0, -1) would raise; there is nothing to rotate.
            return [[] for _ in range(num_orders)]
        orders = []
        for _ in range(num_orders):
            start_idx = random.randint(0, n - 1)
            orders.append(list(range(start_idx, n)) + list(range(start_idx)))
        return orders

    def _generate_obstacle_order(self, num_obstacles, num_order):
        """Return random permutations of ``range(num_obstacles)``.

        ``num_order`` permutations are produced when augmentation is enabled;
        exactly one (still random) permutation otherwise.
        """
        if not self.do_augmentation:
            num_order = 1
        return [
            random.sample(range(num_obstacles), num_obstacles)
            for _ in range(num_order)
        ]

    def __len__(self):
        """Total number of samples (queries)."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return one sample: obstacles, query, label, and augmentation orders.

        The ``obstacles`` and ``query`` values are the cached objects
        themselves, not copies, so callers should not mutate them.
        """
        scene_idx, query_idx = self.samples[idx]
        obstacles, queries = self.cache[scene_idx]
        query, label = queries[query_idx]

        # Generate augmentation orders
        vertex_orders = [self._generate_vertex_order(obstacle) for obstacle in obstacles]
        obstacle_order = self._generate_obstacle_order(
            len(obstacles), self.NUM_OBSTACLE_ORDERS
        )

        return {
            'obstacles': obstacles,        # Original obstacle coordinates
            'query': query,                # Original query coordinates
            'label': label,                # Integer label parsed from the file
            'vertex_orders': vertex_orders, # Vertex augmentation orders per obstacle
            'obstacle_order': obstacle_order # Obstacle augmentation orders
        }

# Custom collate function for batching
def collate_fn(batch):
    """Collate per-sample dicts into one dict of per-field lists.

    Every field is kept as a plain Python list with one entry per sample,
    except the labels, which are packed into a float tensor.
    """
    # One pass over the batch, pulling the five fields of each sample.
    rows = [
        (
            sample['obstacles'],
            sample['query'],
            sample['label'],
            sample['vertex_orders'],
            sample['obstacle_order'],
        )
        for sample in batch
    ]

    # Transpose rows into columns; an empty batch still yields five empty lists.
    columns = [list(column) for column in zip(*rows)] or [[] for _ in range(5)]
    obstacles, queries, labels, vertex_orders, obstacle_orders = columns

    return {
        'obstacles': obstacles,              # Obstacles of each sample
        'queries': queries,                  # Query of each sample
        'labels': torch.tensor(labels, dtype=torch.float),  # Labels as tensor
        'vertex_orders': vertex_orders,      # Vertex augmentation orders
        'obstacle_orders': obstacle_orders,  # Obstacle augmentation orders
    }


# Training scenes: augmentation is left at its default (enabled).
train_dataset = SceneDataset(
    file_path="/kaggle/input/data0007/train_3500.txt",
)
# Validation scenes: augmentation is switched off.
val_dataset = SceneDataset(
    file_path="/kaggle/input/data0001/val_50.txt",
    do_augmentation=False,
)

# Both loaders batch 32 samples through collate_fn; only training shuffles.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn,
)
    
