<a href="https://colab.research.google.com/github/sparks-baird/xtal2png/blob/main/notebooks/2_0_materials_project_ranges.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Selecting Parameter Ranges via Materials Project

In this notebook, we'll go over how we selected parameter ranges for some hyperparameters of `xtal2png`, namely the lower and upper bounds of lattice parameter lengths ($a$, $b$, and $c$), cell volume, and site pairwise distances.

After downloading the data from Materials Project (or loading it from disk on later runs), we extract the parameters from each compound and do some exploratory data analysis. Based on that analysis, we use a quantile as the upper bound of each parameter range in order to exclude outliers. Removing the highest 1% in each parameter category retains about 97% of the entries with 52 or fewer sites, and this gives us our final parameter ranges. Finally, we make publication-ready histogram figures and save them.

## Setup

Let's keep this notebook compatible with both Google Colab and a local Jupyter environment.

In [6]:
from os import path
try:
  import google.colab
  IN_COLAB = True
  base_dir = "/content/drive/MyDrive/"
except ImportError:
  IN_COLAB = False
  base_dir = path.join("data", "external")

In [2]:
if IN_COLAB:
  %pip install pymatgen kaleido


## Data

### Materials Project API Key

There are three ways to supply your [Materials Project API key](https://next-gen.materialsproject.org/api). You can read it from a file stored in your Google Drive (see below) or current directory (`.`), or set the `api_key` variable manually in the form field. Alternatively, in a local command prompt (e.g. miniconda) with an activated environment that has `pymatgen` installed, run `pmg config --add PMG_MAPI_KEY <USER_API_KEY>`, e.g. `pmg config --add PMG_MAPI_KEY 123abc456def`. For this last option, see the [`pymatgen` docs](https://pymatgen.org/usage.html#setting-the-pmg-mapi-key-in-the-config-file).

The file named `mp-api-key.json` placed directly in your `MyDrive` folder or in your current directory would look like the following:
```json
{
    "API_KEY": "YOUR_API_KEY"
}
```
Note that this file is not necessary locally if you use the `pmg config` option above.
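
The file lookup described above can be sketched as a small helper. This is a minimal sketch rather than part of the notebook: `load_api_key` and its list of candidate paths are hypothetical, and only the `API_KEY` field name comes from the example file above.

```python
import json
from pathlib import Path


def load_api_key(candidates, field="API_KEY"):
    """Return the first non-empty API key found in `candidates`, or None.

    `load_api_key` is a hypothetical helper; `candidates` is a list of
    JSON file paths that are tried in order.
    """
    for fpath in candidates:
        path = Path(fpath)
        if not path.is_file():
            continue  # skip locations that do not exist
        with path.open("r") as f:
            data = json.load(f)
        key = data.get(field, "")
        if key:
            return key
    return None


# Try Google Drive first, then the current directory
api_key_example = load_api_key(
    ["/content/drive/MyDrive/mp-api-key.json", "mp-api-key.json"]
)
```

Returning `None` when no file is found mirrors the local branch of the notebook, which sets `api_key = None` and relies on the key stored via `pmg config`.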

In [4]:
import json
if IN_COLAB:
  from google.colab import drive
  drive.mount('/content/drive')
  apikey_fpath = "/content/drive/MyDrive/mp-api-key.json"
  try:
    # https://stackoverflow.com/a/68442279/13697228
    with open(apikey_fpath, 'r') as f:
        json_data = json.load(f)
        api_key = json_data["API_KEY"]
  except Exception as e:
    print(e)
    api_key = "" #@param {type:"string"}
    if api_key == "":
      print(f"Couldn't load API key from {apikey_fpath}, and user-input API key is also empty.")
    else:
      # avoid echoing the key itself into the notebook output
      print("defaulting to user-input API key")
else:
  api_key = None
  print("make sure that you have run `pmg config --add PMG_MAPI_KEY <USER_API_KEY>`")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Download

Let's either download the data directly from Materials Project using the `MPRester` API or load the data that's been saved previously to your device as `structures.pkl` in your `base_dir`.

In [5]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
from pymatgen.ext.matproj import MPRester

In [7]:
pkl_path = path.join(base_dir, "structures.pkl")
try:
  with open(pkl_path, "rb") as f:
    results = pickle.load(f)
except Exception as e:
  # No usable cache (e.g. first run): download, then save for next time
  print(e)
  with MPRester(api_key) as m:
      results = m.query(
          {"nelements": {"$gte": 2},
          "nsites": {"$lte": 52}},
          properties=["structure"],
      )
  with open(pkl_path, "wb") as f:
    pickle.dump(results, f)
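
The load-or-download logic above is an instance of a general cache-or-compute pattern, sketched below with the standard library only. The names `load_or_compute` and `compute` are hypothetical; in this notebook, `compute` would stand in for the `MPRester` query.

```python
import pickle
from pathlib import Path


def load_or_compute(cache_path, compute):
    """Load a pickled result from `cache_path`, or compute and cache it.

    `load_or_compute` and `compute` are hypothetical names used for
    illustration; `compute` is any zero-argument callable.
    """
    cache_path = Path(cache_path)
    if cache_path.is_file():
        with cache_path.open("rb") as f:
            return pickle.load(f)
    result = compute()
    # Create the parent folder if needed before writing the cache
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with cache_path.open("wb") as f:
        pickle.dump(result, f)
    return result
```

One difference from the cell above is that this sketch creates the parent directory before writing, whereas `open(pkl_path, "wb")` assumes `base_dir` already exists.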

### Extract Lattice and Distances

From here, we'll loop through each of the structures and grab the lattice parameter lengths (`a`, `b`, and `c`) as well as the cell volume (`volume`) and pairwise distance matrices between each of the sites for a given structure (`distance`).

In [8]:
a = []
b = []
c = []
volume = []
distance = []

for s in tqdm(results):
    s = s["structure"]
    lattice = s.lattice
    a.append(lattice.a)
    b.append(lattice.b)
    c.append(lattice.c)
    volume.append(lattice.volume)
    distance.append(s.distance_matrix)

print('range of a is: ', min(a), '-', max(a))
print('range of b is: ', min(b), '-', max(b))
print('range of c is: ', min(c), '-', max(c))
print('range of volume is: ', min(volume), '-', max(volume))

dis_min_tmp = []
dis_max_tmp = []
for d in tqdm(distance):
  # exclude zeros (the diagonal, i.e. each site's distance to itself)
  nonzero = d[np.nonzero(d)]
  dis_min_tmp.append(nonzero.min())
  dis_max_tmp.append(nonzero.max())

print('range of pair-wise distance is: ', min(dis_min_tmp), '-', max(dis_max_tmp))

100%|██████████| 106127/106127 [01:19<00:00, 1341.27it/s]


range of a is:  2.296021 - 66.29136774227022
range of b is:  2.258778 - 61.125585795588215
range of c is:  2.131537 - 130.453537
range of volume is:  11.91856931582488 - 20090.90640762975


100%|██████████| 106127/106127 [00:12<00:00, 8797.44it/s]

range of pair-wise distance is:  0.7249349602879995 - 64.8913973530744
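
To see why the zeros are filtered out before taking the minimum, here is a toy example with three sites on a line. It is a sketch in plain NumPy on made-up Cartesian coordinates and ignores periodic boundaries, whereas the `distance_matrix` of a `pymatgen` structure is, as far as I understand, based on nearest periodic images.

```python
import numpy as np

# Three sites on a line at x = 0, 1, and 3 (toy Cartesian coordinates)
coords = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])

# Pairwise Euclidean distances via broadcasting; the diagonal is zero
diff = coords[:, None, :] - coords[None, :, :]
dist = np.linalg.norm(diff, axis=-1)

# Drop the zeros before taking the minimum, otherwise it is always 0
nonzero = dist[np.nonzero(dist)]
d_min, d_max = nonzero.min(), nonzero.max()
print(d_min, d_max)
```

Without the `np.nonzero` filter, the minimum of every distance matrix would be the self-distance of `0` on the diagonal, which says nothing about the structure.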





## Exploratory Data Analysis

### Setup

First, we store the data as a `DataFrame` to make it easier to visualize and apply operations to it.

In [9]:
import plotly.express as px
df = pd.DataFrame(dict(a=a, b=b, c=c, volume=volume, min_distance=dis_min_tmp, max_distance=dis_max_tmp))

### Min/Max
Next, we take a look at the minimum and maximum for each of the parameters.

In [None]:
low_df = df.apply(np.min).drop("max_distance")
low_df

a                2.296021
b                2.258778
c                2.131537
volume          11.918569
min_distance     0.724935
dtype: float64

In [None]:
df.apply(np.max)

a                  66.291368
b                  61.125586
c                 130.453537
volume          20090.906408
min_distance        8.650650
max_distance       64.891397
dtype: float64

The maxima here can be pretty large, for example ~`20000` cubic angstroms for the unit cell volume.

### Histogram

Let's take a quick look at one of the parameters involved, in this case the `a` lattice parameter length.

In [68]:
import plotly.express as px
px.histogram(df, x="a", marginal="rug")

Clearly, there are outliers.

### Quantile Maximum
Since these large ranges will inflate the round-off error of `xtal2png`, let's see if we can tighten them by considering only values up to a certain quantile `q` for the relevant parameters.

In [None]:
q = 0.99
df.apply(lambda a: np.quantile(a, 1 - q)).drop("max_distance")

a                2.921745
b                3.050439
c                3.347172
volume          39.162963
min_distance     0.981744
dtype: float64

In [None]:
upp_df = df.apply(lambda a: np.quantile(a, q))
upp_df = upp_df.drop("min_distance")
upp_df

a                 15.292415
b                 14.953414
c                 35.792380
volume          1467.529411
max_distance      17.550941
dtype: float64
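
The effect of swapping the maximum for a quantile is easiest to see on toy data with a single extreme outlier. The values below are made up for illustration and are not Materials Project data.

```python
import numpy as np

# 99 well-behaved values plus one extreme outlier (toy data)
values = np.append(np.arange(1.0, 100.0), 1000.0)

q = 0.99
upper_max = values.max()
upper_q = np.quantile(values, q)  # default method: linear interpolation
frac_below = np.mean(values < upper_q)
print(upper_max, upper_q, frac_below)
```

The maximum is set entirely by the outlier, while the `0.99` quantile stays close to the bulk of the data and still covers 99% of the values.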

### Data Retention

The ranges are a lot more reasonable now. Let's see how many compounds are retained after applying an upper bound filtering step using this upper quantile.

In [64]:
qstr = " and ".join([f"{lbl} < @upp_df.{lbl}" for lbl in upp_df.index]) # .drop(["volume", "max_distance"])
qstr

'a < @upp_df.a and b < @upp_df.b and c < @upp_df.c and volume < @upp_df.volume and max_distance < @upp_df.max_distance'

In [65]:
filt_df = df.query(qstr)
filt_df

               a         b          c      volume  min_distance  max_distance
0       5.189676  5.189676   5.189676   58.128751      2.906562      2.970818
1       5.388181  5.388181   5.388181   65.313995      3.018739      3.085060
2       3.300603  3.300603   3.300603   25.425237      2.333879      2.333879
3       3.498199  3.498199   3.498199   30.270418      2.142200      2.142200
4       3.510234  3.510234   3.510234   43.252200      3.039952      3.039952
...          ...       ...        ...         ...           ...           ...
106122  8.466314  8.603384   8.606069  469.512423      1.502355      6.098565
106123  8.960343  8.960343   8.960342  466.892631      1.533764      6.157343
106124  6.874616  7.317851   8.159621  363.700541      0.996728      5.331704
106125  5.211291  7.406056  10.707537  370.569429      0.983919      5.761587


In [66]:
frac_retained = filt_df.shape[0] / df.shape[0]
print(f"{100*frac_retained:.1f}% retained")

97.0% retained


We have retained ~97% of the original compounds. The other 3% will be much less likely to be represented during generation (i.e. they have been masked from the distribution), although, since they were outliers to begin with, it's unclear whether most generative models would generate these kinds of compounds anyway. This may be an interesting topic for future study.
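
The retained fraction is below 99% because a compound is dropped if it exceeds the bound in any one of the five filtered parameters. Each filter removes roughly the top 1%, so the combined filter keeps somewhere between about 95% and 99% depending on how much the outliers overlap. The sketch below shows the same row-wise filter as a boolean mask on toy data; the values and the `upper` bounds are made up for illustration.

```python
import pandas as pd

# Toy data: row 2 is an outlier in `b`, row 3 is an outlier in `a`
toy = pd.DataFrame({"a": [1.0, 2.0, 3.0, 100.0], "b": [1.0, 2.0, 50.0, 3.0]})
upper = pd.Series({"a": 10.0, "b": 10.0})  # hypothetical upper bounds

# True only for rows where every column is below its upper bound
mask = (toy[upper.index] < upper).all(axis=1)
kept = toy[mask]
frac_kept = len(kept) / len(toy)
print(f"{100 * frac_kept:.1f}% retained")
```

Here each column has one outlier, but they sit in different rows, so two of the four rows are dropped.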

## Selected Parameter Ranges

We'll leave the lower bound as the minimum of all Materials Project entries (with 52 or fewer sites, that is). Alternatively, the lower bound could be set to `0` for each of these.

In [None]:
low_df # i.e. minima

a                2.296021
b                2.258778
c                2.131537
volume          11.918569
min_distance     0.724935
dtype: float64

In [None]:
upp_df # based on `q` quantile

a                 15.292415
b                 14.953414
c                 35.792380
volume          1467.529411
max_distance      17.550941
dtype: float64
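
To get a feel for how much the tighter upper bound reduces round-off error, the sketch below compares the grid spacing for the `a` lattice parameter under both ranges. It assumes values are scaled linearly onto 256 levels, as in an 8-bit grayscale image; the actual encoding in `xtal2png` may differ, and `quantization_step` is a hypothetical helper.

```python
def quantization_step(low, upp, levels=256):
    """Spacing between adjacent representable values on a linear grid."""
    return (upp - low) / (levels - 1)


low_a = 2.296021   # minimum of `a` from the output above
max_a = 66.291368  # maximum of `a`
q99_a = 15.292415  # 0.99 quantile of `a`

step_full = quantization_step(low_a, max_a)
step_clipped = quantization_step(low_a, q99_a)

# The worst-case round-off error is half a step
print(f"step: {step_full:.3f} vs {step_clipped:.3f} angstroms")
print(f"max error: {step_full / 2:.3f} vs {step_clipped / 2:.3f} angstroms")
```

Under this assumption, using the quantile instead of the maximum shrinks the step size for `a` by roughly a factor of five.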

## Plotting Histogram Distributions

Let's plot and save the distributions for the parameters in `upp_df`. First, we define some helper functions to make the figures more compatible with academic publishing and to save them.

In [10]:
from typing import Union
import plotly.graph_objs as go
from plotly import offline

def matplotlibify(
    fig: go.Figure,
    size: int = 24,
    width_inches: Union[float, int] = 3.5,
    height_inches: Union[float, int] = 3.5,
    dpi: int = 142,
    return_scale: bool = False,
) -> go.Figure:
    """Make plotly figures look more like matplotlib for academic publishing.
    
    modified from: https://medium.com/swlh/formatting-a-plotly-figure-with-matplotlib-style-fa56ddd97539    
    
    Parameters
    ----------
    fig : go.Figure
        Plotly figure to be matplotlibified
    size : int, optional
        Font size for layout and axes, by default 24
    width_inches : Union[float, int], optional
        Width of matplotlib figure in inches, by default 3.5
    height_inches : Union[float, int], optional
        Height of matplotlib figure in inches, by default 3.5
    dpi : int, optional
        Dots per inch (resolution) of matplotlib figure, by default 142. Leave as
        default unless you're willing to verify nothing strange happens with the output.
    return_scale : bool, optional
        If true, then return `scale` which is a quantity that helps with creating a
        high-resolution image at the specified absolute width and height in inches.
        More specifically:
        >>> width_default_px = fig.layout.width
        >>> targ_dpi = 300
        >>> scale = width_inches / (width_default_px / dpi) * (targ_dpi / dpi)
        Feel free to ignore this parameter.
    
    Returns
    -------
    fig : go.Figure
        The matplotlibified plotly figure.
    
    Examples
    --------
    >>> import plotly.express as px
    >>> df = px.data.tips()
    >>> fig = px.histogram(df, x="day")
    >>> fig.show()
    >>> fig = matplotlibify(fig, size=24, width_inches=3.5, height_inches=3.5, dpi=142)
    >>> fig.show()
    
    Note the difference between the figures shown before and after the call.
    """
    font_dict = dict(family="Arial", size=size, color="black")

    # app = QApplication(sys.argv)
    # screen = app.screens()[0]
    # dpi = screen.physicalDotsPerInch()
    # app.quit()

    fig.update_layout(
        font=font_dict,
        plot_bgcolor="white",
        width=width_inches * dpi,
        height=height_inches * dpi,
        margin=dict(r=40, t=20, b=10),
    )

    fig.update_yaxes(
        showline=True,  # add line at x=0
        linecolor="black",  # line color
        linewidth=2.4,  # line size
        ticks="inside",  # ticks inside axis
        tickfont=font_dict,  # tick label font
        mirror="allticks",  # add ticks to top/right axes
        tickwidth=2.4,  # tick width
        tickcolor="black",  # tick color
    )

    fig.update_xaxes(
        showline=True,
        showticklabels=True,
        linecolor="black",
        linewidth=2.4,
        ticks="inside",
        tickfont=font_dict,
        mirror="allticks",
        tickwidth=2.4,
        tickcolor="black",
    )
    fig.update(layout_coloraxis_showscale=False)

    width_default_px = fig.layout.width
    targ_dpi = 300
    scale = width_inches / (width_default_px / dpi) * (targ_dpi / dpi)

    if return_scale:
        return fig, scale
    else:
        return fig

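def plot_and_save(fig_path, fig, mpl_kwargs=None, show=False, update_legend=False):
    # avoid a mutable default argument
    mpl_kwargs = {} if mpl_kwargs is None else mpl_kwargs
    if show:
        try:
            fig.show()
        except Exception as e:
            print(e)
            offline.plot(fig)
    fig.write_html(fig_path + ".html")
    # `to_json` only returns a string; `write_json` saves it to disk
    fig.write_json(fig_path + ".json")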
    if update_legend:
        fig.update_layout(
            legend=dict(
                font=dict(size=16),
                yanchor="bottom",
                y=0.99,
                xanchor="right",
                x=0.99,
                bgcolor="rgba(0,0,0,0)",
                # orientation="h",
            )
        )
    fig = matplotlibify(fig, **mpl_kwargs)
    fig.write_image(fig_path + ".png")
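
The `scale` value that `matplotlibify` can return is easier to follow with the default numbers plugged in. This is a worked example of the arithmetic only, using the defaults from the function above; `export_width_px` is a name introduced here for illustration.

```python
width_inches = 3.5
dpi = 142       # default layout resolution used by `matplotlibify`
targ_dpi = 300  # target resolution hard-coded in `matplotlibify`

# Layout width in pixels, as set by `fig.update_layout(width=...)`
width_px = width_inches * dpi

# Same formula as in `matplotlibify`; it reduces to targ_dpi / dpi
scale = width_inches / (width_px / dpi) * (targ_dpi / dpi)

# Pixel width of an image exported with this scale factor
export_width_px = width_px * scale
print(width_px, round(scale, 3), round(export_width_px))
```

Passing this value as `fig.write_image(..., scale=scale)` would export at the target resolution for the same physical width. Note that `plot_and_save` as written does not pass `scale`, so its PNG files are written at the layout resolution.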

From here, we just loop through the various parameters, plotting and saving histograms as we go. If running on Google Colab, these will be saved to the current directory, which is temporary storage that is purged when the session ends.

In [69]:
figs = []
for lbl in df.columns.drop("min_distance"):
  fig = px.histogram(df, x=lbl, marginal="rug")
  fig = matplotlibify(fig)
  figs.append(fig)
  plot_and_save(lbl+"_hist", fig, show=False)

Here's an example of what the first figure looks like (compare with the histogram from an earlier section in terms of formatting).

In [70]:
figs[0]