# Training and inference of MEGNet models for IDAO-22
Team: NESCafe Gold 3in1

## Imports and data reading

In [None]:
from pathlib import Path
from megnet.models import MEGNetModel
from megnet.data.crystal import CrystalGraph
import numpy as np
import tensorflow as tf

from scripts.utils import structures_to_df, read_json_structures

In [7]:
# Load the public and private splits, then attach each split's defect structures
# (read with read_json_structures from the split's defects/pymatgen directory), joined on '_id'.
df_public, df_private = structures_to_df()
# Public split: the defect structures' 'structure' column is renamed to 'diff' before merging.
df_public = df_public.merge(read_json_structures(Path('../data/train/defects/pymatgen')).rename({'structure': 'diff'}, axis=1), on=['_id'])
# NOTE(review): axis=0 renames *index* labels, not columns, so no 'diff' column is created here
# (unlike the public split above). If df_private already carries a 'structure' column, pandas'
# default merge suffixes turn the clash into 'structure_x' (left) / 'structure_y' (right = these
# defect structures), which is presumably why later cells read df_private['structure_y'] — confirm.
# Switching to axis=1 also requires updating those later cells to read 'diff'.
df_private = df_private.merge(read_json_structures(Path('../data/eval/defects/pymatgen')).rename({'structure': 'diff'}, axis=0), on=['_id'])

2967it [00:18, 157.65it/s]
2966it [00:19, 151.92it/s]
2966it [00:01, 1723.35it/s]
2967it [00:01, 1824.46it/s]


## MEGNet training with a custom metric

In [115]:
def energy_within_threshold(prediction, target, e_thresh=0.02):
    """Return the fraction of elements whose absolute error is below ``e_thresh``.

    Intended to be used as a Keras metric. Keras invokes metric functions as
    ``fn(y_true, y_pred)``, so when called by Keras ``prediction`` receives the
    ground truth and ``target`` receives the model output. The metric depends only
    on ``|target - prediction|``, so this naming mismatch does not affect the result.

    Args:
        prediction: tensor of values (y_true when called by Keras).
        target: tensor broadcast-compatible with ``prediction`` (y_pred when called by Keras).
        e_thresh: absolute-error threshold; a sample counts as a success when its
            error is strictly below it. Defaults to 0.02.

    Returns:
        Scalar tensor: (number of elements with error < e_thresh) / (number of elements
        in ``target``).
    """
    error_energy = tf.math.abs(target - prediction)

    success = tf.math.count_nonzero(error_energy < e_thresh)  # int64 count
    total = tf.cast(tf.size(target), tf.int64)  # match count_nonzero's dtype
    # Dividing two int64 tensors with `/` performs true division and yields a float ratio.
    return success / total

In [119]:
# Seed the global NumPy and TensorFlow RNGs so repeated runs of this cell are reproducible.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

nfeat_bond = 100  # number of Gaussian centers passed to MEGNetModel
r_cutoff = 20  # key hyperparameter: cutoff passed to CrystalGraph
gaussian_centers = np.linspace(0, r_cutoff + 1, nfeat_bond)  # evenly spaced over [0, r_cutoff + 1]
gaussian_width = 0.5
graph_converter = CrystalGraph(cutoff=r_cutoff)

model = MEGNetModel(graph_converter=graph_converter,
                    centers=gaussian_centers,
                    width=gaussian_width,
                    metrics=energy_within_threshold)

# Fit band gap on the public split's defect ('diff') structures; trailing ';' hides the model repr.
model.train(df_public['diff'], df_public['band_gap'], epochs=100);

  super(Adam, self).__init__(name, **kwargs)


Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<megnet.models.megnet.MEGNetModel at 0x1238c6fd0>

In [139]:
# Continue training the model from the previous cell for 20 more epochs, then save it.
# NOTE: not idempotent — every re-run adds another 20 epochs on top of the current weights.
model.train(df_public['diff'], df_public['band_gap'], epochs=20)
megnet_checkpoint = Path('../models/megnet/defects/09598tr_089lb.hdf5')
megnet_checkpoint.parent.mkdir(parents=True, exist_ok=True)  # create the target directory if missing
model.save_model(str(megnet_checkpoint))  # pass a str: the filename is handled as a plain path string

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<megnet.models.megnet.MEGNetModel at 0x1238c6fd0>

In [145]:
# Reload the saved checkpoint and predict band gaps for the private split.
model.load_weights('../models/megnet/defects/09598tr_089lb.hdf5')
# Store the predictions so the 'predictions' column exists on a fresh kernel
# (computing them without assigning would leave the column undefined).
# NOTE(review): the model was trained on df_public['diff']; 'structure_y' is presumably the
# private split's defect structure produced by the merge in the data-loading cell — confirm.
df_private['predictions'] = model.predict_structures(df_private['structure_y'])
df_private['predictions']

0       0.361837
1       1.144807
2       1.799694
3       1.136510
4       0.405438
          ...   
2962    1.143082
2963    0.353465
2964    1.143884
2965    0.387114
2966    0.388613
Name: predictions, Length: 2967, dtype: float32

In [140]:
# Predict the private split and write the submission file: one row per structure,
# indexed by 'id' (renamed from '_id'), with a single 'predictions' column.
df_private['predictions'] = model.predict_structures(df_private['structure_y'])
(
    df_private[['_id', 'predictions']]
    .rename(columns={'_id': 'id'})
    .set_index('id')
    .to_csv('megnet.csv')
)