In [57]:
def lstm_params(lstm, units, d_model):
    """Count the trainable parameters attributed to one LSTM cell.

    Args:
        lstm: Truthy when the layer includes an LSTM cell; falsy otherwise.
        units: Number of hidden units in the cell.
        d_model: Width of the input fed to the cell.

    Returns:
        ``4*units*d_model + 4*units*units + 8*units`` when ``lstm`` is truthy,
        otherwise ``0``.
    """
    if not lstm:
        return 0

    # NOTE(review): the factor 4 presumably stands for the four LSTM gates and
    # 8 * units for two bias vectors of size 4 * units -- confirm against the
    # LSTMCell implementation actually used by the model.
    input_weight_count = 4 * units * d_model
    recurrent_weight_count = 4 * units * units
    bias_count = 8 * units
    return input_weight_count + recurrent_weight_count + bias_count

def ltc_params(units, d_model):
    """Count the trainable parameters of a single LTC cell.

    Args:
        units: Number of units in the cell.
        d_model: Width of the input fed to the cell.

    Returns:
        ``4*units*units + 3*units + 4*d_model*units``.
    """
    # NOTE(review): which tensors each term stands for is not visible here;
    # check the grouping against the LTCCell implementation in use.
    unit_to_unit_count = (units * units) * 4
    per_unit_count = (units) * 3
    input_to_unit_count = (d_model * units) * 4
    return unit_to_unit_count + per_unit_count + input_to_unit_count
    
def calc_params(units, d_input, d_output, d_model, n_layers, lstm):
    """Estimate and print the parameter counts of the whole model.

    The first layer is counted with an input width of ``d_model``; each of the
    remaining ``n_layers - 1`` layers is counted with an input width of
    ``d_output``.

    Args:
        units: Number of units per layer.
        d_input: Input width used for the encoder count.
        d_output: Output width; also the input width of layers after the first.
        d_model: Input width of the first layer.
        n_layers: Number of stacked layers.
        lstm: Truthy to include the LSTM cell count in every layer.

    Returns:
        Total parameter count (trainable + encoder + decoder + non-trainable).
    """
    following_layers = n_layers - 1

    # Per-layer trainable counts: LSTM part (if enabled) plus LTC part.
    first_layer_count = lstm_params(lstm, units, d_model) + ltc_params(units, d_model)
    later_layer_count = lstm_params(lstm, units, d_output) + ltc_params(units, d_output)
    layer_trainable = first_layer_count + later_layer_count * following_layers

    # Non-trainable entries: units * (units + input width) for each layer.
    # NOTE(review): presumably fixed masks/buffers of the LTC cell -- confirm.
    frozen_count = (
        units * (units + d_model)
        + units * (units + d_output) * following_layers
    )

    # The "+ 1" presumably accounts for a bias per output feature -- confirm.
    encoder_count = d_model * (d_input + 1)
    decoder_count = d_output * (d_output + 1)

    trainable_total = layer_trainable + encoder_count + decoder_count
    grand_total = trainable_total + frozen_count

    print(f"Encoder Params: {encoder_count}")
    print(f"Decoder Params: {decoder_count}")
    print(f"Total Trainable Params: {trainable_total}")
    print(f"Total Non-Trainable Params: {frozen_count}")
    print(f"Total Params: {grand_total}")
    return grand_total

In [58]:
def calc_mem(params, batch_size, units, n_layers, ode_unfolds, length):
    """Estimate and print the training memory footprint in bytes.

    Every value is counted at 4 bytes (float32). The estimate is the sum of
    parameter storage, an assumed scheduler/optimizer overhead of twice the
    parameter storage, the forward-pass term, and a backward-pass term that is
    assumed equal to the forward one.

    Args:
        params: Total number of model parameters.
        batch_size: Number of samples per batch.
        units: Number of units per layer.
        n_layers: Number of stacked layers.
        ode_unfolds: Multiplier applied per step (passed as ``ode_unfolds``).
        length: Sequence length.

    Returns:
        Total estimated memory usage in bytes.
    """
    bytes_per_mib = 1024 ** 2

    # Parameter storage: 4 bytes per parameter when kept as float32.
    weights_bytes = params * 4

    # Assumed scheduler/optimizer overhead: twice the parameter storage.
    optimizer_bytes = weights_bytes * 2

    # Forward pass: one term quadratic in units and one linear in units,
    # scaled by layers, sequence length, ODE unfolds and 4 bytes per value.
    quadratic_term = batch_size * units * units * 2
    linear_term = batch_size * units * 6
    activation_bytes = (quadratic_term + linear_term) * n_layers * length * ode_unfolds * 4

    # Backward pass: for now assumed to consume the same amount as forward.
    gradient_bytes = activation_bytes

    overall_bytes = weights_bytes + optimizer_bytes + activation_bytes + gradient_bytes

    print(f"Params Memory Usage: {weights_bytes / bytes_per_mib:.2f} MB")
    print(f"Scheduler Memory Usage: {optimizer_bytes / bytes_per_mib:.2f} MB")
    print(f"Forward Memory Usage: {activation_bytes / bytes_per_mib:.2f} MB")
    print(f"Backward Memory Usage: {gradient_bytes / bytes_per_mib:.2f} MB")
    print(f"Total Memory Usage: {overall_bytes / bytes_per_mib:.2f} MB")
    return overall_bytes

In [60]:
# Configuration for the estimate; every value is passed to calc_params and/or
# calc_mem below.
units = 256        # units per layer
batch_size = 128   # samples per batch
d_input = 9        # input width used for the encoder count
d_output = 6       # output width, also the input width of layers after the first
d_model = 9        # input width of the first layer
n_layers = 2       # number of stacked layers
length = 128       # sequence length
ode_unfolds = 3    # per-step multiplier in the forward-memory term
lstm = True        # include the LSTM cell count in each layer

params = calc_params(
    units=units,
    d_input=d_input,
    d_output=d_output,
    d_model=d_model,
    n_layers=n_layers,
    lstm=lstm,
)
# Last expression of the cell: the returned byte total is shown as cell output.
calc_mem(
    params=params,
    batch_size=batch_size,
    units=units,
    n_layers=n_layers,
    ode_unfolds=ode_unfolds,
    length=length,
)

Encoder Params: 90
Decoder Params: 42
Total Trainable Params: 1085060
Total Non-Trainable Params: 134912
Total Params: 1219972
Params Memory Usage: 4.65 MB
Scheduler Memory Usage: 9.31 MB
Forward Memory Usage: 49728.00 MB
Backward Memory Usage: 49728.00 MB
Total Memory Usage: 99469.96 MB


104301814320