In [10]:
# Wrap the old code into kooplearn
import schnetpack
import ml_confs
import torch
import pytorch_lightning as pl
from typing import Any
import sys
sys.path.append('../../dpnets_legacy/')
import koopnet
from collections import OrderedDict

class Namespace:
    """Read-only-style attribute view over a (possibly nested) config dict.

    Every key of ``d`` becomes an attribute; values that are themselves dicts are
    wrapped recursively, so ``cfg['network']['n_rbf']`` is reachable as
    ``cfg.network.n_rbf``. Non-dict values are stored as-is.
    """

    def __init__(self, d):
        for name, entry in d.items():
            wrapped = Namespace(entry) if isinstance(entry, dict) else entry
            setattr(self, name, wrapped)

# Hyperparameters of the trained network. They fix the layer sizes, so they must
# match the checkpoint for a strict `load_state_dict` to succeed.
configs = dict(
    data=dict(
        cutoff=5,  # radial cutoff for GaussianRBF and CosineCutoff (units presumably Å — confirm)
    ),
    network=dict(
        n_rbf=20,             # number of Gaussian radial basis functions
        n_atom_basis=64,      # per-atom feature width of the SchNet representation
        n_final_features=16,  # output width of the final linear projection
        n_interactions=3,     # number of SchNet interaction blocks
    ),
)

class SchNet(schnetpack.model.AtomisticModel):
    """SchNet-based per-atom feature network.

    Pipeline: pairwise distances -> schnetpack SchNet representation -> linear
    projection of ``scalar_representation`` from ``n_atom_basis`` to
    ``n_final_features`` channels.

    The submodule attribute names (``pwise_dist``, ``radial_basis``, ``net``,
    ``final_lin``, ``batch_norm``) become the ``state_dict`` key prefixes, so they
    must stay in sync with any checkpoint loaded into this model.

    Args:
        configs: attribute-style config providing ``data.cutoff`` and
            ``network.n_rbf``, ``network.n_atom_basis``,
            ``network.n_final_features``, ``network.n_interactions``.
    """

    def __init__(self, configs: Namespace):
        super().__init__(
            input_dtype_str="float32",
            postprocessors=None,
            do_postprocessing=False,
        )
        net_cfg = configs.network
        cutoff = configs.data.cutoff
        self.cutoff = cutoff

        self.pwise_dist = schnetpack.atomistic.PairwiseDistances()
        # The same GaussianRBF instance is also passed to the representation below.
        self.radial_basis = schnetpack.nn.GaussianRBF(n_rbf=net_cfg.n_rbf, cutoff=cutoff)
        self.net = schnetpack.representation.SchNet(
            n_atom_basis=net_cfg.n_atom_basis,
            n_interactions=net_cfg.n_interactions,
            radial_basis=self.radial_basis,
            cutoff_fn=schnetpack.nn.CosineCutoff(cutoff),
        )
        self.final_lin = torch.nn.Linear(net_cfg.n_atom_basis, net_cfg.n_final_features)
        # NOTE(review): registered (so its running-statistics buffers are part of the
        # state_dict) but never applied in `forward` — confirm whether normalization
        # is meant to happen elsewhere.
        self.batch_norm = torch.nn.BatchNorm1d(net_cfg.n_final_features, affine=False)

    def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Run distances + SchNet, then project ``scalar_representation``.

        Returns the schnetpack inputs dict with ``scalar_representation``
        replaced by its ``final_lin`` projection.
        """
        with_distances = self.pwise_dist(inputs)
        features = self.net(with_distances)
        features['scalar_representation'] = self.final_lin(features['scalar_representation'])
        return features

In [16]:
# Load the Lightning checkpoint of the trained run and convert its `state_dict`
# into one that the bare `SchNet` module accepts.
ckpt_path = 'ckpt/colorful-cosmos-97/'
# map_location='cpu' lets a checkpoint saved on a GPU load on any machine; the
# SchNet model built in this notebook lives on the CPU anyway.
# NOTE(review): on torch>=2.6 `weights_only` defaults to True, which may reject
# non-tensor objects stored in a Lightning checkpoint — confirm for your version.
ckpt = torch.load(ckpt_path + 'last.ckpt', map_location='cpu')
lightning_state_dict = ckpt['state_dict']

# Every key carries a 7-character "<attr>." prefix (presumably the attribute name
# under which the network was stored in the Lightning module). Check that all keys
# share that prefix before slicing it off, instead of silently garbling names.
PREFIX_LEN = 7
prefixes = {key[:PREFIX_LEN] for key in lightning_state_dict}
assert len(prefixes) == 1 and next(iter(prefixes)).endswith('.'), (
    f"Expected one common {PREFIX_LEN}-character '<attr>.' key prefix, got {sorted(prefixes)}"
)
state_dict = OrderedDict(
    (key[PREFIX_LEN:], value) for key, value in lightning_state_dict.items()
)

In [17]:
# Rebuild the architecture from the hyperparameter dict (wrapped for attribute access).
model_configs = Namespace(configs)
model = SchNet(model_configs)

In [19]:
# Strict load: every parameter/buffer name and shape must match the checkpoint.
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [22]:
# Persist the prefix-free state dict of the bare SchNet module.
export_path = 'schnet_model_ckpt.pt'
torch.save(model.state_dict(), export_path)