In [15]:
import torch
from torchprofile import profile_macs
from dac.model.dac import DAC
from torch import nn


In [3]:
def get_model_macs(model, inputs) -> int:
    return profile_macs(model, inputs)

In [12]:
model = DAC()

dummy_input = torch.randn(1, 1, 44100)
# pruned_model = channel_prune(model, prune_ratio=0.3)
dac_model_macs = get_model_macs(model, dummy_input)

if dac_model_macs >= 1e9:
    print(f"DAC model MACs: {dac_model_macs/1e9:.2f}B")
else:
    print(f"DAC model MACs: {dac_model_macs/1e6:.2f}M")


  WeightNorm.apply(module, name, dim)


DAC model MACs: 146.90B




In [16]:
def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    calculate the model size in bits
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    return get_num_parameters(model, count_nonzero_only) * data_width

In [18]:
dac_model_size = get_model_size(model)

if dac_model_size >= 1e9:
    print(f"DAC model size: {dac_model_size/1e9:.2f}B")
else:
    print(f"DAC model size: {dac_model_size/1e6:.2f}M")

DAC model size: 2.45B
