In [1]:
%load_ext autoreload
%autoreload 2

In [2]:

import torch
from franken.backbones.utils import CacheDir
from franken.rf.model import FrankenPotential, GaussianRFParams

In [3]:
# It is a bit weird that to initialize a FrankenPotential we need to initialize a rather obscure CacheDir first. Make it easier.
device = torch.device('cuda:0')
CacheDir.initialize()
model = FrankenPotential("SchNet-S2EF-OC20-200k", "gaussian", GaussianRFParams()).to(device)

num_solver_params = 121
all_weights = torch.zeros(
    num_solver_params,
    model.rf.total_random_features,
    dtype=torch.float32,
    device=device,
)

# Dataset init.
from franken.data import BaseAtomsDataset, Target
from franken.data.utils import get_dataloader

dset = BaseAtomsDataset.from_path(
    "/home/novelli/franken/datasets/test/32/train.xyz",
    split="train",
    gnn_backbone_id="SchNet-S2EF-OC20-200k",
)
dloader = get_dataloader(dset, distributed=False)

converting ASE atoms collection to graphs:   0%|          | 0/32 [00:00<?, ? systems/s]

In [8]:
import franken.metrics

metrics = {k: franken.metrics.spawn(k, device=device) for k in franken.metrics.available_metrics()}
for data, targets in dloader:
    data = data.to(device=device)
    targets = targets.to(device=device)
    predictions = Target(*model.energy_and_forces(
        data, weights=all_weights, forces_mode="torch.func"
    ))
    for m in metrics.values():
        m.update(predictions, targets)

In [12]:
metric_values = {
    k: metrics[k].compute() for k in metrics
}

In [14]:
metric_values

{'energy_MAE': tensor([7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233, 7.5233,
         7.523

In [13]:
metrics['energy_MAE'].buffer.shape

AttributeError: 'NoneType' object has no attribute 'shape'

In [11]:
metrics['forces_MAE'].buffer.shape

torch.Size([121])

In [10]:
metrics['forces_cosim'].buffer.shape

torch.Size([121])