In [14]:
import json
import pandas as pd
from pathlib import Path

import torch.nn as nn
import tensorflow as tf
from tensorflow import keras

from tensorflow.keras import layers


from typing import Dict, List, Callable

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Concatenate, Add, Embedding, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2

from megnet.activations import softplus2
from megnet.config import DataType
from megnet.data.crystal import CrystalGraph
from megnet.data.graph import GaussianDistance, StructureGraph
from megnet.layers import MEGNetLayer, Set2Set, GaussianExpansion
from megnet.utils.preprocessing import DummyScaler, Scaler
from megnet.models.base import GraphModel

import torch
import torch.nn as nn


from pymatgen.core import Structure
from sklearn.model_selection import train_test_split


from pathlib import Path

# Utils

In [38]:
def read_pymatgen_struct(file):
    """Deserialize one pymatgen ``Structure`` from a JSON file.

    Args:
        file: path (or anything ``open`` accepts) of a JSON document holding a
            dict in the format ``Structure.from_dict`` expects.

    Returns:
        The reconstructed ``pymatgen.core.Structure``.
    """
    with open(file, "r") as handle:
        return Structure.from_dict(json.load(handle))

def prepare_dataset(dataset_path="dichalcogenides_public", test_size=0.25, random_state=666):
    """Load structures and targets and split them into train/test DataFrames.

    Expects ``<dataset_path>/targets.csv`` (its first column is used as the
    index) and a ``<dataset_path>/structures`` directory of ``<id>.json`` files.

    Args:
        dataset_path: dataset root directory (str or Path).
        test_size: fraction of rows placed in the test split.
        random_state: seed for ``train_test_split``.

    Returns:
        ``(train, test)`` as returned by ``sklearn.model_selection.train_test_split``.
        Both are DataFrames indexed by structure id with columns ``structures``
        and ``targets``.
    """
    dataset_path = Path(dataset_path)
    targets = pd.read_csv(dataset_path / "targets.csv", index_col=0)
    # `Path.stem` removes only the ".json" suffix. `str.strip(".json")` would strip
    # any of the characters ". j s o n" from BOTH ends of the id.
    # Restricting to *.json skips stray files (e.g. .DS_Store). Sorting makes the row
    # order, and therefore the seeded split, independent of filesystem directory order.
    struct = {
        item.stem: read_pymatgen_struct(item)
        for item in sorted((dataset_path / "structures").glob("*.json"))
    }

    data = pd.DataFrame(columns=["structures"], index=struct.keys())
    # `targets` is aligned to the structure ids through the index.
    data = data.assign(structures=struct.values(), targets=targets)

    return train_test_split(data, test_size=test_size, random_state=random_state)

def energy_within_threshold(prediction, target, e_thresh=0.02):
    """Fraction of entries whose absolute error is below ``e_thresh``.

    Used as a Keras metric. Keras calls metrics as ``fn(y_true, y_pred)``, so
    ``prediction`` actually receives the ground truth and ``target`` the model
    output. The absolute error is symmetric, so the value is unaffected.

    Args:
        prediction: first tensor (y_true when called by Keras).
        target: second tensor (y_pred when called by Keras). Its element count is
            used as the denominator.
        e_thresh: absolute-error threshold. Keras only passes two arguments,
            so the default (0.02) applies during training.

    Returns:
        Scalar tensor: (#elements with |target - prediction| < e_thresh) / size(target).
    """
    # Element-wise absolute error; every element is tested individually.
    error_energy = tf.math.abs(target - prediction)

    success = tf.math.count_nonzero(error_energy < e_thresh)
    total = tf.size(target)
    # count_nonzero returns int64, so cast the size to match before dividing.
    return success / tf.cast(total, tf.int64)

class PreConvLayer(tf.keras.layers.Layer):
    """Two stacked ``Conv1D`` layers (ReLU, 'same' padding) applied along axis 1.

    Used as the feed-forward transform for atom, bond and state features around
    the MEGNet blocks.

    NOTE(review): a convolution along axis 1 mixes features of neighbouring
    entries (e.g. consecutive atoms or bonds). The result therefore depends on
    the order in which the graph converter emits them — confirm this is intended.

    Args:
        name_prefix: used verbatim as the Keras layer name.
        filters: pair ``(first, second)`` of filter counts. ``second`` is the output
            width and must match the MEGNet block output for the residual ``Add``.
        kernel_size: kernel size of both convolutions.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``/``dtype``
            when the layer is rebuilt from ``get_config``).
    """

    def __init__(self, name_prefix=None, filters=(64, 32), kernel_size=3, **kwargs):
        # An explicit `name` (as produced by get_config) wins over `name_prefix`.
        kwargs.setdefault("name", name_prefix)
        super().__init__(**kwargs)
        first, second = filters
        self._filter_sizes = (first, second)
        self.kernel_size = kernel_size
        self.conv1 = layers.Conv1D(first, kernel_size=kernel_size, activation='relu', padding='same')
        self.conv2 = layers.Conv1D(second, kernel_size=kernel_size, activation='relu', padding='same')

    def call(self, inputs):
        x = self.conv1(inputs)
        return self.conv2(x)

    def get_config(self):
        # Required for Keras model saving: __init__ takes arguments beyond the base Layer's.
        config = super().get_config()
        config.update({"filters": list(self._filter_sizes), "kernel_size": self.kernel_size})
        return config

class PostDenseLayer(tf.keras.layers.Layer):
    """Two stacked ``Dense`` layers applied to the pooled graph vector.

    Args:
        name_prefix: used verbatim as the Keras layer name.
        units: pair ``(first, second)`` of unit counts. ``second`` is the output width.
        activation: activation for both Dense layers.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``/``dtype``
            when the layer is rebuilt from ``get_config``).
    """

    def __init__(self, name_prefix=None, units=(32, 16), activation='relu', **kwargs):
        # An explicit `name` (as produced by get_config) wins over `name_prefix`.
        kwargs.setdefault("name", name_prefix)
        super().__init__(**kwargs)
        first, second = units
        self._unit_sizes = (first, second)
        self.activation_name = activation
        self.fc1 = layers.Dense(first, activation=activation)
        self.fc2 = layers.Dense(second, activation=activation)

    def call(self, inputs):
        x = self.fc1(inputs)
        return self.fc2(x)

    def get_config(self):
        # Required for Keras model saving: __init__ takes arguments beyond the base Layer's.
        config = super().get_config()
        config.update({"units": list(self._unit_sizes), "activation": self.activation_name})
        return config


# Data

In [6]:
# Load the public dataset and peek at the first training structure.
train, test = prepare_dataset()
train["structures"].iloc[0]

Structure Summary
Lattice
    abc : 25.5225256 25.5225256 14.879004
 angles : 90.0 90.0 119.99999999999999
 volume : 8393.668021812642
      A : 25.5225256 0.0 1.5628039641098191e-15
      B : -12.761262799999994 22.10315553833868 1.5628039641098191e-15
      C : 0.0 0.0 14.879004
    pbc : True True True
PeriodicSite: Mo (1.276e-07, 1.842, 3.72) [0.04167, 0.08333, 0.25]
PeriodicSite: Mo (-1.595, 4.605, 3.72) [0.04167, 0.2083, 0.25]
PeriodicSite: Mo (-3.19, 7.368, 3.72) [0.04167, 0.3333, 0.25]
PeriodicSite: Mo (-4.785, 10.13, 3.72) [0.04167, 0.4583, 0.25]
PeriodicSite: Mo (-6.381, 12.89, 3.72) [0.04167, 0.5833, 0.25]
PeriodicSite: Mo (-7.976, 15.66, 3.72) [0.04167, 0.7083, 0.25]
PeriodicSite: Mo (-9.571, 18.42, 3.72) [0.04167, 0.8333, 0.25]
PeriodicSite: Mo (-11.17, 21.18, 3.72) [0.04167, 0.9583, 0.25]
PeriodicSite: Mo (3.19, 1.842, 3.72) [0.1667, 0.08333, 0.25]
PeriodicSite: Mo (1.595, 4.605, 3.72) [0.1667, 0.2083, 0.25]
PeriodicSite: Mo (8.508e-08, 7.368, 3.72) [0.1667, 0.3333, 0.25]

In [7]:
train["targets"].iloc[0]

1.1355

# Model

In [39]:
def prepare_model(cutoff, lr, pre_layer, post_layer, nfeat_bond=10, gaussian_width=0.8, npass=2, n3=32):
    """Build a compiled ``GNNModel`` with an MAE loss and the threshold metric.

    Args:
        cutoff: neighbour cutoff passed to ``CrystalGraph``. The Gaussian bond
            basis spans ``[0, cutoff + 1]``.
        lr: Adam learning rate.
        pre_layer: factory ``name_prefix -> Keras layer`` for the feature transforms.
        post_layer: factory ``name_prefix -> Keras layer`` for the readout head.
        nfeat_bond: number of Gaussian centres for the bond expansion.
        gaussian_width: width of each Gaussian.
        npass: number of Set2Set recurrent steps.
        n3: last MEGNetLayer width. It must equal the output width of
            ``pre_layer``, because the residual ``Add`` in ``make_gnn_model``
            combines the two. The default 32 presumably matches PreConvLayer's
            final 32 filters.

    Returns:
        A ``GNNModel`` ready for ``train``.
    """
    gaussian_centers = np.linspace(0, cutoff + 1, nfeat_bond)

    return GNNModel(
        pre_layer=pre_layer,
        post_layer=post_layer,
        # No bond_converter: the network applies its own GaussianExpansion to the bond values.
        graph_converter=CrystalGraph(cutoff=cutoff),
        centers=gaussian_centers,
        width=gaussian_width,
        loss=["MAE"],
        npass=npass,
        learning_rate=lr,
        # Keras documents `metrics` as a list.
        metrics=[energy_within_threshold],
        n3=n3,
    )

In [40]:
class GNNModel(GraphModel):
    """
    MEGNet-style crystal graph model with pluggable feed-forward sub-networks.

    The Keras network is built by ``make_gnn_model`` and compiled here with Adam.
    ``GraphModel`` provides graph conversion, batching and target scaling.
    Atoms enter the network as integer types (embedded). Each bond enters as a
    single float that is Gaussian-expanded inside the network.

    Args:
        pre_layer: factory ``name_prefix -> Keras layer`` applied to atom, bond and
            state features around the MEGNet blocks (required by make_gnn_model).
        post_layer: factory ``name_prefix -> Keras layer`` applied to the pooled
            graph vector before the output Dense layer (required by make_gnn_model).
        nblocks: number of MEGNet blocks.
        n1, n2, n3: MEGNetLayer hidden sizes. ``n3`` is also the Set2Set hidden size.
        learning_rate: Adam learning rate.
        nvocal: atom-embedding vocabulary size.
        embedding_dim: atom-embedding width.
        npass: number of Set2Set recurrent steps.
        ntarget: number of outputs.
        act: activation used inside the MEGNetLayers.
        loss: Keras loss passed to ``compile`` (a name, or a list as prepare_model passes).
        metrics: Keras metrics passed to ``compile``.
        graph_converter: structure -> graph converter. Defaults to ``CrystalGraph(cutoff=4)``
            without a bond converter, because the network expands bonds itself.
        target_scaler: target scaler. Defaults to a fresh ``DummyScaler()`` per model.
        optimizer_kwargs: extra Adam kwargs merged over ``learning_rate``. It is
            read, never mutated, so the shared default dict is safe.
        centers, width: Gaussian basis for the in-network bond expansion. Default to
            ``np.linspace(0, 5, 100)`` and ``0.5`` when not given.
    """

    def __init__(
        self,
        pre_layer = None,
        post_layer = None,
        nblocks: int = 3,
        n1: int = 64,
        n2: int = 32,
        n3: int = 16,
        learning_rate: float = 1e-3,
        nvocal: int = 95,
        embedding_dim: int = 16,
        npass: int = 3,
        ntarget: int = 1,
        act: Callable = softplus2,
        loss: str = "mse",
        metrics: List[str] = None,
        graph_converter: StructureGraph = None,
        target_scaler: Scaler = None,
        optimizer_kwargs: Dict = {"clipnorm": 3},
        centers = None,
        width = None
    ):
        # Avoid a DummyScaler instance shared by every model via the default argument.
        if target_scaler is None:
            target_scaler = DummyScaler()

        # make_gnn_model always applies GaussianExpansion to the bond input, so
        # it needs a concrete basis.
        if centers is None:
            centers = np.linspace(0, 5, 100)
        if width is None:
            width = 0.5

        # Build the MEG Model
        model = make_gnn_model(
            pre_layer=pre_layer,
            post_layer=post_layer,
            nblocks=nblocks,
            n1=n1,
            n2=n2,
            n3=n3,
            nvocal=nvocal,
            embedding_dim=embedding_dim,
            npass=npass,
            ntarget=ntarget,
            act=act,
            centers=centers,
            width=width,
        )

        opt_params = {"learning_rate": learning_rate}
        if optimizer_kwargs is not None:
            opt_params.update(optimizer_kwargs)
        model.compile(Adam(**opt_params), loss, metrics=metrics)

        if graph_converter is None:
            # Keep bond features un-expanded, as prepare_model does: the network's
            # `bond_float_input` takes one float per bond and expands it itself.
            # A pre-expanding converter (e.g. GaussianDistance) would feed vectors
            # into that scalar input.
            graph_converter = CrystalGraph(cutoff=4)

        super().__init__(model=model, target_scaler=target_scaler, graph_converter=graph_converter)
        


In [41]:
def make_gnn_model(
    pre_layer = None, # factory name_prefix -> layer; used before the first block and inside blocks 1..nblocks-1
    post_layer = None, # factory name_prefix -> layer; used between the concatenation and the output
    nblocks: int = 3, # number of MEGNetLayer blocks
    n1: int = 64, # number of hidden units in layer 1 in MEGNetLayer
    n2: int = 32, # number of hidden units in layer 2 in MEGNetLayer
    n3: int = 16, # number of hidden units in layer 3 in MEGNetLayer (also Set2Set hidden size)
    nvocal: int = 95, # number of distinct atom types (embedding vocabulary size)
    embedding_dim: int = 16, # atom embedding dimension
    npass: int = 3, # number of recurrent steps in each Set2Set layer
    ntarget: int = 1, # number of output targets
    act: Callable = softplus2, # activation function inside MEGNetLayer
    centers = None, # array defining the Gaussian expansion centers
    width = None # width of the Gaussian basis
) -> Model:
    """Build the Keras graph network used by ``GNNModel``.

    Architecture:
    - Embedding of atom types and Gaussian expansion of the bond values.
    - ``pre_layer`` applied to atom, bond and state features.
    - ``nblocks`` residual blocks. Block 0 is a bare MEGNetLayer; later blocks
      apply ``pre_layer`` to each stream first.
    - Set2Set pooling of atoms and bonds, concatenated with the state.
    - ``post_layer``, then a Dense readout.

    The residual ``Add`` requires the ``pre_layer`` output width to equal the
    MEGNetLayer output width (presumably ``n3``, the last entry of its unit lists).

    Raises:
        ValueError: if ``pre_layer`` or ``post_layer`` is not provided.
    """
    if pre_layer is None or post_layer is None:
        raise ValueError(
            "make_gnn_model requires both `pre_layer` and `post_layer`: callables that "
            "take `name_prefix` and return a Keras layer"
        )

    # Atom inputs: integer atom types mapped to learned embeddings, so the model can
    # distinguish atom types.
    x1 = Input(shape=(None,), dtype=DataType.tf_int, name="atom_int_input")
    x1_ = Embedding(nvocal, embedding_dim, name="atom_embedding")(x1)

    # Bond inputs: one float per bond, Gaussian-expanded so the model can use
    # interatomic distances.
    x2 = Input(shape=(None,), dtype=DataType.tf_float, name="bond_float_input")
    x2_ = GaussianExpansion(centers=centers, width=width)(x2)  # type: ignore

    # Global state inputs: two features per graph (default vector of two zeros).
    x3 = Input(shape=(None, 2), dtype=DataType.tf_float, name="state_default_input")
    x3_ = x3

    # Graph topology indices:
    # - x4, x5: the atoms at each end of every bond.
    # - x6, x7: which graph each atom / bond belongs to.
    x4 = Input(shape=(None,), dtype=DataType.tf_int, name="bond_index_1_input")
    x5 = Input(shape=(None,), dtype=DataType.tf_int, name="bond_index_2_input")
    x6 = Input(shape=(None,), dtype=DataType.tf_int, name="atom_graph_index_input")
    x7 = Input(shape=(None,), dtype=DataType.tf_int, name="bond_graph_index_input")

    def one_block(a, b, c, has_ff=True, block_index=0):
        """One block: optional pre_layer per stream, then a MEGNetLayer.

        Returns the updated (atom, bond, state) tensors.
        """
        if has_ff:
            a = pre_layer(name_prefix=f"block_{block_index}_atom_ff")(a)
            b = pre_layer(name_prefix=f"block_{block_index}_bond_ff")(b)
            c = pre_layer(name_prefix=f"block_{block_index}_state_ff")(c)
        out = MEGNetLayer(
            [n1, n2, n3],
            [n1, n2, n3],
            [n1, n2, n3],
            pool_method="mean",
            activation=act,
            kernel_regularizer=None,
            name=f"megnet_{block_index}",
        )([a, b, c, x4, x5, x6, x7])
        # MEGNetLayer returns three tensors: atoms, bonds, state.
        return out[0], out[1], out[2]

    x1_ = pre_layer(name_prefix="preblock_atom")(x1_)
    x2_ = pre_layer(name_prefix="preblock_bond")(x2_)
    x3_ = pre_layer(name_prefix="preblock_state")(x3_)
    for i in range(nblocks):
        # Block 0 consumes the preblock outputs directly. Later blocks apply their own pre_layer.
        x1_1, x2_1, x3_1 = one_block(x1_, x2_, x3_, i > 0, block_index=i)
        # Residual (skip) connections around each block.
        x1_ = Add(name=f"block_{i}_add_atom")([x1_, x1_1])
        x2_ = Add(name=f"block_{i}_add_bond")([x2_, x2_1])
        x3_ = Add(name=f"block_{i}_add_state")([x3_, x3_1])

    # Set2Set pooling of atoms and bonds, each grouped by its graph index.
    node_vec = Set2Set(T=npass, n_hidden=n3, name="set2set_atom")([x1_, x6])
    edge_vec = Set2Set(T=npass, n_hidden=n3, name="set2set_bond")([x2_, x7])
    # Concatenate the pooled atom and bond vectors with the global state vector.
    final_vec = Concatenate(axis=-1)([node_vec, edge_vec, x3_])

    # Post-processing head applied to the concatenated vector.
    final_vec = post_layer(name_prefix='post_process_layer')(final_vec)

    out = Dense(ntarget, name="readout_2")(final_vec)
    model = Model(inputs=[x1, x2, x3, x4, x5, x6, x7], outputs=out)
    return model

# Training

In [42]:
def pre_layer(name_prefix):
    """Layer factory for ``prepare_model``: a ``PreConvLayer`` named ``name_prefix``."""
    conv_block = PreConvLayer(name_prefix=name_prefix)
    return conv_block

def post_layer(name_prefix):
    """Layer factory for ``prepare_model``: a ``PostDenseLayer`` named ``name_prefix``."""
    dense_head = PostDenseLayer(name_prefix=name_prefix)
    return dense_head

# Seed TensorFlow and NumPy so weight initialisation is repeatable
# (same value as the train/test split seed).
SEED = 666
tf.random.set_seed(SEED)
np.random.seed(SEED)

# Arguments: cutoff = 4, learning rate = 2e-4.
model = prepare_model(4, 2e-4, pre_layer, post_layer)
# summary() prints the table itself and returns None; wrapping it in print() only adds "None".
model.summary()

# NOTE: the test split doubles as the validation set, so validation metrics are
# not an independent hold-out estimate.
model.train(
    train.structures,
    train.targets,
    validation_structures=test.structures,
    validation_targets=test.targets,
    epochs=10,
    batch_size=128,
    save_checkpoint=False,
);  # trailing ';' hides the repr of the returned model in the cell output



Model: "model_8"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 atom_int_input (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 bond_float_input (InputLayer)  [(None, None)]       0           []                               
                                                                                                  
 atom_embedding (Embedding)     (None, None, 16)     1520        ['atom_int_input[0][0]']         
                                                                                                  
 gaussian_expansion_9 (Gaussian  (None, None, 10)    0           ['bond_float_input[0][0]']       
 Expansion)                                                                                 

<__main__.GNNModel at 0x32f4f5100>

# Predict

In [22]:
# Predict with the `model` trained above. Calling prepare_model() again here would
# build a fresh, untrained network and the submission would come from random weights.
dataset_path = Path('dichalcogenides_private')
# `Path.stem` drops only the ".json" suffix (`str.strip('.json')` strips characters
# from both ends). Sorting gives a deterministic row order.
struct = {
    item.stem: read_pymatgen_struct(item)
    for item in sorted((dataset_path / 'structures').glob('*.json'))
}
private_test = pd.DataFrame(columns=['structures'], index=struct.keys())
private_test = private_test.assign(structures=struct.values())
predictions = model.predict_structures(private_test.structures)
# ntarget == 1, so flatten to exactly one value per structure.
private_test = private_test.assign(predictions=np.ravel(predictions))
private_test[['predictions']].to_csv('./submission.csv', index_label='id')

