In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'

from mol_td.utils import load_config
from mol_td.data_fns import load_data, prep_data, get_split, prep_dataloaders
from mol_td.model import SimpleVAE

import jax
from jax import jit
from typing import Any, Callable, Sequence, Optional
from jax import lax, random as rnd, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
import optax
import wandb
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset



In [3]:
cfg = load_config('/home/amawi/projects/mol-td/configs/default_config.yaml')
cfg

{'WANDB': {'user': 'xmax1', 'WANDB_API_KEY': 1},
 'MODEL': {'enc_hidden': [40, 20], 'dec_hidden': [40, 84], 'seed': 1},
 'TRAIN': {'n_epochs': 3, 'lr': 0.001},
 'PATHS': {'root': '/home/amawi/projects/mol-td',
  'data': './data',
  'results': './results/test',
  'default_config': './configs/default_config.yaml',
  'uracil_xyz': './data/uracil.xyz'}}

In [4]:
# load and prep the data
data, raw_data = load_data('/home/amawi/projects/mol-td/data/uracil_dft.npz')
print(list(raw_data.keys()))
train_loader, val_loader, test_loader = prep_dataloaders(data)

(133770, 12, 3) (133770, 12, 3) (133770, 12, 1)
['E', 'name', 'F', 'theory', 'R', 'z', 'type', 'md5']


In [5]:
# initialise the model
model = SimpleVAE(cfg['MODEL'])
rng, video_rng, params_rng, sample_rng = rnd.split(rnd.PRNGKey(cfg['MODEL']['seed']), 4)
ex_batch = next(train_loader)
params = model.init(dict(params=params_rng, sample=sample_rng), ex_batch)

Dimensions:  {'mu': (32, 20), 'sigma': (32, 20), 'predicted': (32, 84)}


In [6]:
cfg = load_config('/home/amawi/projects/mol-td/configs/default_config.yaml')

run = wandb.init(project='test', id='test3', entity='xmax1', config=cfg['TRAIN'])

loss_grad_fn = jit(jax.value_and_grad(model.apply, has_aux=True))
fwd = jit(model.apply)

tx = optax.sgd(learning_rate=cfg['TRAIN']['lr'])
opt_state = tx.init(params)

for epoch in range(cfg['TRAIN']['n_epochs']):
    for batch_idx, batch in enumerate(train_loader):
        
        (loss, signal), grads = loss_grad_fn(params, batch)
        updates, opt_state = tx.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        wandb.log({'loss': loss, 
                   'kl_div': signal['kl_div'], 
                   'nll': signal['nll']})

        # if batch_idx % 100 == 0:
        #     print(f'Step {batch_idx}, loss {loss}')

        # indicators TODO

    train_loader.shuffle()

    if val_loader is not None:
        for batch_idx, batch in enumerate(val_loader):
            val_loss, _ = fwd(params, batch)
            
        wandb.log({'val_loss': loss, 'epoch': epoch})

run.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxmax1[0m (use `wandb login --relogin` to force relogin)


Dimensions:  {'mu': (32, 20), 'sigma': (32, 20), 'predicted': (32, 84)}
Dimensions:  {'mu': (32, 20), 'sigma': (32, 20), 'predicted': (32, 84)}



0,1
epoch,▁▅█
kl_div,█▇▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▅▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
nll,█▅▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▃▁

0,1
epoch,2.0
kl_div,0.2171
loss,214.48035
nll,214.26324
val_loss,214.48035


In [14]:
for x in train_loader:
    print(x)