In [7]:
import torch

from utils import DEVICE, init_random_seeds
from models._utils import print_parameters, FeatureFFNN, NeuralTransformer, NetworkLSTM
from fvcore.nn import (
    FlopCountAnalysis,
    ActivationCountAnalysis,
    flop_count_table,
    flop_count_str,
)

# Report the compute device exported by `utils` so it is recorded in the output
print(f"DEVICE: {DEVICE}")

# Initialize the random seeds with a fixed, named value
RANDOM_SEED = 0
init_random_seeds(RANDOM_SEED)

DEVICE: cuda


In [8]:
# @title Prepare model and input
# @markdown Make sure the model and input are on the same device.

# Dimensions shared by both models and by the dummy input
seq_len, input_size, hidden_size = 100, 302, 512

# --- Option 1: a standard PyTorch layer, called on the input alone (no mask) ---
# `.to()` and `.eval()` both return the module, so the setup chains into one line.
model = torch.nn.Linear(input_size, hidden_size).to(DEVICE).eval()
print(f"Simple PyTorch model: {model}\n")

# Dummy batch holding a single sequence: (1, seq_len, input_size)
input = torch.randn(1, seq_len, input_size).to(DEVICE)  # batch_size=1
mask = None
print(f"Input: {input.shape} \t Output: {model(input).shape}", end="\n\n")
print("\n" + "~" * 100 + "\n")

# --- Option 2: one of our custom models, called as model(input, mask) ---
# The assignments below replace the `model`, `input` and `mask` set above.
model_args = {"input_size": input_size, "hidden_size": hidden_size, "loss": "MSE"}
model = FeatureFFNN(**model_args)
# model = NeuralTransformer(**model_args)
# model = NetworkLSTM(**model_args)
model = model.to(DEVICE)
model.eval()  # switch to eval mode
print(f"Custom model: {model}\n")

# Fresh dummy batch plus an all-True boolean mask of shape (1, input_size),
# allocated directly with the target dtype and device.
input = torch.randn(1, seq_len, input_size).to(DEVICE)  # batch_size=1
mask = torch.ones(1, input_size, dtype=torch.bool, device=DEVICE)
print(
    f"Input: {input.shape} \t Mask: {mask.shape} \t Output: {model(input, mask).shape}",
    end="\n\n",
)

Simple PyTorch model: Linear(in_features=302, out_features=512, bias=True)

Input: torch.Size([1, 100, 302]) 	 Output: torch.Size([1, 100, 512])


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Custom model: NetworkLSTM(
  (identity): Identity()
  (input_hidden): Sequential(
    (0): Linear(in_features=302, out_features=512, bias=True)
    (1): ReLU()
    (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (hidden_hidden): LSTM(512, 512, batch_first=True)
  (inner_hidden_model): InnerHiddenModel(
    (hidden_hidden): LSTM(512, 512, batch_first=True)
  )
  (linear): Linear(in_features=512, out_features=302, bias=True)
)

Input: torch.Size([1, 100, 302]) 	 Mask: torch.Size([1, 302]) 	 Output: torch.Size([1, 100, 302])



In [9]:
# @title Using fvcore

# Custom models are called as model(input, mask), so the two are bundled into a
# tuple for the analyses; a standard PyTorch model (mask is None) gets the bare
# tensor. The tuple check keeps this cell idempotent: re-running it must not wrap
# an already-bundled input a second time, which would nest the tuple.
if mask is not None and not isinstance(input, tuple):
    input = (input, mask)

# Count the total and number of trainable parameters
all_params_ct, train_params_ct = print_parameters(model)

print(f"\nAll params: {all_params_ct}\nTrainable params: {train_params_ct}", end="\n\n")

# Trainable parameter count taken directly from the model; computed once and
# reported after each of the two analyses below.
trainable_param_ct = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Perform FLOP Counting: Use the FlopCountAnalysis class to analyze your model:
flops = FlopCountAnalysis(model, input)

# Print Results: total FLOPs, then the per-operator, per-module and
# per-module-and-operator breakdowns, followed by fvcore's formatted reports.
print(
    f"FLOPs: {flops.total(), flops.by_operator(), flops.by_module(), flops.by_module_and_operator()}",
    end="\n\n",
)
print(flop_count_str(flops), end="\n\n")
print(flop_count_table(flops), end="\n\n")
print(
    f"\tParams: {trainable_param_ct}",
    end="\n\n",
)

# Perform Activations Counting: Use the ActivationCountAnalysis class to analyze your model:
acts = ActivationCountAnalysis(model, input)

# Print Results: total activation count and the same three breakdowns as above.
print(
    f"Activations: {acts.total(), acts.by_operator(), acts.by_module(), acts.by_module_and_operator()}",
    end="\n\n",
)
print(
    f"\tParams: {trainable_param_ct}",
    end="\n\n",
)

Unsupported operator aten::expand_as encountered 1 time(s)
Unsupported operator aten::mul encountered 2 time(s)
Unsupported operator aten::lstm encountered 1 time(s)
Unsupported operator aten::expand_as encountered 1 time(s)
Unsupported operator aten::mul encountered 2 time(s)
Unsupported operator aten::layer_norm encountered 1 time(s)
Unsupported operator aten::lstm encountered 1 time(s)



All params: 2412334
Trainable params: 2412334

FLOPs: (31180800, Counter({'linear': 30924800, 'layer_norm': 256000}), Counter({'': 31180800, 'input_hidden': 15718400, 'input_hidden.0': 15462400, 'linear': 15462400, 'input_hidden.2': 256000, 'identity': 0, 'input_hidden.1': 0, 'hidden_hidden': 0, 'inner_hidden_model': 0}), {'': Counter({'linear': 30924800, 'layer_norm': 256000}), 'identity': Counter(), 'input_hidden': Counter({'linear': 15462400, 'layer_norm': 256000}), 'input_hidden.0': Counter({'linear': 15462400}), 'input_hidden.1': Counter(), 'input_hidden.2': Counter({'layer_norm': 256000}), 'hidden_hidden': Counter(), 'inner_hidden_model': Counter(), 'linear': Counter({'linear': 15462400})})

N/A indicates a possibly missing statistic due to how the module was called. Missing values are still included in the parent's total.
NetworkLSTM(
  #params: 2.41M, #flops: 31.18M
  (identity): Identity(#params: 0, #flops: N/A)
  (input_hidden): Sequential(
    #params: 0.16M, #flops: 15.72M