In [1]:
import pandas as pd
import torch
import numpy as np

In [2]:
# Architecture settings for the hash-grid encoder and the geometric MLP.
# The 'hash_encoding' entries share their names with the keyword arguments of
# extract_hash_encoding_structure.
CONFIG = dict(
    hash_encoding=dict(
        num_levels=16,
        level_dim=2,
        input_dim=3,
        log2_hashmap_size=19,
        base_resolution=16,
    ),
    # NOTE(review): 'mlp' is not read by any cell visible here.
    mlp=dict(
        num_layers=3,   # depth of the geometric MLP
        hidden_dim=64,  # width of its hidden layers
    ),
)

def load_torch_weights(file_path, key='model'):
    """Load the weights stored under ``key`` in a checkpoint file.

    The checkpoint is loaded onto the CPU. Failures are reported with a
    printed message rather than an exception, so callers must check for None.

    Args:
        file_path (str): Path to a checkpoint written with ``torch.save``.
        key (str): Entry of the checkpoint that holds the model weights.

    Returns:
        ``checkpoint[key]``, or None if the file could not be loaded or has
        no such entry.
    """
    # torch.load unpickles the file: only use it on checkpoints you trust.
    try:
        checkpoint = torch.load(file_path, map_location='cpu')
    except Exception as e:
        print(f"Error loading file {file_path}: {e}")
        return None

    # Report a missing entry separately, listing what the checkpoint does
    # contain, instead of an opaque "'model'" KeyError message.
    try:
        return checkpoint[key]
    except (KeyError, IndexError, TypeError):
        if isinstance(checkpoint, dict):
            found = list(checkpoint.keys())
        else:
            found = type(checkpoint).__name__
        print(f"Error loading file {file_path}: no '{key}' entry in checkpoint (found: {found})")
        return None
    
model_path_base = 'shared_data/CarrotKhanStatue/base_000_000_000/checkpoints/final.pth'
model_path_x_180 = 'shared_data/CarrotKhanStatue/x_180_000_000/checkpoints/final.pth'
# Control checkpoint: a NeRF of a different object (GoldBag).
model_path_test = 'shared_data/GoldBag/base_000_000_000/checkpoints/final.pth'


nerf_base = load_torch_weights(model_path_base)
nerf_x_180 = load_torch_weights(model_path_x_180)
nerf_test = load_torch_weights(model_path_test)

# load_torch_weights prints a message and returns None on failure; stop here
# instead of failing later with an unrelated-looking TypeError during extraction.
_failed_checkpoints = [
    path
    for path, weights in [
        (model_path_base, nerf_base),
        (model_path_x_180, nerf_x_180),
        (model_path_test, nerf_test),
    ]
    if weights is None
]
if _failed_checkpoints:
    raise RuntimeError(f"Could not load checkpoint(s): {_failed_checkpoints}")

In [3]:
def extract_hash_encoding_structure(model_weights, num_levels=16, level_dim=2, input_dim=3, log2_hashmap_size=19, base_resolution=16,
                                    desired_resolution=2048, embedding_key='_orig_mod.grid_encoder.embeddings'):
    """
    Extract and organize hash encoding weights into a per-level structure.

    The flat embedding table is split level by level, from level 0 upward,
    assuming each level occupies ``min(2 ** log2_hashmap_size,
    resolution ** input_dim)`` rows rounded up to a multiple of 8.
    NOTE(review): this layout must match the grid encoder that wrote the
    checkpoint (some implementations size levels with
    ``(resolution + 1) ** input_dim``); a mismatch between the computed total
    and the table size is reported with a warning.

    Args:
        model_weights (dict): The loaded model weights dictionary (state dict).
        num_levels (int): Number of levels in hash encoding.
        level_dim (int): Dimension of encoding at each level.
        input_dim (int): Input dimension (typically 3 for 3D).
        log2_hashmap_size (int): Log2 of maximum hash table size per level.
        base_resolution (int): Grid resolution of the coarsest level.
        desired_resolution (int): Target resolution of the finest level; it
            sets the geometric growth factor between consecutive levels.
        embedding_key (str): State-dict key of the embedding table. If that
            key is missing and starts with ``'_orig_mod.'`` (the prefix
            ``torch.compile`` adds), the un-prefixed key is tried as well.

    Returns:
        dict: ``'level_<i>'`` entries (resolution, num_params, weights,
        weights_shape, scale) plus a ``'global_info'`` summary.

    Raises:
        TypeError: If ``model_weights`` is None (e.g. the checkpoint failed to load).
        KeyError: If no embedding table is found under ``embedding_key``.
    """
    import warnings

    if model_weights is None:
        raise TypeError("model_weights is None - did the checkpoint fail to load?")

    # Locate the embedding table, tolerating checkpoints saved without torch.compile.
    compiled_prefix = '_orig_mod.'
    candidate_keys = [embedding_key]
    if embedding_key.startswith(compiled_prefix):
        candidate_keys.append(embedding_key[len(compiled_prefix):])
    for candidate in candidate_keys:
        if candidate in model_weights:
            embeddings = model_weights[candidate]
            break
    else:
        raise KeyError(f"No hash-encoding embeddings found under any of {candidate_keys}")

    # Per-level growth factor so that level num_levels-1 reaches desired_resolution.
    max_params = 2 ** log2_hashmap_size
    if num_levels > 1:
        per_level_scale = np.exp2(np.log2(desired_resolution / base_resolution) / (num_levels - 1))
    else:
        # A single level has no growth factor (and would divide by zero).
        per_level_scale = 1.0

    hash_structure = {}
    offset = 0

    for level in range(num_levels):
        scale = per_level_scale ** level
        resolution = int(np.ceil(base_resolution * scale))

        # Dense grid size, capped at the hash table size, then rounded up to a multiple of 8.
        params_in_level = min(max_params, resolution ** input_dim)
        params_in_level = int(np.ceil(params_in_level / 8) * 8)

        # Levels are stored contiguously, one after the other.
        level_weights = embeddings[offset:offset + params_in_level]

        hash_structure[f'level_{level}'] = {
            'resolution': resolution,
            'num_params': params_in_level,
            'weights': level_weights,
            'weights_shape': level_weights.shape,
            'scale': scale,
        }

        offset += params_in_level

    # If the layout does not cover the table exactly, every level slice after
    # the first discrepancy is misaligned - surface that instead of failing silently.
    num_rows = embeddings.shape[0]
    if offset != num_rows:
        warnings.warn(
            f"Hash-encoding layout covers {offset} rows but the embedding table has {num_rows} rows; "
            "level boundaries are likely misaligned - check the encoding parameters.",
            stacklevel=2,
        )
    if len(embeddings.shape) > 1 and embeddings.shape[1] != level_dim:
        warnings.warn(
            f"Embedding width {embeddings.shape[1]} does not match level_dim={level_dim}.",
            stacklevel=2,
        )

    hash_structure['global_info'] = {
        'total_params': offset,
        'embedding_dim': level_dim,
        'base_resolution': base_resolution,
        'max_resolution': int(np.ceil(base_resolution * (per_level_scale ** (num_levels - 1)))),
        'per_level_scale': per_level_scale,
    }

    return hash_structure

In [4]:
# Pass the settings from CONFIG explicitly so the config cell is the single
# source of truth for the encoding layout (instead of the function defaults).
hash_cfg = CONFIG['hash_encoding']

mrhe_by_layer_base = extract_hash_encoding_structure(nerf_base, **hash_cfg)

mrhe_by_layer_x_180 = extract_hash_encoding_structure(nerf_x_180, **hash_cfg)

mrhe_by_layer_nerf_test = extract_hash_encoding_structure(nerf_test, **hash_cfg)

In [5]:
# Keep only the per-level weight tensors (drop the 'global_info' summary).
base_dict = {
    level_name: level_info['weights']
    for level_name, level_info in mrhe_by_layer_base.items()
    if level_name.startswith('level_')
}

In [6]:
# Inspect the per-level embedding tensors extracted from the base checkpoint.
base_dict

{'level_0': tensor([[-0.0902, -0.0635],
         [ 0.1001,  0.1045],
         [ 0.1441,  0.0420],
         ...,
         [ 0.2313,  0.1918],
         [ 0.1921,  0.0209],
         [-0.0819, -0.0145]]),
 'level_1': tensor([[-0.0154,  0.0137],
         [-0.0872,  0.0826],
         [-0.0383,  0.1021],
         ...,
         [-0.8903,  0.2423],
         [ 0.0768, -0.0400],
         [-0.1069,  0.0990]]),
 'level_2': tensor([[-0.0608,  0.5266],
         [-0.4926,  0.6712],
         [-0.3840,  0.9586],
         ...,
         [-0.2585, -0.2453],
         [-0.0279, -0.0033],
         [ 0.3932,  0.5994]]),
 'level_3': tensor([[-0.3374, -0.3020],
         [-0.4407, -0.8653],
         [ 0.1957, -0.4545],
         ...,
         [-0.5310, -0.8185],
         [-0.6661, -0.0695],
         [ 0.1364,  0.2074]]),
 'level_4': tensor([[ 2.4145e-01, -1.7465e-01],
         [ 8.2795e-01,  1.2799e+00],
         [ 1.7449e+00,  1.3036e+00],
         ...,
         [-2.8503e-05, -3.0760e-05],
         [-1.3017e-05, 

In [7]:
# CarrotKhanStatue, presumably rotated 180° about x (per the 'x_180_000_000' directory name).
# Per-level weight tensors only; the 'global_info' summary is dropped.
x_180_dict = {
    level_name: level_info['weights']
    for level_name, level_info in mrhe_by_layer_x_180.items()
    if level_name.startswith('level_')
}

In [8]:
# Control: a NeRF of a different object to compare against (currently GoldBag).
# Per-level weight tensors only; the 'global_info' summary is dropped.
test_dict = {
    level_name: level_info['weights']
    for level_name, level_info in mrhe_by_layer_nerf_test.items()
    if level_name.startswith('level_')
}

In [9]:
# Flatten all hash-grid levels of the base model into one token sequence.
# tokens[i] is one embedding row (a view into its level tensor);
# positions[i] is the int64 triple [global index, level index, index within level].
tokens = []
positions = []
global_index = 0  # running token count across all levels

for key in sorted(base_dict.keys(), key=lambda x: int(x.split("_")[1])):
    layer_index = int(key.split("_")[1])    # e.g. 'level_3' -> 3
    layer_tensor = base_dict[key]           # one row per hash-table entry
    num_tokens_in_layer = layer_tensor.shape[0]

    # Build all position triples of this level at once; the former per-row
    # Python loop (one torch.tensor per row, ~6M rows) dominated the cell's runtime.
    position_in_layer = torch.arange(num_tokens_in_layer, dtype=torch.long)
    layer_positions = torch.stack(
        [
            position_in_layer + global_index,
            torch.full((num_tokens_in_layer,), layer_index, dtype=torch.long),
            position_in_layer,
        ],
        dim=1,
    )

    # unbind(0) yields the same per-row views as layer_tensor[i], so tokens and
    # positions remain lists of row tensors for the stacking cell below.
    tokens.extend(layer_tensor.unbind(0))
    positions.extend(layer_positions.unbind(0))
    global_index += num_tokens_in_layer

In [10]:
# Stack the per-row lists into tensors. The isinstance guards make the cell
# safe to re-run (torch.stack would fail on an already-stacked tensor).
if not isinstance(tokens, torch.Tensor):
    tokens = torch.stack(tokens)        # -> shape [N, 2]
if not isinstance(positions, torch.Tensor):
    positions = torch.stack(positions)  # -> shape [N, 3]

# Invariants: one position row per token; global indices run 0..N-1 in order.
assert tokens.shape[0] == positions.shape[0]
assert torch.equal(positions[:, 0], torch.arange(positions.shape[0]))

In [11]:
# Show only the shape: the full token tensor has millions of rows
# (this also matches the recorded output below).
tokens.shape

torch.Size([6098120, 2])

In [12]:
# Columns: [global token index, level index, index within level].
positions

tensor([[      0,       0,       0],
        [      1,       0,       1],
        [      2,       0,       2],
        ...,
        [6098117,      15,  524285],
        [6098118,      15,  524286],
        [6098119,      15,  524287]])