In [1]:
import torch

# Path to the pretrained VAE checkpoint, relative to the notebook's working directory.
VAE_CKPT_PATH = './Pretrained_Components/VAE'
# Package emoji (U+1F4E6) for section headings, written as an escape so the heading
# survives any text-encoding round-trip of the notebook source.
SECTION_ICON = "\U0001F4E6"
# Width of the horizontal rules used in the report.
RULE_WIDTH = 70

# Load the VAE checkpoint onto the CPU.
# SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects, which
# can execute code. Only use it with checkpoints from a trusted source.
ckpt = torch.load(VAE_CKPT_PATH, map_location='cpu', weights_only=False)

# PyTorch Lightning checkpoints nest the weights under 'state_dict'; if that key is
# absent, treat the loaded object itself as a plain state dict.
sd = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt

print("=" * RULE_WIDTH)
print("VAE ARCHITECTURE - Layer Names and Dimensions")
print("=" * RULE_WIDTH)

# Bucket every parameter by name prefix: name -> (shape as string, element count).
encoder_layers = {}
decoder_layers = {}
other_layers = {}
for name, param in sd.items():
    entry = (str(list(param.shape)), param.numel())
    if name.startswith('encoder'):
        encoder_layers[name] = entry
    elif name.startswith('decoder'):
        decoder_layers[name] = entry
    else:
        other_layers[name] = entry


def print_layer_section(title: str, layers: dict, total_label: str) -> int:
    """Print one section of the parameter table and return its parameter total.

    Args:
        title: Section heading, e.g. "ENCODER".
        layers: Mapping of parameter name -> (shape string, element count).
        total_label: Label printed on the closing total row.

    Returns:
        Sum of the element counts in ``layers``.
    """
    print(f"\n{SECTION_ICON} {title}")
    print("-" * RULE_WIDTH)
    total = 0
    for layer_name, (shape, numel) in layers.items():
        print(f"{layer_name:50s} {shape:25s} ({numel:,} params)")
        total += numel
    print(f"{total_label:50s} {total:,}")
    return total


total_encoder = print_layer_section("ENCODER", encoder_layers, 'Total Encoder Parameters')
total_decoder = print_layer_section("DECODER", decoder_layers, 'Total Decoder Parameters')

# Everything not prefixed encoder/decoder lands here. The heading assumes these are
# quant_conv / post_quant_conv; NOTE(review): any other top-level entries (e.g. loss
# parameters) would also be listed under it — confirm against the checkpoint.
total_other = 0
if other_layers:
    total_other = print_layer_section(
        "OTHER (Quantization Layers)", other_layers, 'Total Other Parameters'
    )

# Summary
print("\n" + "=" * RULE_WIDTH)
print("SUMMARY")
print("=" * RULE_WIDTH)
total_params = sum(p.numel() for p in sd.values())
# Every key lands in exactly one bucket above, so the section totals must add up.
assert total_params == total_encoder + total_decoder + total_other
print(f"Total Parameters: {total_params:,}")


def _weight_dim(key: str, axis: int):
    """Return ``sd[key].shape[axis]``, or a readable placeholder if ``key`` is missing."""
    weight = sd.get(key)
    if weight is None:
        return f"n/a ({key} not in checkpoint)"
    return weight.shape[axis]


# Assuming fc_in / fc_out are nn.Linear layers (weight = [out_features, in_features]):
# the input dim is fc_in's axis 1 and the output dim is fc_out's axis 0.
print(f"\nInput Dimension:  {_weight_dim('encoder.fc_in.weight', 1)}")
print(f"Output Dimension: {_weight_dim('decoder.fc_out.weight', 0)}")

VAE ARCHITECTURE - Layer Names and Dimensions

📦 ENCODER
----------------------------------------------------------------------
encoder.fc_in.weight                               [2048, 513]               (1,050,624 params)
encoder.fc_in.bias                                 [2048]                    (2,048 params)
encoder.conv_in.weight                             [64, 10, 3, 3]            (5,760 params)
encoder.conv_in.bias                               [64]                      (64 params)
encoder.down.0.block.0.norm1.weight                [64]                      (64 params)
encoder.down.0.block.0.norm1.bias                  [64]                      (64 params)
encoder.down.0.block.0.conv1.weight                [64, 64, 3, 3]            (36,864 params)
encoder.down.0.block.0.conv1.bias                  [64]                      (64 params)
encoder.down.0.block.0.norm2.weight                [64]                      (64 params)
encoder.down.0.block.0.norm2.bias                  

In [6]:
import sys

import torch

# Vendored DNNWG checkout that provides the `stage1` package.
DNNWG_PATH = './External/DNNWG'
# Guarded so re-running this cell does not keep prepending duplicate sys.path entries.
if DNNWG_PATH not in sys.path:
    sys.path.insert(0, DNNWG_PATH)

from stage1.models.autoencoder import VAENoDiscModel

# Fixed seed so model initialisation and the random dummy input are reproducible.
torch.manual_seed(0)

# Autoencoder architecture config, passed straight through to VAENoDiscModel.
# NOTE(review): the pretrained checkpoint inspected in cell 1 has
# encoder.fc_in.weight [2048, 513] and encoder.conv_in.weight [64, 10, 3, 3];
# confirm in_dim / fdim / in_channels / ch here match it before loading those weights.
ddconfig = {
    'double_z': True,
    'z_channels': 4,
    'resolution': 64,
    'in_channels': 10,
    'my_channels': 10,
    'out_ch': 10,
    'ch': 64,
    'ch_mult': [1, 2, 4],
    'num_res_blocks': 2,
    'attn_resolutions': [],
    'dropout': 0.0,
    'in_dim': 513,
    'fdim': 4096
}

# Loss config in target/params form, passed through to VAENoDiscModel.
lossconfig = {
    'target': 'stage1.modules.losses.CustomLosses.Myloss',
    'params': {'logvar_init': 0.0, 'kl_weight': 1e-6}
}

print('=== Creating model (this should succeed) ===')
model = VAENoDiscModel(
    ddconfig=ddconfig,
    lossconfig=lossconfig,
    embed_dim=4,
    learning_rate=1e-4,
    input_key='weight',
    device='cpu'
)
print('Model created successfully!')

# Smoke-test one forward pass on random data and report the outcome either way;
# the banner makes no prediction about the result.
print('=== Trying forward pass ===')
# Presumably one flattened sample of in_channels * in_dim values (10 * 513 = 5130)
# — TODO confirm against how VAENoDiscModel reshapes its input.
x = torch.randn(1, 5130)
print(f'Input: {x.shape}')

try:
    output = model({'weight': x})
    print('Forward pass succeeded!')
except Exception as e:  # deliberate: reporting the failure is the point of this cell
    print(f'Forward pass FAILED: {e}')
    print()
    print('This means:')
    print('  - Model can be CREATED (no error)')
    print('  - Model CANNOT be TRAINED (forward pass fails)')
    print('  - Your friend could NOT have trained this successfully!')


=== Creating model (this should succeed) ===
making attention of type 'vanilla' with 256 in_channels
Working with z of shape (1, 4, 16, 16) = 1024 dimensions.
making attention of type 'vanilla' with 256 in_channels
Model created successfully!
=== Trying forward pass (this will fail) ===
Input: torch.Size([1, 5130])
Forward pass succeeded!
