In [1]:
# Seed the random number generators before anything else runs so the whole
# notebook is reproducible. The same value (123321) is reused as the
# `random_state` of the train/validation split in the training cell.
# `workers=True` asks Lightning to also derive seeds for dataloader workers.
from pytorch_lightning import seed_everything
seed_everything(123321, workers=True)

# Silence every Python warning for the rest of the session to keep the cell
# output readable.
# NOTE(review): this is a blanket filter, so deprecation and numerical
# warnings raised during training are hidden as well.
import warnings
warnings.filterwarnings("ignore")

# Put the parent directory on the import path; presumably that is where the
# project-local packages imported below (utils, data, models, confidential)
# live relative to this notebook.
import sys
sys.path.append("../")

import os
import argparse
import yaml
from sklearn.model_selection import train_test_split

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split

from utils.training import DielectricModule
from data.dataset import StructureDataset
from models.equivariant_model import GatedEquivariantModel
from confidential.models import CombinedModel
from confidential.utils import load_pfp, collate_fn_dict
from matbench.bench import MatbenchBenchmark

Global seed set to 123321


In [2]:
# Load the benchmark configuration. It is expected to provide the "Train" and
# "Model" sections that the next cell unpacks.
# The encoding is pinned to UTF-8 so parsing does not depend on the locale
# default of whichever machine runs the notebook.
with open('../scripts/bench_config.yaml', 'r', encoding='utf-8') as f:
    configs = yaml.safe_load(f)

In [3]:
# Benchmark handle restricted to the dielectric task. With autoload=False the
# dataset is not loaded here; the training cell calls task.load() explicitly.
mb = MatbenchBenchmark(autoload=False,subset=["matbench_dielectric"])

# Work on shallow copies of the two config sections. Popping straight out of
# `configs` would strip the "Train"/"Model" keys from the dict loaded in the
# previous cell, so re-running this cell on its own would fail with KeyError.
train_config = dict(configs["Train"])
model_config = dict(configs["Model"])

# --- Training / runtime settings ------------------------------------------
outdir = train_config.pop("save_path")        # root directory for checkpoints
batch_size = train_config.pop("batch")
num_workers = train_config.pop("num_workers")
lr = train_config.pop("lr")
patience = train_config.pop("patience")       # early-stopping patience
epoch = train_config.pop("epoch")             # passed to Trainer(max_epochs=...)
accelerator = train_config.pop("accelerator")
devices = train_config.pop("device")
clip_val = train_config.pop("gradient_clip")  # passed to Trainer(gradient_clip_val=...)

# --- Model settings ---------------------------------------------------------
# These two keys are consumed here and must be *removed* from model_config:
# whatever remains is forwarded verbatim as keyword arguments to
# GatedEquivariantModel(**model_config) in the training cell.
pfp_layer = model_config.pop("pfp_layer")
train_pfp = model_config.pop("train_pfp") # Bool

2024-08-15 00:07:17 INFO     Initialized benchmark 'matbench_v0.1' with 1 tasks: 
['matbench_dielectric']


In [4]:
def _make_loader(dataset, shuffle):
    """Build a DataLoader over `dataset` with the notebook-wide loader settings.

    Uses the `batch_size` and `num_workers` globals read from the config and
    the project's `collate_fn_dict`; only the shuffling differs between the
    training loader and the validation/test loaders.
    """
    return DataLoader(
        dataset, batch_size=batch_size, collate_fn=collate_fn_dict,
        shuffle=shuffle, num_workers=num_workers, pin_memory=True
    )


# Fraction of each fold's train+val data held out for validation, and the seed
# that makes that split reproducible (same value as the global seed).
VAL_FRACTION = 0.1
SPLIT_SEED = 123321

for task in mb.tasks:
    task.load()
    for fold in task.folds:
        print(f"Start training of fold-{fold}.")

        # Create the directory the checkpoint callback below actually writes
        # to. (Previously `outdir + str(fold)` was created instead, which is a
        # different path and left an unused directory behind for every fold.)
        ckpt_dir = os.path.join(outdir, f"pl_checkpoints_fold{fold}")
        os.makedirs(ckpt_dir, exist_ok=True)

        train_inputs, train_outputs = task.get_train_and_val_data(fold)

        # Hold out part of the fold's training data for validation. The
        # index labels are split alongside inputs and targets so each sample
        # keeps its identifier.
        x_train, x_val, y_train, y_val, k_train, k_val = train_test_split(
            train_inputs.tolist(), train_outputs.tolist(), train_inputs.index.tolist(),
            test_size=VAL_FRACTION, random_state=SPLIT_SEED)

        train_dataset = StructureDataset(x_train, y_train, k_train)
        val_dataset = StructureDataset(x_val, y_val, k_val)

        train_loader = _make_loader(train_dataset, shuffle=True)
        val_loader = _make_loader(val_dataset, shuffle=False)

        # Load the pre-trained PFP model
        pfp_wrapped = load_pfp(
            load_parameters=True,
            return_layer=pfp_layer)

        if train_pfp:
            # NOTE(review): reset_parameters(4.0) is called on the PFP right
            # after it was loaded with load_parameters=True. Its effect and
            # the meaning of 4.0 are not visible here; confirm this does not
            # unintentionally discard the pre-trained weights.
            pfp_wrapped.pfp.reset_parameters(4.0)
        else:
            # Freeze PFP: exclude all of its parameters from gradient updates.
            for param in pfp_wrapped.parameters():
                param.requires_grad = False

        # Build the readout network from the remaining model config entries
        tensorial_model = GatedEquivariantModel(**model_config)
        # Combine PFP and the readout network into a single model
        model = CombinedModel(pfp_wrapped, tensorial_model)

        pl_module = DielectricModule(
            model,
            learning_rate=lr)

        # Run label for this fold. The CSV logger does not use it; it is kept
        # for optional experiment trackers (e.g. WandbLogger(project=...)).
        project_name = f'matbench_layer{pfp_layer}_fold{fold}'
        csv_logger = CSVLogger("logs")

        # Keep only the single checkpoint with the lowest validation loss.
        checkpoint_callback = ModelCheckpoint(
            save_top_k=1,
            monitor="val_loss",
            mode="min",
            dirpath=ckpt_dir,
            filename="eps-{epoch:02d}-{val_loss:.2f}",
        )

        # Stop once val_loss has not improved for `patience` validation checks.
        earlystopping_callback = EarlyStopping("val_loss", mode="min",
            patience=patience)

        trainer = pl.Trainer(
            max_epochs=epoch,
            accelerator=accelerator,
            logger=csv_logger,
            devices=devices,
            callbacks=[checkpoint_callback, earlystopping_callback],
            gradient_clip_val=clip_val,
            enable_model_summary=False)
        trainer.fit(
            model=pl_module,
            train_dataloaders=train_loader,
            val_dataloaders=val_loader)

        test_inputs, test_outputs = task.get_test_data(fold, include_target=True)
        test_dataset = StructureDataset(test_inputs.tolist(), test_outputs.tolist(), test_inputs.index.tolist())
        test_loader = _make_loader(test_dataset, shuffle=False)

        # Predict with the best checkpoint (lowest val_loss), not the weights
        # from the last epoch.
        preds = trainer.predict(pl_module, test_loader, ckpt_path='best')
        preds = torch.cat(preds, dim=0).detach().cpu().numpy()
        # Reduce every predicted matrix to the mean of its diagonal, yielding
        # one scalar per sample.
        # Assumes preds has shape (n_samples, d, d) -- TODO confirm.
        preds = np.mean(np.diagonal(preds, axis1=1, axis2=2), axis=1)

        # Record this fold's predictions with the benchmark
        task.record(fold, preds)

# Save the results of all folds
mb.to_file("dielectric_matbench1.json.gz")

2024-08-15 00:07:17 INFO     Loading dataset 'matbench_dielectric'...
2024-08-15 00:07:20 INFO     Dataset 'matbench_dielectric loaded.
Start training of fold-0.


processing structures to build the dataset..: 3429it [00:18, 187.29it/s]
processing structures to build the dataset..: 382it [00:01, 206.93it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]


Epoch 454: 100%|██████████| 54/54 [00:14<00:00,  3.82it/s, v_num=32, val_loss=0.276, val_mae=0.276, val_rmse=1.450, train_loss=0.121, train_mae=0.121, train_rmse=0.842, lr=2.25e-6]


processing structures to build the dataset..: 953it [00:06, 158.81it/s]
Restoring states from the checkpoint path at ../confidential/matbench/checkpoints/pl_checkpoints_fold0/eps-epoch=254-val_loss=0.27.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]
Loaded model weights from the checkpoint at ../confidential/matbench/checkpoints/pl_checkpoints_fold0/eps-epoch=254-val_loss=0.27.ckpt


Predicting DataLoader 0: 100%|██████████| 15/15 [00:02<00:00,  5.46it/s]
2024-08-15 01:51:30 INFO     Recorded fold matbench_dielectric-0 successfully.
Start training of fold-1.


processing structures to build the dataset..: 3429it [00:19, 174.84it/s]
processing structures to build the dataset..: 382it [00:02, 130.64it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]


Epoch 575: 100%|██████████| 54/54 [00:07<00:00,  7.39it/s, v_num=36, val_loss=0.227, val_mae=0.227, val_rmse=0.972, train_loss=0.125, train_mae=0.125, train_rmse=0.831, lr=5.9e-7] 


processing structures to build the dataset..: 953it [00:05, 187.46it/s]
Restoring states from the checkpoint path at ../confidential/matbench/checkpoints/pl_checkpoints_fold1/eps-epoch=375-val_loss=0.23.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]
Loaded model weights from the checkpoint at ../confidential/matbench/checkpoints/pl_checkpoints_fold1/eps-epoch=375-val_loss=0.23.ckpt


Predicting DataLoader 0: 100%|██████████| 15/15 [00:00<00:00, 17.83it/s]
2024-08-15 03:21:33 INFO     Recorded fold matbench_dielectric-1 successfully.
Start training of fold-2.


processing structures to build the dataset..: 3429it [00:19, 177.34it/s]
processing structures to build the dataset..: 382it [00:01, 220.09it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]


Epoch 337: 100%|██████████| 54/54 [00:07<00:00,  7.31it/s, v_num=39, val_loss=0.245, val_mae=0.245, val_rmse=1.220, train_loss=0.0867, train_mae=0.0867, train_rmse=0.547, lr=3.52e-6]


processing structures to build the dataset..: 953it [00:05, 186.64it/s]
Restoring states from the checkpoint path at ../confidential/matbench/checkpoints/pl_checkpoints_fold2/eps-epoch=137-val_loss=0.24.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]
Loaded model weights from the checkpoint at ../confidential/matbench/checkpoints/pl_checkpoints_fold2/eps-epoch=137-val_loss=0.24.ckpt


Predicting DataLoader 0: 100%|██████████| 15/15 [00:00<00:00, 17.56it/s]
2024-08-15 04:05:10 INFO     Recorded fold matbench_dielectric-2 successfully.
Start training of fold-3.


processing structures to build the dataset..: 3429it [00:18, 189.67it/s]
processing structures to build the dataset..: 382it [00:02, 169.58it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]


Epoch 697: 100%|██████████| 54/54 [00:06<00:00,  7.79it/s, v_num=41, val_loss=0.170, val_mae=0.170, val_rmse=0.628, train_loss=0.111, train_mae=0.111, train_rmse=0.737, lr=3.02e-7]


processing structures to build the dataset..: 953it [00:05, 176.98it/s]
Restoring states from the checkpoint path at ../confidential/matbench/checkpoints/pl_checkpoints_fold3/eps-epoch=497-val_loss=0.17.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]
Loaded model weights from the checkpoint at ../confidential/matbench/checkpoints/pl_checkpoints_fold3/eps-epoch=497-val_loss=0.17.ckpt


Predicting DataLoader 0: 100%|██████████| 15/15 [00:00<00:00, 19.46it/s]
2024-08-15 05:27:43 INFO     Recorded fold matbench_dielectric-3 successfully.
Start training of fold-4.


processing structures to build the dataset..: 3430it [00:18, 189.40it/s]
processing structures to build the dataset..: 382it [00:02, 180.47it/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]


Epoch 379: 100%|██████████| 54/54 [00:06<00:00,  7.91it/s, v_num=42, val_loss=0.240, val_mae=0.240, val_rmse=0.919, train_loss=0.116, train_mae=0.116, train_rmse=0.753, lr=2.25e-6]


processing structures to build the dataset..: 952it [00:04, 209.40it/s]
Restoring states from the checkpoint path at ../confidential/matbench/checkpoints/pl_checkpoints_fold4/eps-epoch=179-val_loss=0.24.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [6,5,7,2,4,3,1,0]
Loaded model weights from the checkpoint at ../confidential/matbench/checkpoints/pl_checkpoints_fold4/eps-epoch=179-val_loss=0.24.ckpt


Predicting DataLoader 0: 100%|██████████| 15/15 [00:00<00:00, 19.02it/s]
2024-08-15 06:12:27 INFO     Recorded fold matbench_dielectric-4 successfully.
2024-08-15 06:12:27 INFO     Successfully wrote MatbenchBenchmark to file 'dielectric_matbench1.json.gz'.


In [5]:
# Aggregate metrics over the folds recorded in the training cell.
scores = mb.matbench_dielectric.scores
# Display a plain-dict copy rather than the scores object itself. When IPython
# renders an object it probes it for an attribute named
# `_ipython_canary_method_should_not_exist_`; the scores mapping apparently
# answers unknown attributes by creating an empty entry, which then shows up
# as a bogus extra "metric" in the output.
dict(scores)

{'mae': {'mean': 0.2370850306791295,
  'max': 0.3516925340028822,
  'min': 0.11753604419222681,
  'std': 0.07649557856878016},
 'rmse': {'mean': 1.6829936226656266,
  'max': 2.901275849331958,
  'min': 0.5732083587460219,
  'std': 0.8485401319998302},
 'mape': {'mean': 0.05530619144780041,
  'max': 0.06721239289649394,
  'min': 0.03617504271599022,
  'std': 0.013170799705305278},
 'max_error': {'mean': 34.59824110916204,
  'max': 58.784419044481886,
  'min': 13.862970454340712,
  'std': 18.27463111793412},
 '_ipython_canary_method_should_not_exist_': {}}