In [1]:
import torch

from utils import DEVICE, init_random_seeds
from models._utils import (
    print_parameters,
    NaivePredictor,
    LinearRegression,
    FeatureFFNN,
    PureAttention,
    NeuralTransformer,
    NetworkLSTM,
    NetworkCTRNN,
    LiquidCfC,
)
from fvcore.nn import (
    FlopCountAnalysis,
    ActivationCountAnalysis,
    flop_count_table,
    flop_count_str,
)


# Initialize the random seeds. The seed lives in a named constant so it is
# visible and tunable in one place (the RNG draws in later cells depend on it).
SEED = 0
init_random_seeds(SEED)

CUDA device found.


In [105]:
# @title Prepare model and input
# @markdown Make sure the model and input are on the same device.

# Set shapes for model and input
seq_len = 100
input_size = 302
hidden_size = 516

# --- Baseline: a standard PyTorch model --------------------------------------
model = torch.nn.Linear(in_features=input_size, out_features=hidden_size)
model = model.to(DEVICE)
model.eval()  # switch to eval mode
print(f"Simple PyTorch model: {model}\n")

# Create input of the correct shape for the model (batch_size=1).
# NOTE: `input` shadows the Python builtin; the name is kept because the
# FLOP-counting cell reads `input` and `mask` from this cell.
# Sampled on the default generator, then moved, so the seeded values are unchanged.
input = torch.randn(1, seq_len, input_size).to(DEVICE)
mask = None  # nn.Linear takes no mask (reassigned below for the custom model)
# Shape-only sanity check: no autograd graph is needed for this forward pass.
with torch.no_grad():
    print(f"Input: {input.shape} \t Output: {model(input).shape}", end="\n\n")
print(f"\n{'~'*100}\n")

# --- One of our custom models instead ----------------------------------------
# Registry of the available custom models; pick one via MODEL_NAME.
# The note beside each entry records the hidden_size that gives ~580k parameters
# (for comparing models at similar size); set `hidden_size` above to match.
MODEL_CLASSES = {
    "NaivePredictor": NaivePredictor,
    "LinearRegression": LinearRegression,
    "FeatureFFNN": FeatureFFNN,  # hidden_size = 516 -> num_params = 580286
    "PureAttention": PureAttention,  # hidden_size = 312 -> num_params = 580310
    "NeuralTransformer": NeuralTransformer,  # hidden_size = 196 -> num_params = 584186
    "NetworkLSTM": NetworkLSTM,  # hidden_size = 234 -> num_params = 582260
    "NetworkCTRNN": NetworkCTRNN,  # hidden_size = 408 -> num_params = 582110
    "LiquidCfC": LiquidCfC,  # hidden_size = 422 -> num_params = 582368
}
MODEL_NAME = "FeatureFFNN"

model_args = dict(input_size=input_size, hidden_size=hidden_size, loss="MSE")
model = MODEL_CLASSES[MODEL_NAME](**model_args)
model = model.to(DEVICE)
model.eval()  # switch to eval mode
print(f"Custom model: {model}\n")

# Create input of the correct shape for the model (batch_size=1)
input = torch.randn(1, seq_len, input_size).to(DEVICE)
# All-True boolean mask of shape (1, input_size), built directly on the target
# device with the target dtype (no intermediate float tensor / extra copies).
# Presumably marks which input features are valid -- confirm against forward().
mask = torch.ones(1, input_size, dtype=torch.bool, device=DEVICE)
# Shape-only sanity check: no autograd graph is needed for this forward pass.
with torch.no_grad():
    print(
        f"Input: {input.shape} \t Mask: {mask.shape} \t Output: {model(input, mask).shape}",
        end="\n\n",
    )

Simple PyTorch model: Linear(in_features=302, out_features=516, bias=True)

Input: torch.Size([1, 100, 302]) 	 Output: torch.Size([1, 100, 516])


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Custom model: FeatureFFNN(
  (identity): Identity()
  (input_hidden): Sequential(
    (0): Linear(in_features=302, out_features=516, bias=True)
    (1): ReLU()
  )
  (hidden_hidden): FeedForward(
    (ffwd): Sequential(
      (0): Linear(in_features=516, out_features=516, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.1, inplace=False)
    )
  )
  (inner_hidden_model): InnerHiddenModel(
    (hidden_hidden): FeedForward(
      (ffwd): Sequential(
        (0): Linear(in_features=516, out_features=516, bias=True)
        (1): ReLU()
        (2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (latent_embedding): Linear(in_features=302, out_features=516, bias=True)
  (linear): Linear(in_features=516, out_features=302, bias=True)
  (layer_

In [106]:
# @title Using fvcore


def _print_analysis(label, analysis):
    """Print the total and the per-operator / per-module breakdowns of an fvcore analysis.

    Args:
        label: Prefix for each printed line (e.g. "FLOP", "Activations").
        analysis: An fvcore analysis object exposing total(), by_operator(),
            by_module() and by_module_and_operator().
    """
    print(f"{label} total: {analysis.total()}")
    print(f"{label} by operator: {analysis.by_operator()}")
    print(f"{label} by module: {analysis.by_module()}")
    print(
        f"{label} by module and operator: {analysis.by_module_and_operator()}",
        end="\n\n",
    )


# Adjust input based on if we use standard PyTorch model or custom model.
# Bind a NEW name instead of rebinding `input`: rebinding made this cell
# non-idempotent (re-running it wrapped the input again as ((input, mask), mask)).
model_inputs = (input, mask) if mask is not None else input

# Count the total and number of trainable parameters
all_params_ct, train_params_ct = print_parameters(model)
print(f"\nAll params: {all_params_ct}\nTrainable params: {train_params_ct}", end="\n\n")

# Trainable-parameter count taken directly from the model; computed once and
# reused in both summaries below.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

# FLOP counting. fvcore only counts operators it has handlers for: any
# "Unsupported operator ..." warning means those ops are missing from the totals.
flops = FlopCountAnalysis(model, model_inputs)
_print_analysis("FLOP", flops)
print(flop_count_str(flops), end="\n\n")
print(flop_count_table(flops), end="\n\n")
print(f"\tParams: {n_trainable}", end="\n\n")

# Activation counting: same tracing approach, same unsupported-operator caveat.
acts = ActivationCountAnalysis(model, model_inputs)
_print_analysis("Activations", acts)
print(f"\tParams: {n_trainable}", end="\n\n")

Unsupported operator aten::expand_as encountered 1 time(s)
Unsupported operator aten::mul encountered 1 time(s)
Unsupported operator aten::add encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
layer_norm
Unsupported operator aten::expand_as encountered 1 time(s)
Unsupported operator aten::mul encountered 1 time(s)
Unsupported operator aten::add encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
layer_norm



All params: 580286
Trainable params: 580286

FLOP: (57792000, Counter({'linear': 57792000}), Counter({'': 57792000, 'hidden_hidden': 26625600, 'hidden_hidden.ffwd': 26625600, 'hidden_hidden.ffwd.0': 26625600, 'input_hidden': 15583200, 'input_hidden.0': 15583200, 'linear': 15583200, 'identity': 0, 'input_hidden.1': 0, 'hidden_hidden.ffwd.1': 0, 'hidden_hidden.ffwd.2': 0, 'inner_hidden_model': 0, 'layer_norm': 0}), {'': Counter({'linear': 57792000}), 'identity': Counter(), 'input_hidden': Counter({'linear': 15583200}), 'input_hidden.0': Counter({'linear': 15583200}), 'input_hidden.1': Counter(), 'hidden_hidden': Counter({'linear': 26625600}), 'hidden_hidden.ffwd': Counter({'linear': 26625600}), 'hidden_hidden.ffwd.0': Counter({'linear': 26625600}), 'hidden_hidden.ffwd.1': Counter(), 'hidden_hidden.ffwd.2': Counter(), 'inner_hidden_model': Counter(), 'linear': Counter({'linear': 15583200}), 'layer_norm': Counter()})

N/A indicates a possibly missing statistic due to how the module was ca