In [1]:
import sys

sys.path.append("../")

In [2]:
import numpy 
import torch
from models import Pilgrim 
from utils import open_pickle
from cube3_game import Cube3Game
from datasets import get_torch_scrambles

from sklearn.metrics import root_mean_squared_error

In [3]:
game = Cube3Game("../assets/envs/qtm_cube3.pickle")

In [4]:
generators = torch.tensor(game.actions, dtype=torch.int64)

In [5]:
permutations = torch.tensor(
    game.actions, 
    dtype=torch.int64,
    device="cpu",            
)

In [6]:
model = Pilgrim()
model.load_state_dict(torch.load("../assets/models/Cube3ResnetModel.pt"))
model.eval()

Pilgrim(
  (input_layer): Linear(in_features=324, out_features=400, bias=True)
  (hidden_layer): Linear(in_features=400, out_features=200, bias=True)
  (residual_blocks): ModuleList(
    (0-1): 2 x ResidualBlock(
      (fc1): Linear(in_features=200, out_features=200, bias=True)
      (bn1): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (dropout): Dropout(p=0.1, inplace=False)
      (fc2): Linear(in_features=200, out_features=200, bias=True)
      (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (output_layer): Linear(in_features=200, out_features=1, bias=True)
  (output_probs_layer): Linear(in_features=200, out_features=12, bias=True)
  (relu): ReLU()
  (bn1): BatchNorm1d(400, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [7]:
validation = open_pickle("../assets/data/validation/val.pickle")

In [8]:
states, actions, values = get_torch_scrambles(
    n = 2,
    space_size = game.space_size,
    action_size = game.action_size,
    length = 26,
    permutations = permutations
)
# print(states[:2, :])

print("values:", values)
with torch.no_grad():
    predict_value, _ = model(states)
    print("predict:", predict_value.squeeze(1))

root_mean_squared_error(
    values.detach().cpu().numpy(),
    predict_value.squeeze(1).detach().cpu().numpy(),
)

values: tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26.,  1.,  2.,
         3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14., 15., 16.,
        17., 18., 19., 20., 21., 22., 23., 24., 25., 26.])
predict: tensor([ 1.4058,  2.8751,  3.9373,  4.9909,  3.9373,  5.7543,  7.5723, 10.5223,
        11.8231, 16.7446, 17.1952, 17.6661, 18.4018, 18.7154, 18.2653, 17.7984,
        17.4907, 18.5518, 18.5048, 17.8386, 18.0384, 16.9139, 16.7532, 17.6934,
        17.9158, 18.4731,  1.4058,  2.5966,  3.5465,  4.8532,  6.4005,  7.5903,
         7.7779,  9.7212, 10.6549,  9.9496, 13.4054, 16.7819, 16.9136, 13.7297,
        14.6201, 18.5912, 17.8902, 18.3051, 17.7731, 18.3851, 16.9488, 17.4943,
        17.6450, 16.4988, 18.1087, 18.8623])


3.7055116

In [9]:
# k = 2
# # state = torch.tensor(validation['states'][k], dtype=torch.int64).unsqueeze(0)
# # true_value = validation["values"][k]
# # print(state.shape, "; true_value:", true_value)
# predict_value, _ = model(state)
# print("predict:", predict_value)
# # V0 = torch.arange(0, 54)#.repeat_interleave(54//6)

# # def get_neighbors(states: torch.Tensor) -> torch.Tensor:
# #     n_gens = 12
# #     state_size = 54
    
# #     return torch.gather(
# #         states.unsqueeze(1).expand(states.size(0), n_gens, state_size), 
# #         2, 
# #         generators.unsqueeze(0).expand(states.size(0), n_gens, state_size)
# #     )

# # def search(
# #     state, 
# #     B: int = 10000
# # ):
# #     neighbors = get_neighbors(state).flatten(end_dim=1)
# #     values, _ = model(neighbors)

# #     print("neighbors:", neighbors.shape)
# #     print("values:", values)
# #     pass

# # search(state)

In [23]:
import torch

class BeamSearch:
    def __init__(self, model: torch.nn.Module, state: torch.Tensor, num_steps: int, generators: torch.Tensor, device: torch.device) -> None:
        """
        Initialize the BeamSearch class.

        :param model: Model to use for predictions.
        :param state: Initial state tensor.
        :param num_steps: Number of steps to perform in the search.
        :param generators: Generators to create new states.
        :param device: Device to perform computations (e.g., 'cuda', 'cpu').
        """
        self.model = model
        self.state = state
        self.num_steps = num_steps
        self.generators = generators
        self.n_gens = generators.size(0)
        self.state_size = generators.size(1)
        self.device = device
        self.hash_vec = torch.randint(0, 1_000_000_000_000, (self.state_size,))
        self.target_val = None

    def get_unique_states(self, states: torch.Tensor) -> torch.Tensor:
        """
        Get unique states by hashing.

        :param states: Tensor of states.
        :return: Tensor of unique states.
        """
        hashed = torch.sum(self.hash_vec * states, dim=1)
        hashed_sorted, idx = torch.sort(hashed)
        mask = torch.cat((torch.tensor([True]), hashed_sorted[1:] - hashed_sorted[:-1] > 0))
        return states[idx][mask]

    def get_neighbors(self, states: torch.Tensor) -> torch.Tensor:
        """
        Get neighboring states.

        :param states: Tensor of states.
        :return: Tensor of neighboring states.
        """
        return torch.gather(
            states.unsqueeze(1).expand(states.size(0), self.n_gens, self.state_size), 
            2, 
            self.generators.unsqueeze(0).expand(states.size(0), self.n_gens, self.state_size))

    def states_to_input(self, states: torch.Tensor) -> torch.Tensor:
        """
        Convert states to input tensor.

        :param states: Tensor of states.
        :return: Input tensor for the model.
        """
        return torch.nn.functional.one_hot(states, num_classes=6).view(-1, self.state_size * 6).to(torch.float)

    def batch_predict(self, model: torch.nn.Module, data: torch.Tensor, device: torch.device, batch_size: int) -> torch.Tensor:
        """
        Perform batch prediction.

        :param model: Model to use for predictions.
        :param data: Input data tensor.
        :param device: Device to perform computations (e.g., 'cuda', 'cpu').
        :param batch_size: Batch size for predictions.
        :return: Predictions tensor.
        """
        model.eval()
        model.to(device)

        n_samples = data.shape[0]
        outputs = []

        for start in range(0, n_samples, batch_size):
            end = start + batch_size
            batch = data[start:end].to(device)

            with torch.no_grad():
                batch_output, _  = model(batch)
                batch_output = batch_output.flatten()

            outputs.append(batch_output)

        final_output = torch.cat(outputs, dim=0)
        return final_output

    def predict_values(self, states: torch.Tensor) -> torch.Tensor:
        """
        Predict values for given states.

        :param states: Tensor of states.
        :return: Predicted values tensor.
        """
        return self.batch_predict(self.model, states, self.device, 4096).cpu()

    def predict_clipped_values(self, states: torch.Tensor) -> torch.Tensor:
        """
        Predict clipped values for given states.

        :param states: Tensor of states.
        :return: Clipped predicted values tensor.
        """
        return torch.clip(self.predict_values(states) - self.target_val, 0, torch.inf)

    def do_greedy_step(self, states: torch.Tensor, B: int = 1000) -> torch.Tensor:
        """
        Perform a greedy step in the search.

        :param states: Tensor of current states.
        :param B: Beam size.
        :return: Tensor of new states after the greedy step.
        """
        neighbors = self.get_neighbors(states).flatten(end_dim=1)
        neighbors = self.get_unique_states(neighbors)
        y_pred = self.predict_clipped_values(neighbors)
        idx = torch.argsort(y_pred)[:B]
        return neighbors[idx]

    def search(
            self, 
            # V0: torch.Tensor = torch.arange(6, dtype=torch.int64).repeat_interleave(54//6), 
            V0: torch.Tensor = torch.arange(0, 54, dtype=torch.int64),
            B: int = 1000
        ) -> int:
        """
        Perform the beam search.

        :param V0: Target state tensor.
        :param B: Beam size.
        :return: Number of steps to reach the target state, or -1 if not found.
        """
        self.target_val = self.predict_values(V0.unsqueeze(0)).item()
        states = self.state.clone()
        for j in range(self.num_steps):
            states = self.do_greedy_step(states, B)
            if (states == V0).all(dim=1).any():
                return j
        return -1
    
beam_search = BeamSearch(
    model, 
    states[20, :].unsqueeze(0), 
    100, 
    generators, 
    device="cpu"
)

beam_search.search(B=100000)

24