In [1]:
import json
import os

path_dir = "./data"

# os.listdir returns entries in arbitrary order; sort them so that indexing
# into the resulting list selects the same file on every run and machine.
saved_files = sorted(os.listdir(path_dir))

# Keep only CHGCAR files (later cells re-append ".chgcar" to every stem) and
# strip just the final extension, so stems that contain dots stay intact and
# hidden entries such as ".ipynb_checkpoints" do not turn into empty names.
saved_files_wo_ext = [
    os.path.splitext(file_name)[0]
    for file_name in saved_files
    if file_name.endswith(".chgcar")
]

In [2]:
import json
import os
import time

import numpy as np
import torch
# from e3nn import o3
import torch.nn.functional as F
from torch.utils.data import Dataset
# from torch_geometric.data import Data

In [3]:
def pbc_expand(atom_type, atom_coord):
    """
    Add periodic images of atoms by shifting them into the neighboring cells.

    Every atom is shifted by each of the eight 0/1 lattice offsets (the zero
    offset included, so the original atom is always a candidate). A shifted
    copy is kept only when all three of its coordinates are <= 1, i.e. the
    check assumes fractional coordinates with the cell spanning [0, 1]. For
    coordinates in [0, 1) this means only components that are exactly 0 yield
    an extra image at 1.

    :param atom_type: atom types, tensor of shape (n_atom,)
    :param atom_coord: atom coordinates, tensor of shape (n_atom, 3)
    :return: expanded atom types (long tensor of shape (n_exp,)) and
        coordinates (tensor of shape (n_exp, 3)). Results are ordered atom by
        atom and, within an atom, in the order of the offsets listed below.
        Empty input yields empty outputs instead of raising.
    """
    exp_direction = torch.tensor(
        [
            [0, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [0, 0, 1],
            [0, 1, 1],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1],
        ],
        dtype=torch.float,
        device=atom_coord.device,
    )
    atom_type = torch.as_tensor(atom_type, dtype=torch.long, device=atom_coord.device)

    # (n_atom, 1, 3) + (1, 8, 3) -> every atom combined with every offset.
    candidates = atom_coord.unsqueeze(1) + exp_direction.unsqueeze(0)
    # An image survives only if none of its coordinates leaves the cell.
    keep = (candidates <= 1).all(dim=-1)

    # Boolean-mask indexing walks the (n_atom, 8) grid row-major, which is the
    # same "atom outer, offset inner" order a nested loop would produce.
    exp_type = atom_type.unsqueeze(1).expand(-1, exp_direction.shape[0])[keep]
    exp_coord = candidates[keep]
    return exp_type, exp_coord


In [4]:
# Pick which structure to load: index into the list of file stems.
item = 0
file_pattern = ".chgcar"
file_list = saved_files_wo_ext
print(file_list[item])

# Rebuild the on-disk file name from the stem and the extension.
data_path = "./data"
file_name = file_list[item] + file_pattern

# The handle is intentionally left open: the following cells consume the
# file incrementally through fileobj.readline / fileobj.read.
fileobj = open(os.path.join(data_path, file_name), mode='r')

# Build the element-name -> integer-index lookup from the order of the
# entries in crystal.json.
atom_file = "crystal.json"
with open(atom_file) as f:
    atom_info = json.load(f)
atom_list = [entry["name"] for entry in atom_info]
atom_name2idx = dict(zip(atom_list, range(len(atom_list))))


mp-720294


In [5]:
"""Read atoms and data from CHGCAR file."""
# Bind the line reader once; the following cells pull successive lines of
# the file through this name.
readline = fileobj.readline
# First line: free-form comment, consumed and discarded.
readline()
# Second line: the scaling factor (lattice constant).
scale_text = readline()
scale = float(scale_text)

In [6]:
# Show the scaling factor parsed from the second line of the file.
print("scale: ", scale)

scale:  1.0


In [7]:
# The next three lines each hold one lattice vector; collect them as rows
# of a 3x3 float matrix and apply the global scaling factor.
lattice_rows = [
    [float(token) for token in readline().split()]
    for _row in range(3)
]
cell = torch.tensor(lattice_rows, dtype=torch.float) * scale

In [8]:
# Show the 3x3 lattice matrix after the scaling factor has been applied.
print("cell: ", cell)

cell:  tensor([[ 6.3662,  0.0000,  0.0000],
        [ 0.0000,  6.7907,  0.0000],
        [ 0.0000,  1.6387, 23.0747]])


In [9]:
# Line 6: the constituting element symbols, whitespace separated.
elements = readline().split()
# Line 7: the number of atoms for each atomic species, in the same order.
n_atoms = list(map(int, readline().split()))
# Line 8: coordinate-mode line, always "Direct" in our application.
# Left as a bare expression so the cell displays the consumed line.
readline()

'Direct\n'

In [10]:
# Show the element symbols and the per-species atom counts read from the header.
print("elements: ", elements)
print("n_atoms: ", n_atoms)

elements:  ['As', 'H', 'N', 'O']
n_atoms:  [4, 72, 12, 28]


In [11]:
tot_atoms = sum(n_atoms)
atom_type = torch.empty(tot_atoms, dtype=torch.long)
atom_coord = torch.empty(tot_atoms, 3, dtype=torch.float)
# The upcoming lines contain the atomic positions in fractional coordinates,
# grouped by species in the same order as `elements`.
idx = 0
for elem, n in zip(elements, n_atoms):
    stop = idx + n
    # All atoms of this species share one integer type index.
    atom_type[idx:stop] = atom_name2idx[elem]
    # One position line per atom of the species.
    while idx < stop:
        position = [float(token) for token in readline().split()]
        atom_coord[idx] = torch.tensor(position, dtype=torch.float)
        idx += 1

In [12]:
# Show the per-atom type indices and the fractional coordinates read from the file.
print("atom_type: ", atom_type)
print("atom_coord: ", atom_coord)

atom_type:  tensor([52, 52, 52, 52, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
        67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
        67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
        67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
        67, 67, 67, 67, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0])
atom_coord:  tensor([[0.8600, 0.7905, 0.8430],
        [0.3600, 0.2095, 0.6570],
        [0.1400, 0.2095, 0.1570],
        [0.6400, 0.7905, 0.3430],
        [0.8914, 0.7447, 0.6401],
        [0.3914, 0.2553, 0.8599],
        [0.1086, 0.2553, 0.3599],
        [0.6086, 0.7447, 0.1401],
        [0.9700, 0.9803, 0.6578],
        [0.4700, 0.0197, 0.8422],
        [0.0300, 0.0197, 0.3422],
        [0.5300, 0.9803, 0.1578],
        [0.8889, 0.1984, 0.7847],
        [0.3

In [13]:
# Toggle for the periodic-boundary expansion of the atom set.
pbc = True
if pbc:
    # Replace the atom arrays with the expanded versions returned by pbc_expand.
    atom_type, atom_coord = pbc_expand(atom_type, atom_coord)

In [14]:
# Show the atom types after the optional periodic-boundary expansion.
print("atom_type: ", atom_type)

atom_type:  tensor([52, 52, 52, 52, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
        67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
        67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
        67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67, 67,
        67, 67, 67, 67, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0])


In [15]:
# Show the atom coordinates after the optional periodic-boundary expansion
# (not yet multiplied by the lattice matrix at this point).
print("atom_coord: ", atom_coord)

atom_coord:  tensor([[0.8600, 0.7905, 0.8430],
        [0.3600, 0.2095, 0.6570],
        [0.1400, 0.2095, 0.1570],
        [0.6400, 0.7905, 0.3430],
        [0.8914, 0.7447, 0.6401],
        [0.3914, 0.2553, 0.8599],
        [0.1086, 0.2553, 0.3599],
        [0.6086, 0.7447, 0.1401],
        [0.9700, 0.9803, 0.6578],
        [0.4700, 0.0197, 0.8422],
        [0.0300, 0.0197, 0.3422],
        [0.5300, 0.9803, 0.1578],
        [0.8889, 0.1984, 0.7847],
        [0.3889, 0.8016, 0.7153],
        [0.1111, 0.8016, 0.2153],
        [0.6111, 0.1984, 0.2847],
        [0.9609, 0.2809, 0.7155],
        [0.4609, 0.7191, 0.7845],
        [0.0391, 0.7191, 0.2845],
        [0.5391, 0.2809, 0.2155],
        [0.7244, 0.8685, 0.0101],
        [0.2244, 0.1315, 0.4899],
        [0.2756, 0.1315, 0.9899],
        [0.7756, 0.8685, 0.5101],
        [0.6585, 0.6325, 0.0284],
        [0.1585, 0.3675, 0.4716],
        [0.3415, 0.3675, 0.9716],
        [0.8415, 0.6325, 0.5284],
        [0.7991, 0.1725, 0.9940],
 

In [16]:
# Multiply the (n_atom, 3) coordinates by the 3x3 lattice matrix: each row
# becomes a linear combination of the lattice vectors (rows of `cell`).
# Note: this overwrites atom_coord, so re-running the cell applies it again.
atom_coord = torch.matmul(atom_coord, cell)

In [17]:
# Show the atom coordinates after multiplication by the lattice matrix.
print("atom_coord: ", atom_coord)

atom_coord:  tensor([[ 5.4752,  6.7493, 19.4510],
        [ 2.2921,  2.4995, 15.1610],
        [ 0.8910,  1.6801,  3.6237],
        [ 4.0741,  5.9299,  7.9136],
        [ 5.6748,  6.1063, 14.7696],
        [ 2.4917,  3.1425, 19.8424],
        [ 0.6914,  2.3231,  8.3051],
        [ 3.8745,  5.2869,  3.2322],
        [ 6.1751,  7.7350, 15.1788],
        [ 2.9920,  1.5138, 19.4332],
        [ 0.1911,  0.6944,  7.8959],
        [ 3.3742,  6.9156,  3.6415],
        [ 5.6590,  2.6334, 18.1071],
        [ 2.4759,  6.6154, 16.5049],
        [ 0.7072,  5.7961,  4.9676],
        [ 3.8903,  1.8140,  6.5697],
        [ 6.1170,  3.0797, 16.5094],
        [ 2.9339,  6.1691, 18.1026],
        [ 0.2492,  5.3497,  6.5652],
        [ 3.4323,  2.2603,  4.9721],
        [ 4.6119,  5.9140,  0.2325],
        [ 1.4289,  1.6961, 11.3049],
        [ 1.7542,  2.5155, 22.8422],
        [ 4.9373,  6.7333, 11.7698],
        [ 4.1922,  4.3418,  0.6549],
        [ 1.0091,  3.2682, 10.8825],
        [ 2.1740,  4.0876

In [18]:
# Skip the blank separator line, then read the grid size line.
readline()
shape = list(map(int, readline().split()))
n_x, n_y, n_z = shape[0], shape[1], shape[2]
n_grid = n_x * n_y * n_z

# The grids are corner-aligned: point i along an axis sits at fraction i / N
# of the corresponding lattice vector, so the last point stops short of 1.
x_coord = torch.linspace(0, n_x - 1, n_x).unsqueeze(-1) / n_x * cell[0]
y_coord = torch.linspace(0, n_y - 1, n_y).unsqueeze(-1) / n_y * cell[1]
z_coord = torch.linspace(0, n_z - 1, n_z).unsqueeze(-1) / n_z * cell[2]

# Broadcast the three per-axis contributions into an (N_x, N_y, N_z, 3)
# lattice of points, then flatten with x as the slowest and z as the
# fastest varying index.
grid_coord = (
    x_coord[:, None, None, :]
    + y_coord[None, :, None, :]
    + z_coord[None, None, :, :]
)
grid_coord = grid_coord.reshape(-1, 3)

In [19]:
# Show the flattened grid coordinates and their shape (one row of 3 values per grid point).
print("grid_coord: ", grid_coord)
print("grid_coord.shape: ", grid_coord.shape)

grid_coord:  tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 4.5521e-03, 6.4096e-02],
        [0.0000e+00, 9.1042e-03, 1.2819e-01],
        ...,
        [6.2999e+00, 8.3529e+00, 2.2882e+01],
        [6.2999e+00, 8.3574e+00, 2.2946e+01],
        [6.2999e+00, 8.3620e+00, 2.3011e+01]])
grid_coord.shape:  torch.Size([3732480, 3])


In [20]:
# Read everything that is left in the file. Only the first n_grid tokens are
# grid values; whatever follows (the augmentation occupancies) is ignored.
grid_tokens = fileobj.read().split()
# The rest of the file has been consumed, so release the handle that was
# kept open across the previous cells.
fileobj.close()
if len(grid_tokens) < n_grid:
    raise ValueError(
        f"expected {n_grid} grid values but found only {len(grid_tokens)}; "
        "the file looks truncated"
    )
density = torch.tensor([float(s) for s in grid_tokens[:n_grid]], dtype=torch.float)
# CHGCAR stores the density multiplied by the volume of the whole cell, so
# divide by the cell volume (|det| of the lattice matrix, not the volume of a
# single grid voxel) to recover the density.
volume = torch.linalg.det(cell).abs()
density = density / volume
# CHGCAR file stores the values with x varying fastest (Z-Y-X when viewed as a
# 3-D array); convert them to X-Y-Z so the flattened order has x slowest.
density = density.view(shape[2], shape[1], shape[0]).transpose(0, 2).contiguous().view(-1)

In [22]:
# Show the flattened density values and their shape.
print("density: ", density)
print("density.shape: ", density.shape)

density:  tensor([0.0127, 0.0128, 0.0129,  ..., 0.0125, 0.0125, 0.0126])
density.shape:  torch.Size([3732480])
