In [3]:
# Wrap the old code into kooplearn
import torch
import schnetpack
from kooplearn.abc import FeatureMap
from kooplearn.models import DeepEDMD

class Namespace:
    def __init__(self, d):
        for k, v in d.items():
            if isinstance(v, dict):
                setattr(self, k, Namespace(v))
            else:
                setattr(self, k, v)
configs = Namespace({
    'data': {
        'cutoff': 5
    },
    'network': {
        'n_rbf': 20,
        'n_atom_basis': 64,
        'n_final_features': 16,
        'n_interactions': 3
    }
})

class SchNet(schnetpack.model.AtomisticModel):
    def __init__(self, configs: Namespace):
        super().__init__(
            input_dtype_str="float32",
            postprocessors=None,
            do_postprocessing=False,
        )
        self.cutoff = configs.data.cutoff
        self.pwise_dist = schnetpack.atomistic.PairwiseDistances()
        self.radial_basis = schnetpack.nn.GaussianRBF(n_rbf=configs.network.n_rbf, cutoff=self.cutoff)
        self.net = schnetpack.representation.SchNet(
            n_atom_basis=configs.network.n_atom_basis, 
            n_interactions=configs.network.n_interactions,
            radial_basis=self.radial_basis,
            cutoff_fn=schnetpack.nn.CosineCutoff(self.cutoff)
        )
        self.final_lin = torch.nn.Linear(configs.network.n_atom_basis, configs.network.n_final_features)
        self.batch_norm = torch.nn.BatchNorm1d(configs.network.n_final_features, affine=False)
    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        inputs = self.pwise_dist(inputs)
        inputs = self.net(inputs)
        inputs['scalar_representation'] = self.final_lin(inputs['scalar_representation'])
        return inputs

### Data loading

In [5]:
import os
from pathlib import Path
data_path = Path('data/')
db_path = data_path / 'CLN025-0-protein-ALL.db'
cache_path = data_path / 'cache'
nb_list_transform = schnetpack.transform.CachedNeighborList(cache_path, schnetpack.transform.MatScipyNeighborList(cutoff=configs.data.cutoff), keep_cache=True)
in_transforms = [schnetpack.transform.CastTo32(), nb_list_transform]
dataset = schnetpack.data.ASEAtomsData(str(db_path), transforms=in_transforms)
dataloader = schnetpack.data.AtomsLoader(dataset, num_workers=20, persistent_workers=True)

In [6]:
for batch in dataloader:
    print (batch)
    break

{'_idx': tensor([0]), '_n_atoms': tensor([93]), '_atomic_numbers': tensor([7, 6, 6, 6, 6, 6, 6, 8, 6, 6, 6, 8, 7, 6, 6, 6, 6, 6, 6, 8, 6, 6, 6, 8,
        7, 6, 6, 6, 8, 8, 6, 8, 7, 6, 6, 6, 6, 6, 8, 7, 6, 6, 6, 6, 8, 8, 6, 8,
        7, 6, 6, 8, 6, 6, 8, 7, 6, 6, 8, 7, 6, 6, 8, 6, 6, 8, 7, 6, 6, 6, 6, 7,
        6, 6, 6, 6, 6, 6, 6, 8, 6, 8, 8, 7, 6, 6, 6, 6, 6, 6, 8, 6, 6]), '_positions': tensor([[ 13.2955, -14.7450, -11.2065],
        [ 13.5762, -13.8061, -12.2575],
        [ 14.5278, -14.4698, -13.2362],
        [ 14.7350, -13.4173, -14.3212],
        [ 14.1697, -13.5783, -15.6179],
        [ 14.1953, -12.5928, -16.6435],
        [ 14.6779, -11.3134, -16.2287],
        [ 14.8198, -10.3106, -17.2147],
        [ 15.2583, -12.1480, -14.0601],
        [ 15.2582, -11.1114, -15.0057],
        [ 12.2009, -13.3506, -12.6907],
        [ 11.6628, -12.4301, -12.0687],
        [ 11.6444, -13.7536, -13.7862],
        [ 10.3856, -13.1981, -14.2790],
        [ 10.7510, -11.9113, -15.1152],
      

In [2]:
state_dict = torch.load('schnet_model_ckpt.pt')
model = SchNet(configs).load_state_dict(state_dict)