In [70]:
from pathlib import Path
import pandas as pd
import torch
import numpy as np


In [72]:
CONFIG = {
    'hash_encoding': {
        'num_levels': 16,
        'level_dim': 2,
        'input_dim': 3,
        'log2_hashmap_size': 19,
        'base_resolution': 16
    },
    'mlp': {
        'num_layers': 3,  # Number of layers in geometric MLP
        'hidden_dim': 64,  # Hidden dimension size
    }
}

def load_torch_weights(file_path):
    """Load model weights from a checkpoint file."""
    try:
        weights = torch.load(file_path, map_location='cpu')
        return weights['model']
    except Exception as e:
        print(f"Error loading file {file_path}: {e}")
        return None
    
model_path = 'shared_data/CarrotKhanStatue/base_000_000_000/checkpoints/final.pth'
nerf = load_torch_weights(model_path)

  weights = torch.load(file_path, map_location='cpu')


In [73]:
for k in nerf.keys():
    print(f"{k}".ljust(40) + f" at nerf[k] you find a {type(nerf[k])}")

_orig_mod.aabb_train                     at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.aabb_infer                     at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.density_grid                   at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.density_bitfield               at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.grid_encoder.embeddings        at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.grid_encoder.offsets           at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.grid_mlp.net.0.weight          at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.grid_mlp.net.1.weight          at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.grid_mlp.net.2.weight          at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.view_mlp.net.0.weight          at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.view_mlp.net.1.weight          at nerf[k] you find a <class 'torch.Tensor'>
_orig_mod.view_mlp.net.2.weight          at nerf[k] yo

In [74]:
def extract_hash_encoding_structure(model_weights, num_levels=16, level_dim=2, input_dim=3, log2_hashmap_size=19, base_resolution=16):
    """
    Extract and organize hash encoding weights into hierarchical structure.
    
    Args:
        model_weights (dict): The loaded model weights dictionary
        num_levels (int): Number of levels in hash encoding
        level_dim (int): Dimension of encoding at each level
        input_dim (int): Input dimension (typically 3 for 3D)
        log2_hashmap_size (int): Log2 of maximum hash table size
        base_resolution (int): Base resolution of the grid
        
    Returns:
        dict: Hierarchical structure of hash encoding weights
    """
    # Extract hash encoding embeddings
    embeddings = model_weights['_orig_mod.grid_encoder.embeddings']
    
    # Calculate per-level parameters
    max_params = 2 ** log2_hashmap_size
    per_level_scale = np.exp2(np.log2(2048 / base_resolution) / (num_levels - 1))
    
    # Initialize structure to store weights
    hash_structure = {}
    offset = 0
    
    for level in range(num_levels):
        # Calculate resolution at this level
        resolution = int(np.ceil(base_resolution * (per_level_scale ** level)))
        
        # Calculate number of parameters for this level
        params_in_level = min(max_params, (resolution) ** input_dim)
        params_in_level = int(np.ceil(params_in_level / 8) * 8)  # make divisible by 8
        
        # Extract weights for this level
        level_weights = embeddings[offset:offset + params_in_level]
        
        # Store level information
        hash_structure[f'level_{level}'] = {
            'resolution': resolution,
            'num_params': params_in_level,
            'weights': level_weights,
            'weights_shape': level_weights.shape,
            'scale': per_level_scale ** level
        }
        
        offset += params_in_level
    
    # Add global information
    hash_structure['global_info'] = {
        'total_params': offset,
        'embedding_dim': level_dim,
        'base_resolution': base_resolution,
        'max_resolution': int(np.ceil(base_resolution * (per_level_scale ** (num_levels-1)))),
        'per_level_scale': per_level_scale
    }
    
    return hash_structure

mrhe_by_layer = extract_hash_encoding_structure(nerf)

In [75]:
tmo = 0
for k in mrhe_by_layer.keys():
    if 'level' in k:
        print(k, "Resolution: " + str( mrhe_by_layer[k]['resolution']), "\t\tShape of Hash Layer Params:" + str(mrhe_by_layer[k]['weights'].shape))
        tmo += mrhe_by_layer[k]['weights'].shape[0]*mrhe_by_layer[k]['weights'].shape[1]
print(f"Total MRHE Table Params: " + str(tmo))


level_0 Resolution: 16 		Shape of Hash Layer Params:torch.Size([4096, 2])
level_1 Resolution: 23 		Shape of Hash Layer Params:torch.Size([12168, 2])
level_2 Resolution: 31 		Shape of Hash Layer Params:torch.Size([29792, 2])
level_3 Resolution: 43 		Shape of Hash Layer Params:torch.Size([79512, 2])
level_4 Resolution: 59 		Shape of Hash Layer Params:torch.Size([205384, 2])
level_5 Resolution: 81 		Shape of Hash Layer Params:torch.Size([524288, 2])
level_6 Resolution: 112 		Shape of Hash Layer Params:torch.Size([524288, 2])
level_7 Resolution: 154 		Shape of Hash Layer Params:torch.Size([524288, 2])
level_8 Resolution: 213 		Shape of Hash Layer Params:torch.Size([524288, 2])
level_9 Resolution: 295 		Shape of Hash Layer Params:torch.Size([524288, 2])
level_10 Resolution: 407 		Shape of Hash Layer Params:torch.Size([524288, 2])
level_11 Resolution: 562 		Shape of Hash Layer Params:torch.Size([524288, 2])
level_12 Resolution: 777 		Shape of Hash Layer Params:torch.Size([524288, 2])
level_1

In [66]:
def extract_mlp_weights(model_weights):
    """Extract geometric and view-dependent MLP weights from the model."""
    geometry_layers = {}
    view_mlp_layers = {}
    
    # Extract geometry MLP weights
    for i in range(CONFIG['mlp']['num_layers']):
        weight_key = f'_orig_mod.grid_mlp.net.{i}.weight'
        bias_key = f'_orig_mod.grid_mlp.net.{i}.bias'
        
        if weight_key in model_weights:
            geometry_layers[f'layer_{i}'] = {
                'weights': model_weights[weight_key],
                'shape': model_weights[weight_key].shape
            }
            
            if bias_key in model_weights:
                geometry_layers[f'layer_{i}']['bias'] = model_weights[bias_key]
    
    # Extract view-dependent MLP weights
    for i in range(CONFIG['mlp']['num_layers']):
        weight_key = f'_orig_mod.view_mlp.net.{i}.weight'
        bias_key = f'_orig_mod.view_mlp.net.{i}.bias'
        
        if weight_key in model_weights:
            view_mlp_layers[f'layer_{i}'] = {
                'weights': model_weights[weight_key],
                'shape': model_weights[weight_key].shape
            }
            
            if bias_key in model_weights:
                view_mlp_layers[f'layer_{i}']['bias'] = model_weights[bias_key]
    
    return {
        'geometry_mlp': geometry_layers,
        'view_mlp': view_mlp_layers
    }

# Example usage
mlp_weights = extract_mlp_weights(nerf)
# Print number of layers and details
print(f"Geometry MLP layers: {len(mlp_weights['geometry_mlp'])}")
print(f"View MLP layers: {len(mlp_weights['view_mlp'])}")

# Print shape of each layer in the geometry MLP
print("\nGeometry MLP layer shapes:")
for layer_name, layer_data in mlp_weights['geometry_mlp'].items():
    print(f"  {layer_name}: {layer_data['shape']} - Input: {layer_data['shape'][1]}, Output: {layer_data['shape'][0]}")

# Print shape of each layer in the view MLP
print("\nView MLP layer shapes:")
for layer_name, layer_data in mlp_weights['view_mlp'].items():
    print(f"  {layer_name}: {layer_data['shape']} - Input: {layer_data['shape'][1]}, Output: {layer_data['shape'][0]}")

Geometry MLP layers: 3
View MLP layers: 3

Geometry MLP layer shapes:
  layer_0: torch.Size([64, 32]) - Input: 32, Output: 64
  layer_1: torch.Size([64, 64]) - Input: 64, Output: 64
  layer_2: torch.Size([16, 64]) - Input: 64, Output: 16

View MLP layer shapes:
  layer_0: torch.Size([32, 31]) - Input: 31, Output: 32
  layer_1: torch.Size([32, 32]) - Input: 32, Output: 32
  layer_2: torch.Size([3, 32]) - Input: 32, Output: 3
