In [1]:
import os
import torch
import json

from bcnf.utils import get_dir, load_config, sub_root_path
from bcnf.train import Trainer
from bcnf import CondRealNVP_v2

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Run name: selects the run config `<MODEL_NAME>.yaml` (configs/runs) and names the save folder under models/bcnf-models.
MODEL_NAME = "videos_CNN_LSTM_large"

In [3]:
# Locate the YAML run config for MODEL_NAME and load it.
# NOTE(review): verify=False disables whatever validation load_config performs — confirm this is intentional.
config_path = os.path.join(
    get_dir("configs", "runs"),
    "{}.yaml".format(MODEL_NAME),
)
config = load_config(config_path, verify=False)

In [4]:
# Instantiate the conditional RealNVP described by the run config, then move it to the selected device.
model = CondRealNVP_v2.from_config(config)
model = model.to(device)

# Show the parameter count with thousands separators.
print(format(model.n_params, ","))

47,835,415


In [5]:
# The Trainer receives the run config as a plain dict whose top-level keys are lower-cased;
# constructing it loads the training data (see the log output below).
trainer = Trainer(
    config=dict(
        (key.lower(), value) for key, value in config.to_dict().items()
    ),
    project_name="bcnf-test",
    parameter_index_mapping=model.parameter_index_mapping,
    verbose=True,
)

Using dtype: torch.float32
Loading data from /home/psaegert/Projects/bcnf/data/bcnf-data/fixed_data_render_2s_15FPS/train...


Loading data from directory: 100%|██████████| 1/1 [00:01<00:00,  1.84s/it, file=fixed_data_render_2s_15FPS_1.pkl]


Using videos data for training. Shapes:
Conditions: [torch.Size([1000, 2, 30, 90, 160]), torch.Size([1000, 7])]
Parameters: torch.Size([1000, 19])


In [None]:
# Train the flow; the model returned by `trainer.train` (presumably the trained one) is rebound to `model`.
# NOTE: long-running — the recorded run logged ~3.17 s/it and reached 3154/50000 iterations after ~2h47m.
model = trainer.train(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpsaegert[0m ([33mbalisticcnf[0m). Use [1m`wandb login --relogin`[0m to force relogin


Train: -35.1344 - Val: -24.4621 (avg: -24.5370, min: -24.6922) | lr: 1.25e-05 - Patience: 48/500 - z: (0.0417 ± 0.5339) ± (1.3649 ± 0.6076):   6%|▋         | 3154/50000 [2:46:40<41:15:29,  3.17s/it]  


VBox(children=(Label(value='0.004 MB of 0.019 MB uploaded\r'), FloatProgress(value=0.2277873243707672, max=1.0…

0,1
distance_to_last_best_val_loss_fold_-1,▁▁▁▁▁▁▁▂▁▁▂▁▃▅▂▄▆▁▃▃▅▇▁▃▄▇▁▁▁▄▁▄▅█▂▅▂▃▄▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lr_fold_-1,█████████████████▄▄▄▄▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
time_fold_-1,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
train_loss_fold_-1,█▇▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss_fold_-1,█▇▆▅▄▃▃▃▃▃▂▂▂▂▂▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
z_mean_mean_fold_-1,▅▅▆▄▅▅▆▅▄▇▄▂▁▆█▄▅▆▃▄▅▄▃▆▃▅▄▄▃▅▄▄▄▃▆▃▄▅▄▅
z_mean_std_fold_-1,▃▁▂▃▂▃▂▆▄▆▅▅▄▄▁▆▅▆▄▄▅▇▆▄▆▃▄▄▄▄█▃▅▄▄▄▆▃▃▅
z_std_mean_fold_-1,▁▁▁▂▂▃▂▂▂▂▃▅▄▅▄▄▅▅▄▄▅▆▄▆▅▄▄▅█▄▆▆▅▄▆▅▅▆█▅
z_std_std_fold_-1,▁▁▁▁▂▂▁▁▂▁▃▃▂▃▂▃▄▄▄▃▄▄▅▃▃▄▄▄▅▄▃▅▄▃▄▆▆▄█▄

0,1
distance_to_last_best_val_loss_fold_-1,48.0
epoch,3155.0
lr_fold_-1,1e-05
time_fold_-1,1710783686.77525
train_loss_fold_-1,-35.13435
val_loss_fold_-1,-24.46211
z_mean_mean_fold_-1,0.04173
z_mean_std_fold_-1,0.53386
z_std_mean_fold_-1,1.36487
z_std_std_fold_-1,0.60761


In [None]:
# Persist the trained weights and a small JSON pointer to the run config so the model can be rebuilt later.
# Resolve (and create) the output directory once instead of recomputing it for every path.
model_dir = get_dir('models', 'bcnf-models', MODEL_NAME, create=True)

torch.save(model.state_dict(), os.path.join(model_dir, "state_dict.pt"))

# Store the config path with the project root replaced by the {{BCNF_ROOT}} placeholder (via bcnf.utils.sub_root_path).
# Previously the placeholder was simply prepended to config_path; if get_dir returns an absolute path
# (the logged data path suggests it does), that yields "{{BCNF_ROOT}}/home/..." which is not portable.
with open(os.path.join(model_dir, 'config.json'), 'w') as f:
    json.dump({'config_path': sub_root_path(config_path)}, f)

print(f"Model saved to {model_dir}")