In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# matplotlib.use("Agg")

from ase import Atoms
from ase.io import read
from agox.databases import Database
from agox.environments import Environment
from agox.utils.graph_sorting import Analysis

import glob
import numpy as np
from sklearn.decomposition import PCA

In [None]:
## Configure matplotlib for publication figures
# Text is rendered through LaTeX with the cmr10 family at 12 pt. Tick-label
# formatting goes through mathtext; presumably this is to avoid glyph problems
# (e.g. the minus sign) that cmr10 has with plain-text tick labels.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "cmr10",
    "font.size": 12,
    "axes.formatter.use_mathtext": True,
})

In [None]:
## Run parameters
# seed       -> selects the DOutput<identifier>/*_seed<seed>.traj inputs and output file names
# identifier -> suffix of the DOutput directory and output names (e.g. "_VASP")
seed, identifier = 0, ""

In [None]:
## Set the descriptors
# Per-atom SOAP descriptor covering every species present (C, Mg, O), with a
# cutoff radius r_cut=6.0 (presumably Å, ASE's length unit).
from agox.models.descriptors import SOAP

local_descriptor = SOAP.from_species(["C", "Mg", "O"], r_cut=6.0)

In [None]:
## Set the calculator
# One CHGNet instance; it is attached to every Atoms object loaded or built below.
from chgnet.model import CHGNetCalculator

calc = CHGNetCalculator()

In [None]:
## Host matching: same lattice and same number of carbon atoms
def compare_cells_and_carbon(structure, host):
    """Return True when `structure` and `host` share a cell and a carbon count.

    The cells match when ``np.allclose`` accepts them with ``atol=1e-5`` (and the
    default relative tolerance). The carbon count is the number of atoms whose
    ``symbol`` is ``'C'``. Both checks are always evaluated.
    """
    same_cell = np.allclose(structure.cell, host.cell, atol=1e-5)

    n_carbon = [len([a for a in atoms if a.symbol == 'C']) for atoms in (structure, host)]
    same_carbon = n_carbon[0] == n_carbon[1]

    return same_cell and same_carbon

In [None]:
## Load the host structures
# glob returns paths in arbitrary, filesystem-dependent order; sorting makes the host
# list (and therefore which host a first-match search picks) reproducible.
poscar_files = sorted(glob.glob("DC-MgO_hosts/POSCAR_*"))
if not poscar_files:
  print("Warning: no host files matched DC-MgO_hosts/POSCAR_*")
hosts = []
for poscar_file in poscar_files:
  host = read(poscar_file)
  host.calc = calc
  hosts.append(host)

In [None]:
## Bulk MgO reference: 8-atom rock-salt cell (4 Mg + 4 O), lattice constant 4.194
# Positions are fractional multiples of half the lattice constant (2.097).
MgO = Atoms(
    "Mg4O4",
    positions=[[2.097 * f for f in frac] for frac in (
        (0, 0, 0), (0, 1, 1), (1, 0, 1), (1, 1, 0),  # Mg: fcc sublattice
        (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 1),  # O: fcc sublattice offset by half a lattice vector
    )],
    cell=[4.1940] * 3,
    pbc=True,
)
MgO.calc = calc

In [None]:
## Load every frame of the unrelaxed trajectory and attach the shared calculator
unrlxd_structures = read(f"DOutput{identifier}/unrlxd_structures_seed{seed}.traj", index=":")
for structure in unrlxd_structures:
  structure.calc = calc

In [None]:
## Load every frame of the relaxed trajectory and attach the shared calculator
rlxd_structures = read(f"DOutput{identifier}/rlxd_structures_seed{seed}.traj", index=":")
for structure in rlxd_structures:
  structure.calc = calc

In [None]:
## Get the unrelaxed formation energies
unrlxd_delta_en_per_atom = []
for structure in unrlxd_structures:
  host_energy = None
  area = None
  energy = structure.get_potential_energy()
  # compare cell of structure to cell of host and find matching host cell, i.e. same cell and same number of carbon atoms
  for host in hosts:
    if compare_cells_and_carbon(structure, host):
      host.calc = calc
      host_energy = host.get_potential_energy()
      area = np.linalg.norm(np.cross(host.cell[0], host.cell[1]))
      break
  if host_energy is None:
    print("No matching host for structure")
    continue
  energy -= host_energy
  energy -= MgO.get_potential_energy() * sum(atom.symbol == 'Mg' for atom in structure)
  energy /= 2 * area
  unrlxd_delta_en_per_atom.append(energy)
print("Unrelaxed min energy: ", min(unrlxd_delta_en_per_atom))
    

In [None]:
## Get the relaxed formation energies
rlxd_delta_en_per_atom = []
for structure in rlxd_structures:
  host_energy = None
  area = None
  energy = structure.get_potential_energy()
  # compare cell of structure to cell of host and find matching host cell, i.e. same cell and same number of carbon atoms
  for host in hosts:
    if compare_cells_and_carbon(structure, host):
      host.calc = calc
      host_energy = host.get_potential_energy()
      area = np.linalg.norm(np.cross(host.cell[0], host.cell[1]))
      break
  if host_energy is None:
    print("No matching host for structure")
    continue
  energy -= host_energy
  energy -= MgO.get_potential_energy() * sum(atom.symbol == 'Mg' for atom in structure)
  energy /= 2 * area
  rlxd_delta_en_per_atom.append(energy)
print("Relaxed min energy: ", min(rlxd_delta_en_per_atom))
    

In [None]:
## Set up a 2-component PCA for the 2-D structure map
# A fixed random_state keeps the fit reproducible if scikit-learn selects a
# randomized SVD solver (possible under svd_solver='auto' for large inputs).
pca = PCA(n_components=2, random_state=0)

In [None]:
## 'Super atom' descriptor per unrelaxed structure: the local-descriptor feature
## matrix averaged over axis 0 (presumably the atom axis)
unrlxd_super_atoms = [
  np.mean(local_descriptor.get_features(s), axis=0) for s in unrlxd_structures
]

In [None]:
## 'Super atom' descriptor per relaxed structure: the local-descriptor feature
## matrix averaged over axis 0 (presumably the atom axis)
rlxd_super_atoms = [
  np.mean(local_descriptor.get_features(s), axis=0) for s in rlxd_structures
]

In [None]:
## Tag naming the structure set the PCA model belongs to ("rlxd" = relaxed)
# It appears in the PCA model file name, selects the axis-limit preset and is part
# of the saved figure's name.
rlxd_string = "rlxd"

In [None]:
## Fit and save the PCA model, then load the model used for plotting
import pickle

# Set to False to skip refitting and only load a previously saved model.
refit_pca = True
# The plotted projection is always loaded from this seed's model, presumably so that
# every seed is drawn in the same principal-component axes.
pca_fit_seed = 0

if refit_pca:
  # Fit on the set selected by `rlxd_string` so the data matches the file name.
  fit_super_atoms = rlxd_super_atoms if rlxd_string == "rlxd" else unrlxd_super_atoms
  # Explicit (n_structures, n_features) matrix; np.squeeze would collapse a single
  # structure to 1-D, which PCA cannot fit.
  pca.fit(np.asarray(fit_super_atoms).reshape(len(fit_super_atoms), -1))
  with open("pca_model"+identifier+"_all_"+rlxd_string+"_"+str(seed)+".pkl", "wb") as f:
    pickle.dump(pca, f)

## Load pca model
# NOTE: pickle.load can execute arbitrary code; only load model files you created.
with open("pca_model"+identifier+"_all_"+rlxd_string+"_"+str(pca_fit_seed)+".pkl", "rb") as f:
  pca = pickle.load(f)

In [None]:
## Transform the unrelaxed and relaxed structures to the reduced space
# reshape(len(...), -1) always gives a 2-D (n_structures, n_features) matrix; np.squeeze
# would collapse a single-structure list to 1-D and make pca.transform fail.
unrlxd_X_reduced = pca.transform(np.asarray(unrlxd_super_atoms).reshape(len(unrlxd_super_atoms), -1))
rlxd_X_reduced = pca.transform(np.asarray(rlxd_super_atoms).reshape(len(rlxd_super_atoms), -1))

In [None]:
## Get the index of the lowest-energy relaxed structure
# nanargmin skips NaN entries (e.g. structures without a valid energy); plain argmin
# would return the index of the first NaN instead.
min_energy_index = np.nanargmin(rlxd_delta_en_per_atom)
print(min_energy_index)

In [None]:
## Plot the PCA
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 6))

plt.subplots_adjust(wspace=0.05, hspace=0)

## Get the maximum energy for the colourbar, capped at 3.5
# NaN-aware so a structure without an energy does not turn the colour scale into NaN.
max_en = min(3.5, max(np.nanmax(unrlxd_delta_en_per_atom), np.nanmax(rlxd_delta_en_per_atom)))

## Scatter both sets in PCA space, coloured by formation energy on a shared scale
# NaN energies are drawn in the colormap's "bad" colour.
axes[0].scatter(unrlxd_X_reduced[:, 0], unrlxd_X_reduced[:, 1], c=unrlxd_delta_en_per_atom, cmap="viridis", vmin=0, vmax=max_en)
axes[1].scatter(rlxd_X_reduced[:, 0], rlxd_X_reduced[:, 1], c=rlxd_delta_en_per_atom, cmap="viridis", vmin=0, vmax=max_en)

## Circle the lowest-energy relaxed structure (its relaxed PCA coordinates) on both panels
for ax in axes:
  ax.scatter(rlxd_X_reduced[min_energy_index, 0], rlxd_X_reduced[min_energy_index, 1], s=200, edgecolor='red', facecolor='none', linewidth=2)

## Add labels
fig.text(0.5, 0.04, 'Principal component 1', ha='center', fontsize=15)
axes[0].set_ylabel('Principal component 2', fontsize=15)
axes[0].set_title('Unrelaxed')
axes[1].set_title('Relaxed')

## Hand-tuned axis limits per dataset (identifier) and PCA fit set (rlxd_string)
if identifier == "_VASP":
  if rlxd_string == "rlxd":
    xlims = [-11, 8]
    ylims = [-5, 6]
  else:
    xlims = [-9, 13]
    ylims = [-7, 12]
else:
  if rlxd_string == "rlxd":
    xlims = [-250, 150]
    ylims = [-10, 70]
  else:
    xlims = [-600, 600]
    ylims = [-100, 100]

for ax in axes:
  ax.tick_params(axis='both', direction='in')
  ax.set_xlim(xlims)
  ax.set_ylim(ylims)

## Unify tick labels: both panels use the left panel's x ticks that lie inside the limits
xticks = axes[0].get_xticks()
xticks = xticks[(xticks >= xlims[0]) & (xticks <= xlims[1])]

axes[0].set_xticks(xticks)
axes[1].set_xticks(xticks)
axes[1].set_yticklabels([])  # right panel: no y tick labels (same y limits as the left)
for ax in axes:
  ax.tick_params(axis='x', labelbottom=True, top=True)
  # Keep the primary (left) y tick labels enabled and mirror the ticks on the right.
  ax.tick_params(axis='y', labelleft=True, right=True)

## Make axes[0] and axes[1] the same width
axes[0].set_box_aspect(1.7)
axes[1].set_box_aspect(1.7)

## Add colorbar next to the axes
# The formation energies are an energy difference divided by 2 * cell area, so the
# unit is eV/Å^2, not eV/atom.
cbar = fig.colorbar(axes[1].collections[0], ax=axes, orientation='vertical', fraction=0.085, pad=0.02)
cbar.set_label(r'Formation energy (eV/\AA$^{2}$)', fontsize=15)

## Save the figure
plt.savefig('C-MgO_RAFFLE'+identifier+'_pca_'+rlxd_string+'_fit_seed'+str(seed)+'.pdf', bbox_inches='tight', pad_inches=0, facecolor=fig.get_facecolor(), edgecolor='none')