In [3]:
# Reload edited imported modules automatically before each cell executes (IPython autoreload, mode 2).
%load_ext autoreload
%autoreload 2

In [4]:
import os
import copy
import re
import random
import json
from pathlib import Path
import dotenv
import warnings

from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt

from mp_api.client import MPRester
from pymatgen.core import Structure, Composition
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from robocrys.condense.mineral import MineralMatcher

from openai import OpenAI

# Notebook-wide settings: larger default font for all figures, and all warnings silenced.
plt.rcParams.update({"font.size": 20})
warnings.filterwarnings("ignore")

In [5]:
# Root directory for every artifact of the mp-40 dataset (CSVs, abstract_data/, prompts/).
base_path = Path("data") / "mp-40"

### Set up API key

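Save your Materials Project API key in a `.env` file in the same directory as this notebook. Section 2 also needs an OpenAI API key (`OPENAI_API_KEY`); `SCOPUS_API_KEY` is also read if present. The file should look like this:

```
MP_API_KEY=your_api_key
OPENAI_API_KEY=your_openai_api_key
```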

In [7]:
# Populate os.environ from .env, then read each key (None when a key is not set).
dotenv.load_dotenv()
MP_API_KEY, SCOPUS_API_KEY, OPENAI_API_KEY = (
    os.getenv(key_name) for key_name in ("MP_API_KEY", "SCOPUS_API_KEY", "OPENAI_API_KEY")
)

# 1. Materials Project

### 1.1. Retrieving created_at from the Materials Project API

`created_at` is only available through `mpr.materials.search`, so we use that endpoint to retrieve the registration date of each material.

In [None]:
# Query all materials with num_sites in [0, 40]; only the ID and registration date are fetched.
with MPRester(MP_API_KEY) as mpr:
    total_docs = mpr.materials.search(
        fields=["material_id", "created_at"],
        num_sites=[0, 40],
    )

In [None]:
# Flatten the search results into a (material_id, created_at) table.
df_mp_created_at = pd.DataFrame(
    [{"material_id": doc.material_id, "created_at": doc.created_at} for doc in total_docs]
)
# drop repeated material IDs, keeping the first occurrence
df_mp_created_at = df_mp_created_at.drop_duplicates(subset=["material_id"])
df_mp_created_at.to_csv(base_path / "mp-created-at.csv", index=False)

### 1.2. Download snapshot with constraints num_sites <= 40, energy above hull <= 0.25 eV/atom, and experimental = True (note: the energy-above-hull filter is currently commented out in the query below)

In [None]:
# Summary documents for experimentally observed materials (theoretical=False) with <= 40 sites.
# The energy-above-hull restriction is currently disabled; it is kept below for reference.
with MPRester(MP_API_KEY) as mpr:
    docs = mpr.summary.search(
        fields=[
            "material_id",
            "structure",
            "energy_above_hull",
            "band_gap",
            "theoretical",
        ],
        theoretical=False,
        num_sites=[0, 40],
        # energy_above_hull=[0, 0.25],
    )

In [10]:
# Elements whose single-element structures are skipped when building the dataset.
# NOTE(review): "Fr" (francium) is an alkali metal, not a gas — confirm it belongs here.
excluded_gas_list = "H He N O F Ne Cl Ar Kr Xe Rn Fr Og".split()

In [None]:
# Build the base table from the summary docs: skip single-element structures made of an
# excluded element and structures with any lattice length > 20, then serialize each
# structure to CIF.
data = []
for doc in tqdm(docs):
    st = doc.structure
    elements = [elmt.symbol for elmt in st.composition.elements]

    # single-element systems whose only element is in excluded_gas_list
    if len(elements) == 1 and elements[0] in excluded_gas_list:
        print(elements)
        continue

    # very elongated cells (any of a, b, c > 20) are dropped
    if max(st.lattice.abc) > 20:
        print(st.formula, st.lattice.abc)
        continue

    data.append(
        {
            "material_id": doc.material_id,
            "energy_above_hull": doc.energy_above_hull,
            "band_gap": doc.band_gap,
            "cif": st.to(fmt="cif"),
        }
    )

df_mp_api = (
    pd.DataFrame(data)
    .drop_duplicates(subset="material_id")
    # seeded shuffle; drop=True so the pre-shuffle index is discarded instead of being
    # written to the CSV (and every downstream CSV) as a spurious "index" column
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
df_mp_api.to_csv(base_path / "mp-api.csv", index=False)

In [None]:
# Histogram of energy above hull; the 0-0.25 eV/atom window is shaded red.
plt.rcParams["font.size"] = 20
ax = df_mp_api["energy_above_hull"].hist(bins=100)
ax.set_xlabel("Energy above hull (eV/atom)")
ax.set_ylabel("Number of materials")
ax.set_ylim(0, 1000)
ax.axvspan(0, 0.25, color="red", alpha=0.2)
ax.text(0.25, 800, "Meta-stable", color="red")

### 1.3. Download abstracts using the Crossref API

In [15]:
# refer to "./data/mp-40/get_abstract_from_materials_id.py"

In [None]:
df_mp_api = pd.read_csv(base_path / "mp-api.csv")
material_ids = df_mp_api["material_id"].values

# Coverage of the abstract download (files produced by get_abstract_from_materials_id.py)
print(f"Number of total materials: {len(material_ids)}")
path_abstract_data = base_path / "abstract_data"
# sorted so the resulting DataFrame has a deterministic row order
abstract_data_files = sorted(path_abstract_data.glob("*.json"))
print(f"Number of abstract data: {len(abstract_data_files)}")

# one record per JSON file, tagged with the material ID taken from the file name
data = []
for file in abstract_data_files:
    # context manager closes each handle instead of leaking one per file
    with open(file, encoding="utf-8") as fh:
        d = json.load(fh)
    # file stem is "<material_id>" or "<material_id>_<suffix>"; presumably one file per paper — TODO confirm
    material_id = file.stem.split("_")[0]
    d["material_id"] = material_id
    data.append(d)

df_abstract_data = pd.DataFrame(data)
# NOTE(review): section 2.1 reads abstract-data.csv, which this cell does not write while
# the line below stays commented out.
# df_abstract_data.to_csv(base_path / "abstract-data.csv", index=False)

### 1.4. Calculate parameters using pymatgen

In [None]:
# Structure-derived properties are computed row-wise in parallel via pandarallel.
from pandarallel import pandarallel

# a single matcher instance, used by every calculate_property call
mineral_matcher = MineralMatcher()

pandarallel.initialize(progress_bar=True)


def calculate_property(data):
    """Derive composition, density and symmetry descriptors for one dataset row.

    Args:
        data: one row (as passed by ``DataFrame.apply(axis=1)``) holding a ``cif`` string.

    Returns:
        The same row with ``composition``, ``volume``, ``density``, ``atomic_density``,
        ``crystal_system``, ``space_group_symbol``, ``space_group_number`` and
        ``mineral`` filled in. ``mineral`` is None when no match is found or matching fails.
    """
    st = Structure.from_str(data.cif, fmt="cif")
    sg = SpacegroupAnalyzer(st, symprec=0.1)
    data["composition"] = st.composition.reduced_composition.alphabetical_formula
    data["volume"] = st.volume
    data["density"] = st.density  # pymatgen mass density (g/cm^3)
    # number of sites per unit-cell volume, not a second copy of the mass density
    data["atomic_density"] = st.num_sites / st.volume
    data["crystal_system"] = sg.get_crystal_system()
    data["space_group_symbol"] = sg.get_space_group_symbol()
    data["space_group_number"] = sg.get_space_group_number()
    try:
        # robocrys' MineralMatcher.get_best_mineral(structure) returns a dict whose
        # "type" is the mineral name (None when there is no match)
        data["mineral"] = mineral_matcher.get_best_mineral(st)["type"]
    except Exception:
        # mineral matching is best-effort; a failure must not abort the parallel apply
        data["mineral"] = None
    return data


# Start again from the saved snapshot, enrich every row in parallel, and persist the result.
df_mp_api = pd.read_csv(base_path.joinpath("mp-api.csv"))
df_mp_total = df_mp_api.parallel_apply(calculate_property, axis=1)
df_mp_total.to_csv(base_path.joinpath("mp-total.csv"), index=False)

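### 1.5. Make train / validation / test splits

TODO: this heading originally read "Make test set registered after" (unfinished). The split below is a seeded random split and does not use `created_at`.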

In [19]:
df_mp_total = pd.read_csv(base_path.joinpath("mp-total.csv"))  # property table from section 1.4

In [21]:
def _stability_label(e_above_hull):
    """Energy above hull (eV/atom) -> "stable" (== 0), "metastable" (< 0.25) or "unstable"."""
    if e_above_hull == 0:
        return "stable"
    if e_above_hull < 0.25:
        return "metastable"
    return "unstable"


def _metallic_label(band_gap):
    """Band gap (eV) below 0.1 -> "metallic"; anything else -> "insulator"."""
    return "metallic" if band_gap < 0.1 else "insulator"


df_mp_total["stability"] = df_mp_total["energy_above_hull"].apply(_stability_label)
df_mp_total["metallic"] = df_mp_total["band_gap"].apply(_metallic_label)

In [None]:
# Seeded shuffle, then: first N_TEST rows -> test; the rest -> TRAIN_FRACTION train / remainder val.
N_TEST = 1000
TRAIN_FRACTION = 0.90

# Shuffle into a new frame instead of overwriting df_mp_total, so re-running this cell
# reproduces the same split (re-shuffling an already-shuffled frame would not).
df_mp_shuffled = df_mp_total.sample(frac=1, random_state=42)
df_test = df_mp_shuffled.iloc[:N_TEST]
df_train_val = df_mp_shuffled.iloc[N_TEST:]
num_train = int(len(df_train_val) * TRAIN_FRACTION)
df_train = df_train_val.iloc[:num_train]
df_val = df_train_val.iloc[num_train:]
assert len(df_train) + len(df_val) + len(df_test) == len(df_mp_total)
print(len(df_train), len(df_val), len(df_test))

In [34]:
# Write each split as <split>.csv under base_path.
for split_name, split_df in {"train": df_train, "val": df_val, "test": df_test}.items():
    split_df.to_csv(base_path / f"{split_name}.csv", index=False)

# 2. Text Prompts (MatCap)

### 2.1. Generate prompts using the OpenAI API

In [6]:
# !generate_text_prompt.py

In [7]:
df_mp_total = pd.read_csv(base_path / "mp-total.csv")
# keep only rows that have a non-missing abstract
df_abstract_data = pd.read_csv(base_path / "abstract-data.csv").dropna(subset=["abstract"])

In [54]:
def composition_augmentation(composition: "Composition"):
    """Return the distinct formula spellings of a composition's reduced form.

    Collects the reduced formula plus the alphabetical, IUPAC and Hill formulas of the
    reduced composition, with duplicates removed. The order is deterministic (first
    occurrence in that sequence); a ``set`` of strings would iterate in an order that
    changes between interpreter runs (hash randomization), making prompts irreproducible.

    Args:
        composition: a pymatgen ``Composition``.

    Returns:
        list[str]: the unique formula strings.
    """
    reduced = composition.reduced_composition
    candidates = [
        composition.reduced_formula,  # reduced formula
        reduced.alphabetical_formula,  # reduced alphabetical formula
        reduced.iupac_formula,  # reduced IUPAC formula
        reduced.hill_formula,  # reduced Hill formula
    ]
    # dict.fromkeys de-duplicates while preserving insertion order
    return list(dict.fromkeys(candidates))

In [None]:
# Build (and print) the caption-request prompt; the loop stops once the row labelled 0 is done,
# so this cell only previews a prompt and leaves `template` defined for the request below.
for idx, row in tqdm(df_mp_total.iterrows(), total=len(df_mp_total)):
    st = Structure.from_str(row["cif"], fmt="cif")
    # composition
    composition = st.composition
    reduced_formula = composition.reduced_formula
    comp_list = composition_augmentation(composition)

    # crystal system
    crystal_system = row["crystal_system"]

    # stability: stable == 0 eV/atom, metastable < 0.25 eV/atom, unstable otherwise
    e_hull = row["energy_above_hull"]
    stability = "stable" if e_hull == 0 else "metastable" if e_hull < 0.25 else "unstable"

    # metallic: band gap < 0.1 eV -> metallic, otherwise insulator
    metallic = "metallic" if row["band_gap"] < 0.1 else "insulator"

    # mineral (missing -> empty line in the template)
    mineral = "" if pd.isna(row["mineral"]) else row["mineral"]

    # paper data: one line per abstract row of this material, capped at 3000 characters.
    # The inner loop variable is `paper` so it no longer shadows the outer `row`.
    paper_data = ""
    papers = df_abstract_data[df_abstract_data["material_id"] == row["material_id"]]
    for i, (_, paper) in enumerate(papers.iterrows()):
        # NOTE(review): [2:-2] presumably strips list-literal wrapping such as "['Title']" — confirm
        paper_data += f"paper{i+1} - title: {paper['title'][2:-2]} | abstract: {paper['abstract']}\n"
    paper_data = paper_data[:3000]  # limit the length of paper_data

    template = f"""provide concise captions for "{reduced_formula}" with the following properties:


{crystal_system}
{stability}
{metallic}
{mineral}
{paper_data}

Here are some examples for other crystal systems:
1. Orthorhombic crystal structure of ZnMnO4
2. metastable crystal structure of LiO2
3. Si1 C1 crystal structure with metallic properties

Please provide "five concise captions" to describe the crystal structure with the various compound name: {", ".join(comp_list)}
"""
    print(template)
    print("-" * 40)
    if idx == 0:
        break

In [20]:
# OpenAI client for the caption-generation request below.
client = OpenAI(api_key=OPENAI_API_KEY)

In [None]:
print(template)
# Single-turn request: the previewed template is the only user message.
chat_completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": template}],
)

In [None]:
# Text of the first returned choice.
message = chat_completion.choices[0].message
output = message.content
print(output)

### 2.2. Post-process prompts

In [8]:
# One generated caption file per material, named <material_id>.txt.
prompts_dir = base_path / "prompts"
prompts_files = [*prompts_dir.glob("*.txt")]
print(len(prompts_files))

35318


In [58]:
def _parse_caption_file(raw_text):
    """Split a raw LLM response into its numbered captions, with numbering removed.

    Returns one string per non-blank numbered line; markdown bold markers are stripped.
    """
    # join wrapped lines: drop every newline that does not start a new "<n>." item
    text = re.sub(r"\n(?!\d+\.)", "", raw_text)
    # discard any preamble before the first "1." (keep the whole text if there is none,
    # instead of slicing with find()'s -1 and keeping only the last character)
    start_idx = text.find("1.")
    if start_idx != -1:
        text = text[start_idx:]
    # remove "**" (markdown bold)
    text = text.replace("**", "")
    # one caption per non-blank line, with the leading "1. "-style numbering removed
    return [re.sub(r"^\d+\.\s", "", line) for line in text.split("\n") if line.strip()]


# The former "< 5 lines" fallback could only re-introduce blank strings, which let empty
# captions slip past the count check below; blank lines are now always filtered out.
prompts = {}
for text_file in prompts_files:
    material_id = text_file.stem
    with open(text_file, "r") as f:
        raw_text = f.read()
    prompt_list = _parse_caption_file(raw_text)

    if len(prompt_list) != 5:
        raise ValueError(
            f"The number of prompts is not 5 ({len(prompt_list)} found in {text_file})."
        )

    # all five captions are stored per material (no single caption is sampled here)
    prompts[material_id] = prompt_list
df_prompts = pd.DataFrame(prompts.items(), columns=["material_id", "prompt"])

In [62]:
# Attach the generated captions to every split CSV (left join on material_id).
# Any existing "prompt" column is dropped first so re-running this cell overwrites it
# instead of producing duplicated prompt_x / prompt_y columns.
# NOTE(review): each prompt is a list of five captions, which to_csv stores as its string repr.
for split in ["train", "val", "test"]:
    split_path = base_path / f"{split}.csv"
    df = pd.read_csv(split_path)
    df = df.drop(columns=["prompt"], errors="ignore").merge(
        df_prompts, on="material_id", how="left"
    )
    df.to_csv(split_path, index=False)