In [29]:
from typing import List
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class PuzzleHeuristicModel(nn.Module):
    """Fully connected network that maps ``input_size`` features to one scalar.

    The network is a stack of ``nn.Linear`` layers. Every layer except the
    last is followed by a ReLU; the last layer is a plain linear projection
    to a single output value.

    Args:
        input_size: Number of features in each input row.
        hidden_sizes: Widths of the hidden layers, in order. Any iterable of
            ints is accepted (list, tuple, ...). An empty iterable yields a
            single ``Linear(input_size, 1)`` layer.
    """

    def __init__(self, input_size, hidden_sizes):
        super().__init__()
        # Unpacking (rather than list concatenation) lets callers pass a
        # tuple or any other iterable, not only a list.
        layer_widths = [input_size, *hidden_sizes, 1]
        # NOTE: the attribute name is part of every state_dict key
        # ("hidden_layers.<i>.weight"), so it must stay as-is for saved
        # checkpoints to load. Despite the name, the final entry is the
        # output layer.
        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(n_in, n_out)
                for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:])
            ]
        )

    def forward(self, x):
        """Run ``x`` through the stack and return the un-activated output.

        Args:
            x: Tensor whose last dimension equals ``input_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        *hidden, output_layer = self.hidden_layers
        for layer in hidden:
            x = F.relu(layer(x))
        # No activation on the final layer: the output is a raw regression value.
        return output_layer(x)


def puzzle_model_inference(input_data: List[int], model: nn.Module) -> int:
    """Predict an integer value for one state with ``model``.

    Each entry of ``input_data`` is one-hot encoded over 16 classes and the
    encodings are concatenated into a single ``(1, 16 * len(input_data))``
    float32 row, which is fed to ``model`` under ``torch.no_grad()``.

    Args:
        input_data: Integer values, each used as an index into a 16x16
            identity matrix. Values above 15 raise ``IndexError``; negative
            values follow NumPy indexing and wrap around.
            Presumably a flattened 16-tile puzzle board (confirm with callers).
        model: Module that returns a single-element tensor for the encoded
            row. May be a bare model or a wrapper such as ``DataParallel``.

    Returns:
        The model output rounded to the nearest integer (Python ``round``).
    """
    tile_ids = np.asarray(input_data).astype(np.int64)
    one_hot = np.eye(16)[tile_ids].ravel()
    features = torch.tensor(one_hot, dtype=torch.float32).unsqueeze(0)

    # The model may have been moved to a GPU; the freshly built tensor lives
    # on the CPU, so follow the model's parameters to avoid a device mismatch.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        features = features.to(first_param.device)

    with torch.no_grad():
        predicted_value = model(features).item()

    return int(round(predicted_value))

In [8]:
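# NOTE: legacy inline loader, kept for reference only; the same steps now live in `_get_model` below.
# The model repr shown as this cell's output presumably dates from a run made before the code was commented out.
# input_size = 256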
# hidden_sizes = [1024, 1024, 512, 128, 64]
# model = PuzzleHeuristicModel(input_size, hidden_sizes)

# checkpoint_path = "../../data/models/run5/immediate/best_puzzle_model.pth"
# checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

# if "module." in list(checkpoint.keys())[0]:
#     new_state_dict = {}
#     for key, value in checkpoint.items():
#         new_key = key.replace("module.", "")
#         new_state_dict[new_key] = value
#     checkpoint = new_state_dict

# model.load_state_dict(checkpoint)
# model.eval()

PuzzleHeuristicModel(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=256, out_features=1024, bias=True)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
    (2): Linear(in_features=1024, out_features=512, bias=True)
    (3): Linear(in_features=512, out_features=128, bias=True)
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): Linear(in_features=64, out_features=1, bias=True)
  )
)

In [30]:
def _get_model(device, checkpoint_path="../../data/models/run5/immediate/best_puzzle_model.pth"):
    """Build the puzzle model, load its checkpoint and place it on ``device``.

    Args:
        device: Target device, passed to ``Module.to``.
        checkpoint_path: Path of the saved ``state_dict``. Defaults to the
            run5 "immediate" best checkpoint.

    Returns:
        The model in eval mode on ``device``. When the parameters end up on
        a CUDA device and more than one GPU is visible, the model is wrapped
        in ``torch.nn.DataParallel``.
    """
    input_size = 256
    hidden_sizes = [1024, 1024, 512, 128, 64]
    model = PuzzleHeuristicModel(input_size, hidden_sizes)

    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources; consider
    # passing weights_only=True if the installed torch version supports it.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # Checkpoints saved from a DataParallel-wrapped model carry a "module."
    # prefix on every key. Strip it only as a leading prefix, and decide per
    # key instead of inspecting just the first one.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint.items()
    }

    model.load_state_dict(state_dict)
    model.eval()
    model = model.to(device)

    # DataParallel needs the parameters on a GPU, so only wrap when the model
    # actually landed on CUDA (a CPU target on a multi-GPU host stays bare).
    on_gpu = next(model.parameters()).device.type == "cuda"
    if on_gpu and torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    return model

In [31]:
# Prefer the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the checkpointed model onto the chosen device.
model = _get_model(device)

In [32]:
# Example state: 16 integer entries, written four per line for readability.
input_data = [
    1, 2, 3, 4,
    5, 6, 7, 8,
    0, 10, 11, 12,
    9, 13, 14, 15,
]

# Run a single prediction with the loaded model and report it.
predicted_value = puzzle_model_inference(input_data, model)
print("Predicted Value:", predicted_value)

Predicted Value: 5
