In [7]:
import torch
from models.sleepppgnet import SleepPPGNet
from config import SleepPPGNetConfig

def count_tunable_params(model):
    """
    Return the total number of scalar values held in trainable parameters.

    A parameter counts only when its ``requires_grad`` flag is set, so frozen
    parameters are excluded. Each parameter contributes its ``numel()``.
    A model with no trainable parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue  # frozen weights are not tunable
        total += param.numel()
    return total

# Example usage with different configurations.
# Sweep hidden_dim at a fixed depth and report the trainable-parameter count of
# each SleepPPGNet variant. Every dict holds SleepPPGNet keyword arguments and
# is forwarded as-is (**config), so any other constructor argument can be swept
# simply by adding a key here. Settings shared by all variants come from
# SleepPPGNetConfig.
configs = [
    {"num_res_blocks": 8, "tcn_layers": 2, "hidden_dim": 32},
    {"num_res_blocks": 8, "tcn_layers": 2, "hidden_dim": 64},
    {"num_res_blocks": 8, "tcn_layers": 2, "hidden_dim": 128},
    {"num_res_blocks": 8, "tcn_layers": 2, "hidden_dim": 256},
]

for config in configs:
    model = SleepPPGNet(
        input_channels=SleepPPGNetConfig.INPUT_CHANNELS,
        num_classes=SleepPPGNetConfig.NUM_CLASSES,
        dropout_rate=SleepPPGNetConfig.DROPOUT_RATE,
        **config,
    )
    # NOTE: train()/eval() only switch layer behaviour (e.g. dropout); they do
    # not change requires_grad, so the count below is identical in either mode.
    # The call is kept so `model` is left in eval mode for any later cells.
    model.eval()
    num_params = count_tunable_params(model)

    print(f"Config: {config} -> Tunable Parameters: {num_params}")

Config: {'num_res_blocks': 8, 'tcn_layers': 2, 'hidden_dim': 32} -> Tunable Parameters: 3628547
Config: {'num_res_blocks': 8, 'tcn_layers': 2, 'hidden_dim': 64} -> Tunable Parameters: 14490627
Config: {'num_res_blocks': 8, 'tcn_layers': 2, 'hidden_dim': 128} -> Tunable Parameters: 57915395
Config: {'num_res_blocks': 8, 'tcn_layers': 2, 'hidden_dim': 256} -> Tunable Parameters: 231567363
