In [1]:
import os
from pathlib import Path
import pandas as pd
import numpy as np
import orjson
import zipfile
import zstandard as zstd
import shutil
from pymatgen.core import Structure

#### Load data

##### From QMOF

This notebook expects the archive `qmof_database.zip` (the QMOF database hosted on Figshare; see the download instructions at https://docs.materialsproject.org/apps/explorer-apps/mof-explorer/downloading-the-data) to be placed in the folder `data/QMOF/`.

In [2]:
# Locations of the QMOF download. All of these are plain strings and the
# directory ones end in "/", because later cells build file names by
# string concatenation.
basedir = "data/"
qmof_db_dir = f"{basedir}QMOF/"
qmof_db_zip_path = f"{qmof_db_dir}qmof_database.zip"
qmof_db_path = f"{qmof_db_dir}qmof_database/"

# Unpack the archive only when the extracted folder is not there yet.
if not Path(qmof_db_path).is_dir():
    with zipfile.ZipFile(qmof_db_zip_path, mode="r") as archive:
        archive.extractall(path=qmof_db_dir)

In [3]:
def save_to_json(data: dict, output_file):
    """Serialize ``data`` with orjson and write the resulting bytes to ``output_file``.

    Args:
        data: Object to serialize. ``OPT_SERIALIZE_NUMPY`` is enabled, so
            numpy arrays inside it are accepted by orjson.
        output_file: Destination path; opened in binary write mode, so an
            existing file is overwritten.
    """
    with open(output_file, "wb") as sink:
        payload = orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY)
        sink.write(payload)


def load_from_json(file):
    """Read ``file`` as raw bytes and return the object orjson parses from it."""
    with open(file, "rb") as source:
        raw = source.read()
    return orjson.loads(raw)

In [4]:
def compress(from_file, to_file):
    """Zstandard-compress the whole content of ``from_file`` into ``to_file``.

    The input is read fully into memory and compressed in one shot at
    level 15; ``threads=-1`` asks zstandard to pick the worker count itself.

    Args:
        from_file: Path of the file to compress.
        to_file: Path the compressed bytes are written to (overwritten).
    """
    with open(from_file, "rb") as source:
        compressor = zstd.ZstdCompressor(level=15, threads=-1)
        packed = compressor.compress(source.read())
    with open(to_file, "wb") as sink:
        sink.write(packed)


def decompress(from_file, to_file):
    """Decompress the Zstandard file ``from_file`` and write the result to ``to_file``.

    Both the compressed input and the decompressed output are held in memory
    in full.

    Args:
        from_file: Path of the compressed file.
        to_file: Path the decompressed bytes are written to (overwritten).
    """
    with open(from_file, "rb") as source:
        packed = source.read()
    restored = zstd.decompress(packed)
    with open(to_file, "wb") as sink:
        sink.write(restored)


def compress_inplace(dataset_path):
    """Compress ``props.json`` and ``cifs.json`` found directly in ``dataset_path``.

    Each matching file ``<name>`` gets a sibling ``<name>.zstd``; the original
    file is left in place. Other directory entries are ignored.

    Args:
        dataset_path: Directory to scan. May be given with or without a
            trailing path separator.
    """
    for filename in os.listdir(dataset_path):
        if filename in ("props.json", "cifs.json"):
            # os.path.join instead of string concatenation: "data" and
            # "data/" must both resolve to "data/<filename>".
            source = os.path.join(dataset_path, filename)
            compress(source, source + ".zstd")


def decompress_inplace(dataset_path):
    """Decompress ``props.json.zstd`` and ``cifs.json.zstd`` found in ``dataset_path``.

    Each matching archive is expanded next to itself under the same name
    without the ``.zstd`` suffix; the archive is left in place. Other
    directory entries are ignored.

    Args:
        dataset_path: Directory to scan. May be given with or without a
            trailing path separator.
    """
    for filename in os.listdir(dataset_path):
        if filename in ("props.json.zstd", "cifs.json.zstd"):
            # os.path.join instead of string concatenation: "data" and
            # "data/" must both resolve to "data/<filename>".
            packed = os.path.join(dataset_path, filename)
            decompress(packed, packed.removesuffix(".zstd"))

In [5]:
def test_cif_read():
    """Smoke check: load and parse ``qmof_structure_data.json`` from the QMOF folder.

    The parsed data is discarded; the call only verifies that the file can
    be opened and decoded by orjson. Returns ``None``.
    """
    structure_file = qmof_db_path + "qmof_structure_data.json"
    with open(structure_file, "rb") as source:
        raw = source.read()
    orjson.loads(raw)


def create_cifs():
    """Write ``cifs.json`` into ``basedir``: QMOF id -> that entry's ``"structure"`` value.

    The input is ``qmof_structure_data.json`` from the extracted QMOF
    database; each entry is expected to carry ``"qmof_id"`` and
    ``"structure"`` keys.
    """
    structure_file = qmof_db_path + "qmof_structure_data.json"
    with open(structure_file, "rb") as source:
        entries = orjson.loads(source.read())

    structure_by_id = {entry["qmof_id"]: entry["structure"] for entry in entries}

    with open(basedir + "cifs.json", "wb") as sink:
        sink.write(orjson.dumps(structure_by_id))

In [6]:
def create_props_from_band_gap():
    """Write ``props.json`` into ``basedir``: QMOF id -> ``{"bandgap": value}``.

    The value is read from ``outputs -> pbe -> bandgap`` of every record in
    ``qmof.json`` of the extracted QMOF database.
    """
    with open(qmof_db_path + "qmof.json", "rb") as source:
        records = orjson.loads(source.read())

    props = {
        record["qmof_id"]: {"bandgap": record["outputs"]["pbe"]["bandgap"]}
        for record in records
    }

    with open(basedir + "props.json", "wb") as sink:
        sink.write(orjson.dumps(props))

In [7]:
# Build the two dataset files in basedir from the extracted QMOF database:
# props.json (band gaps per id) and cifs.json (structures per id).
create_props_from_band_gap()
create_cifs()

In [8]:
def set_property_to_ids(
    df: pd.DataFrame, property: str, csv: str = "./data/root/data/id_prop.csv"
):
    """Dump one property column to a header-less ``index,value`` CSV file.

    Rows whose value for ``property`` is missing are left out.

    Args:
        df: Frame whose index holds the sample ids.
        property: Name of the column to export.
        csv: Destination path of the CSV file (overwritten).
    """
    known_values = df[property].dropna()
    # index=True puts the row label first; header=False keeps the file data-only.
    known_values.to_csv(csv, header=False, index=True)

### Training

In [9]:
from typing import Callable

# Switch read by clearml_train_logger: True wraps every training run in a
# ClearML task; False takes its plain branch (train, then print the result).
is_clearml = True

In [10]:
# plot general train progress
import matplotlib.pyplot as plt

# Results accumulated across runs: dataset name -> {property name -> MAE}.
# Filled by clearml_train_logger and read by plot_ds_prop_mae.
ds_prop_mae = dict()


def plot_ds_prop_mae():
    """Plot the MAE per property for every dataset stored in ``ds_prop_mae``.

    A new 10x6 pyplot figure is opened and left as the current figure; the
    caller is responsible for showing, reporting or closing it.
    """
    plt.figure(figsize=(10, 6))
    for dataset_name, mae_by_prop in ds_prop_mae.items():
        prop_names = list(mae_by_prop)
        mae_values = [mae_by_prop[name] for name in prop_names]

        # Faint connecting line; its auto-assigned colour is reused below so
        # the markers and the legend entry match it.
        [trend] = plt.plot(prop_names, mae_values, alpha=0.3, lw=2)
        shade = trend.get_color()

        # One cross per measured (property, MAE) pair.
        plt.scatter(prop_names, mae_values, alpha=0.6, color=shade, marker="x", s=25)

        # Empty, fully opaque line that exists only to carry the legend label.
        plt.plot([], [], alpha=1, color=shade, label=dataset_name)

    plt.ylabel("MAE")
    plt.xticks(rotation=45)
    plt.title("MAE for different datasets and properties")
    plt.legend()
    plt.tight_layout()

In [11]:
def train_default(ds_path: str, **kwargs) -> tuple[float, dict]:
    """Run ``model.main`` on ``ds_path`` with its default hyperparameters.

    Only three fields of ``main.args`` are overridden before training:
    ``use_clearml``, ``max_cache_size`` and ``data_options``.

    Args:
        ds_path: Dataset directory handed to the model as its single data option.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        The MAE returned by ``main.main()`` as a float, and a copy of the
        argument namespace used for the run as a dict.
    """
    # Imported lazily, as in the rest of this notebook's training helpers.
    import model.main as main

    # NOTE(review): use_clearml is forced on here regardless of the
    # notebook-level is_clearml flag — confirm this is intended.
    overrides = {
        "use_clearml": True,
        "max_cache_size": 160000,
        "data_options": [ds_path],
    }
    for field, value in overrides.items():
        setattr(main.args, field, value)

    result = main.main()
    return float(result), dict(vars(main.args))


def clearml_train_logger(
    ds_path: str,
    property: str,
    train_fn: Callable[..., tuple[float, dict]],
    notes: str = "",
) -> None:
    """Train on one dataset/property pair and, when enabled, log the run to ClearML.

    With ``is_clearml`` off, the function only trains and prints the dataset
    name with the resulting MAE. With it on, an offline ClearML task is
    created around the training run and receives: the compressed
    ``id_prop.csv`` of the dataset, ``./model_best.pth.tar``, the compressed
    ``./test_results.csv``, the training arguments, the accumulated
    ``ds_prop_mae`` table and the general progress plot. The model and
    test-result files are presumably written to the working directory by
    ``train_fn`` — verify against the training code.

    Args:
        ds_path: Dataset directory; must end with a path separator because
            file names are appended to it directly.
        property: Name of the target property being trained.
        train_fn: Called as ``train_fn(ds_path, property=property)``; must
            return ``(mae, args_dict)``.
        notes: Optional free text appended to the ClearML task name and
            stored with the run information.
    """
    ds_name = Path(ds_path).name
    if not is_clearml:
        # Same calling convention as the ClearML branch below: the trainer
        # needs the dataset path and returns (mae, args), not a bare number.
        mae, _ = train_fn(ds_path, property=property)
        print(ds_name, mae)
        return

    # prepare task
    from clearml import Task

    Task.set_offline(True)

    task: Task = Task.init(
        project_name="rcgcnn",
        task_name=f"train {property} on {ds_name}",
        auto_connect_frameworks={
            # only tensorboard, matplotlib and pytorch are auto-logged
            "tensorboard": True,
            "matplotlib": True,
            "tensorflow": False,
            "pytorch": True,
            "xgboost": False,
            "scikit": False,
            "fastai": False,
            "lightgbm": False,
            "hydra": False,
            "detect_repository": False,
            "tfdefines": False,
            "joblib": False,
            "megengine": False,
            "catboost": False,
        },
    )

    mae, args_dict = train_fn(ds_path, property=property)

    # finish task: attach the training targets that were used
    compress(ds_path + "id_prop.csv", ds_path + "id_prop.csv.zstd")
    task.upload_artifact(
        name="id_prop.csv.zstd", artifact_object=ds_path + "id_prop.csv.zstd"
    )
    # save model file
    task.upload_artifact(
        name="model_best.pth.tar", artifact_object="./model_best.pth.tar"
    )
    # save test results
    compress("./test_results.csv", "./test_results.csv.zstd")
    task.upload_artifact(
        name="test_results.csv.zstd", artifact_object="./test_results.csv.zstd"
    )

    # remember the result so the progress plot covers all runs so far
    ds_prop_mae.setdefault(ds_name, dict())[property] = mae

    task.upload_artifact(name="general_ds_prop_mae", artifact_object=ds_prop_mae)
    task.upload_artifact(name="args", artifact_object=args_dict)

    args_info = {
        "args": args_dict,
    }
    train_info = {
        "result_mae": float(mae),
        "ds_path": ds_path,
        "ds_name": ds_name,
        "property": property,
        "notes": notes,
    }
    task.connect(args_info)
    task.connect(train_info, name="train_info")

    if notes:
        # rename task taking notes into account
        task.set_name(f"train {property} on {ds_name} {notes}")

    # general progress plot, reported once as a static image and once with
    # the logger's default settings
    logger = task.get_logger()
    plot_ds_prop_mae()
    logger.report_matplotlib_figure(
        title="General progress plot",
        series="MAE",
        figure=plt,
        report_interactive=False,
        report_image=True,
    )
    logger.report_matplotlib_figure(
        title="General progress plot",
        series="MAE",
        figure=plt,
    )
    plt.close()

    task.close()
    #
    # upload offline task according to https://clear.ml/docs/latest/docs/guides/set_offline/
    # Task.set_offline(False)
    # Task.import_offline_session(str(task.get_offline_mode_folder()))


def train_on_dataset(
    ds_path: str,
    ds_props: list | dict,
    clearml_train_logger: Callable[[str, Callable[[str], float]], float],
    train_fn: Callable[[], float],
):
    """Train one model per requested property of the dataset stored in ``ds_path``.

    For every property that exists as a column of ``props.json``, the
    dataset's ``id_prop.csv`` is regenerated for that property and
    ``clearml_train_logger(ds_path, prop, train_fn)`` is invoked. Properties
    missing from the data are reported with a printed message and skipped.

    Args:
        ds_path: Dataset directory; must end with a path separator because
            file names are appended to it directly.
        ds_props: Property names to train on (iterated in order).
        clearml_train_logger: Callable that runs and logs a single training.
        train_fn: Training function forwarded to ``clearml_train_logger``.
    """
    print(ds_path, ds_props)
    props_by_id = load_from_json(ds_path + "props.json")
    # props.json maps id -> {property: value}; transposing puts ids on the rows.
    full_df = pd.DataFrame(props_by_id).transpose()

    # Replace the process arguments, presumably so that the argparse-based
    # training code does not see the notebook kernel's own command line.
    import sys

    sys.argv = [ds_path, ds_path]

    for prop in ds_props:
        if prop not in full_df.columns.values:
            print("no", prop, "in", ds_path)
            continue
        set_property_to_ids(full_df, prop, ds_path + "id_prop.csv")
        clearml_train_logger(ds_path, prop, train_fn)

In [12]:
# Train on the "bandgap" property of the dataset in ./data/, using the
# default hyperparameters (train_default) and the ClearML-aware logger.
train_on_dataset("./data/", ["bandgap"], clearml_train_logger, train_default)

./data/ ['bandgap']
ClearML Task: created new task id=offline-88076011088a4121b325419669c1e9aa
ClearML running in offline mode, session stored in /home/nodoteve/.clearml/cache/offline/offline-88076011088a4121b325419669c1e9aa
ClearML running in offline mode, session stored in /home/nodoteve/.clearml/cache/offline/offline-88076011088a4121b325419669c1e9aa
ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start
20375
Epoch: [0][0/48]	Time 0.789 (0.789)	Data 0.130 (0.130)	Loss 1.1961 (1.1961)	MAE 1.011 (1.011)
Epoch: [0][10/48]	Time 0.338 (0.382)	Data 0.036 (0.058)	Loss 0.6587 (1.0231)	MAE 0.772 (0.938)
Epoch: [0][20/48]	Time 0.326 (0.358)	Data 0.035 (0.049)	Loss 0.6325 (0.8216)	MAE 0.705 (0.823)
Epoch: [0][30/48]	Time 0.330 (0.350)	Data 0.081 (0.047)	Loss 0.4787 (0.7177)	MAE 0.616 (0.762)
Epoch: [0][40/48]	Time 0.360 (0.347)	Data 0.040 (0.045)	Loss 0.3629 (0.6520)	MAE 0.541 (0.721)
Test: [0/16]	Time 0.255 (0.255)	Loss 0.7120 (0.7120)	MAE 0.75