In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from env import env # environment variables
from pathlib import Path
from nn_nananana_ansatz.pyfig.pretty import print_tree
# Show the layout of the parent directory ('../', relative to the working directory),
# leaving out the ignore-listed entries.
print_tree(
	'../',
	ignore=['-dep-v2', '-dep', 'build', 'wandb', 'tmp*'],
)

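### Geometries of molecular oxygen in different spin and charge states

Equilibrium O–O bond lengths; the configuration below uses the triplet (ground-state) value.

| Species | Multiplicity (2S+1) | Bond length (Å) |
|---------|---------------------|-----------------|
| O2      | 3 (triplet)         | 1.2075          |
| O2      | 1 (singlet)         | 1.2255          |
| O2+     | 2 (doublet)         | 1.1164          |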


In [3]:
from rich.pretty import pprint
from rich import print

from nn_nananana_ansatz.pyfig.pyfig import DEVICES, DTYPES, RUNMODE
from nn_nananana_ansatz.pyfig.paths import Paths
from nn_nananana_ansatz.pyfig.plugins import Logger, Opt, OPTS, LOGMODE, RUNTYPE
from nn_nananana_ansatz.pyfig.utils import get_attrs

from nn_nananana_ansatz.systems import System
from appfig import Ansatzcfg, Walkers, Pyfig



In [4]:


import appfig  # bind the base class explicitly (see class docstring)


class Pyfig(appfig.Pyfig):
	"""Experiment configuration for this notebook: an O2 molecule trained with VMC.

	Class attributes are the defaults; the instance ``c`` created below overrides some of them.
	The class extends ``appfig.Pyfig`` rather than the bare name ``Pyfig``: after this cell has
	run once, the bare name refers to this subclass, so ``class Pyfig(Pyfig)`` would keep
	stacking subclasses every time the cell is re-executed.
	"""

	seed: int           	= 42
	debug: bool         	= False

	device: DEVICES     	= DEVICES.gpu
	dtype: DTYPES          	= DTYPES.float32

	paths: Paths = Paths(
		project         = 'hwat',
		exp_name        = 'test',
		run_name        = 'run.py',
		exp_data_dir    = 'data',
	)

	mode: RUNMODE       	= RUNMODE.train  # run mode; the optimiser is configured via `opt` below

	n_epoch: int      		= 100
	n_step: int      		= 100

	loss: str        		= 'vmc'  # orb_mse, vmc

	# class-body scope: `paths` defined above is visible here, so the logger reuses its fields
	logger: Logger      = Logger(
		exp_name 		= paths.exp_name,
		entity			= env.WANDB_ENTITY,
		project			= paths.project,
		n_log_metric	= 2,
		n_log_state		= -1,
		run_path		= paths.run_path,
		runtype			= RUNTYPE.runs,
		exp_data_dir	= paths.exp_data_dir,
	)

	opt: Opt         	= Opt(
		name = OPTS.AdamW,
	)

	walkers: Walkers = Walkers(
		n_b = 64,
		n_equil_step= 10  # kept small for debugging
	)

	# validation issue when nested, don't nest
	system: System = System(
		# two O nuclei (Z = 8) on the x axis.
		# NOTE(review): 1.2075 is the O2 triplet bond length from the geometry table above;
		# confirm which length unit System expects (Angstrom vs Bohr).
		a =    [[0.0, 0.0, 0.0], [1.2075, 0.0, 0.0]],
		a_z =  [8, 8],
		charge = 0,
		spin = 2,  # 2S = n_alpha - n_beta (triplet); the SCF output reports 2S+1 ~ 3
	)

	ansatzcfg: Ansatzcfg = Ansatzcfg(
		n_l 	= 3,
		ke_method 	= 'grad_grad',  # inefficient, can do jit compiled version
	)


# per-run overrides of the class defaults (edit here also)
c = Pyfig(
	seed= 1,
	ansatzcfg= Ansatzcfg(
		n_l 	= 2,
		ke_method 	= 'grad_grad',
	),
)

from nn_nananana_ansatz.pyfig.utils import print_maybe_inspect

for k,v in c.items():
	print(k, v)
# attrs = get_attrs(c)
# pprint(attrs)

# attrs = get_attrs(c.walkers)
# pprint(attrs)

In [5]:
# leave for now
# from pyfig.cli import TyperSource
# ts = TyperSource(c)
# ts.write()
# print(ts)

In [6]:
# import pandas as pd
# data_explore = ex.reshape(ex.shape[:3], -1) / 255.
# df = pd.DataFrame.from_dict({f"dim_{i}": data_explore[:100, i] for i in range(100)})
# df.head()
# import pygwalker as pyg
# gwalker = pyg.walk(df)

In [7]:
from nn_nananana_ansatz.hwat import Scf # https://github.com/pyscf wrapper

# Mean-field baseline via the PySCF wrapper for the configured system. The log printed
# below (SCF energy, <S^2>, mo_coef shape, MO occupation table) comes from these calls.
scf = Scf(system= c.system)
scf.init_app()
print(scf)

# Spin bookkeeping: the config's `spin` is 2S = n_alpha - n_beta (the MO table below shows
# 9 occupied alpha vs 7 occupied beta orbitals), while the SCF log reports <S^2> and the
# multiplicity 2S+1. With spin=2 this run converges to the triplet (<S^2> ~ 2.0, 2S+1 ~ 3);
# a closed-shell singlet would instead give <S^2> ~ 0 and 2S+1 = 1.


pyfig:pyscf: 
converged SCF energy = -144.084750676126  <S^2> = 2.0008124  2S+1 = 3.0005416
app:init_app: mo_coef shape: (2, 10, 10)
**** MO energy ****
                             alpha | beta                alpha | beta
MO #1   energy= -21.0284314415787  | -20.9911041542446  occ= 1 | 1
MO #2   energy= -20.9471190199291  | -20.9065291461555  occ= 1 | 1
MO #3   energy= -2.40494141278155  | -2.3222595169621   occ= 1 | 1
MO #4   energy= -1.32132963072916  | -1.11969361990839  occ= 1 | 1
MO #5   energy= -1.32132963072916  | -1.11969361990839  occ= 1 | 1
MO #6   energy= -0.987304712466787 | -0.788701048233836 occ= 1 | 1
MO #7   energy= -0.85914411033142  | -0.757484628916062 occ= 1 | 1
MO #8   energy= -0.115754124642637 | 0.680443753354031  occ= 1 | 0
MO #9   energy= -0.115754124642635 | 0.680443753354034  occ= 1 | 0
MO #10  energy= 3.24064806202365   | 3.35909272248913   occ= 0 | 0

To work with the spin densities directly, `use mulliken_meta_spin()` only printing them here.

 ** Mullik

In [8]:
from copy import deepcopy

import torch
from torch.utils.data import DataLoader
from functorch import make_functional_with_buffers, vmap

from nn_nananana_ansatz.hwat import PyfigDataset, Ansatz



# REPRODUCIBILITY: seed torch before any random initialisation / sampling below.
# NOTE(review): nothing visible in this notebook consumed c.seed before; if Pyfig already
# seeds elsewhere, re-seeding with the same value is harmless.
torch.manual_seed(c.seed)


# CONVERT TO TORCH
def _resolve_torch_dtype(spec) -> torch.dtype:
	"""Map the config dtype to a ``torch.dtype``.

	Accepts a ``torch.dtype``, a string such as ``'float32'``, or an Enum member whose
	``name`` or ``value`` is one of those. Raises ValueError instead of returning None,
	so an unknown dtype cannot silently fall back to torch's default dtype.
	"""
	table = dict(float32= torch.float32, float64= torch.float64, float16= torch.float16)
	for key in (spec, getattr(spec, 'value', None), getattr(spec, 'name', None)):
		if isinstance(key, torch.dtype):
			return key
		if isinstance(key, str) and key in table:
			return table[key]
	raise ValueError(f'unsupported dtype {spec!r}; expected one of {sorted(table)}')


dtype = _resolve_torch_dtype(c.dtype)
device = torch.device(c.device)
system_th = get_attrs(c.system)
system_th.update(dict(
	a= torch.tensor(system_th['a'], dtype= dtype, device= device),
	a_z= torch.tensor(system_th['a_z'], dtype= dtype, device= device),
))
mo_coef = torch.tensor(scf.mo_coef, device=device, dtype=dtype)


print(f"""
device: {device}
dtype: {dtype}
""")
# INITIALISE THE MODEL
# A first Ansatz (mol=None) is deep-copied into a stateless function + params/buffers
# (functorch; deprecated in favour of torch.func, see warnings in the output) whose
# function is passed to the loss; the trainable `model` is then rebuilt with the PySCF molecule.
model = Ansatz(**system_th, **get_attrs(c.ansatzcfg), mol= None, mo_coef= mo_coef).to(device=device, dtype= dtype)
model_to_fn: torch.nn.Module = deepcopy(model)
model_fn, param, buffer = make_functional_with_buffers(model_to_fn)
model_fn_vmap = vmap(model_fn, in_dims=(None, None, 0))  # batched over the 3rd argument (walkers? TODO confirm)
del model
model = Ansatz(**system_th, **get_attrs(c.ansatzcfg), mol= scf.mol, mo_coef= mo_coef).to(device=device, dtype= dtype)

# INITIALISE THE DATASET
dataset = PyfigDataset(c, system= c.system, walkers= c.walkers)


def custom_collate(batch):
	"""Unwrap the batch of size 1: each dataset item is presumably already a full walker batch."""
	return batch[0]


# batch_size=1 because the dataset's internal sampler already produces c.walkers.n_b walkers per item
dataloader = DataLoader(dataset, batch_size= 1, collate_fn= custom_collate)

# INITIALISE THE OPTIMISER
from torch.optim import Optimizer
from nn_nananana_ansatz.utils import get_opt
opt: Optimizer = get_opt(**get_attrs(c.opt))(model.parameters())

# TORCH SETTINGS https://jamesmccaffrey.wordpress.com/2019/01/23/pytorch-train-vs-eval-mode/
# train/eval only toggles the behaviour of layers such as dropout / batch-norm
model.train()
if 'eval' in c.mode:
	model.eval()

# wrap up: initialise the dataset/sampler with the final model (produces the equilibration log below)
dataloader.dataset.init_dataset(c, device= device, dtype= dtype, model= model)

model fb layers: 
 [(32, 32), (128, 32), (128, 32), (128, 32)] [(4, 16), (16, 16), (16, 16), (16, 16)]
model fb layers: 
 [(32, 32), (128, 32), (128, 32), (128, 32)] [(4, 16), (16, 16), (16, 16), (16, 16)]
hwat:dataset: init
hwat:dataset:init: center_points
tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 1.2075, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 1.2075, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 1.2075, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 1.2075, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 1.2075, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 1.2075, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 1.2075, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 1.2075, 0.0000]], dtype=torch.float64)
dataset:len  100
dataset:init_dataset: data  torch.Size([64, 16, 3]) torch.float32 cuda:0


  warn_deprecated('make_functional_with_buffers', 'torch.func.functional_call')
  warn_deprecated('vmap', 'torch.vmap')


|    |     deltar |        acc |      diff |
|---:|-----------:|-----------:|----------:|
|  0 | 0.001      | 0.676406   | 0.0363906 |
|  1 | 0.00215443 | 0.674024   | 0.0367422 |
|  2 | 0.00464159 | 0.624766   | 0.0527188 |
|  3 | 0.01       | 0.66668    | 0.0251953 |
|  4 | 0.0215444  | 0.643125   | 0.0675    |
|  5 | 0.0464159  | 0.565859   | 0.0941406 |
|  6 | 0.1        | 0.48168    | 0.17832   |
|  7 | 0.215443   | 0.246133   | 0.413867  |
|  8 | 0.464159   | 0.0463281  | 0.613672  |
|  9 | 1          | 0.00628906 | 0.653711  |
dataset:init_dataset deltar  tensor([0.0100], device='cuda:0')
equil  0  acc  0.6519531011581421  deltar  -0.010118304751813412
dataset:init_dataset sampler is pretraining  False
dataset:init_dataset
data torch.Size([64, 16, 3]) cuda:0 torch.float32 tensor(0.2070, device='cuda:0') tensor(0.7467, device='cuda:0')
acc torch.Size([1]) cuda:0 torch.float32 tensor(0.6551, device='cuda:0') tensor(nan, device='cuda:0')
deltar torch.Size([1]) cuda:0 torch.float32 

In [18]:
import os

# Name under which wandb records this notebook; set before wandb is imported and initialised.
os.environ.update(WANDB_NOTEBOOK_NAME='main')

import wandb
from nn_nananana_ansatz.pyfig.utils import gen_rnd

run_init_kwargs = {
	'project': env.WANDB_PROJECT,
	'entity': env.WANDB_ENTITY,
	'name': gen_rnd(),  # random run name (alternative: c.logger.exp_name)
	'mode': c.logger.log_mode,
	'tags': [c.logger.runtype, c.mode],
}
# Last expression of the cell, so Jupyter displays the returned run (https://wandb.ai/).
wandb.init(**run_init_kwargs)

In [19]:
from nn_nananana_ansatz.hwat import loss_fn, loss_fn_pretrain, update_grads, update_params
from nn_nananana_ansatz.pyfig.utils import compute_metrix
from nn_nananana_ansatz.torch_utils import npify_tree

log_params = False  # also send raw parameter / gradient tensors to wandb (large)
# NOTE: this cell always trains, overriding whatever run mode the config requested.
mode = c.mode = 'train'
model.requires_grad_(True)  # torch stuff: sets parameters to compute gradients in backward pass

print(f"""
mode: {mode}
device: {device}
dtype: {dtype}
n_step {c.n_step}
n_log_metric: {c.logger.n_log_metric}
wandb: {wandb.run.name}
""")

# Second-order optimisers (AdaHessian) need the graph of the gradients themselves, so that
# backward builds a differentiable graph. This does not change per step; decide it once.
create_graph = c.opt.name.lower() == 'AdaHessian'.lower()

v_cpu_d = dict()  # latest numpy copy of the logged values (refreshed every n_log_metric steps)
for step, loader_d in enumerate(dataloader, start= 1):  # dataloader contains sampling loop

	data = loader_d['data'].to(device= device, dtype= dtype)  # move the data to the gpu
	model.zero_grad(set_to_none= True)  # drop gradients left from the previous step

	# compute the loss for the active mode; v_d collects everything produced this step
	if c.mode == 'train':
		v_d = loss_fn(model, data, model_fn, system_th, c.ansatzcfg)
	elif c.mode == 'pre':
		v_d = loss_fn_pretrain(model, data, c, step)
	else:
		# without this branch v_d would be undefined (first step) or silently stale (later steps)
		raise ValueError(f"unknown mode {c.mode!r}; expected 'train' or 'pre'")

	loss = v_d.get('loss')
	if loss is None:
		# a constant placeholder loss has no grad_fn and cannot be backpropagated
		raise KeyError(f"loss function returned no 'loss' entry (got keys: {list(v_d)})")

	loss.backward(create_graph= create_graph)  # autograd: do backpropagation
	# parameters that did not take part in the loss keep grad=None; skip them
	v_d['grads'] = {k: p.grad.detach() for k, p in model.named_parameters() if p.grad is not None}
	update_grads(model, v_d, step)  # update the gradients on the model parameters

	update_params(model, opt)
	v_d['params'] = {k: p.detach() for k, p in model.named_parameters()}  # collecting for logging

	if step % c.logger.n_log_metric == 0:

		if not log_params:
			v_d.pop('params')
			v_d.pop('grads')

		v_cpu_d: dict = npify_tree(v_d)  # convert to numpy
		v_metrix: dict = compute_metrix(v_cpu_d)  # rewrite the names and compute the metrics

		wandb.log(v_metrix)  # send to wandb


In [20]:
# Mark the wandb run started above as finished and finish uploading its data
# (the run history / summary tables below are printed at this point).
wandb.finish()

0,1
energy/e.mean,▆▆▆▄▅▆▇▇█▇▇▅▄▃▃▃▄▄▄▅▄▄▃▃▂▂▂▃▃▄▄▂▃▂▂▂▃▁▃▁
energy/e.std,▆▄▃▃▄▅▃▂▂▃▂▄▂▂▂▂▂▂▂▃▂▂▂▃▃▂▃▃█▁▁▇▃▅▂▁▃▄▁▂
energy/ke.mean,▂▂▃▄▅▄▂▁▁▁▁▂▃▄▅▅▅▄▄▃▄▄▄▄▄▅▇▇▅▄▄▆▅▇▆▇▇█▆▇
energy/ke.std,▂▂▂▂▃▃▁▁▁▁▁▂▂▃▃▃▃▂▂▂▃▃▃▃▃▄▅▅▅▄▂▅▄█▄▅▅▅▃▅
energy/pe.mean,▇▇▆▅▄▅▇▇█▇█▆▅▅▄▃▄▅▄▅▅▅▄▄▄▃▂▂▄▄▄▃▄▂▃▂▂▁▃▂
energy/pe.std,▃▂▂▂▃▃▁▁▁▂▁▃▂▂▃▂▂▂▂▂▃▃▃▃▃▃▄▃▂▃▂▅▃█▄▄▄▅▂▄
grads/p_lay.0.bias.mean,▇█▄▅▆▄▇▅▃▇▅▃▆▄▆█▃▅▂▃▄▄▆▆▃▃▅▅▄▆▆▇▄▃▄▇▁▅▂▄
grads/p_lay.0.bias.std,▃▄▃▂█▃▄▂▄▃▅▃▂▃▃▄▂▂▄▂▃▁▂▄▁▄▃▂▇▃▂▃▃▃▅▂▄▄▂▁
grads/p_lay.0.weight.mean,▃▆▅▅▆▄▇▅▅▅▆▅█▇▁▇▄▆▄▆▆▄▆▄▅▅▅▃▁▅▄▇█▃▅▅▇▄▁▆
grads/p_lay.0.weight.std,█▇▆▅▆▄▆▂▄▇▅▅▄▇▆▅▂▂▃▃▄▁▄▄▄▅▃▃▇▂▆▃▄▅▇▃▃▅▅▃

0,1
energy/e.mean,-136.49573
energy/e.std,23.19153
energy/ke.mean,132.4101
energy/ke.std,80.59218
energy/pe.mean,-268.90582
energy/pe.std,99.78049
grads/p_lay.0.bias.mean,-0.01253
grads/p_lay.0.bias.std,0.03566
grads/p_lay.0.weight.mean,0.01212
grads/p_lay.0.weight.std,0.08993


# fin

## TODO
- [x] demo wandb logging
- [ ] test on Slurm
- [ ] test Hugging Face Accelerate
- [ ] distributed training (Accelerate + Slurm)
- [ ] Docker support
- [ ] Kubernetes support
- [ ] publish as a pip-installable package
- [ ] PyTorch Lightning integration

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=c3092bd2-44d7-455e-aedb-c6812270d279' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>