In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torcheval.metrics.functional import r2_score

import time
import numpy as np
import math
import heapq

import sklearn
from sklearn import model_selection
from sklearn import linear_model
from catboost import CatBoostRegressor

import matplotlib.pyplot as plt
import pandas as pd
plt.rcParams["font.family"] = "serif"

import pickle 
import json

from tqdm import tqdm
from IPython.display import clear_output
from ipywidgets import IntProgress
from IPython.display import display

In [2]:
# Pick the best available accelerator: CUDA, then Apple-silicon MPS, else CPU.
# (Previously "mps" was chosen whenever CUDA was missing, even without MPS.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
dtype_int   = torch.int64
# dtype_float = torch.float64
dtype_float = torch.float32

## Helpers: tensor ↔ set/list conversion and one-hot encoding

In [4]:
def tensor2set(states):
    """Collect the rows of ``states`` into a set of tuples (duplicate rows collapse)."""
    result = set()
    for row in states:
        result.add(tuple(row.tolist()))
    return result
def set2tensor(states):
    """Stack an iterable of state tuples into a tensor of dtype ``dtype_int``."""
    rows = list(states)
    return torch.tensor(rows, dtype=dtype_int)

def tensor2list(states):
    """Convert each row of ``states`` to a tuple, keeping order and duplicates."""
    return list(map(lambda row: tuple(row.tolist()), states))
def list2tensor(states):
    """Build a ``dtype_int`` tensor from a (nested) list of states."""
    converted = torch.tensor(states, dtype=dtype_int)
    return converted

def states2X(states, num_classes=6):
    """One-hot encode integer states into a flat float feature matrix.

    :param states: integer tensor whose last dimension runs over state positions;
        every entry must lie in ``[0, num_classes)``.
    :param num_classes: one-hot depth (6 by default, as before).
        NOTE(review): the notebook's current V0 is ``arange(54)``, whose values
        exceed 6 — confirm which encoding callers pass in.
    :return: float tensor of shape ``(-1, states.size(-1) * num_classes)``.
    """
    # The row width is taken from the input instead of the global ``state_size``,
    # so the helper works for any state length and before that global exists.
    width = states.size(-1) * num_classes
    return torch.nn.functional.one_hot(states, num_classes=num_classes).view(-1, width).to(torch.float)

## RC3 generators

In [5]:
# QTM (presumably quarter-turn metric) generators, written out by hand:
# n_gens = 12 moves, each an index permutation over the state_size = 54 positions
# (one row per move, applied with torch.gather elsewhere in the notebook).
# NOTE: all_moves_kirill is not what the search helpers use — they read the
# global all_moves, which is loaded later from qtm_cube3.pickle.
n_gens = 12
state_size = 54
all_moves_kirill = list2tensor(
    [[0,1,2,3,4,5,44,41,38,15,12,9,16,13,10,17,14,11,6,19,20,7,22,23,8,25,26,27,28,29,30,31,32,33,34,35,36,37,45,39,40,46,42,43,47,24,21,18,48,49,50,51,52,53],
     [0,1,2,3,4,5,18,21,24,11,14,17,10,13,16,9,12,15,47,19,20,46,22,23,45,25,26,27,28,29,30,31,32,33,34,35,36,37,8,39,40,7,42,43,6,38,41,44,48,49,50,51,52,53],
     [42,39,36,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,0,21,22,1,24,25,2,29,32,35,28,31,34,27,30,33,51,37,38,52,40,41,53,43,44,45,46,47,48,49,50,26,23,20],
     [20,23,26,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,53,21,22,52,24,25,51,33,30,27,34,31,28,35,32,29,2,37,38,1,40,41,0,43,44,45,46,47,48,49,50,36,39,42],
     [0,1,11,3,4,14,6,7,17,9,10,47,12,13,50,15,16,53,24,21,18,25,22,19,26,23,20,8,28,29,5,31,32,2,34,35,36,37,38,39,40,41,42,43,44,45,46,33,48,49,30,51,52,27],
     [0,1,33,3,4,30,6,7,27,9,10,2,12,13,5,15,16,8,20,23,26,19,22,25,18,21,24,53,28,29,50,31,32,47,34,35,36,37,38,39,40,41,42,43,44,45,46,11,48,49,14,51,52,17],
     [9,1,2,12,4,5,15,7,8,45,10,11,48,13,14,51,16,17,18,19,20,21,22,23,24,25,26,27,28,6,30,31,3,33,34,0,38,41,44,37,40,43,36,39,42,35,46,47,32,49,50,29,52,53],
     [35,1,2,32,4,5,29,7,8,0,10,11,3,13,14,6,16,17,18,19,20,21,22,23,24,25,26,27,28,51,30,31,48,33,34,45,42,39,36,43,40,37,44,41,38,9,46,47,12,49,50,15,52,53],
     [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,42,43,44,18,19,20,21,22,23,15,16,17,27,28,29,30,31,32,24,25,26,36,37,38,39,40,41,33,34,35,51,48,45,52,49,46,53,50,47],
     [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,24,25,26,18,19,20,21,22,23,33,34,35,27,28,29,30,31,32,42,43,44,36,37,38,39,40,41,15,16,17,47,50,53,46,49,52,45,48,51],
     [2,5,8,1,4,7,0,3,6,36,37,38,12,13,14,15,16,17,9,10,11,21,22,23,24,25,26,18,19,20,30,31,32,33,34,35,27,28,29,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53],
     [6,3,0,7,4,1,8,5,2,18,19,20,12,13,14,15,16,17,27,28,29,21,22,23,24,25,26,36,37,38,30,31,32,33,34,35,9,10,11,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53]]
)
# Goal state (what the beam search checks for): position i holds value i.
# Older colour encoding, kept for reference — each of the 6 values repeated
# state_size//6 times:
# V0 = torch.arange(6, dtype=dtype_int).repeat_interleave(state_size//6)
# Length comes from state_size rather than a hard-coded 54, so it stays in sync.
V0 = torch.arange(state_size, dtype=dtype_int)

# Define inverse moves mapping
# inverse_moves = torch.tensor([1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10], dtype=dtype_int)

def get_neighbors(states, moves=None):
    """Apply every generator to every state.

    :param states: integer tensor of shape (batch, state_size).
    :param moves: optional (n_moves, state_size) int64 index table on the same
        device as ``states``; defaults to the global ``all_moves``.
    :return: tensor of shape (batch, n_moves, state_size) with
        ``out[b, m, j] = states[b, moves[m, j]]``.
    """
    if moves is None:
        moves = all_moves
    batch = states.size(0)
    # Shapes come from the move table itself rather than the globals
    # n_gens/state_size, so a differently sized table cannot silently mismatch.
    n_moves, width = moves.shape
    return torch.gather(
        states.unsqueeze(1).expand(batch, n_moves, width),
        2,
        moves.unsqueeze(0).expand(batch, n_moves, width),
    )

# Random int64 weight per position; a state's hash is its weighted sum.
# Unseeded, so hashes differ between kernel sessions. Collisions are not detected.
hash_vec = torch.randint(low=0, high=1_000_000_000_000, size=(state_size,))
def state2hash(states):
    """Hash each row of ``states`` to one integer via a dot product with ``hash_vec``."""
    weighted = states * hash_vec
    return weighted.sum(dim=1)

In [6]:
# Sanity check: the hand-written table should have shape (n_gens, state_size).
all_moves_kirill.shape

torch.Size([12, 54])

## Replace all_moves with the QTM generators loaded from assets

In [7]:
import sys

# Repository root, relative to the notebook's working directory. It is expected
# to provide utils.py / g_datasets.py and the assets/ folder used below.
# base_dir = "./DeepDeepCube2"
base_dir = ".."

# Make the repository's top-level modules importable.
sys.path.append(base_dir)

from utils import open_pickle
from g_datasets import reverse_actions

# Load the QTM generator table shipped with the repository.
# NOTE(review): open_pickle presumably unpickles the file, which can execute
# arbitrary code — only load trusted assets.
qtm = open_pickle(f"{base_dir}/assets/envs/qtm_cube3.pickle")
qtm_actions = torch.from_numpy(np.array(qtm["actions"]))

# torch.gather requires int64 indices; np.array may yield int32 on some platforms.
all_moves = qtm_actions.to(dtype_int).cpu()
# get_neighbors / batch_processVP assume a (n_gens, state_size) table.
assert tuple(all_moves.shape) == (n_gens, state_size), f"unexpected generator table shape {tuple(all_moves.shape)}"
inverse_moves = reverse_actions(torch.arange(0, n_gens, dtype=dtype_int), n_gens=n_gens).long()

## Neighborhood

In [8]:
def get_unique_elements_first_idx(tensor):
    """Return, for each distinct value of a 1-D tensor, the index of its first occurrence.

    The returned indices are ordered by ascending value, not by position.

    :param tensor: 1-D tensor of comparable values (e.g. int64 hashes).
    :return: int64 tensor of indices into ``tensor``; empty for empty input.
    """
    # A stable sort keeps equal values in input order, so the head of each run
    # of duplicates really is the first occurrence (the default sort is unstable).
    sorted_tensor, indices = torch.sort(tensor, stable=True)
    # Start-of-run mask built on the input's device. It also works for empty
    # input, where prepending a fixed [True] would make the mask too long.
    is_new = torch.ones_like(sorted_tensor, dtype=torch.bool)
    is_new[1:] = sorted_tensor[1:] != sorted_tensor[:-1]
    return indices[is_new]

def get_next_ring_idx(ring_next, visited_hash):
    """Pick which candidate states form the next ring of the neighborhood expansion.

    Candidates whose hash is already in ``visited_hash`` are dropped. Among the
    rest, only the first occurrence of each hash is kept.

    :param ring_next: (n, state_size) tensor of candidate states.
    :param visited_hash: 1-D tensor of hashes of already-visited states.
    :return: int64 indices into ``ring_next`` of the kept states.
    """
    new_hash = state2hash(ring_next)
    keep = ~torch.isin(new_hash, visited_hash)
    # Positions (in ring_next) of the unvisited candidates, on the input's device
    # (a bare torch.arange would live on CPU and fail for GPU tensors).
    candidate_pos = torch.nonzero(keep, as_tuple=True)[0]
    first = get_unique_elements_first_idx(new_hash[keep])
    return candidate_pos[first]

In [9]:
# neighborhood_size = 5

# ring_0         = V0.unsqueeze(0)
# neighborhood_S = ring_0.clone()                      # состояния окрестности
# neighborhood_Y = torch.tensor((0,), dtype=dtype_int) # дистанция окрестности
# neighborhood_Z = torch.tensor((0,), dtype=dtype_int) # мосты окрестности
# visited_hash   = state2hash(neighborhood_S)

# ring_last = ring_0.clone()
# for j in tqdm(range(1, neighborhood_size+1)):
#     # do all steps
#     ring_next  = get_neighbors(ring_last).flatten(end_dim=1)
#     moves_next = torch.arange(12).repeat(ring_last.size(0))
    
#     # delete steps back and copies
#     idx        = get_next_ring_idx(ring_next, visited_hash)
#     ring_next  = ring_next[idx]
#     moves_next = moves_next[idx]
    
#     # save ring
#     neighborhood_S = torch.concat((neighborhood_S, ring_next))
#     neighborhood_Y = torch.concat((neighborhood_Y, torch.tensor([j] * ring_next.size(0))))
#     neighborhood_Z = torch.concat((neighborhood_Z, inverse_moves[moves_next]))
    
#     # prepare to next step
#     visited_hash   = torch.concat((visited_hash, state2hash(ring_next)))
#     ring_last      = ring_next.clone()
# clear_output()

# neighborhood_X = states2X(neighborhood_S)
# neighborhood_H = state2hash(neighborhood_S)
# print(f"# states in V0 neighborhood {neighborhood_size} = {neighborhood_X.size(0)}")

## Dataset

In [10]:
def do_random_step(states, last_moves, moves=None, inverse=None):
    """Apply one random generator to every state, never undoing the previous move.

    :param states: (batch, state_size) int64 tensor of states.
    :param last_moves: (batch,) int64 tensor with each state's previous move.
        A negative entry (the ``-1`` sentinel used by callers) means "no
        previous move", so every generator is allowed for that row.
    :param moves: optional (n_moves, state_size) generator table; defaults to
        the global ``all_moves``.
    :param inverse: optional (n_moves,) table mapping a move to its inverse;
        defaults to the global ``inverse_moves``.
    :return: ``(new_states, next_moves)``; ``next_moves`` has shape (batch,).
    """
    if moves is None:
        moves = all_moves
    if inverse is None:
        inverse = inverse_moves
    dev = states.device
    moves, inverse, last_moves = moves.to(dev), inverse.to(dev), last_moves.to(dev)
    batch = states.size(0)
    if batch == 0:
        # Guard the empty batch explicitly instead of sampling from zero rows.
        return states, torch.empty((0,), dtype=torch.int64, device=dev)

    # Allowed-move mask. Only rows that actually have a previous move get its
    # inverse forbidden: indexing with the -1 sentinel would otherwise silently
    # forbid inverse[-1] on the very first step.
    allowed = torch.ones((batch, moves.size(0)), dtype=torch.bool, device=dev)
    has_prev = last_moves >= 0
    rows = torch.arange(batch, device=dev)[has_prev]
    allowed[rows, inverse[last_moves[has_prev]]] = False

    # Sample uniformly among allowed moves. squeeze(1) (not squeeze()) keeps
    # shape (1,) for a single state so the gather below stays 2-D.
    next_moves = torch.multinomial(allowed.float(), 1).squeeze(1)
    return torch.gather(states, 1, moves[next_moves]), next_moves

def generate_random_walks(k=10000, K_min=1, K_max=28, neighborhood=None):
    """Sample ``k`` non-backtracking random walks from V0 for every length in [K_min, K_max].

    :param k: number of walks per length.
    :param K_min: shortest walk length (inclusive).
    :param K_max: longest walk length (inclusive).
    :param neighborhood: optional ``(S, Y, Z, H)`` tuple of neighborhood states,
        their distances, their moves and their hashes. Walk endpoints whose hash
        is in ``H`` are dropped, and the neighborhood rows are prepended. When
        omitted, the notebook-level ``neighborhood_S/Y/Z/H`` tensors are used if
        they exist. Otherwise the raw walks are returned.
    :return: ``(states, Ks, Zs)`` — endpoint states, walk lengths, and the
        inverse of each walk's last move.
    """
    n_lengths = K_max - K_min + 1
    dataset = torch.zeros((n_lengths * k, state_size), dtype=dtype_int)
    Ks = torch.arange(K_min, K_max + 1).repeat_interleave(k)
    Zs = torch.zeros((n_lengths * k,), dtype=dtype_int)
    for j, K in enumerate(range(K_min, K_max + 1)):
        states = V0.repeat(k, 1)
        last_moves = torch.full((k,), -1, dtype=dtype_int)  # -1: no previous move yet
        for _ in range(K):
            states, last_moves = do_random_step(states, last_moves)
        rows = slice(j * k, (j + 1) * k)
        dataset[rows] = states
        Zs[rows] = inverse_moves[last_moves]

    if neighborhood is None:
        # The cell that builds neighborhood_* is commented out; referencing those
        # names unconditionally raised NameError.
        try:
            neighborhood = (neighborhood_S, neighborhood_Y, neighborhood_Z, neighborhood_H)
        except NameError:
            return dataset, Ks, Zs
    nb_S, nb_Y, nb_Z, nb_H = neighborhood

    keep = ~torch.isin(state2hash(dataset), nb_H)
    dataset = torch.concat((nb_S, dataset[keep]))
    Ks = torch.concat((nb_Y, Ks[keep]))
    Zs = torch.concat((nb_Z, Zs[keep]))
    return dataset, Ks, Zs

In [11]:
def generate_xyz_dataset(k=100000, K_min=1, K_max=28, num_classes=6, random_state=None):
    """Build one-hot encoded train/validation splits from random-walk data.

    :param k: walks per length, forwarded to ``generate_random_walks``.
    :param K_min: shortest walk length, forwarded likewise.
    :param K_max: longest walk length, forwarded likewise.
    :param num_classes: one-hot depth; every state entry must be < num_classes.
        NOTE(review): 6 fits a colour encoding, but the notebook's current V0 is
        ``arange(54)``, for which one_hot with 6 classes raises — confirm which
        encoding the model expects.
    :param random_state: forwarded to ``train_test_split`` for a reproducible split.
    :return: ``X_train, Y_train, Z_train, X_val, Y_val, Z_val`` — float one-hot
        features, float walk lengths, and inverse-last-move targets.
    """
    dataset, Ks, Zs = generate_random_walks(k=k, K_min=K_min, K_max=K_max)

    idx_train, idx_val = sklearn.model_selection.train_test_split(
        torch.arange(dataset.size(0)), test_size=0.2, random_state=random_state
    )

    X = torch.nn.functional.one_hot(dataset, num_classes=num_classes).view(-1, state_size * num_classes)
    X_train = X[idx_train].to(torch.float)
    Y_train = Ks[idx_train].to(torch.float)
    Z_train = Zs[idx_train]
    X_val   = X[idx_val].to(torch.float)
    Y_val   = Ks[idx_val].to(torch.float)
    Z_val   = Zs[idx_val]

    return X_train, Y_train, Z_train, X_val, Y_val, Z_val

## Test

In [12]:
# XYZ = generate_xyz_dataset(k=60_000, K_min=1, K_max=30)

In [13]:
# # pred_v, pred_p = batch_processVP(model, XYZ[0], device, 4096)
# # pred_v, pred_p = pred_v.cpu(), pred_p.cpu()
# true_v, true_p = XYZ[1], XYZ[2]

## Beam Search

In [14]:
class TimeContext:
    """Context manager that records how long its body takes.

    Each exit appends the elapsed seconds to ``store[label]``. When ``store`` is
    not given, the notebook-level ``log`` dict is used. That dict is looked up
    at exit time, so re-binding ``log = {}`` between runs works, and it is
    created if missing.
    """

    def __init__(self, label: str, store=None):
        self.label = label
        self.store = store

    def __enter__(self):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        self.start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end = time.perf_counter()
        store = self.store if self.store is not None else globals().setdefault("log", {})
        store.setdefault(self.label, []).append(self.end - self.start)
        # Returning None lets any exception raised in the body propagate.

In [15]:
# Per the PyTorch docs, "high" lets float32 matmuls use faster reduced-precision
# internal math (e.g. TF32) where supported; "highest" keeps full float32.
torch.set_float32_matmul_precision('high') #high, highest

In [16]:
def batch_processVP(model, data, device, batch_size):
    """
    Run ``model`` over ``data`` in chunks and collect its value and policy outputs.

    :param model: module returning a ``(value, policy)`` pair per batch; value is
        expected to carry a trailing singleton dim, policy ``n_gens`` columns.
    :param data: input tensor, split along dim 0.
    :param device: device the chunks are moved to and the outputs live on.
    :param batch_size: number of rows per forward pass.
    :return: ``(values, policies)`` of shapes (n,) and (n, n_gens), dtype ``dtype_float``.
    """
    total = data.size(0)
    values = torch.zeros((total,), dtype=dtype_float, device=device)
    policies = torch.zeros((total, n_gens), dtype=dtype_float, device=device)

    for lo in range(0, total, batch_size):
        hi = lo + batch_size
        chunk = data[lo:hi].to(device)
        with torch.no_grad():
            v, p = model(chunk)
        values[lo:hi] = v.squeeze(dim=1)
        policies[lo:hi] = p

    return values, policies

In [17]:
# Kept as a module global: check_B declares it global and moves it to its device.
dummy_true = torch.tensor([True])
def get_unique_states(states):
    """Deduplicate the rows of ``states`` by hash.

    :param states: (n, state_size) int64 tensor.
    :return: ``(unique_states, idx)``, where ``idx`` indexes the kept rows of
        ``states``. Rows come out sorted by hash. Rows with equal hashes are
        treated as identical; collisions are not checked.
    """
    hashed_sorted, order = torch.sort(state2hash(states))
    # Start-of-run mask built on the input's own device. It also handles n == 0,
    # where prepending a fixed [True] would make the mask longer than ``order``.
    is_first = torch.ones_like(hashed_sorted, dtype=torch.bool)
    is_first[1:] = hashed_sorted[1:] != hashed_sorted[:-1]
    keep = order[is_first]
    return states[keep], keep

In [18]:
def check_B(B, tests, num_steps=100, device=torch.device("cpu")):
    """Run beam search with width ``B`` on each test state and print summary statistics.

    Side effect: moves the globals all_moves, hash_vec, dummy_true and V0 to ``device``.

    :param B: beam width passed to ``do_greedy_step``.
    :param tests: (num_tests, state_size) tensor of start states.
    :param num_steps: maximum number of beam expansions per test.
    :param device: device the search runs on.
    :return: ``(paths, win_prob, ts)`` — solution lengths and wall-clock times
        of the solved tests, and the fraction of tests solved.
    """
    global all_moves, hash_vec, dummy_true, V0

    all_moves = all_moves.to(device)
    hash_vec = hash_vec.to(device)
    dummy_true = dummy_true.to(device)
    V0 = V0.to(device)

    paths = []
    ts = []
    num_tests = tests.size(0)
    for i in tqdm(range(num_tests), desc=f'B={B:6d}'):
        states = tests[i].unsqueeze(0).to(device)
        y_pred = torch.tensor([0], dtype=dtype_float, device=device)

        solved_at = None
        t1 = time.time()
        for j in range(num_steps):
            states, y_pred, idx = do_greedy_step(states, y_pred, B)
            if (states == V0).all(dim=1).any():
                solved_at = j + 1
                break
        t2 = time.time()

        # Use an explicit flag: the old `j+1 < num_steps` check counted a solve
        # on the final allowed step as a failure.
        if solved_at is not None:
            ts.append(t2 - t1)
            paths.append(solved_at)
        else:
            print("Not found!")

    clear_output()

    win_prob = len(paths) / num_tests
    print(f"beam size    = {B}")
    print(f"# tests      = {num_tests}")
    print(f"win prob     = {win_prob:.3f}")
    print(f"time         = {np.mean(ts):.2f} s")
    print(f"avg path len = {np.mean(paths):.1f} ± {np.std(paths)/np.sqrt(len(paths)):.2f}")

    return paths, win_prob, ts

In [19]:
def get_tests(num_tests, R=1000):
    """Generate start states by scrambling V0 with ``R`` non-backtracking random moves.

    Each state then gets one extra move with probability 1/2, randomising the
    parity of the scramble length.

    :param num_tests: number of states to generate.
    :param R: number of random moves applied to every state. Previously this was
        silently overwritten by a hard-coded 1000; it is now honoured, with the
        same default.
    :return: (num_tests, state_size) tensor of states.
    """
    states = V0.unsqueeze(0).repeat(num_tests, 1)
    dev = states.device
    last_move = torch.full((num_tests,), -1, dtype=dtype_int, device=dev)  # -1: no previous move
    for _ in range(R):
        states, last_move = do_random_step(states, last_move)
    extra = torch.randint(0, 1 + 1, (num_tests,), device=dev) > 0  # random parity
    if extra.any():
        states[extra], last_move[extra] = do_random_step(states[extra], last_move[extra])
    return states

## Model

In [20]:
# from pilgrim import PilgrimVP, count_parameters
# model = PilgrimVP(input_dim=6*state_size, hd1=5000, hd2=1000)

# model.to(device)
# model.eval()

# print("model is ready:", count_parameters(model))
# print(model.load_state_dict(torch.load("weights/VP_5000x1000_0.pth")))

In [28]:
import sys
sys.path.append("./DeepDeepCube2")

from models import Pilgrim, count_parameters

# model = Pilgrim(
#     input_dim = 54, 
#     hidden_dim1 = 5000, 
#     hidden_dim2 = 1000, 
#     num_residual_blocks = 4 
# ) # ~14M

# model.to(device)
# model.eval()

# print("model is ready:", count_parameters(model))
# print(model.load_state_dict(torch.load(
#     f"{base_dir}/assets/models/Cube3ResnetModel_value_policy_3_8B_14M.pt",
#     map_location=device
# )))

# The checkpoint is loaded as a whole model object (not a state_dict). On
# torch >= 2.6, torch.load defaults to weights_only=True and would refuse it.
# NOTE(review): full unpickling can execute arbitrary code — load trusted files only.
# weights_only requires torch >= 1.13.
model = torch.load(
    f"{base_dir}/assets/models/pruning_finetune_Cube3ResnetModel_value_policy_3_8B_14M.pt",
    map_location=device,
    weights_only=False,
)
# Inference only: put the module in eval mode, as the commented state_dict path
# did. Otherwise dropout / batch-norm may still be in training mode after finetuning.
model.eval()

In [29]:
def pred_d(states):
    """Predict the value (distance estimate) and policy output for a batch of states.

    Rows equal to V0 get value 0, and all values are clamped to be non-negative.

    :return: ``(values, policy, any_goal)``; ``any_goal`` is True if some row equals V0.
    """
    values, policy = batch_processVP(model, states, device, 4096)
    at_goal = (states == V0).all(dim=1)
    values[at_goal] = 0
    return values.clamp(min=0), policy, at_goal.any()

In [30]:
# value
def do_greedy_step(states, value_last, B=1000):
    """Expand the beam by one move and keep the ``B`` neighbors with the lowest predicted value.

    :param states: (beam, state_size) current beam.
    :param value_last: values of the current beam. Unused by this value-based
        variant; kept so the signature matches the policy-based variant.
    :param B: beam width.
    :return: ``(next_states, next_values, parent_idx)``, where ``parent_idx[i]``
        is the row of ``states`` that ``next_states[i]`` was expanded from.
    """
    # Expand every beam state by every generator: (beam, n_moves, state_size).
    expanded = get_neighbors(states)
    # Parent row of each expanded neighbor. The fan-out comes from the expansion
    # itself, and the tensor lives on the states' device rather than the global
    # ``device``.
    parents = torch.arange(states.size(0), device=states.device).repeat_interleave(expanded.size(1))
    # Drop duplicate neighbors.
    neighbors, uniq_idx = get_unique_states(expanded.flatten(end_dim=1))
    # Score the neighbors with the value head.
    value = pred_d(neighbors)[0]
    # Keep the B lowest-valued neighbors.
    best = torch.argsort(value)[:B]
    return neighbors[best], value[best], parents[uniq_idx[best]]

In [31]:
# # policy
# def do_greedy_step(states, policy_last, B=1000):
#     # распространить policy_last
#     with TimeContext("policy_repeat"):
#         idx0 = torch.arange(states.size(0), device=device).repeat_interleave(n_gens)
#         policy_last = policy_last.repeat_interleave(n_gens)
#     # посчитать policy
#     with TimeContext("pred_d"):
#         policy = pred_d(states)[1]
#         policy = F.softmax(policy, dim=1)
#         policy = policy.flatten()
#     # найти соседей
#     with TimeContext("get_neighbors"):
#         neighbors = get_neighbors(states).flatten(end_dim=1)
#     # отфильтровать соседей
#     with TimeContext("get_unique_states"):
#         neighbors, idx1 = get_unique_states(neighbors)
#     # отсортировать и обрубить
#     with TimeContext("argsort"):
#         policy_next = policy_last[idx1] + torch.log(policy[idx1])
#         idx2 = torch.argsort(-policy_next)[:B]
#     return neighbors[idx2], policy_next[idx2], idx0[idx1[idx2]]

## Run

In [32]:
# Scramble 10 start states with get_tests' default R=1000 random moves.
# No seed is set, so the tests (and results below) vary between runs.
tests = get_tests(10)

In [33]:
# value
log = {}
paths, *_ = check_B(100, tests.to(device), device=device)

beam size    = 100
# tests      = 10
win prob     = 0.600
time         = 6.44 s
avg path len = 38.7 ± 3.16


In [None]:
# Distribution of solution lengths over the solved tests.
pd.Series(paths).value_counts()

In [None]:
# time in ms: mean duration per TimeContext label (log stores seconds).
# NOTE(review): only the commented-out policy variant of do_greedy_step uses
# TimeContext, so with the value variant this frame is empty.
pd.DataFrame(log).mean() * 1e3