In [22]:
from xenonpy.descriptor.graph import CrystalGraphFeaturizer
from xenonpy.model.nn.graph import CrystalGraphConvNet, ConvLayer, CrystalGraphDataset, collate_pool
from xenonpy.datatools import preset
from torch.utils.data import DataLoader
from torch import optim
from torch import nn

import pandas as pd
import torch

In [2]:
samples = preset.mp_samples
structures = samples['structure']
tmp = structures[0]
tmp

Structure Summary
Lattice
    abc : 4.3265793673570885 4.3265793673570885 4.3265793673570885
 angles : 59.99999999999999 59.99999999999999 59.99999999999999
 volume : 57.26892444621164
      A : -3.05935361 -3.05935361 0.0
      B : -3.05935361 -0.0 -3.05935361
      C : 0.0 -3.05935361 -3.05935361
PeriodicSite: Rb (-3.0594, -3.0594, -3.0594) [0.5000, 0.5000, 0.5000]
PeriodicSite: Cu (0.0000, 0.0000, 0.0000) [-0.0000, 0.0000, 0.0000]
PeriodicSite: O (-4.5890, -4.5890, -4.5890) [0.7500, 0.7500, 0.7500]

In [3]:
cgfea = CrystalGraphFeaturizer(atom_feature='elements')

In [11]:
a = cgfea.transform(structures, return_type='array')

In [9]:
dataset = CrystalGraphDataset(a, samples['formation_energy_per_atom'].values)

In [26]:
dataloader = DataLoader(dataset, batch_size=10, collate_fn=collate_pool)

In [28]:
def run(atom_fea_len, h_fea_len, n_conv, n_h):
    print('training model with paras:', atom_fea_len, h_fea_len, n_conv, n_h)
    
#     model = CrystalGraphConvNet(orig_atom_fea_len=401, nbr_fea_len=41, atom_fea_len=atom_fea_len, h_fea_len=h_fea_len, n_conv=n_conv, n_h=n_h)
    model = CrystalGraphConvNet(orig_atom_fea_len=58, nbr_fea_len=41, atom_fea_len=atom_fea_len, h_fea_len=h_fea_len, n_conv=n_conv, n_h=n_h)

    ### parameter
    EPOCH = 500
    LR = 0.0001
    WD = 0.99 #weight decay


    optimizer = optim.Adam(model.parameters(), lr = LR)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

#     model.to(device)

    criterion = nn.MSELoss()

    Loss = []
    best_loss = 100

    for i in range(EPOCH):
        model.train(mode=True)
        print('-------')
        print('Epoch %d'%(i+1))
        scheduler.step()
        for inp_var, label in dataloader:
            output = model(*inp_var)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Train: %f' % loss.item())

        _ = model.eval()

#         tmp = []
#         for td in TEST_DATA:
#             pred = model(*td).cpu().detach().numpy()
#             tmp.append(pred)
#         pred = np.vstack(tmp)

#         tmp = []
#         for td in TRAIN_DATA:
#             pred_fit = model(*td).cpu().detach().numpy()
#             tmp.append(pred_fit)
#         pred_fit = np.vstack(tmp)

#         test_loss = mean_squared_error(pred, true)
# #         print('Test: %f' % test_loss)

#         if test_loss < best_loss:
#             best_loss = test_loss
#             save_name = '%s_%s_%s_%s_%s' % (name, atom_fea_len, h_fea_len, n_conv, n_h)
#             torch.save({
#                 'train_loss': loss.cpu().detach().numpy().item(),
#                 'test_loss': test_loss,
#                 'best_model': model,
#                 'pred': pred,
#                 'true': true,
#                 'pred_fit': pred_fit,
# #                 'true_fit': true_fit,
#             }, 'Li3P_sample/train_0.3/%s_model.pkl' % save_name)
#             draw(pred, true, pred_fit, true_fit)
#             plt.savefig('Li3P_sample/train_0.3/%s.png' % save_name)

#     print('Best test: %f\n' % best_loss)

In [29]:
from xenonpy.math import Product

# # dft
# for paras in Product([200, 250, 300], [100, 150, 200], [3, 4, 5], [1, 2, 3]):
#     run(*paras)
    
# org
for paras in Product([128, 64], [64, 32], [3, 4, 5], [1, 2, 3]):
    run(*paras)

training model with paras: 128 64 3 1
-------
Epoch 1
Train: 22.756613
-------
Epoch 2
Train: 7.342506
-------
Epoch 3
Train: 13.946541
-------
Epoch 4
Train: 18.655800
-------
Epoch 5
Train: 20.084021
-------
Epoch 6
Train: 19.509020
-------
Epoch 7
Train: 17.829136
-------
Epoch 8
Train: 15.060686
-------
Epoch 9
Train: 12.080341
-------
Epoch 10
Train: 10.218332
-------
Epoch 11
Train: 9.005825
-------
Epoch 12
Train: 7.748944
-------
Epoch 13
Train: 6.907959
-------
Epoch 14
Train: 6.296173
-------
Epoch 15
Train: 5.770061
-------
Epoch 16


KeyboardInterrupt: 

In [11]:
tmp = structures[0]
tmp

Structure Summary
Lattice
    abc : 4.3265793673570885 4.3265793673570885 4.3265793673570885
 angles : 59.99999999999999 59.99999999999999 59.99999999999999
 volume : 57.26892444621164
      A : -3.05935361 -3.05935361 0.0
      B : -3.05935361 -0.0 -3.05935361
      C : 0.0 -3.05935361 -3.05935361
PeriodicSite: Rb (-3.0594, -3.0594, -3.0594) [0.5000, 0.5000, 0.5000]
PeriodicSite: Cu (0.0000, 0.0000, 0.0000) [-0.0000, 0.0000, 0.0000]
PeriodicSite: O (-4.5890, -4.5890, -4.5890) [0.7500, 0.7500, 0.7500]

In [17]:
a = tmp.get_all_neighbors(8, include_index=True)
a = [sorted(n, key=lambda x: x[1]) for n in a]
a[1]

[(PeriodicSite: O (1.5297, -1.5297, -1.5297) [-0.2500, -0.2500, 0.7500],
  2.6494779454196298,
  2),
 (PeriodicSite: O (-1.5297, 1.5297, -1.5297) [-0.2500, 0.7500, -0.2500],
  2.6494779454196298,
  2),
 (PeriodicSite: O (-1.5297, -1.5297, 1.5297) [0.7500, -0.2500, -0.2500],
  2.6494779454196298,
  2),
 (PeriodicSite: O (1.5297, 1.5297, 1.5297) [-0.2500, -0.2500, -0.2500],
  2.64947794541963,
  2),
 (PeriodicSite: Rb (3.0594, 0.0000, 0.0000) [-0.5000, -0.5000, 0.5000],
  3.05935361,
  0),
 (PeriodicSite: Rb (0.0000, 3.0594, 0.0000) [-0.5000, 0.5000, -0.5000],
  3.05935361,
  0),
 (PeriodicSite: Rb (0.0000, 0.0000, -3.0594) [-0.5000, 0.5000, 0.5000],
  3.05935361,
  0),
 (PeriodicSite: Rb (0.0000, 0.0000, 3.0594) [0.5000, -0.5000, -0.5000],
  3.05935361,
  0),
 (PeriodicSite: Rb (0.0000, -3.0594, 0.0000) [0.5000, -0.5000, 0.5000],
  3.05935361,
  0),
 (PeriodicSite: Rb (-3.0594, 0.0000, 0.0000) [0.5000, 0.5000, -0.5000],
  3.05935361,
  0),
 (PeriodicSite: Cu (3.0594, 3.0594, 0.0000) [-1

In [18]:
for s in tmp.species:
    print(s.name.__class__)

<class 'str'>
<class 'str'>
<class 'str'>


In [11]:
tmp = preset.dataset_elements_completed