# EfficientCube+: DNN-based Rubik's Cube Solver

## Overview
This standalone notebook demonstrates the methods proposed in our ECE448 Final Project at UIUC in Spring 2025 and provides the code needed to reproduce the experiments.

## Related work
This work is based on the following publication:
> K. Takano. Self-Supervision is All You Need for Solving Rubik's Cube. Transactions on Machine Learning Research, ISSN 2835-8856, 2023. URL: https://openreview.net/forum?id=bnBeNFB27b.

## Environment Reference
The notebook is designed to be run and tested on [Illinois Computes Research Notebooks](http://go.ncsa.illinois.edu/jupyter) (ICRN), which is equipped with the following resources:
- 2 AMD EPYC-Milan processor cores
- 8GB of RAM
- NVIDIA A100-SXM4-80GB (shared)

## Setup
Because of the project's short timeline and limited resources, the default training and search configurations are deliberately sub-optimal: they balance time and resource consumption against the performance of the resulting model.

To compare with and reproduce the best reported results of the original [EfficientCube](https://github.com/kyo-takano/efficientcube) project, on which this work builds, set `TrainConfig.num_steps = 2000000` and `SearchConfig.beam_width = 2**18`.

To accelerate training and inference, mixed-precision (FP16) mode can be enabled by setting `TrainConfig.ENABLE_FP16` and `SearchConfig.ENABLE_FP16`, respectively, to `True`, possibly at the cost of a minor drop in model performance.
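Note that some definitions below read these settings only once, when their cell runs: `EchoDataset.__init__` uses `TrainConfig.batch_size_per_depth * TrainConfig.num_steps` as a default argument, and `beam_search` does the same with `SearchConfig.beam_width` and `SearchConfig.max_depth`. Python evaluates default arguments when the `def` statement executes, so the safest way to change a setting is to edit the configuration cell and re-run the notebook from there, or to pass the value explicitly. The toy below, with a hypothetical `DemoSearchConfig` and `beam_search_stub`, illustrates the effect.

```python
class DemoSearchConfig:
    beam_width = 2**11

def beam_search_stub(beam_width: int = DemoSearchConfig.beam_width) -> int:
    # The default value is looked up once, when this `def` statement executes
    return beam_width

DemoSearchConfig.beam_width = 2**18                   # changed after the function was defined
print(beam_search_stub())                             # 2048: the stale default is still used
print(beam_search_stub(DemoSearchConfig.beam_width))  # 262144: passing the value explicitly works
```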

In [40]:
class TrainConfig:
    max_depth = 26                          # God's Number
    batch_size_per_depth = 1000
    num_steps = 10000
    learning_rate = 1e-3
    INTERVAL_PLOT, INTERVAL_SAVE = 100, 1000
    ENABLE_FP16 = False                     # Set this to True if you want to train the model faster

class SearchConfig:
    beam_width = 2**11                      # This controls the trade-off between time and optimality
    max_depth = TrainConfig.max_depth * 2   # Any number above God's Number will do
    ENABLE_FP16 = False                     # Set this to True if you want to solve faster

In [41]:
import os
import pickle
import random
import time
from contextlib import nullcontext
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from cycler import cycler
from IPython.display import clear_output
from torch import nn
from tqdm import tqdm, trange

# Set the default color cycle for matplotlib plots
plt.rcParams["axes.prop_cycle"] = cycler(color=["#000000", "#2180FE", "#EB4275"])

# Enable TensorFloat32 (TF32) Training/Inference on Ampere (or higher) GPUs
# Supported on the NVIDIA A100 of the test environment
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Use GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
# Diagnostic information
print(f'device: {device}')
print(f'os.cpu_count(): {os.cpu_count()}')
!nvidia-smi -L

device: cpu
os.cpu_count(): 8


'nvidia-smi' is not recognized as an internal or external command,
operable program or batch file.


## Rubik's Cube
The Rubik's Cube is represented by the color labels of its $6\times3\times3 = 54$ stickers, each stored at a fixed location index; a move is a permutation of these locations.

In this work, moves are counted in the [Quarter-Turn Metric](https://www.speedsolving.com/wiki/index.php/Metric#QTM) (QTM): a 90° turn counts as one move and a 180° turn as two.
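As a small illustration of the metric, the hypothetical helpers below count a sequence written in standard notation under QTM and rewrite half turns as two quarter turns, since the 12-move set `Cubes.MOVES` contains only quarter turns (the DeepCubeA solutions loaded later likewise contain `F F` rather than `F2`). They assume single-letter face names with an optional `'` or `2` suffix.

```python
def qtm_length(moves: list[str]) -> int:
    """Length of a move sequence in QTM: X and X' cost 1, a half turn X2 costs 2."""
    return sum(2 if move.endswith("2") else 1 for move in moves)

def to_quarter_turns(moves: list[str]) -> list[str]:
    """Rewrite each half turn X2 as two clockwise quarter turns X X."""
    quarter_turns: list[str] = []
    for move in moves:
        if move.endswith("2"):
            quarter_turns += [move[0], move[0]]
        else:
            quarter_turns.append(move)
    return quarter_turns

sequence = ["R", "U2", "F'"]
print(qtm_length(sequence))        # 4
print(to_quarter_turns(sequence))  # ['R', 'U', 'U', "F'"]
```

The two helpers agree by construction: the QTM length equals the number of quarter turns after expansion.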

Compared with the original implementation, ours relies on **vectorized (SIMD) tensor operations**: a move is applied to a whole batch of cubes at once by gathering and scattering sticker indices. This lets the cube environment run on the GPU and considerably speeds up both training (data generation) and inference (search).
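A minimal numpy sketch of the batched move follows, using the `U` entry copied from `Cubes.__MOVE_MAP` (the sticker at `source[k]` moves to `target[k]`; the primed move swaps the two arrays, as `init_class` does when building `MOVE_MAP_SOURCE` and `MOVE_MAP_TARGET`). Stacking one row of indices per move lets each cube in the batch receive a different move through a single advanced-indexing assignment, the same pattern `Cubes.__move_torch` uses with torch tensors. The two-move table and `apply_moves` are illustrative assumptions, not part of the notebook.

```python
import numpy as np

# `U` entry copied from Cubes.__MOVE_MAP: the sticker at U_SOURCE[k] moves to U_TARGET[k]
U_SOURCE = np.array([6, 7, 8, 5, 2, 1, 0, 3, 47, 50, 53, 29, 32, 35, 38, 41, 44, 20, 23, 26])
U_TARGET = np.array([0, 3, 6, 7, 8, 5, 2, 1, 20, 23, 26, 47, 50, 53, 29, 32, 35, 38, 41, 44])

# Toy move table, one row per move: index 0 = U, index 1 = U' (source/target swapped)
SOURCE = np.stack([U_SOURCE, U_TARGET])
TARGET = np.stack([U_TARGET, U_SOURCE])

GOAL = np.arange(54) // 9  # solved state: face k carries color k

def apply_moves(states: np.ndarray, moves: np.ndarray) -> np.ndarray:
    """Apply moves[i] to states[i] for the whole batch in one gather/scatter."""
    batch_idx = np.arange(states.shape[0])[:, None]  # (batch, 1), broadcasts against (batch, 20)
    out = states.copy()
    out[batch_idx, TARGET[moves]] = states[batch_idx, SOURCE[moves]]
    return out

batch_states = np.stack([GOAL, GOAL])
once = apply_moves(batch_states, np.array([0, 1]))  # cube 0 gets U, cube 1 gets U'
back = apply_moves(once, np.array([1, 0]))          # undo each turn with its inverse
four = batch_states
for _ in range(4):
    four = apply_moves(four, np.array([0, 0]))      # U applied four times
print((once == GOAL).all(axis=1))  # neither cube is solved
print((back == GOAL).all(axis=1))  # both cubes are solved again
print((four == GOAL).all(axis=1))  # four quarter turns of one face are the identity
```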

In [43]:
class Cubes:
    """
    A class for a set of 3x3x3 Rubik's cubes.

    Each cube is represented as a 1D tensor of shape (6 * 3 * 3,).
    Initial color:

            0 0 0
            0 Y 0
            0 0 0

    2 2 2   5 5 5   3 3 3  4 4 4
    2 B 2   5 R 5   3 G 3  4 O 4
    2 2 2   5 5 5   3 3 3  4 4 4

            1 1 1
            1 W 1
            1 1 1
    
    Sticker indices (bracketed entries mark the first index 9*k of face k):

                 2   5   8
                 1   4   7
                [0]  3   6
     20  23 26  47  50  53  29  32 35  38  41 44
     19  22 25  46  49  52  28  31 34  37  40 43
    [18] 21 24 [45] 48  51 [27] 30 33 [36] 39 42
                11   14 17
                10   13 16
                [9]  12 15
    """

    ## Class variables ##

    # Initialization indicator
    __initialized: bool = False

    # Dtype of the cube representation (0-5 integers) and position representation (0-53 integers)
    __dtype: torch.dtype = torch.long

    # Move map for the cube
    __MOVE_MAP = {
        'U': (np.array([ 6,  7,  8,  5,  2,  1,  0,  3, 47, 50, 53, 29, 32, 35, 38, 41, 44, 20, 23, 26]),
              np.array([ 0,  3,  6,  7,  8,  5,  2,  1, 20, 23, 26, 47, 50, 53, 29, 32, 35, 38, 41, 44])),
        'D': (np.array([15, 12,  9, 10, 11, 14, 17, 16, 36, 39, 42, 18, 21, 24, 45, 48, 51, 27, 30, 33]),
              np.array([ 9, 10, 11, 14, 17, 16, 15, 12, 18, 21, 24, 45, 48, 51, 27, 30, 33, 36, 39, 42])),
        'R': (np.array([27, 28, 29, 32, 35, 34, 33, 30, 38, 37, 36, 15, 16, 17, 51, 52, 53,  6,  7,  8]),
              np.array([29, 32, 35, 34, 33, 30, 27, 28, 15, 16, 17, 51, 52, 53,  6,  7,  8, 38, 37, 36])),
        'L': (np.array([20, 23, 26, 25, 24, 21, 18, 19, 42, 43, 44,  2,  1,  0, 47, 46, 45, 11, 10,  9]),
              np.array([26, 25, 24, 21, 18, 19, 20, 23,  2,  1,  0, 47, 46, 45, 11, 10,  9, 42, 43, 44])),
        'F': (np.array([45, 46, 47, 50, 53, 52, 51, 48, 24, 25, 26,  0,  3,  6, 29, 28, 27, 17, 14, 11]),
              np.array([47, 50, 53, 52, 51, 48, 45, 46,  0,  3,  6, 29, 28, 27, 17, 14, 11, 24, 25, 26])),
        'B': (np.array([36, 37, 38, 41, 44, 43, 42, 39, 35, 34, 33, 15, 12,  9, 18, 19, 20,  2,  5,  8]),
              np.array([38, 41, 44, 43, 42, 39, 36, 37,  2,  5,  8, 35, 34, 33, 15, 12,  9, 18, 19, 20])),
    }
    
    # Faces and turn directions
    __FACES: list[str] = ["U", "D", "L", "R", "B", "F"]

    # Available rotation degrees
    # Currently only 90-degree turns are supported
    ## [90 degrees clockwise, 90 degrees counter-clockwise]
    __DEGREES: list[str] = ["", "'"]

    # Goal state of the cube
    GOAL: torch.Tensor = torch.arange(0, 6 * 3 * 3, dtype=__dtype, device=device) // 9

    ## Variables to be initialized ##

    # Moves available for the cube
    MOVES: list[str] = None

    # Map of move names to indices
    MOVE_TO_INDEX: dict[str, int] = None

    # Source and target indices for the move map
    MOVE_MAP_SOURCE: torch.Tensor = None
    MOVE_MAP_TARGET: torch.Tensor = None

    # The moves available for scrambling the cube
    SCRAMBLE_MOVES_AVAILABLE: torch.Tensor = None

    @classmethod
    def init_class(cls, device=device) -> None:
        """
        Initialize the class variables.
        Supposed to be called once when the class is first loaded.

        Args:
            device (str): Device to use for the tensor.
        """
        # Check if the class has already been initialized
        if cls.__initialized:
            return
        
        # Initialize the moves available for the cube
        cls.MOVES = [f"{face}{degree}" for degree in cls.__DEGREES for face in cls.__FACES]

        # Initialize the move-to-index mapping
        cls.MOVE_TO_INDEX = {move: i for i, move in enumerate(cls.MOVES)}

        # Initialize the source and target indices for the move map
        cls.MOVE_MAP_SOURCE = torch.tensor(np.array([cls.__MOVE_MAP[move[0]][0 if "'" not in move else 1] for move in cls.MOVES]), dtype=cls.__dtype, device=device)
        cls.MOVE_MAP_TARGET = torch.tensor(np.array([cls.__MOVE_MAP[move[0]][1 if "'" not in move else 0] for move in cls.MOVES]), dtype=cls.__dtype, device=device)

        # Initialize the scramble moves available for the cube
        cls.SCRAMBLE_MOVES_AVAILABLE = torch.arange(len(cls.MOVES), dtype=cls.__dtype, device=device).repeat(len(cls.MOVES), 1)
        exclude = ((torch.arange(len(cls.MOVES), device=device) + 6) % len(cls.MOVES)).unsqueeze(1)
        cls.SCRAMBLE_MOVES_AVAILABLE = cls.SCRAMBLE_MOVES_AVAILABLE[cls.SCRAMBLE_MOVES_AVAILABLE != exclude].reshape(len(cls.MOVES), -1)

        # Set the initialization flag to True
        cls.__initialized = True
    
    @staticmethod
    def reverse_moves(moves: torch.Tensor) -> torch.Tensor:
        """
        Reverse the moves for the cube.

        Args:
            moves (torch.Tensor): Tensor of moves to reverse.

        Returns:
            torch.Tensor: Tensor of reversed moves.
        """
        # Move i + 6 is the same face turned in the opposite direction, i.e. the inverse of move i
        return (moves + 6) % 12

    def __init__(self, tensor: torch.Tensor | None = None, num_cubes: int | None = 1, device: str | None = device):
        """
        Initialize the Cubes object.

        Args:
            tensor (torch.Tensor): Tensor representation of the cubes.
            num_cubes (int): Number of cubes to initialize.
            device (str): Device to use for the tensor.
        """
        # Call the class initialization method
        self.init_class(device=device)

        # Set the tensor for the cubes
        if tensor is None:
            # Check if num_cubes and device are provided
            if num_cubes is None or device is None:
                raise ValueError("Either tensor or both num_cubes and device must be provided")
            
            # Create num_cubes cubes in the goal state on the specified device
            self.reset(num_cubes=num_cubes, device=device)

        else:
            # Ignore the num_cubes and device arguments if tensor is provided
            self.tensor = tensor

            # Verify the tensor shape
            if (self.tensor.ndim == 1):
                self.tensor = self.tensor.unsqueeze(0)

            if (self.tensor.ndim != 2) or (self.tensor.shape[1] != 6 * 3 * 3):
                raise ValueError("Tensor must be of shape (num_cubes, 6 * 3 * 3)")
            
    def __len__(self):
        """
        Get the number of cubes.

        Returns:
            int: Number of cubes.
        """
        return self.tensor.shape[0]
    
    def __getitem__(self, index: int):
        """
        Get a specific cube by index.

        Args:
            index (int): Index of the cube.

        Returns:
            torch.Tensor: The cube at the specified index.
        """
        return self.tensor[index]
    
    def __setitem__(self, index: int, value: torch.Tensor):
        """
        Set a specific cube by index.

        Args:
            index (int): Index of the cube.
            value (torch.Tensor): New value for the cube.
        """
        if value.shape != (6 * 3 * 3,):
            raise ValueError("Value must be of shape (6 * 3 * 3)")
        self.tensor[index] = value.to(self.tensor.device)

    def __repr__(self):
        """
        Get a string representation of the cubes.

        Returns:
            str: String representation of the cubes.
        """
        return f"Cubes(tensor={self.tensor})"
    
    def to(self, device: str) -> None:
        """
        Move the cubes to the specified device.

        Args:
            device (str): Device to move the tensor to.
        """
        self.tensor = self.tensor.to(device)

    def reset(self, num_cubes: int | None = None, device: str | None = None) -> None:
        """
        Reset the cubes to the goal state.

        Args:
            num_cubes (int): Number of cubes to reset.
            device (str): Device to use for the tensor.
        """
        # Set parameters to default values if not provided
        if num_cubes is None:
            num_cubes = self.tensor.shape[0]
        if device is None:
            device = self.tensor.device
        
        # Move the goal state to the specified device
        self.GOAL = self.GOAL.to(device)

        # Create num_cubes cubes in the goal state on the specified device
        self.tensor = self.GOAL.unsqueeze(0).repeat(num_cubes, 1)

    def is_solved(self) -> torch.Tensor:
        """
        Check if the cubes are in the solved state.

        Returns:
            torch.Tensor: Boolean tensor indicating if each cube is solved.
        """
        return (self.tensor == self.GOAL).all(dim=1)
    
    def move(self, moves: str | list[str] | list[list[str]] | int | list[int] | list[list[int]] | torch.Tensor) -> None:
        """
        Apply a single or a sequence of moves to the cubes.

        Args:
            move (str): Move to apply.
        """
        # Convert the parameter to Tensor

        # str -> int
        if isinstance(moves, str):
            return self.move(self.MOVE_TO_INDEX[moves])
        
        # list[str] -> list[int]
        elif isinstance(moves, list) and all(isinstance(move, str) for move in moves):
            if len(moves) != self.tensor.shape[0]:
                raise ValueError("Length of move list must match the number of cubes")
            return self.move([self.MOVE_TO_INDEX[move] for move in moves])
        
        # list[list[str]] -> list[list[int]]
        elif isinstance(moves, list) and all(isinstance(step, list) and all(isinstance(move, str) for move in step) for step in moves):
            if any(len(moves[i]) != self.tensor.shape[0] for i in range(len(moves))):
                raise ValueError("Length of move list must match the number of cubes")
            return self.move([[self.MOVE_TO_INDEX[move] for move in round] for round in moves])
        
        # int -> Tensor
        elif isinstance(moves, int):
            moves = torch.full((self.tensor.shape[0],), moves, dtype=self.__dtype, device=self.tensor.device)
            return self.move(moves)
        
        # list[int] -> Tensor
        elif isinstance(moves, list) and all(isinstance(move, int) for move in moves):
            if len(moves) != self.tensor.shape[0]:
                raise ValueError("Length of move list must match the number of cubes")
            moves = torch.tensor(moves, dtype=self.__dtype, device=self.tensor.device)
            return self.move(moves)
        
        # list[list[int]] -> Tensor
        elif isinstance(moves, list) and all(isinstance(step, list) and all(isinstance(move, int) for move in step) for step in moves):
            if any(len(moves[i]) != self.tensor.shape[0] for i in range(len(moves))):
                raise ValueError("Length of move list must match the number of cubes")
            moves = torch.tensor(moves, dtype=self.__dtype, device=self.tensor.device)
            return self.move(moves)
        
        # 2D Tensor -> Tensor
        elif isinstance(moves, torch.Tensor) and moves.ndim == 2 and moves.shape[1] == self.tensor.shape[0]:
            moves = moves.to(self.tensor.device)
            for i in range(moves.shape[0]):
                self.move(moves[i])
            return
        
        # 1D Tensor -> Tensor
        elif isinstance(moves, torch.Tensor) and moves.ndim == 1:
            if moves.shape[0] != self.tensor.shape[0]:
                raise ValueError("Length of move list must match the number of cubes")
            self.__move_torch(moves)  # Implement move for cubes using the __move_torch method
            
        # Other types
        else:
            raise ValueError("Invalid move type or shape.")

    def __move_torch(self, move: torch.Tensor) -> None:
        """
        Apply a move to the cubes using PyTorch.

        Args:
            move (torch.Tensor): Tensor of moves to apply.
        """
        move = move.to(self.tensor.device)

        # Source indices for the move map
        self.MOVE_MAP_SOURCE = self.MOVE_MAP_SOURCE.to(self.tensor.device)
        source_idx = self.MOVE_MAP_SOURCE[move]

        # Target indices for the move map
        self.MOVE_MAP_TARGET = self.MOVE_MAP_TARGET.to(self.tensor.device)
        target_idx = self.MOVE_MAP_TARGET[move]

        # Batch indices for the cubes
        batch_idx = torch.arange(self.tensor.shape[0], device=self.tensor.device).unsqueeze(1)

        # Apply the move to the cubes
        self.tensor[batch_idx, target_idx] = self.tensor[batch_idx, source_idx]

    def scramble(self, scramble_length: int) -> torch.Tensor:
        """
        Generate a random scramble for the cubes.

        Args:
            scramble_length (int): Length of the scramble.

        Returns:
            torch.Tensor: Tensor of moves for the scramble.
        """
        # Generate a random scramble
        plan = self.plan_scramble(scramble_length)
        self.move(plan)

        return plan

    def plan_scramble(self, scramble_length: int) -> torch.Tensor:
        """
        Generate a random scramble plan.

        Args:
            scramble_length (int): Length of the scramble.

        Returns:
            torch.Tensor: Tensor of moves for the scramble plan.
        """
        # Move the needed variables to the device
        self.SCRAMBLE_MOVES_AVAILABLE = self.SCRAMBLE_MOVES_AVAILABLE.to(self.tensor.device)

        # Generate a random scramble plan
        plan = torch.empty((scramble_length, self.tensor.shape[0]), dtype=self.__dtype, device=self.tensor.device)
        for i in range(scramble_length):
            if i == 0:
                # The initial move is chosen randomly from the available moves
                plan[i] = torch.randint(0, len(self.MOVES), (self.tensor.shape[0],), dtype=self.__dtype, device=self.tensor.device)
            elif i == 1:
                # The second move is chosen randomly from the available moves, excluding the inverse of the first move
                plan[i] = torch.randint(0, self.SCRAMBLE_MOVES_AVAILABLE.shape[1], (self.tensor.shape[0],), dtype=self.__dtype, device=self.tensor.device)
                plan[i] = self.SCRAMBLE_MOVES_AVAILABLE[plan[i - 1], plan[i]]
            else:
                generate_idx = torch.arange(self.tensor.shape[0], device=self.tensor.device)
                while generate_idx.shape[0] > 0:
                    # Choose a random move from the available moves, excluding the inverse of the previous move
                    plan[i, generate_idx] = torch.randint(0, self.SCRAMBLE_MOVES_AVAILABLE.shape[1], (generate_idx.shape[0],), dtype=self.__dtype, device=self.tensor.device)
                    plan[i, generate_idx] = self.SCRAMBLE_MOVES_AVAILABLE[plan[i - 1, generate_idx], plan[i, generate_idx]]

                    # Look back two moves to reject redundant sequences
                    # Prevent three consecutive moves from being the same -> Can be replaced with a single move
                    # e.g. U U (U) -> U'
                    mask1 = (plan[i, generate_idx] == plan[i - 1, generate_idx]) & (plan[i - 1, generate_idx] == plan[i - 2, generate_idx])

                    # Prevent two mutually canceling moves sandwiching an opposite face move
                    # e.g. U D (U') -> D
                    mask2 = (self.reverse_moves(plan[i, generate_idx]) == plan[i - 2, generate_idx]) & ((plan[i, generate_idx] // 2) % 3 == (plan[i - 1, generate_idx] // 2) % 3) & (plan[i, generate_idx] % 6 != plan[i - 1, generate_idx] % 6)

                    # Re-sample only the cubes whose move was redundant; stop when none remain
                    generate_idx = generate_idx[mask1 | mask2]

        return plan

Cubes.init_class(device=device)  # Initialize the Cubes class with the specified device

## Model
This section defines the model, which takes a cube state as input and predicts the last move of the scramble that produced it.
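The model's input and output sizes follow from the cube representation: each of the 54 sticker color labels is one-hot encoded over 6 colors (`F.one_hot(..., num_classes=6)` in `Model.forward`), giving 54 × 6 = 324 input features, and the output has one logit per move in `Cubes.MOVES` (12 in QTM). The numpy sketch below mirrors that encoding; `one_hot_encode` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

NUM_STICKERS = 6 * 3 * 3  # 54 sticker positions
NUM_COLORS = 6            # color labels 0..5

def one_hot_encode(states: np.ndarray) -> np.ndarray:
    """Map integer states of shape (batch, 54) to float features of shape (batch, 324)."""
    # Indexing the identity matrix yields (batch, 54, 6); flattening keeps sticker-major order
    return np.eye(NUM_COLORS, dtype=np.float32)[states].reshape(-1, NUM_STICKERS * NUM_COLORS)

solved = np.arange(NUM_STICKERS) // 9  # face k carries color k
encoded = one_hot_encode(np.stack([solved, solved]))
print(encoded.shape)   # (2, 324)
print(encoded[0, :6])  # sticker 0 has color 0, so only the first entry is 1
```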

In [44]:
class LinearBlock(nn.Module):
    """
    Linear layer with ReLU and BatchNorm
    """
    def __init__(self, input_prev, embed_dim):
        super(LinearBlock, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_prev, embed_dim),
            nn.ReLU(),
            nn.BatchNorm1d(embed_dim)
        )

    def forward(self, inputs):
        return self.layers(inputs)
    
class ResidualBlock(nn.Module):
    """
    Residual block with two linear layers
    """
    def __init__(self, embed_dim):
        super(ResidualBlock, self).__init__()
        self.layers = nn.Sequential(
            LinearBlock(embed_dim, embed_dim),
            LinearBlock(embed_dim, embed_dim)
        )

    def forward(self, inputs):
        x = inputs
        x = self.layers(x)
        x += inputs # skip-connection
        return x
    
class Model(nn.Module):
    """
    Fixed architecture following DeepCubeA.
    """
    def __init__(self, input_dim=6*3*3*6, output_dim=len(Cubes.MOVES)):
        super(Model, self).__init__()
        self.input_dim = input_dim
        self.layers = nn.Sequential(
            LinearBlock(input_dim, 5000),
            LinearBlock(5000,1000),
            ResidualBlock(1000),
            ResidualBlock(1000),
            ResidualBlock(1000),
            ResidualBlock(1000),
            nn.Linear(1000, output_dim)
        )

    def forward(self, inputs):
        # int indices => float one-hot vectors
        x = F.one_hot(inputs, num_classes=6).to(torch.float)
        x = x.reshape(-1, self.input_dim)
        x = self.layers(x)
        return x
    
model = Model().to(device)
model.compile()  # Compiles in place; nn.Module.compile() returns None, so its result must not be assigned

## Training
In this section, the model is trained on data generated on the fly: every batch consists of fresh random scrambles.
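Each DataLoader batch of `batch_size_per_depth` indices becomes `max_depth * batch_size_per_depth` training states (26 × 1000 = 26,000 with the default `TrainConfig`), and `collate_fn` lays them out depth-major: the first `num_cubes` rows hold the states after one scramble move, the next `num_cubes` rows the states after two moves, and so on, with `plan.flatten()` in the same order. That layout makes it easy to break a metric down by scramble depth. The sketch below assumes predicted move indices are available as a flat array (for example from `model(batch_x).argmax(-1)`); `per_depth_accuracy` and the toy arrays are hypothetical.

```python
import numpy as np

def per_depth_accuracy(pred_moves: np.ndarray, target_moves: np.ndarray, max_depth: int) -> np.ndarray:
    """Fraction of correct last-move predictions for each scramble depth 1..max_depth.

    Assumes the depth-major ordering produced by collate_fn: rows [i*n, (i+1)*n) hold depth i + 1.
    """
    correct = (pred_moves == target_moves).reshape(max_depth, -1)  # (max_depth, num_cubes)
    return correct.mean(axis=1)

# Hypothetical toy batch: max_depth = 3, two cubes per depth
targets = np.array([0, 5, 7, 1, 3, 9])  # depth 1: [0, 5], depth 2: [7, 1], depth 3: [3, 9]
preds = np.array([0, 5, 7, 2, 4, 8])
acc = per_depth_accuracy(preds, targets, max_depth=3)
print(acc)  # per-depth accuracy: 1.0, 0.5, 0.0
```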

In [45]:
class EchoDataset(torch.utils.data.Dataset):
    """
    Dummy dataset to drive the training loop.

    The dataset generation logic is implemented in the collate_fn function.
    """

    def __init__(
            self, 
            total_samples = TrainConfig.batch_size_per_depth * TrainConfig.num_steps
        ):
        self.total_samples = total_samples

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        # Return the index only
        return idx


def collate_fn(batch):
    """
    Collate function to generate a batch of data.

    Args:
        batch (list): List of indices from the dataset.

    Returns:
        tuple: A tuple containing the input and target tensors.
    """
    # Generate random cubes and moves
    num_cubes = len(batch)
    cubes = Cubes(num_cubes=num_cubes, device=device)

    # Prepare output data
    tensor = torch.empty(
        (TrainConfig.max_depth * num_cubes, 6 * 3 * 3), dtype=torch.long, device=device
    )

    # Generate cubes for each depth
    plan = cubes.plan_scramble(TrainConfig.max_depth)
    for i in range(plan.shape[0]):
        # Apply the moves to the cubes
        cubes.move(plan[i])

        # Store the cubes in the tensor
        tensor[i * num_cubes : (i + 1) * num_cubes] = cubes.tensor

    # Generate target moves
    moves = plan.flatten().to(device)

    return tensor, moves


dataloader = torch.utils.data.DataLoader(
    EchoDataset(),
    collate_fn=collate_fn,
    batch_size=TrainConfig.batch_size_per_depth,
)

In [None]:
def plot_loss_curve(losses):
    """
    Plot the loss curve.
    
    Args:
        losses (list): List of loss values.
    """
    fig, ax = plt.subplots(1, 1)
    ax.plot(losses)
    ax.set_xlabel("Steps")
    ax.set_ylabel("Cross-entropy loss")
    ax.set_xscale("log")
    plt.show()

def train(model: Model, dataloader: torch.utils.data.DataLoader) -> Model:
    """
    Train the model on the dataset.
    
    Args:
        model (Model): The model to be trained.
        dataloader (torch.utils.data.DataLoader): DataLoader for the training data.

    Returns:
        Model: The trained model.
    """
    # Set the model to training mode
    model.train()

    # Loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=TrainConfig.learning_rate)

    # Data generator
    loop = tqdm(dataloader, unit="batch")

    # Training losses
    losses = []

    # Context manager and gradient scaler for mixed-precision training
    # (both are no-ops when ENABLE_FP16 is False)
    ctx = torch.amp.autocast('cuda', dtype=torch.float16) if TrainConfig.ENABLE_FP16 else nullcontext()
    scaler = torch.amp.GradScaler('cuda', enabled=TrainConfig.ENABLE_FP16)

    # TODO: Change steps to epochs
    for batch_index, (batch_x, batch_y) in enumerate(loop):
        # Adjust data shape for the model
        batch_x, batch_y = batch_x.reshape(-1, 54).to(device), batch_y.reshape(-1).to(device)

        # Training step (the scaler keeps FP16 gradients from underflowing)
        with ctx:
            pred_y = model(batch_x)
            loss = loss_fn(pred_y, batch_y)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update losses and progress bar
        losses.append(loss.item())
        loop.set_postfix(loss=loss.item())

        # Plot the loss curve every INTERVAL_PLOT steps
        if TrainConfig.INTERVAL_PLOT and (batch_index+1) % TrainConfig.INTERVAL_PLOT == 0:
            clear_output()
            plot_loss_curve(losses)

        # Save the model every INTERVAL_SAVE steps
        if TrainConfig.INTERVAL_SAVE and (batch_index+1) % TrainConfig.INTERVAL_SAVE == 0:
            torch.save(model.state_dict(), f"{batch_index+1}steps.pth")
            print("Model saved.")

    print(f"Trained on data equivalent to {TrainConfig.batch_size_per_depth * TrainConfig.num_steps} solves.")
    
    return model

model = train(model, dataloader)

## Inference
We test our model on the DeepCubeA test dataset and compare it with the optimal solver and with DeepCubeA.

### Load Dataset
Clone the DeepCubeA repository, which contains the dataset, from GitHub if it is not already present.

In [9]:
# Download the DeepCubeA repository if not already present
if "DeepCubeA" != os.getcwd().split("/")[-1]:
    if not os.path.exists("DeepCubeA"):
        !git clone -q https://github.com/forestagostinelli/DeepCubeA
    %cd ./DeepCubeA/

# Load the data set
print('### Optimal Solver ###')
filename = 'data/cube3/test/data_0.pkl'
with open(filename, 'rb') as f:
    result_Optimal = pickle.load(f)

    print(result_Optimal.keys())
    result_Optimal["solution_lengths"] = [len(s) for s in result_Optimal["solutions"]]
    result_Optimal["solution_lengths_count"] = {
        i: result_Optimal["solution_lengths"].count(i)
        for i in range(min(result_Optimal["solution_lengths"]), max(result_Optimal["solution_lengths"]))
    }

    print('No. of cases:', len(result_Optimal["solution_lengths"]))

# Load the result of DeepCubeA for comparison
print('\n### DeepCubeA ###')
filename = 'results/cube3/results.pkl'
with open(filename, 'rb') as f:
    result_DeepCubeA = pickle.load(f)

    print(result_DeepCubeA.keys())
    result_DeepCubeA["solution_lengths"] = [len(s) for s in result_DeepCubeA["solutions"]]
    result_DeepCubeA["solution_lengths_count"] = {
        i: result_DeepCubeA["solution_lengths"].count(i)
        for i in range(min(result_DeepCubeA["solution_lengths"]), max(result_DeepCubeA["solution_lengths"]))
    }

    print('No. of cases:', len(result_DeepCubeA["solution_lengths"]))

%cd ../

c:\Users\haoti\code\uiuc\ece448\project\UIUC-ECE448-Project-SP25\DeepCubeA
### Optimal Solver ###
dict_keys(['states', 'times', 'solutions', 'num_nodes_generated'])
No. of cases: 1000

### DeepCubeA ###
dict_keys(['states', 'solutions', 'paths', 'times', 'num_nodes_generated'])
No. of cases: 1000
c:\Users\haoti\code\uiuc\ece448\project\UIUC-ECE448-Project-SP25


In [10]:
# Convert optimal solutions to test scrambles
def solution2scramble(solution):
    return [m[0] if m[1] == -1 else m[0] + "'" for m in solution[::-1]]

test_scrambles = [solution2scramble(s) for s in result_Optimal["solutions"]]

print(f"""Example:\n{result_Optimal["solutions"][0]}\n-> {test_scrambles[0]}""")

Example:
[['D', -1], ['F', 1], ['R', 1], ['U', -1], ['F', 1], ['F', 1], ['R', 1], ['U', 1], ['F', 1], ['R', 1], ['B', -1], ['R', -1], ['F', -1], ['R', -1], ['D', -1], ['U', -1], ['R', -1], ['U', -1], ['U', -1], ['R', -1], ['U', 1], ['B', -1]]
-> ['B', "U'", 'R', 'U', 'U', 'R', 'U', 'D', 'R', 'F', 'R', 'B', "R'", "F'", "U'", "R'", "F'", "F'", 'U', "R'", "F'", 'D']
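`solution2scramble` relies on the scramble being the inverse of an optimal solution: reverse the order of the moves and flip the direction of each quarter turn, so that applying the scramble and then the solution returns the cube to the solved state. The same rule in the string notation of `Cubes.MOVES` (`X` and `X'`) can be written as the hypothetical helper below; inverting twice gives back the original sequence.

```python
def invert_sequence(moves: list[str]) -> list[str]:
    """Inverse of a quarter-turn sequence: reverse the order and flip each turn's direction."""
    return [move[:-1] if move.endswith("'") else move + "'" for move in reversed(moves)]

demo_solution = ["R", "U'", "F"]
demo_scramble = invert_sequence(demo_solution)
print(demo_scramble)                                    # ["F'", 'U', "R'"]
print(invert_sequence(demo_scramble) == demo_solution)  # True: inverting twice is the identity
```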


### Beam Search
We use beam search to expand the set of candidate solution paths tracked at each depth. Beam search does not guarantee that a solution will be found, but it effectively increases the probability of finding one, and of finding one as close to optimal as possible.
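The core of the search can be sketched in plain Python: every candidate path is extended by every move, scored by its cumulative log-probability (the sum of per-step log-probabilities, as `candidate_log_probs + log_softmax(pred)` does in the cell below), and only the `beam_width` highest-scoring paths survive to the next depth. In this sketch `log_prob_fn` and `is_goal` are hypothetical stand-ins for the model and the `is_solved` check; unlike the notebook, it neither tracks cube states nor batches several cubes.

```python
import heapq
import math
from typing import Callable, Optional, Sequence

Path = tuple[int, ...]

def beam_search_sketch(
    log_prob_fn: Callable[[Path], Sequence[float]],
    is_goal: Callable[[Path], bool],
    beam_width: int,
    max_depth: int,
) -> Optional[Path]:
    """Return the most probable goal path found within max_depth moves, or None."""
    beam: list[tuple[float, Path]] = [(0.0, ())]  # (cumulative log-probability, path)
    for depth in range(max_depth + 1):
        # The beam is sorted by score, so the first solved path is the most probable one
        for _, path in beam:
            if is_goal(path):
                return path
        if depth == max_depth:
            break
        # Extend every path by every move and accumulate log-probabilities
        expanded = [
            (score + log_p, path + (move,))
            for score, path in beam
            for move, log_p in enumerate(log_prob_fn(path))
        ]
        # Keep only the beam_width most probable paths
        beam = heapq.nlargest(beam_width, expanded, key=lambda item: item[0])
    return None

# Toy "model": three moves with fixed probabilities; the goal is the path (1, 0)
def toy_log_probs(path: Path) -> list[float]:
    return [math.log(0.6), math.log(0.3), math.log(0.1)]

def toy_goal(path: Path) -> bool:
    return path == (1, 0)

print(beam_search_sketch(toy_log_probs, toy_goal, beam_width=1, max_depth=2))  # None: greedy misses it
print(beam_search_sketch(toy_log_probs, toy_goal, beam_width=3, max_depth=2))  # (1, 0)
```

With `beam_width=1` the search degenerates to greedy decoding and only ever follows move 0; widening the beam keeps the less probable first move alive long enough to reach the goal.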

In [None]:
@torch.no_grad()
def beam_search(
    cubes: Cubes, 
    model: Model, 
    beam_width: int = SearchConfig.beam_width, 
    max_depth: int = SearchConfig.max_depth,
    skip_redundant_moves: bool = True   # Only prunes a move that directly undoes the previous one (e.g. U then U')
    ) -> list[None | dict]:
    """
    Best-first beam search for the optimal solution.
    
    Args:
        cubes (Cubes): The cubes to be solved.
        model (Model): The model to be used for prediction.
        beam_width (int): The width of the beam search.
        max_depth (int): The maximum depth of the search.
        skip_redundant_moves (bool): Whether to skip redundant moves.
    
    Returns:
        torch.Tensor: The predicted moves for the solution.
    """
    # Set the model to evaluation mode
    model.eval()

    # Prepare the data structure for the beam search
    candidates = torch.empty((len(cubes), beam_width, 6 * 3 * 3), dtype=torch.long, device=cubes.tensor.device)
    candidates[:, 0] = cubes.tensor
    candidate_paths = torch.empty((len(cubes), beam_width, max_depth), dtype=torch.long, device=candidates.device)
    candidate_log_probs = torch.zeros((len(cubes), beam_width), dtype=torch.float, device=candidates.device)
    candidate_cube_idx = torch.arange(len(cubes), device=candidates.device).unsqueeze(1).expand(-1, beam_width)

    # Prepare the data structure for output
    output = [None] * len(cubes)
    time_0 = time.time()

    # Initialize the beam search
    for depth in range(max_depth):
        # Select the candidates for the current depth
        candidate_len = min(beam_width, len(Cubes.MOVES)**depth)
        active_candidates = candidates[:, :candidate_len].reshape(-1, 6 * 3 * 3)
        active_candidate_paths = candidate_paths[:, :candidate_len, :depth].reshape(candidate_paths.shape[0] * candidate_len, depth)
        active_candidate_cube_idx = candidate_cube_idx[:, :candidate_len].flatten()

        # Check if the candidates are already solved
        solved_mask = Cubes(active_candidates).is_solved()
        if solved_mask.any():
            # If any of the candidates are solved, update the output
            solved_active_candidate_idx = torch.arange(solved_mask.shape[0], device=candidates.device)[solved_mask]
            solved_cube_idx = active_candidate_cube_idx[solved_mask]
            time_duration = time.time() - time_0
            for i in range(solved_active_candidate_idx.shape[0]):
                solved_idx = solved_active_candidate_idx[i]
                cube_idx = solved_cube_idx[i]
                if output[cube_idx] is None:
                    output[cube_idx] = {
                        "solution": active_candidate_paths[solved_idx].cpu().numpy().tolist(),
                        "time": time_duration,
                        "depth": depth,
                    }

            # If all cubes are solved, break the loop
            solved_cube_idx = solved_cube_idx.unique()
            if solved_cube_idx.shape[0] == candidates.shape[0]:
                break
            
            # Remove the solved cubes from the candidates
            # candidate_cube_idx holds the original cube indices, so match solved cubes by value, not by position
            cube_mask = ~torch.isin(candidate_cube_idx[:, 0], solved_cube_idx)
            candidates = candidates[cube_mask]
            candidate_paths = candidate_paths[cube_mask]
            candidate_log_probs = candidate_log_probs[cube_mask]
            candidate_cube_idx = candidate_cube_idx[cube_mask]

            # Regenerate the active candidates
            active_candidates = candidates[:, :candidate_len].reshape(-1, 6 * 3 * 3)
            active_candidate_paths = candidate_paths[:, :candidate_len, :depth].reshape(candidate_paths.shape[0] * candidate_len, depth)
            active_candidate_cube_idx = candidate_cube_idx[:, :candidate_len].flatten()
            
        # Get the predictions from the model
        pred = model(active_candidates).reshape(candidates.shape[0], candidate_len, -1)

        # Calculate the log probabilities
        log_probs = F.log_softmax(pred, dim=-1)

        # Filter the log probabilities based on the active candidates
        active_candidate_log_probs = candidate_log_probs[:, :candidate_len]
        next_move_log_probs = active_candidate_log_probs.unsqueeze(-1) + log_probs
        next_move_log_probs = next_move_log_probs.reshape(next_move_log_probs.shape[0], -1)
        next_moves = torch.arange(log_probs.shape[-1], device=candidates.device).repeat(candidate_len).unsqueeze(0).expand(next_move_log_probs.shape[0], -1)
        next_move_candidate_idx = torch.arange(candidate_len, device=candidates.device).unsqueeze(-1).repeat(1, log_probs.shape[-1]).flatten().unsqueeze(0).expand(next_move_log_probs.shape[0], -1)

        # Remove redundant moves if specified
        if skip_redundant_moves and depth > 0:
            # Build the mask for redundant moves
            last_moves = active_candidate_paths[:, -1].reshape(candidate_paths.shape[0], candidate_len, 1).expand(-1, -1, log_probs.shape[-1]).reshape(next_move_log_probs.shape[0], -1)
            assert last_moves.shape == next_moves.shape, f"last_moves: {last_moves.shape}, next_moves: {next_moves.shape}"
            mask = Cubes.reverse_moves(last_moves) == next_moves
            next_move_log_probs = next_move_log_probs.masked_fill(mask, -float("inf"))
        
        # Filter next moves based on probabilities
        sorted_next_move_log_probs_idx = torch.argsort(next_move_log_probs, dim=-1, descending=True)[:, :beam_width]
        next_move_log_probs = next_move_log_probs.gather(1, sorted_next_move_log_probs_idx)
        next_moves = next_moves.gather(1, sorted_next_move_log_probs_idx)
        next_move_candidate_idx = next_move_candidate_idx.gather(1, sorted_next_move_log_probs_idx)

        # Update the candidates with the next moves
        candidates[:, :next_move_candidate_idx.shape[1]] = candidates.gather(1, next_move_candidate_idx.unsqueeze(-1).expand(-1, -1, 6 * 3 * 3))
        candidate_paths[:, :next_move_candidate_idx.shape[1], :depth] = active_candidate_paths.reshape(candidate_paths.shape[0], candidate_len, depth).gather(1, next_move_candidate_idx.unsqueeze(-1).expand(-1, -1, depth))
        candidate_log_probs[:, :next_move_candidate_idx.shape[1]] = next_move_log_probs
        
        # Apply the next moves to the candidates
        candidate_paths[:, :next_move_candidate_idx.shape[1], depth] = next_moves
        temp_cubes = Cubes(candidates[:, :next_move_candidate_idx.shape[1]].reshape(-1, 6 * 3 * 3))
        temp_cubes.move(Cubes.reverse_moves(next_moves.flatten()))
        candidates[:, :next_move_candidate_idx.shape[1]] = temp_cubes.tensor.reshape(candidates.shape[0], -1, 6 * 3 * 3).clone()
    
    return output
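Each iteration of `beam_search` scores every (path, next move) pair by its cumulative log-probability, masks any move that undoes the previous one, and keeps the best `beam_width` pairs. The following is a minimal single-cube sketch of that expansion step, not the notebook's implementation: `expand_beam` is a hypothetical helper, the `inverse` lookup tensor stands in for `Cubes.reverse_moves`, and `torch.topk` replaces the full `argsort`.

```python
import torch


def expand_beam(beam_log_probs, move_log_probs, last_moves, inverse, beam_width):
    """One beam-search expansion step for a single cube (hypothetical helper).

    beam_log_probs: (B,) cumulative log-probability of each surviving path
    move_log_probs: (B, M) log-softmax of the model's prediction for each path
    last_moves:     (B,) last move index of each path, or None at depth 0
    inverse:        (M,) assumed lookup tensor mapping a move index to its inverse
    """
    num_moves = move_log_probs.shape[-1]
    # Score of extending path b with move m = log P(path b) + log P(m | path b)
    scores = beam_log_probs.unsqueeze(-1) + move_log_probs  # (B, M)
    if last_moves is not None:
        # Forbid a move that cancels the previous one (e.g. U followed by U')
        redundant = inverse[last_moves].unsqueeze(-1) == torch.arange(num_moves)
        scores = scores.masked_fill(redundant, float("-inf"))
    flat = scores.flatten()  # (B * M,), row-major: path-major, move-minor
    k = min(beam_width, flat.numel())
    top_scores, top_idx = torch.topk(flat, k)  # sorted, best first
    parent = torch.div(top_idx, num_moves, rounding_mode="floor")  # path being extended
    move = top_idx % num_moves  # move appended to that path
    return top_scores, parent, move
```

For the rows that survive, `topk` selects the same entries as `argsort(..., descending=True)[:, :beam_width]`, up to tie-breaking. It avoids ordering the entries that are discarded anyway.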


In [None]:
cube = Cubes(num_cubes=1)
cube.move([[move] for move in test_scrambles[1]])
output = beam_search(cube, model, beam_width=SearchConfig.beam_width)

In [None]:
cube.move(Cubes.reverse_moves(torch.tensor([[move] for move in output[0]["solution"]])))
cube.is_solved()
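`beam_search` fills one entry per input cube with a dict holding `"solution"`, `"time"` and `"depth"`. Entries that are never written presumably stay `None`. When many test scrambles are solved at once, a small summary is often more useful than inspecting each entry. The sketch below assumes that output shape. `summarize_results` is a hypothetical helper, not part of the notebook.

```python
from statistics import mean


def summarize_results(output):
    """Aggregate beam_search output; unsolved cubes are assumed to be None."""
    solved = [r for r in output if r is not None]
    return {
        "solved": len(solved),
        "total": len(output),
        # len(solution) equals the recorded search depth
        "mean_length": mean(len(r["solution"]) for r in solved) if solved else None,
        "mean_time": mean(r["time"] for r in solved) if solved else None,
    }
```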

In [None]:
generator = iter(torch.utils.data.DataLoader(
    EchoDataset(),
    collate_fn=collate_fn,
    batch_size=2,
))

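import matplotlib.pyplot as plt
import matplotlib.patches as patches

def visualize(cubes):
    # grid[i][j] is the facelet index drawn at row i, column j of the unfolded net (-1 = empty cell)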
    grid = [[-1, -1, -1,  2, 5, 8,  -1, -1, -1,  -1, -1, -1],
            [-1, -1, -1,  1, 4, 7,  -1, -1, -1,  -1, -1, -1],
            [-1, -1, -1,  0, 3, 6,  -1, -1, -1,  -1, -1, -1],
            [20, 23, 26,  47, 50, 53,  29, 32, 35,  38, 41, 44],
            [19, 22, 25,  46, 49, 52,  28, 31, 34,  37, 40, 43],
            [18, 21, 24,  45, 48, 51,  27, 30, 33,  36, 39, 42],
            [-1, -1, -1,  11, 14, 17,  -1, -1, -1,  -1, -1, -1],
            [-1, -1, -1,  10, 13, 16,  -1, -1, -1,  -1, -1, -1],
            [-1, -1, -1,  9, 12, 15,  -1, -1, -1,  -1, -1, -1]]
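    # squeeze=False keeps `axes` 2-D, so a single cube can be plotted as well
    fig, axes = plt.subplots(cubes.shape[0], 1, figsize=(12, 8 * cubes.shape[0]), dpi=100, squeeze=False)
    for cube_idx, ax in enumerate(axes[:, 0]):
        ax.set_aspect('equal')
        ax.set_xlim(0, 12)
        ax.set_ylim(-8, 1)
        ax.axis('off')  # hide the frame of every subplot, not only the last one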
        for i in range(9):
            for j in range(12):
                if grid[i][j] == -1:
                    continue
                color = {0: 'yellow', 1: 'white', 2: 'blue',
                         3: 'green', 4: 'orange', 5: 'red'}[cubes[cube_idx, grid[i][j]].item()]
                square = patches.Rectangle(
                    (j, -i), 1, 1, edgecolor='black', facecolor=color)
                ax.add_patch(square)
    plt.show()
cube = Cubes(num_cubes=2)
cube.move(["U", "U'"])
visualize(cube.tensor)
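The `grid` in `visualize` is a lookup table from a cell of the unfolded net to a facelet index. The same table can also drive a plain-text rendering, which is handy when no display is available. This is a minimal sketch under two assumptions: a cube state is a flat sequence of 54 color codes 0 to 5 (as `cubes[cube_idx, grid[i][j]].item()` suggests), and the color codes match those used in `visualize`. `render_net`, `NET_GRID` and the letter codes are hypothetical.

```python
# Hypothetical text renderer; NET_GRID is copied from `visualize`
NET_GRID = [[-1, -1, -1,  2, 5, 8,  -1, -1, -1,  -1, -1, -1],
            [-1, -1, -1,  1, 4, 7,  -1, -1, -1,  -1, -1, -1],
            [-1, -1, -1,  0, 3, 6,  -1, -1, -1,  -1, -1, -1],
            [20, 23, 26,  47, 50, 53,  29, 32, 35,  38, 41, 44],
            [19, 22, 25,  46, 49, 52,  28, 31, 34,  37, 40, 43],
            [18, 21, 24,  45, 48, 51,  27, 30, 33,  36, 39, 42],
            [-1, -1, -1,  11, 14, 17,  -1, -1, -1,  -1, -1, -1],
            [-1, -1, -1,  10, 13, 16,  -1, -1, -1,  -1, -1, -1],
            [-1, -1, -1,  9, 12, 15,  -1, -1, -1,  -1, -1, -1]]

# Same color codes as `visualize`: 0 yellow, 1 white, 2 blue, 3 green, 4 orange, 5 red
COLOR_LETTERS = {0: "Y", 1: "W", 2: "B", 3: "G", 4: "O", 5: "R"}


def render_net(state, grid=NET_GRID):
    """Render one cube state (54 color codes, assumed layout) as lines of letters."""
    lines = []
    for row in grid:
        # Empty cells (-1) become spaces so the net keeps its cross shape
        lines.append("".join(" " if idx == -1 else COLOR_LETTERS[int(state[idx])] for idx in row))
    return "\n".join(lines)
```

With the notebook's objects this would be called as `print(render_net(cube.tensor[0]))`. `int(...)` accepts both Python ints and single-element tensors.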

In [None]:
beam_search(cube, model)

In [37]:
torch.randn(3, 4, 5)[:, :2].reshape(-1, 5).reshape(torch.randn(3, 4, 5)[:, :2].shape)

tensor([[[-0.9506, -1.4207,  0.2974,  1.0400, -1.1955],
         [ 0.7597, -0.1511, -1.4915, -0.0370,  0.4546]],

        [[-0.4683,  1.4906,  0.4209, -0.8778,  0.3123],
         [-0.4545, -0.3256,  0.2490,  1.0260,  0.2440]],

        [[ 0.2344,  1.6269,  1.6757,  0.7930,  0.8204],
         [ 1.3450, -0.2901,  1.6133, -0.4839,  0.3678]]])
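The cell above draws two separate `torch.randn` tensors, so it only demonstrates that the shapes round-trip. `beam_search` depends on the values round-tripping too. It flattens `candidates[:, :candidate_len]` into one batch for `model(...)` and reshapes the predictions back to `(num_cubes, candidate_len, -1)`. The following deterministic version of the check reuses a single tensor. The sizes are illustrative.

```python
import torch

num_cubes, beam_width, num_facelets = 3, 4, 5
k = 2  # number of active beam slots, playing the role of candidate_len
candidates = torch.randn(num_cubes, beam_width, num_facelets)

flat = candidates[:, :k].reshape(-1, num_facelets)   # (num_cubes * k, num_facelets)
restored = flat.reshape(num_cubes, k, num_facelets)  # back to the per-cube layout

# Row r of `flat` is cube r // k, beam slot r % k (row-major order)
r = 5
assert torch.equal(flat[r], candidates[r // k, r % k])
assert torch.equal(restored, candidates[:, :k])
```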

In [30]:
torch.randn(3, 4, 5).gather(1, torch.tensor([[0, 1], [2, 3], [1, 2]]).unsqueeze(0).expand(3, -1, -1))

tensor([[[-0.0636, -0.7576],
         [-1.6166,  0.8245],
         [-1.8340, -1.8919]],

        [[-1.0256, -1.0919],
         [ 0.4804,  0.6505],
         [-0.5176,  0.1917]],

        [[-0.4464,  0.7612],
         [-0.2182,  0.5445],
         [ 1.9417, -0.8868]]])
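With `dim=1`, `gather` computes `out[n, k, d] = x[n, index[n, k, d], d]`. Expanding an `(N, K)` index of parent slots over the last dimension therefore copies whole rows. This is how `beam_search` moves each surviving parent's cube state into its new beam slot. The following small sketch uses illustrative values and cross-checks the result against advanced indexing.

```python
import torch

# (num_cubes, beam_slots, facelets) with distinct values so rows are easy to tell apart
candidates = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
# For each cube, the parent slot that each new beam entry is copied from
parent = torch.tensor([[2, 2, 0], [1, 0, 0]])

index = parent.unsqueeze(-1).expand(-1, -1, candidates.shape[-1])  # (2, 3, 4)
selected = candidates.gather(1, index)

# Equivalent advanced indexing: selected[c, j] == candidates[c, parent[c, j]]
expected = candidates[torch.arange(candidates.shape[0]).unsqueeze(-1), parent]
assert torch.equal(selected, expected)
```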

In [None]:
cubes = Cubes(num_cubes=12)
cubes.move(Cubes.MOVES)
cubes

In [None]:
cubes.is_solved()

In [None]:
reverse_moves = Cubes.reverse_moves(np.array([Cubes.MOVE_TO_INDEX[m] for m in Cubes.MOVES]))
cubes.move(reverse_moves.tolist())
cubes

In [None]:
cubes.is_solved()
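The last cells turn each of the 12 cubes by one move from `Cubes.MOVES` and undo it with `Cubes.reverse_moves`. To undo a longer sequence, apply the inverse moves in reverse order. This is a minimal sketch in standard face-turn notation (`X`, `X'`, `X2`) that is independent of the `Cubes` class. `invert_move` and `invert_sequence` are hypothetical helpers, and half turns are handled only for completeness.

```python
from typing import List


def invert_move(move: str) -> str:
    """Inverse of a single face turn: X <-> X', and X2 is its own inverse."""
    if move.endswith("'"):
        return move[:-1]
    if move.endswith("2"):
        return move
    return move + "'"


def invert_sequence(moves: List[str]) -> List[str]:
    """Undo a move sequence: invert every move and apply them in reverse order."""
    return [invert_move(m) for m in reversed(moves)]


scramble = ["R", "U", "R'", "U2"]
print(invert_sequence(scramble))  # ['U2', 'R', "U'", "R'"]
```

The verification cell after `beam_search` does not reverse the order. At each depth the search applies `Cubes.reverse_moves(next_moves)`, so the recorded solution is already in solving order, and each of its moves is inverted front to back.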