In [11]:
using Pkg
Pkg.activate("/home/qingyanz/CrystalGraphConvNets.jl")

using CSV
using DataFrames
using SparseArrays
using Random, Statistics
using Flux
using Flux: @epochs
using GeometricFlux
using SimpleWeightedGraphs
using CrystalGraphConvNets
using DelimitedFiles
using ProgressBars
using PeriodicTable

using PyCall
s = pyimport("pymatgen.core.structure")

[32m[1m Activating[22m[39m environment at `~/CrystalGraphConvNets.jl/Project.toml`


PyObject <module 'pymatgen.core.structure' from '/home/qingyanz/miniconda3/envs/pymatgen/lib/python3.8/site-packages/pymatgen/core/structure.py'>

In [2]:
# Evaluate the graph helper source file from a local checkout of the package in the
# current module, and declare the four helper names as exported from this module.
# NOTE(review): presumably graph_functions.jl is where these four functions are defined
# (the recorded cell output, `visualize_graph`, is consistent with that) — confirm.
# NOTE(review): CrystalGraphConvNets is also loaded with `using` in the first cell, so
# these definitions may shadow or clash with the package's own — confirm which is intended.
export inverse_square, exp_decay, build_graph, visualize_graph
include("../../CrystalGraphConvNets.jl/src/graph_functions.jl")

visualize_graph

In [3]:
# ---- Hyperparameters ----

# Training settings
train_frac = 0.8
num_epochs = 10
lr = 0.001

# Graph-construction settings (decay_fcn is one of the functions from graph_functions.jl)
cutoff_radius = 8.0
max_num_nbr = 12
decay_fcn = inverse_square

# Dataset selection and on-disk locations; `datasize` is spliced into the output paths
datasize = "_mid"
cif_dir = "../cif/"
el_list_dir = "../graphs/ellists$datasize"
graph_weights_dir = "../graphs/grwts$datasize"

# Network-shape settings
num_conv = 3
atom_fea_len = 32
pool_type = "mean"
crys_fea_len = 128
num_hidden_layers = 1

# Atom featurization: the three vectors below have one entry per feature
# (presumably matched by position — confirm against make_feature_vectors).
features = ["group", "row", "X", "atomic_radius", "block"]
num_bins = [18, 8, 10, 10, 4]
logspaced = [false, false, false, true, false]

# Total number of bins across all features (kept last so the cell displays it)
num_features = sum(num_bins)

50

In [4]:
# Path of a single example CIF file used to try out graph construction.
cif_path = "../cif/1/00/00/1000001.cif"

"../cif/1/00/00/1000001.cif"

In [5]:
# Build the graph for the example CIF and time the call. The result is destructured
# into the weighted graph (`gr`) and the per-atom element symbols (`els`), as the
# recorded output of this cell shows.
gr, els = @time build_graph(cif_path)

 13.119537 seconds (17.95 M allocations: 471.981 MiB, 5.72% gc time)


({1156, 7340} undirected simple Int32 graph with Float32 weights, ["H", "H", "H", "H", "H", "H", "H", "H", "H", "H"  …  "O", "O", "O", "O", "O", "O", "O", "O", "O", "O"])

In [33]:
# Remove and return the last entry of `els` (mutates `els` in place).
# NOTE(review): the recorded output is "He", but the list returned by build_graph ended
# in "O", so presumably a test element was appended in a cell that is not preserved
# here. Re-running the notebook top to bottom would instead drop a real atom and leave
# `els` one entry shorter than the graph — confirm and remove or pair with the push.
pop!(els)

"He"

In [34]:
# Display the element list to check its contents and length after the pop above.
els

1156-element Array{String,1}:
 "H"
 "H"
 "H"
 "H"
 "H"
 "H"
 "H"
 "H"
 "H"
 "H"
 "H"
 "H"
 "H"
 ⋮
 "O"
 "O"
 "O"
 "O"
 "O"
 "O"
 "O"
 "O"
 "O"
 "O"
 "O"
 "O"

In [25]:
# Elements to leave out: the listed noble gases plus every element whose atomic number
# runs from `max_atno` up to 100.
# NOTE(review): the range starts AT `max_atno`, so element 83 (Bi) is itself skipped.
# If `max_atno` is meant to be the largest atomic number that is kept, the range should
# start at `max_atno + 1` — confirm the intent before changing it.
max_atno = 83
noble_gas_atnos = [2, 10, 18, 36, 54]
nums_to_skip = union(noble_gas_atnos, min(max_atno, 100):100)

# `skipped_elements` is the name the pass/fail check further down reads, and it says
# what the list actually holds. `all_elements` is kept as an alias because the
# intersect cell below still uses the old name.
skipped_elements = [e.symbol for e in elements[nums_to_skip]]
all_elements = skipped_elements

23-element Array{String,1}:
 "He"
 "Ne"
 "Ar"
 "Kr"
 "Xe"
 "Bi"
 "Po"
 "At"
 "Rn"
 "Fr"
 "Ra"
 "Ac"
 "Th"
 "Pa"
 "U"
 "Np"
 "Pu"
 "Am"
 "Cm"
 "Bk"
 "Cf"
 "Es"
 "Fm"

In [35]:
# Element symbols that occur both in this structure and in the skip list
# (an empty set means the structure contains none of the skipped elements).
Set(els) ∩ all_elements

Set{String} with 0 elements

In [36]:
# Print "pass" when the structure shares no element with the skip list;
# print nothing otherwise.
let overlap = intersect(Set(els), skipped_elements)
    if isempty(overlap)
        println("pass")
    end
end

pass


In [None]:
"""
    build_and_save_graphs(cif_ids, cif_dir, el_list_subdir, gr_wt_subdir)

Build a graph for every identifier in `cif_ids` and write its weights and element list
to disk.

For each identifier the CIF path is assembled as
`<cif_dir><c1>/<c2c3>/<c4c5>/<id>.cif`, `build_graph` is called on it, and the results
are written with `writedlm` to `<gr_wt_subdir><id>.txt` (graph weights) and
`<el_list_subdir><id>.txt` (element symbols). Both directory arguments are expected to
end in a path separator and to exist already.

# Returns
A two-column `Matrix{Any}` with one `[row cif]` line per identifier for which building
or writing the graph threw an error (zero rows if everything succeeded).
"""
function build_and_save_graphs(cif_ids, cif_dir, el_list_subdir, gr_wt_subdir)
    failures = Matrix{Any}(undef, 0, 2)

    for (row, cif) in ProgressBar(enumerate(cif_ids))
        # path to the cif
        cif_path = string(cif_dir, cif[1], "/", cif[2:3], "/",
                          cif[4:5], "/", cif, ".cif")

        try
            # TODO: pass radius / max_num_nbr / dist_decay_func once the parameter
            # sweep is wired up; for now build_graph runs with its defaults.
            gr, els = @time build_graph(cif_path)

            writedlm(string(gr_wt_subdir, cif, ".txt"), gr.weights)
            writedlm(string(el_list_subdir, cif, ".txt"), els)
        catch e
            # Let the user abort a long run instead of logging Ctrl-C as a failure.
            e isa InterruptException && rethrow()

            # Keep going on a bad CIF, but report why it failed.
            @warn "unable to build graph at cif = $cif" exception = (e, catch_backtrace())

            # Record cifs where graph building failed
            failures = vcat(failures, [row cif])
        end
    end

    return failures
end

# Read the label table with every column parsed as a string. `type` takes a type, not
# its name as a string, and going through CSV.File avoids relying on the sink-less
# CSV.read form.
# NOTE(review): newer CSV.jl releases appear to spell this keyword `types` — adjust if
# the environment is upgraded.
labels = DataFrame(CSV.File("../labels/example$datasize.txt"; type=String, delim=", "))

# Discard cifs with missing labels for time being
# (rows whose crystal system is one of the placeholder markers "?", "i" or "!").
labels = labels[[!(cs in ("?", "i", "!")) for cs in labels.crystalsystem], :]

# Not used further in this cell; presumably consumed by later training cells.
atom_feature_vecs = make_feature_vectors(features) # , num_bins, logspaced)
inputs = FeaturedGraph{SimpleWeightedGraph{Int32, Float32}, Array{Float32,2}}[]

# The first column of the label table holds the cif identifiers.
cif_root = labels[:, 1]

# TODO: sweep over Iterators.product(cutoff_radii, nbr_num_cutoffs, decay_fcns) and
# encode each combination in `params_suffix`; a single default configuration for now.
params_suffix = ""
el_list_subdir = string(el_list_dir, params_suffix, '/')
gr_wt_subdir = string(graph_weights_dir, params_suffix, '/')

# Create the output directories once, up front (mkpath also creates missing parents
# and is a no-op when the directory already exists).
mkpath(el_list_subdir)
mkpath(gr_wt_subdir)

cifs_without_graphs = build_and_save_graphs(cif_root, cif_dir, el_list_subdir,
                                            gr_wt_subdir)

# Record cifs where graph building failed
# NOTE(review): writedlm's default delimiter is a tab although the file is named .csv;
# left unchanged in case a reader depends on it — confirm and pass ',' if appropriate.
writedlm("../graphs/cifs_without_graphs$datasize.csv", cifs_without_graphs)