# Notebook for downloading and converting datasets

In [None]:
import os
# Scratch directory used by the download/convert cells below.
# Indexing os.environ raises KeyError when TMPDIR is not set.
tmp = os.environ['TMPDIR']

## Helper functions

In [None]:
def step(cnt):
    """Advance a progress counter, reporting on every 100th count.

    The report overwrites the current terminal line (leading carriage
    return, no trailing newline).

    Args:
        cnt (int): number of data processed so far

    Returns:
        int: the incremented counter, ``cnt + 1``
    """
    _, remainder = divmod(cnt, 100)
    if remainder == 0:
        print(f'\r{cnt} data processed', end='')
    return cnt + 1

def load_transition1x(fname, selection='all', slient=False):
    """Build a ``tips.io.Dataset`` over a Transition1x HDF5 file.

    Adapted from the original transition1x dataset loader, with
    different selection rules that allow for selecting reactants,
    products, etc.

    Selection rules, applied to every reaction group under ``data``:

    - "reactant", "product", "ts": the first frame of that sub-group
    - "final": reactant, product and transition state
    - "all": reactant, product, then every frame of the reaction group

    Args:
        fname (str): path to the file
        selection (str): "all", "final", "reactant", "product", or "ts"
        slient (bool): when True, the progress counter is not printed
            (parameter spelling kept for backward compatibility)

    Returns:
        Dataset: dataset wrapping a generator function; the HDF5 file is
        opened each time that generator is started and closed when it
        finishes.
    """
    assert selection in ("all", "final", "reactant", "product", "ts"), "Unknown selection"

    import h5py
    from tips.io import Dataset

    meta = {
        'spec': {
            'elems': {'shape': [None], 'dtype': 'int32'},
            'coord': {'shape': [None, 3], 'dtype': 'float32'},
            'e_data': {'shape': [], 'dtype': 'float32'},
            'f_data': {'shape': [None, 3], 'dtype': 'float32'},
        },
        'fmt': 'Transition 1x dataset'
    }

    def grp2d(grp):
        """ Iterates through a h5 group (transition 1x ver.), one datum per frame"""
        numbers = list(grp["atomic_numbers"])
        frames = zip(
            grp["wB97x_6-31G(d).energy"],
            grp["wB97x_6-31G(d).forces"],
            grp["positions"],
        )
        for energy, force, coord in frames:
            # The yield must stay inside the loop: one datum per frame.
            yield {
                "e_data": energy.__float__(),
                "f_data": force.tolist(),
                "coord": coord,
                # Key matches the 'elems' entry declared in meta['spec'].
                "elems": numbers,
            }

    def select(rxn_grp):
        """ Yields the data of one reaction group according to `selection`"""
        if selection in ("all", "final", "reactant"):
            yield next(grp2d(rxn_grp["reactant"]))

        if selection in ("all", "final", "product"):
            yield next(grp2d(rxn_grp["product"]))

        if selection in ("final", "ts"):
            yield next(grp2d(rxn_grp["transition_state"]))

        if selection == "all":
            yield from grp2d(rxn_grp)

    def generator():
        cnt = 0
        # Read-only handle, released when the generator finishes or is closed.
        with h5py.File(fname, 'r') as h5file:
            for formula_grp in h5file['data'].values():
                for rxn_grp in formula_grp.values():
                    for mol in select(rxn_grp):
                        if not slient:
                            cnt = step(cnt)
                        yield mol

    return Dataset(generator=generator, meta=meta)

def load_qm9x(fname, slient=False):
    """Build a ``tips.io.Dataset`` over a QM9x HDF5 file.

    Every top-level group of the file is traversed and one datum is
    produced per (energy, forces, positions) frame stored in it.

    Args:
        fname (str): path to the file
        slient (bool): when True, the progress counter is not printed
            (parameter spelling kept for backward compatibility)

    Returns:
        Dataset: dataset wrapping a generator function; the HDF5 file is
        opened each time that generator is started and closed when it
        finishes.
    """
    import h5py
    from tips.io import Dataset

    def grp2d(grp):
        """ Iterates through a h5 group (QM9x ver.), one datum per frame"""
        numbers = list(grp["atomic_numbers"])
        frames = zip(grp["energy"], grp["forces"], grp["positions"])
        for energy, force, coord in frames:
            yield {
                "e_data": energy.__float__(),
                "f_data": force.tolist(),
                "coord": coord,
                # Key matches the 'elems' entry declared in meta['spec'].
                "elems": numbers,
            }

    meta = {
        'spec': {
            'elems': {'shape': [None], 'dtype': 'int32'},
            'coord': {'shape': [None, 3], 'dtype': 'float32'},
            'e_data': {'shape': [], 'dtype': 'float32'},
            'f_data': {'shape': [None, 3], 'dtype': 'float32'},
        },
        'fmt': 'QM9x dataset'
    }

    def generator():
        cnt = 0
        # Read-only handle, released when the generator finishes or is closed.
        with h5py.File(fname, 'r') as h5file:
            for grp in h5file.values():
                for mol in grp2d(grp):
                    if not slient:
                        cnt = step(cnt)
                    yield mol

    return Dataset(generator=generator, meta=meta)

## Download QM9x

In [None]:
!wget -q -O download_qm9x.py https://gitlab.com/matschreiner/qm9x/-/raw/master/scripts/download_qm9x.py
!python download_qm9x.py {tmp}

## Download Transition 1x

In [None]:
!wget -q -O download_t1x.py https://gitlab.com/matschreiner/Transition1x/-/raw/main/download_t1x.py
!python download_t1x.py {tmp}

## Load and convert T1X

**OPTIONS FOR T1X**
- "reactant", "product", or "ts"
- "final": reactants, products, transition states
- "all": all data from t1x

In [None]:
# Transition1x HDF5 file inside the scratch directory
fname = f'{tmp}/transition1x.h5'
# Second argument is the selection: "all", "final", "reactant", "product" or "ts"
ds = load_transition1x(fname, 'reactant') 
# NOTE(review): the output is named 'product' while 'reactant' is selected
# above — confirm which of the two is intended before running.
ds.convert('product', fmt='pinn')

In [None]:
# QM9x HDF5 file inside the scratch directory
fname = f'{tmp}/qm9x.h5'
ds = load_qm9x(fname)
# Convert to the 'pinn' format under the name 'qm9x'
ds.convert('qm9x', fmt='pinn')

## Download and convert QM9

In [None]:
!mkdir -p {tmp}/dsgdb9nsd && curl -sSL https://ndownloader.figshare.com/files/3195389 | tar xj -C {tmp}/dsgdb9nsd

In [None]:
# QM9 support is through the pinn package, note the syntax difference
from pinn.io import load_qm9, write_tfrecord
from glob import glob

def hartree2ev(datum):
    """Rescale the 'e_data' entry of a datum from Hartree to eV.

    The datum is modified in place and also returned, so the function
    can be used directly as a ``map`` callback.

    Args:
        datum (dict): mapping holding an 'e_data' entry

    Returns:
        dict: the same ``datum`` object with 'e_data' rescaled
    """
    # NOTE(review): CODATA 2018 gives 27.211386245988 eV per Hartree;
    # confirm the origin of the factor used here.
    ev_per_hartree = 27.211407953
    energy = datum['e_data']
    datum['e_data'] = ev_per_hartree * energy
    return datum

# glob returns paths in arbitrary order; sorting makes the file order reproducible
filelist = sorted(glob(f'{tmp}/dsgdb9nsd/*.xyz'))
# Apply the Hartree -> eV rescaling of 'e_data' to every datum
ds = load_qm9(filelist).map(hartree2ev)
# presumably writes the dataset description to qm9.yml alongside the records — TODO confirm
write_tfrecord('qm9.yml', ds)