In [1]:
import os
import shutil
import zipfile
import warnings
from functools import partial

import pandas as pd
from tqdm import tqdm
import lightning as pl
from ase.io import read
import matplotlib.pyplot as plt
from pymatgen.core import Structure
from dgl.data.utils import split_dataset
from pytorch_lightning.loggers import CSVLogger

from matgl.models import M3GNet
from matgl.utils.io import RemoteFile
from matgl.utils.training import ModelLightningModule
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import MGLDataset, MGLDataLoader, collate_fn_graph

warnings.simplefilter("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
torch.__version__

'2.2.0+cu121'

In [None]:
# https://matgl.ai/tutorials/Training%20a%20M3GNet%20Potential%20with%20PyTorch%20Lightning.html

### Data Load

In [3]:
files_structures = os.listdir('Dataset_ZrO2')
if '.ipynb_checkpoints' in files_structures: files_structures.remove('.ipynb_checkpoints')

strucs,energies,forces = [],[],[]

for file in files_structures:

    structures = read(f'Dataset_ZrO2/{file}',index=':')

    for structure in structures:

        struc_forces = structure.get_forces()
        energy = structure.get_potential_energy()

        append = True
        
        for force in struc_forces:
            if append:
                for atom_force in force:
                    if atom_force > 35:
                        append = False
                        break

        if append:

            structure = Structure.from_ase_atoms(structure)
            
            strucs.append(structure)
            forces.append(struc_forces)
            energies.append(energy)

In [4]:
print('strucs\tforces\tenergies')
print(len(strucs),'\t',len(forces),'\t',len(energies))

strucs	forces	energies
14427 	 14427 	 14427


In [5]:
# get element types in the dataset
elem_list = get_element_list(strucs)

# setup a graph converter
converter = Structure2Graph(element_types=elem_list, cutoff=4.0)

# convert the raw dataset into M3GNetDataset
mp_dataset = MGLDataset(
    threebody_cutoff=4.0,
    structures=strucs,
    converter=converter,
    labels={'forces':forces,'energies':energies},
    include_line_graph=True,
)

In [6]:
train_data, val_data, test_data = split_dataset(
    mp_dataset,
    frac_list=[0.8, 0.1, 0.1],
    shuffle=True,
    random_state=42,
)

my_collate_fn = partial(collate_fn_graph, include_line_graph=True)
train_loader, val_loader, test_loader = MGLDataLoader(
    train_data=train_data,
    val_data=val_data,
    test_data=test_data,
    collate_fn=my_collate_fn,
    batch_size=2,
    num_workers=1,
)

In [7]:
# setup the architecture of M3GNet model
model = M3GNet(
    element_types=elem_list,
    is_intensive=True,
    readout_type="set2set",
)
# setup the M3GNetTrainer
lit_module = ModelLightningModule(model=model, include_line_graph=True)

In [8]:
logger = CSVLogger("logs", name="M3GNet_training")
trainer = pl.Trainer(max_epochs=20, accelerator="gpu", logger=logger)
trainer.fit(model=lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader)

Trainer will use only 1 of 3 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=3)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA L40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]

  | Name  | Type              | Params | Mode 
----------------------------------------------------
0 | model | M3GNet            | 386 K  | train
1 | mae   | MeanAbsoluteError | 0      | train
2 | rmse  | 

Epoch 0: 100%|██████████| 5771/5771 [03:53<00:00, 24.72it/s, v_num=0]      
[Aidation: |          | 0/? [00:00<?, ?it/s]
[Aidation:   0%|          | 0/721 [00:00<?, ?it/s]
[Aidation DataLoader 0:   0%|          | 0/721 [00:00<?, ?it/s]
[Aidation DataLoader 0:   0%|          | 1/721 [00:00<00:22, 32.43it/s]
[Aidation DataLoader 0:   0%|          | 2/721 [00:00<00:20, 35.37it/s]
[Aidation DataLoader 0:   0%|          | 3/721 [00:00<00:19, 37.41it/s]
[Aidation DataLoader 0:   1%|          | 4/721 [00:00<00:18, 38.56it/s]
[Aidation DataLoader 0:   1%|          | 5/721 [00:00<00:18, 38.53it/s]
[Aidation DataLoader 0:   1%|          | 6/721 [00:00<00:18, 38.46it/s]
[Aidation DataLoader 0:   1%|          | 7/721 [00:00<00:18, 38.12it/s]
[Aidation DataLoader 0:   1%|          | 8/721 [00:00<00:18, 38.25it/s]
[Aidation DataLoader 0:   1%|          | 9/721 [00:00<00:18, 38.39it/s]
[Aidation DataLoader 0:   1%|▏         | 10/721 [00:00<00:18, 38.45it/s]
[Aidation DataLoader 0:   2%|


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [51]:
trainer.test(dataloaders=test_loader)

Restoring states from the checkpoint path at logs/M3GNet_training/version_0/checkpoints/epoch=0-step=5771.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]
Loaded model weights from the checkpoint at logs/M3GNet_training/version_0/checkpoints/epoch=0-step=5771.ckpt


Testing: |          | 722/? [00:17<00:00, 40.43it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_MAE             77.02153778076172
        test_RMSE            94.90870666503906
     test_Total_Loss           38196.765625
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_Total_Loss': 38196.765625,
  'test_MAE': 77.02153778076172,
  'test_RMSE': 94.90870666503906}]