In [2]:
def lstm_params(lstm, units, d_model):
    """Count the trainable parameters of a single LSTMCell, or 0 when disabled.

    The count is four gate blocks, each with an input->hidden weight matrix
    (units x d_model) and a hidden->hidden weight matrix (units x units), plus
    8 * units bias terms (2 per gate block, consistent with a cell that keeps
    separate input and recurrent bias vectors, e.g. PyTorch's LSTMCell).

    Args:
        lstm: Truthy to include the LSTMCell; falsy makes the result 0.
        units: Hidden size of the cell.
        d_model: Size of the input fed into the cell.

    Returns:
        int: Parameter count, or 0 when ``lstm`` is falsy.
    """
    if not lstm:
        return 0
    n_gates = 4
    input_weights = n_gates * units * d_model
    recurrent_weights = n_gates * units * units
    bias_terms = 2 * n_gates * units
    return input_weights + recurrent_weights + bias_terms

def ltc_params(units, d_model):
    """Estimate the trainable parameters of a standalone LTCCell.

    The estimate has three parts:
      * four tensors of shape units x units,
      * three tensors of length units,
      * four tensors of shape d_model x units.

    NOTE(review): the mapping of these terms onto concrete LTC parameters
    (synaptic vs. sensory vs. per-neuron) is not shown here; confirm against
    the LTC implementation in use.

    Args:
        units: Number of neurons in the cell.
        d_model: Size of the input fed into the cell.

    Returns:
        int: Estimated parameter count.
    """
    neuron_pair_terms = 4 * (units * units)
    per_neuron_terms = 3 * units
    input_pair_terms = 4 * (d_model * units)
    return neuron_pair_terms + per_neuron_terms + input_pair_terms

def calc_params(units, d_input, d_output, d_model, n_layers, lstm):
    """Estimate, print, and return the parameter count of an encoder -> (LSTM +) LTC stack -> decoder model.

    The first recurrent layer is sized for ``d_model``-dimensional input.
    Each of the remaining ``n_layers - 1`` layers is sized for
    ``d_output``-dimensional input.

    Args:
        units: Neurons / hidden size per recurrent layer.
        d_input: Raw input feature size (encoder input).
        d_output: Output size (decoder size and input size of later layers).
        d_model: Encoder output size, i.e. the first layer's input size.
        n_layers: Number of stacked recurrent layers.
            NOTE(review): values < 1 make the extra-layer terms negative.
        lstm: Truthy to also count an LSTMCell in every layer.

    Returns:
        int: Trainable (incl. encoder/decoder) plus non-trainable parameters.
    """
    extra_layers = n_layers - 1

    first_layer = lstm_params(lstm, units, d_model) + ltc_params(units, d_model)
    later_layer = lstm_params(lstm, units, d_output) + ltc_params(units, d_output)
    trainable_params = first_layer + later_layer * extra_layers

    # units x (units + layer input size) per layer, counted as non-trainable.
    # NOTE(review): presumably fixed wiring/mask entries — confirm.
    non_trainable_params = units * (units + d_model) + units * (units + d_output) * extra_layers

    # Dense weight + bias shapes: d_input -> d_model and d_output -> d_output.
    encoder_params = d_model * (d_input + 1)
    decoder_params = d_output * (d_output + 1)

    trainable_total = trainable_params + encoder_params + decoder_params
    grand_total = trainable_total + non_trainable_params

    print(f"Encoder Params: {encoder_params}")
    print(f"Decoder Params: {decoder_params}")
    print(f"Total Trainable Params: {trainable_total}")
    print(f"Total Non-Trainable Params: {non_trainable_params}")
    print(f"Total Params: {grand_total}")
    return grand_total

In [3]:
def calculate_rnn_params(cell_type, n_layers, d_input, d_model, d_output):
    """Compute the total parameter count of an encoder -> stacked RNN -> decoder model.

    Model layout:
      * encoder: dense layer ``d_input -> d_model`` with bias,
      * ``n_layers`` recurrent layers of ``cell_type`` with hidden size
        ``d_model``. Every layer, including the first, sees
        ``d_model``-dimensional input because the encoder has already
        projected the raw features,
      * decoder: dense layer ``d_model -> d_output`` with bias.

    Each recurrent layer has ``gate_factor`` blocks of input->hidden and
    hidden->hidden weights (both ``d_model x d_model``) plus two bias vectors.
    That is ``gate_factor * (2 * d_model**2 + 2 * d_model)`` parameters per
    layer.

    Args:
        cell_type (str): One of 'RNN', 'LSTM', 'GRU' (case-insensitive).
        n_layers (int): Number of stacked recurrent layers; must be >= 1.
        d_input (int): Dimensionality of the raw input features.
        d_model (int): Hidden-state dimensionality.
        d_output (int): Dimensionality of the model output.

    Returns:
        int: Total number of parameters.

    Raises:
        ValueError: If ``cell_type`` is unknown or ``n_layers < 1``.
    """
    # Number of gate blocks per cell type.
    gate_factors = {'RNN': 1, 'LSTM': 4, 'GRU': 3}
    try:
        gate_factor = gate_factors[cell_type.upper()]
    except KeyError:
        raise ValueError("cell_typeは 'RNN', 'LSTM', 'GRU' のいずれかである必要があります") from None

    # Previously n_layers <= 0 silently still counted one recurrent layer.
    if n_layers < 1:
        raise ValueError(f"n_layers must be >= 1, got {n_layers}")

    encoder_params = d_model * (d_input + 1)
    decoder_params = d_output * (d_model + 1)

    # The same for every layer, since all layers take d_model-dimensional input.
    params_per_layer = gate_factor * (2 * d_model * d_model + 2 * d_model)

    return encoder_params + decoder_params + n_layers * params_per_layer

In [4]:
def calc_mem(params, batch_size, units, n_layers, ode_unfolds, length,
             bytes_per_value=4, optimizer_state_factor=2):
    """Roughly estimate training memory, print a breakdown, and return the total in bytes.

    The estimate has four parts:
      * model parameters: ``params * bytes_per_value``,
      * optimizer/scheduler state: ``optimizer_state_factor`` copies of the
        parameter memory (the default 2 matches an optimizer that keeps two
        buffers per parameter, e.g. Adam's moment estimates),
      * forward activations:
        ``(B*U*U*2 + B*U*6) * n_layers * length * ode_unfolds`` values,
      * backward pass: assumed equal to the forward pass (a placeholder).

    Printed figures are labelled "MB" but are computed with 1024**2 (MiB).

    Args:
        params: Total parameter count (e.g. from ``calc_params``).
        batch_size: Batch size B.
        units: Neurons per recurrent layer U.
        n_layers: Number of stacked recurrent layers.
        ode_unfolds: Multiplier on forward activation memory (solver unfolds).
        length: Multiplier on forward activation memory (sequence length).
        bytes_per_value: Bytes per stored value; 4 = float32, 2 = fp16/bf16.
        optimizer_state_factor: Copies of parameter memory kept as
            optimizer state.

    Returns:
        Estimated total memory in bytes.
    """
    mib = 1024 ** 2
    params_mem = params * bytes_per_value
    print(f"Params Memory Usage: {params_mem / mib:.2f} MB")
    # Optimizer (and scheduler) state; label kept for output compatibility.
    scheduler_mem = params_mem * optimizer_state_factor
    print(f"Scheduler Memory Usage: {scheduler_mem / mib:.2f} MB")
    activations_per_step = (batch_size * units * units * 2) + (batch_size * units * 6)
    forward_mem = activations_per_step * n_layers * length * ode_unfolds * bytes_per_value
    print(f"Forward Memory Usage: {forward_mem / mib:.2f} MB")
    # Backward pass: assumed to cost the same as the forward pass for now.
    backward_mem = forward_mem
    print(f"Backward Memory Usage: {backward_mem / mib:.2f} MB")
    total_mem = params_mem + scheduler_mem + forward_mem + backward_mem
    print(f"Total Memory Usage: {total_mem / mib:.2f} MB")
    return total_mem

In [11]:
# Configuration for a single-layer LSTM + LTC model.
units = 40
batch_size = 128
d_input = 40
d_output = 12
d_model = 40
n_layers = 1
length = 128
ode_unfolds = 1
lstm = True

# Count parameters first, then feed the total into the memory estimate.
params = calc_params(
    units=units, d_input=d_input, d_output=d_output,
    d_model=d_model, n_layers=n_layers, lstm=lstm,
)
calc_mem(
    params=params, batch_size=batch_size, units=units,
    n_layers=n_layers, ode_unfolds=ode_unfolds, length=length,
)

Encoder Params: 1640
Decoder Params: 156
Total Trainable Params: 27836
Total Non-Trainable Params: 3200
Total Params: 31036
Params Memory Usage: 0.12 MB
Scheduler Memory Usage: 0.24 MB
Forward Memory Usage: 215.00 MB
Backward Memory Usage: 215.00 MB
Total Memory Usage: 430.36 MB


451260112

In [6]:
calculate_rnn_params(cell_type='RNN', n_layers=1, d_input=9, d_model=128, d_output=6)

35078