In [1]:
# Import necessary modules
import sys
import os

# Set root folder to project root
os.chdir(os.path.dirname(os.getcwd()))

# Add root folder to path
sys.path.append(os.getcwd())

import torch
from maze_dataset import set_serialize_minimal_threshold
from maze_dataset.generation import LatticeMazeGenerators
from maze_dataset.dataset.rasterized import MazeDatasetConfig, MazeDataset, RasterizedMazeDataset

from src.utils.loading import get_mazes, load_model
from src.utils.plotting import plot_mazes

In [None]:
def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Print available devices
print("Available devices:")
print(torch.cuda.device_count())

In [4]:
def get_mazes(dataset='maze-dataset', maze_size=9, num_mazes=10, gen='dfs_perc', percolation=0.0, deadend_start=True):
    """ Generate mazes of the given size and number, 
        from the given dataset, and load to device"""
    
    if dataset == 'maze-dataset':
        """ https://github.com/understanding-search/maze-dataset """

        assert maze_size % 2 == 1
        grid_n = (maze_size + 1) // 2
        
        # Generate base maze dataset
        if gen == 'dfs':
            maze_ctor = LatticeMazeGenerators.gen_dfs
            maze_ctor_kwargs = dict()
        elif gen == 'dfs_perc':
            maze_ctor = LatticeMazeGenerators.gen_dfs_percolation
            maze_ctor_kwargs = dict(p=percolation)
        elif gen == 'percolation':
            maze_ctor = LatticeMazeGenerators.gen_percolation
            maze_ctor_kwargs = dict(p=percolation)
        endpoint_kwargs=dict(deadend_start=deadend_start, endpoints_not_equal=True)

        base_dataset = MazeDataset.from_config(
            MazeDatasetConfig(
                name='test',
                grid_n=grid_n,
                n_mazes=num_mazes,
                seed=42,
                maze_ctor=maze_ctor, 
                maze_ctor_kwargs=maze_ctor_kwargs,
                endpoint_kwargs=endpoint_kwargs
            ),
            gen_parallel = True,
            local_base_path='data/maze_dataset/',
        )

        # Generate rasterized maze dataset
        dataset = RasterizedMazeDataset.from_base_MazeDataset(
            base_dataset=base_dataset,
            added_params=dict(
                remove_isolated_cells=True,
                extend_pixels=True, # maps from 1x1 to 2x2 pixels and adds 3 padding
            )
        )

        dataset = dataset.get_batch(idxs=None)

        # Get inputs
        inputs = dataset[0,:,:,:]
        inputs = inputs / 255.0
        inputs = inputs.permute(0, 3, 1, 2)
        inputs = inputs.float().detach().to(get_device(), dtype=torch.float32)

        # Get solutions
        solutions = dataset[1,:,:, :]
        solutions = solutions / 255.0
        solutions = solutions.permute(0, 3, 1, 2)
        solutions, _ = torch.max(solutions, dim=1)
        solutions = solutions.float().detach().to(get_device(), dtype=torch.float32) 

    elif dataset == 'easy-to-hard-data':
        """ https://github.com/aks2203/easy-to-hard-data """

        assert maze_size in [9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 59]
        # 50,000 training mazes for maze_size [9]
        # 10,000 testing mazes for each smaller maze_size in [9, 11, 13, 15, 17]
        # 1,000 testing mazes for each larger maze_size in [19, 21, 23, 25, 27, 29, 31, 33, 35, 37, 59]

        maze_dataset = EasyToHardMazeDataset(root='data/easy-to-hard-data/', train=False, size=maze_size)
        inputs = maze_dataset.inputs[:num_mazes].float().detach().to(get_device(), dtype=torch.float32)
        solutions = maze_dataset.targets[:num_mazes].float().detach().to(get_device(), dtype=torch.float32)

    return inputs, solutions

ValueError: no valid start or end positions found

In [3]:
# Create percolated maze

inputs, solutions = get_mazes(
    dataset='maze-dataset', 
    maze_size=9, 
    num_mazes=10,
    percolation=0.5,
    deadend_start=True)

plot_mazes(inputs, solutions=solutions, font_size=20, file_name='outputs/mazes/maze_percolated');

ValueError: no valid start or end positions found

In [None]:
# Plot predictions too

# model = load_model('dt_net')
# predictions = model.predict(inputs, iters=30)
# plot_mazes(inputs, predictions=predictions, solutions=solutions, font_size=20, file_name='outputs/mazes/maze_percolated_prediction');