In [1]:
import torch
import lightning.pytorch as pl
from datapile import FastPMPile, HuggingfaceLoader, HuggingfaceLoader_iter
from model import Lpt2NbodyNetLightning
import yaml
import argparse
import os

def load_config(config_file):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_file: Path of the YAML file to open for reading.

    Returns:
        The object produced by ``yaml.safe_load`` for the file's contents.
    """
    with open(config_file, 'r') as stream:
        return yaml.safe_load(stream)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def verify_model(model, input_shape, output_shape=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """
    Verify that a model can successfully complete forward and backward passes.

    A random input is pushed through the model, a scalar loss is built from the
    output(s), and ``backward()`` is called. The gradient norm of every named
    parameter is then printed (or a warning if a parameter received no gradient).

    The model may return a single tensor, a list/tuple of tensors, or a dict
    whose values are tensors.

    Args:
        model: The PyTorch model to verify. It is moved to ``device`` and put
            in training mode; any gradients already stored on its parameters
            are cleared before the check.
        input_shape: Shape of input tensor (batch size included)
        output_shape: Expected output shape (for loss calculation). For a
            single-tensor output this is one shape. For multiple outputs it is
            a sequence of shapes in output order, or, for dict outputs, a dict
            mapping the same keys to shapes. If None, the loss is simply the
            sum of all output elements.
        device: Device to run verification on

    Returns:
        True if both passes completed, False if any exception was raised
        (the error message is printed).
    """
    model = model.to(device)
    model.train()  # Set to training mode for proper gradient computation

    try:
        # Drop stale gradients so the "has no gradient" report below reflects
        # this backward pass only, not an earlier one.
        model.zero_grad(set_to_none=True)

        # Create random input tensor
        dummy_input = torch.randn(*input_shape).to(device)
        dummy_input.requires_grad = True  # For models that need input gradients

        # Forward pass
        output = model(dummy_input)
        print("Forward pass completed successfully.")

        # Normalise the output into a flat list of tensors. Iterating a dict
        # yields its keys, so the values have to be pulled out explicitly.
        is_multi = isinstance(output, (list, tuple, dict))
        if isinstance(output, dict):
            outputs = list(output.values())
        elif is_multi:
            outputs = list(output)
        else:
            outputs = [output]

        # Containers have no ``.shape``; report each element's shape instead.
        if is_multi:
            print(f"Output shapes: {[tuple(o.shape) for o in outputs]}")
        else:
            print(f"Output shape: {output.shape}")

        if output_shape is not None:
            # Work out one target shape per output tensor.
            if not is_multi:
                shapes = [output_shape]
            elif isinstance(output_shape, dict):
                if isinstance(output, dict):
                    # Match shapes to outputs by key, in output order.
                    shapes = [output_shape[key] for key in output]
                else:
                    shapes = list(output_shape.values())
            else:
                shapes = list(output_shape)

            # zip() would silently ignore the surplus; fail loudly instead.
            if len(shapes) != len(outputs):
                raise ValueError(
                    f"Expected {len(shapes)} output(s) from output_shape "
                    f"but model produced {len(outputs)}"
                )

            loss = sum(
                torch.nn.functional.mse_loss(o, torch.randn(*shape).to(device))
                for o, shape in zip(outputs, shapes)
            )
        else:
            # Simple sum of outputs if no target shape provided
            loss = sum(o.sum() for o in outputs)

        # Backward pass
        loss.backward()
        print("Backward pass completed successfully.")

        # Check parameter gradients
        for name, param in model.named_parameters():
            if param.grad is None:
                print(f"Warning: Parameter {name} has no gradient")
            else:
                print(f"Parameter {name} gradient norm: {param.grad.norm().item():.4f}")

        return True

    except Exception as e:
        # Deliberately broad: this is a smoke test that reports failure via
        # its return value rather than propagating.
        print(f"Error during verification: {str(e)}")
        return False

In [7]:
# Load the experiment configuration from disk.
path = "/home/user/ckwan1/ml/configs/new_config/naf_denoise_coslr.yaml"
config = load_config(path)

# Copy the data/trainer settings into the model section and set every
# NAF block count to 1 before building the model.
config['model'].update({
    'batch_size': config['data']['batch_size'],
    'max_epochs': config['trainer']['max_epochs'],
    'naf_middle_blk_num': 1,
    'naf_enc_blk_nums': [1, 1],
    'naf_dec_blk_nums': [1, 1],
})
model = Lpt2NbodyNetLightning(**config['model'])

# Same shape for the random input and the random regression target.
input_shape = (1, 3, 32, 32, 32)
output_shape = (1, 3, 32, 32, 32)

# Bare last expression so the notebook displays the True/False result.
verify_model(model, input_shape, output_shape)

Forward pass completed successfully.
Output shape: torch.Size([1, 3, 32, 32, 32])
Backward pass completed successfully.
Parameter model.intro.weight gradient norm: 0.6609
Parameter model.intro.bias gradient norm: 0.0307
Parameter model.ending.weight gradient norm: 3.0512
Parameter model.ending.bias gradient norm: 0.0508
Parameter model.encoders.0.0.beta gradient norm: 0.0022
Parameter model.encoders.0.0.gamma gradient norm: 0.0039
Parameter model.encoders.0.0.conv1.weight gradient norm: 0.0000
Parameter model.encoders.0.0.conv1.bias gradient norm: 0.0000
Parameter model.encoders.0.0.conv2.weight gradient norm: 0.0000
Parameter model.encoders.0.0.conv2.bias gradient norm: 0.0000
Parameter model.encoders.0.0.conv3.weight gradient norm: 0.0000
Parameter model.encoders.0.0.conv3.bias gradient norm: 0.0000
Parameter model.encoders.0.0.sca.1.weight gradient norm: 0.0000
Parameter model.encoders.0.0.sca.1.bias gradient norm: 0.0000
Parameter model.encoders.0.0.conv4.weight gradient norm: 0.00

True