In [1]:
# Re-import edited local modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import re
import random
import requests
import json
from pathlib import Path
import dotenv
import warnings
import datetime

from tqdm import tqdm
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import bibtexparser
from mp_api.client import MPRester
from pymatgen.core import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from ase.visualize import view

# Global notebook configuration: large default font for figures, and silence every warning.
plt.rc("font", size=20)
warnings.simplefilter("ignore")

In [3]:
# Root directory (relative to the notebook) for every artifact read or written below.
base_path = Path("data") / "mp-40"

# 1. Materials Project

### 1.0. Set up API key

Save your Materials Project API key in a `.env` file in the same directory as this notebook. The file should look like this:

```
MP_API_KEY=your_api_key
```

In [4]:
# Load variables from a .env file into the process environment, then read the API keys.
# A key that is not set comes back as None.
dotenv.load_dotenv()
MP_API_KEY, SCOPUS_API_KEY, OPENAI_API_KEY = (
    os.environ.get(key_name)
    for key_name in ("MP_API_KEY", "SCOPUS_API_KEY", "OPENAI_API_KEY")
)

### 1.1. Retrieving created_at from the Materials Project API

`created_at` is only available from `mpr.materials.search`, so we use that endpoint to retrieve each material's registration date.

In [None]:
# `created_at` is only exposed by the materials endpoint, so query it for every
# material with num_sites in [0, 40], requesting just the id and creation date.
created_at_fields = ["material_id", "created_at"]
with MPRester(MP_API_KEY) as mpr:
    total_docs = mpr.materials.search(num_sites=[0, 40], fields=created_at_fields)

In [None]:
# One record per returned document; keep the first occurrence of each material_id and save.
data = [
    {"material_id": doc.material_id, "created_at": doc.created_at}
    for doc in total_docs
]
df_mp_created_at = pd.DataFrame(data).drop_duplicates(subset=["material_id"])
df_mp_created_at.to_csv(base_path / "mp-created-at.csv", index=False)

### 1.2. Download a snapshot of experimentally observed materials (`theoretical=False`) with num_sites <= 40 (the energy above hull <= 0.25 eV/atom filter is currently commented out)

In [None]:
# Experimentally observed materials (theoretical=False) with num_sites in [0, 40].
# The energy_above_hull filter is commented out, so all hull distances are kept here;
# their distribution is plotted further down.
summary_fields = [
    "material_id",
    "structure",
    "energy_above_hull",
    "band_gap",
    "theoretical",
]
with MPRester(MP_API_KEY) as mpr:
    docs = mpr.summary.search(
        num_sites=[0, 40],
        # energy_above_hull=[0, 0.25],
        theoretical=False,
        fields=summary_fields,
    )

In [10]:
# Element symbols whose single-element structures are skipped when building the dataset.
# NOTE(review): most of these are gases at ambient conditions, but "Fr" and "Og" are not;
# they are presumably excluded for another reason (e.g. radioactivity) — confirm.
excluded_gas_list = "H He N O F Ne Cl Ar Kr Xe Rn Fr Og".split()

In [None]:
# Convert each summary document into a CSV-ready row. Two kinds of structure are skipped:
# single-element structures whose element is in excluded_gas_list, and cells whose
# longest lattice length exceeds 20. Skipped entries are printed for inspection.
data = []
for doc in tqdm(docs):
    st = doc.structure
    elements = [el.symbol for el in st.composition.elements]

    is_excluded_element = len(elements) == 1 and elements[0] in excluded_gas_list
    if is_excluded_element:
        print(elements)
        continue

    longest_axis = max(st.lattice.abc)
    if longest_axis > 20:
        print(st.formula, st.lattice.abc)
        continue

    data.append(
        dict(
            material_id=doc.material_id,
            energy_above_hull=doc.energy_above_hull,
            band_gap=doc.band_gap,
            cif=st.to(fmt="cif"),
        )
    )

# Deduplicate, shuffle with a fixed seed, and persist. reset_index() (without drop=True)
# keeps the pre-shuffle index as an "index" column in the CSV.
df_mp_api = (
    pd.DataFrame(data)
    .drop_duplicates(subset="material_id")
    .sample(frac=1, random_state=42)
    .reset_index()
)
df_mp_api.to_csv(base_path / "mp-api.csv", index=False)

In [None]:
# Distribution of energy above hull; the red band highlights 0–0.25 eV/atom (meta-stable range).
plt.rcParams["font.size"] = 20
ax = df_mp_api["energy_above_hull"].hist(bins=100)
ax.set_xlabel("Energy above hull (eV/atom)")
ax.set_ylabel("Number of materials")
ax.set_ylim(0, 1000)
ax.axvspan(0, 0.25, color="red", alpha=0.2)
ax.text(0.25, 800, "Meta-stable", color="red")

### 1.3. Download abstracts using the Crossref API

In [15]:
# refer to "./data/mp-40/get_abstract_from_materials_id.py"

In [None]:
# Compare the number of materials with the number of abstract JSON files fetched for them,
# then collect all abstract records into one DataFrame.
df_mp_api = pd.read_csv(base_path / "mp-api.csv")
material_ids = df_mp_api["material_id"].values
print(f"Number of total materials: {len(material_ids)}")

path_abstract_data = base_path / "abstract_data"
# sorted() makes the row order independent of the filesystem's directory listing order.
abstract_data_files = sorted(path_abstract_data.glob("*.json"))
print(f"Number of abstract data: {len(abstract_data_files)}")

# One row per JSON file; the material id is the part of the file stem before the first "_".
data = []
for file in abstract_data_files:
    with open(file) as fp:  # close each file promptly instead of leaking one handle per file
        record = json.load(fp)
    record["material_id"] = file.stem.split("_")[0]
    data.append(record)

df_abstract_data = pd.DataFrame(data)
df_abstract_data

In [None]:
# How many reference papers were found per material: count records per material_id,
# then count how many materials share each reference count.
refs_per_material = df_abstract_data["material_id"].value_counts()
refs_per_material.value_counts().sort_index().plot.bar()
plt.xlabel("Number of references per material")
plt.ylabel("Number of materials")
# NOTE(review): on a bar plot the x-limits are bar positions, not reference counts,
# so this shows the first ~10 bars rather than counts 0–10.
plt.xlim(0, 10)

In [None]:
# Histogram of publication year. The "year" field may be stored as text, as numbers, or as a
# mix of both (with missing values); `.str.isnumeric()` fails on numeric or mixed columns,
# so coerce to numbers instead and drop anything that does not parse.
df_year = pd.to_numeric(df_abstract_data["year"], errors="coerce").dropna().astype(int)
df_year.hist(bins=50)
plt.xlabel("Publication year")
plt.ylabel("Number of abstracts")

In [13]:
# Keep only records that actually have an abstract (drop rows whose "abstract" is NaN).
df = df_abstract_data.dropna(subset=["abstract"])

In [None]:
# Spot-check up to the first 1000 abstracts together with their URLs.
# head() caps the count, so this no longer raises IndexError when fewer rows exist.
for _, row in df.head(1000).iterrows():
    print(row["abstract"], row["url"])
    print("-" * 40)

### 1.4. Calculate structural properties using pymatgen

In [None]:
# Calculate structural properties with pymatgen, parallelised over DataFrame rows.
# pandarallel.initialize() makes DataFrame.parallel_apply available (used below);
# progress_bar=True displays progress while it runs.
from pandarallel import pandarallel

pandarallel.initialize(progress_bar=True)


def calculate_property(data):
    """Parse ``data.cif`` with pymatgen and add composition, size and symmetry fields.

    Intended for ``DataFrame.apply(..., axis=1)``: ``data`` is one row with a ``cif`` field.
    The same row object is returned with these fields set:

    - ``composition``: reduced composition as an alphabetical formula
    - ``volume``: cell volume (``Structure.volume``)
    - ``density``: mass density (``Structure.density``)
    - ``atomic_density``: number of sites divided by the cell volume
    - ``crystal_system``, ``space_group_symbol``, ``space_group_number``: from
      ``SpacegroupAnalyzer`` with ``symprec=0.1``
    """
    st = Structure.from_str(data.cif, fmt="cif")
    sg = SpacegroupAnalyzer(st, symprec=0.1)
    data["composition"] = st.composition.reduced_composition.alphabetical_formula
    data["volume"] = st.volume
    data["density"] = st.density
    # Sites per unit cell volume. This must differ from the mass density above;
    # it was previously a duplicate of `density`.
    data["atomic_density"] = len(st) / st.volume
    data["crystal_system"] = sg.get_crystal_system()
    data["space_group_symbol"] = sg.get_space_group_symbol()
    data["space_group_number"] = sg.get_space_group_number()
    return data


# Compute the structural properties for every row in parallel and persist the enriched table.
mp_api_csv = base_path / "mp-api.csv"
mp_total_csv = base_path / "mp-total.csv"
df_mp_api = pd.read_csv(mp_api_csv)
df_mp_total = df_mp_api.parallel_apply(calculate_property, axis=1)
df_mp_total.to_csv(mp_total_csv, index=False)

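### 1.5. Split into train / validation / test sets

TODO: the intent was a test set of materials registered after a cutoff date (registration dates are saved in `mp-created-at.csv`), but the split below is currently random.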

In [111]:
# Reload the property table and drop the "index" column carried over from reset_index()
# when mp-api.csv was written.
df_mp_total = pd.read_csv(base_path / "mp-total.csv").drop(columns="index")

In [112]:
def stability_label(e_hull):
    """Label by energy above hull (eV/atom).

    Returns "stable" for exactly 0, "metastable" below 0.25, and "unstable" otherwise
    (including NaN, since NaN comparisons are False).
    """
    if e_hull == 0:
        return "stable"
    if e_hull < 0.25:
        return "metastable"
    return "unstable"


def metallic_label(gap):
    """Return "metallic" for band gaps below 0.1, otherwise "insulator" (including NaN)."""
    return "metallic" if gap < 0.1 else "insulator"


df_mp_total["stability"] = df_mp_total["energy_above_hull"].map(stability_label)
df_mp_total["metallic"] = df_mp_total["band_gap"].map(metallic_label)

In [None]:
# Random split: shuffle with a fixed seed, hold out the first 1000 rows as the test set,
# then split the remainder 90/10 into train/validation.
n_test = 1000
train_fraction = 0.90
df_total = df_mp_total.sample(frac=1, random_state=42)
df_test, df_train_val = df_total.iloc[:n_test], df_total.iloc[n_test:]
num_train = int(len(df_train_val) * train_fraction)
df_train, df_val = df_train_val.iloc[:num_train], df_train_val.iloc[num_train:]
print(len(df_train), len(df_val), len(df_test))

In [114]:
# Persist the three splits as <split>.csv under base_path (train, then val, then test).
for split_name, split_df in (("train", df_train), ("val", df_val), ("test", df_test)):
    split_df.to_csv(base_path / f"{split_name}.csv", index=False)

# 2. Text Prompts

In [None]:
# ! generate_text_prompt.py

In [None]:
path_prompts = Path("../mp-50/prompts/")  # TODO: change the path
# sorted() fixes the iteration order so the seeded random choice below is reproducible.
text_files = sorted(path_prompts.glob("*.txt"))

# Numbered-list marker at the start of a line ("1. ", "  12. "). The pattern is anchored to the
# line start so numbers that end a sentence (e.g. "space group 225. It ...") are left intact.
LIST_MARKER_RE = re.compile(r"^[ \t]*\d+\.[ \t]+", flags=re.MULTILINE)


def pick_prompt(text, rng):
    """Strip list markers from ``text`` and return one non-empty line chosen with ``rng``.

    Args:
        text: Contents of a prompt file, one candidate prompt per line.
        rng: ``random.Random`` instance used for the choice.

    Returns:
        The chosen line with surrounding whitespace stripped, or "" if ``text`` has no
        non-empty lines.
    """
    candidates = [line.strip() for line in LIST_MARKER_RE.sub("", text).splitlines()]
    candidates = [line for line in candidates if line]  # never pick a blank line
    return rng.choice(candidates) if candidates else ""


# Read each prompt file (named <material_id>.txt) and select one prompt per material.
rng_prompts = random.Random(42)
prompts = {}
for text_file in text_files:
    with open(text_file, "r") as f:
        prompts[text_file.stem] = pick_prompt(f.read(), rng_prompts)

df_prompts = pd.DataFrame(list(prompts.items()), columns=["material_id", "prompt"])

In [None]:
# Attach the sampled prompt to every split and overwrite the split CSVs.
for split in ["train", "val", "test"]:
    df = pd.read_csv(base_path / f"{split}.csv")
    # Drop a "prompt" column written by a previous run; otherwise the merge would produce
    # prompt_x / prompt_y and this cell would not be re-runnable.
    df = df.drop(columns=["prompt"], errors="ignore")
    n_before = len(df)
    # Inner join: materials without a prompt file are removed from the split.
    df = pd.merge(df, df_prompts, on="material_id")
    if len(df) < n_before:
        print(f"{split}: dropped {n_before - len(df)} rows without a prompt")
    df.to_csv(base_path / f"{split}.csv", index=False)

## Lattice parameter statistics (training set)

In [97]:
# Parse every training-set CIF string into a pymatgen Structure.
df_train = pd.read_csv(base_path / "train.csv")
st_list = list(map(lambda cif_text: Structure.from_str(cif_text, fmt="cif"), df_train["cif"]))

In [101]:
# Column-wise mean and standard deviation of the lattice parameters over the training structures.
# Each row is one structure's `lattice.parameters` (presumably a, b, c, alpha, beta, gamma — confirm).
lattice_params = np.array([structure.lattice.parameters for structure in st_list])
lattice_params_mean = np.mean(lattice_params, axis=0)
lattice_params_std = np.std(lattice_params, axis=0)
print(lattice_params_mean, lattice_params_std)

In [105]:
# Convert to plain Python lists so the file contains list literals rather than numpy array
# reprs, then write one "mean: [...]" and one "std: [...]" line.
lattice_params_mean, lattice_params_std = (
    lattice_params_mean.tolist(),
    lattice_params_std.tolist(),
)
with open(base_path / "lattice_params.txt", "w") as f:
    print(f"mean: {lattice_params_mean}", file=f)
    print(f"std: {lattice_params_std}", file=f)