In [None]:
import os
from pymatgen.core import Structure
from pathlib import Path
import pandas as pd
import numpy as np
import sys
from mp_api.client import MPRester
import matplotlib.pyplot as plt
import scipy.stats as stats
import matplotlib.colors as mcolors
import traceback
import orjson
import zipfile
import zstandard as zstd
import shutil
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from PIL import Image
from ase.visualize import view
from ase.io import write
from ase.io.utils import PlottingVariables # write kwargs here
import itertools
from pymatgen.analysis.structure_matcher import StructureMatcher
from typing import Callable
from tqdm import tqdm


colors = list(mcolors.TABLEAU_COLORS)
sys.path.append("../../../")
from utils.save_and_load import save_to_json, load_from_json

#### Tensor to scalar

In [None]:
def dict2tensor(data: dict):
    """Build a 3x3x3 array from a ``{"dinf": {"ijk": {"value": ...}}}`` mapping.

    Every entry of ``data["dinf"]`` is read for its ``"value"``; the component
    stored under key ``"ijk"`` (1-based index characters) ends up at
    ``result[i-1, j-1, k-1]``.
    """
    values = {key: entry["value"] for key, entry in data["dinf"].items()}
    axes = ("1", "2", "3")
    return np.array(
        [[[values[i + j + k] for k in axes] for j in axes] for i in axes]
    )

In [None]:
def tensor2scalar(d: np.ndarray):
    """Collapse a 3x3x3 tensor ``d`` into one scalar.

    Returns the square root of a weighted sum of five groups of products of
    tensor components (weights 19/105, 13/105, 44/105, 13/105 and 5/7).
    The last group uses the mean of the squared components whose three
    indices are all different.
    """
    axes = range(3)
    unequal_pairs = [(i, j) for i in axes for j in axes if i != j]
    cyclic_triples = ((0, 1, 2), (1, 2, 0), (2, 0, 1))

    diagonal_sq = sum(d[i, i, i] ** 2 for i in axes)
    diagonal_cross = sum(d[i, i, i] * d[i, j, j] for i, j in unequal_pairs)
    iij_sq = sum(d[i, i, j] ** 2 for i, j in unequal_pairs)
    cyclic_cross = sum(d[a, a, b] * d[b, c, c] for a, b, c in cyclic_triples)
    distinct_sq_mean = np.mean(
        [
            d[i, j, k] ** 2
            for i in axes
            for j in axes
            for k in axes
            if i != j and j != k and k != i
        ]
    )

    squared = (
        19 / 105 * diagonal_sq
        + 13 / 105 * diagonal_cross
        + 44 / 105 * iij_sq
        + 13 / 105 * cyclic_cross
        + 5 / 7 * distinct_sq_mean
    )
    return squared**0.5

In [None]:
# Levine, Z. H.; Allan, D. C. Large Local-Field Effects in the Second-Harmonic Susceptibility of Crystalline Urea. Phys. Rev. B 1993, 48, 7783.

# Urea reference tensor: the six components whose three indices are all
# different are 1.2 pm/V; every other component is 0.
urea_dict = {
    "dinf": {
        f"{i}{j}{k}": {"value": 1.2 if len({i, j, k}) == 3 else 0, "unit": "pm/V"}
        for i in "123"
        for j in "123"
        for k in "123"
    },
}

In [None]:
# d0ra03136d.pdf, 252.pdf
# KDP reference tensor: the six components whose three indices are all
# different are 0.41 pm/V; every other component is 0.
KDP_dict = {
    "dinf": {
        f"{i}{j}{k}": {"value": 0.41 if len({i, j, k}) == 3 else 0, "unit": "pm/V"}
        for i in "123"
        for j in "123"
        for k in "123"
    },
}

In [None]:
# Scalar effective coefficient of the KDP reference, computed from KDP_dict.
KDP_eff = tensor2scalar(dict2tensor(KDP_dict))
KDP_eff

In [None]:
# Scalar effective coefficient of the urea reference and its ratio to KDP_eff.
urea = tensor2scalar(dict2tensor(urea_dict))
urea_div_KDP = urea / KDP_eff
urea, urea_div_KDP

#### Articles data

In [None]:
# Directories holding the CIF files of the two source articles.
# path1 is paired with data1 and path2 with data2 when the files are parsed.
path1 = Path(
    "../../raw_data/private/Metal–organic frameworks as competitive (1)/Metal–organic frameworks as competitive/"
)
path2 = Path(
    "../../raw_data/private/Rational_Synthesis_of_Noncentrosymmetric_Metal–Organic_Frameworks/Rational Synthesis of Noncentrosymmetric Metal–Organic Frameworks for Second-Order Nonlinear Optics/"
)

In [None]:
# Manually copied measurements for article 1, one tuple per compound:
# (reference material, SHG intensity relative to that reference).
# The 1-based position in this list is used as the compound index
# (it is prepended to each tuple in a later cell).
data1: list[tuple[str, float]] = [
    ("KDP", 2.5),
    ("urea", 80),
    ("KDP", 0.1),
    ("KDP", 0.13),
    ("SiO2", 80),
    ("KDP", 5),
    ("urea", 0.8),
    ("urea", 1),
    ("SiO2", 15),
    ("KDP", 2.1),
    ("urea", 0.5),
    ("urea", 0.5),
    ("urea", 0.3),
    ("urea", 0.3),
    ("urea", 0.3),
    ("urea", 0.7),
    ("KDP", 4.24),
    ("KDP", 0.9),
    ("KDP", 0.1),
    ("KDP", 0.5),
    ("urea", 0.8),
    ("KDP", 4),
    ("KDP", 7),
    ("urea", 0.5),
    ("urea", 1.5),
    ("KDP", 1.5),
    ("urea", 0.9),
    ("KDP", 0.35),
    ("KDP", 0.40),
    ("KDP", 0.17),
    ("KDP", 0.08),
    ("KDP", 0.10),
    ("urea", 0.3),
    ("urea", 0.4),
    ("urea", 0.4),
    ("urea", 0.8),
    ("KDP", 3.6),
    ("urea", 0.7),
    ("urea", 0.7),
    ("urea", 0.8),
    ("KDP", 1),
    ("KDP", 1.1),
    ("KDP", 15),
    ("KDP", 0.27),
    ("urea", 0.3),
    ("KDP", 0.7),
    ("KDP", 0.466),
    ("KDP", 0.122),
    ("KDP", 5.6),
]

In [None]:
# Manually copied measurements for article 2, one tuple per compound:
# (per-article compound index, reference material, relative SHG intensity).
# Reference and intensity are None where no number was recorded; the trailing
# comments keep the original qualitative note (presumably the wording used in
# the article — confirm against the source).
data2: list[tuple[int | str, str | None, float | None]] = [
    (1, "SiO2", 1.5),
    (4, "SiO2", 126),
    (6, "SiO2", 18),
    (8, "SiO2", 310),
    (9, "SiO2", 400),
    (10, "SiO2", 345),
    (11, "KDP", 3),
    (12, "KDP", 5),
    (13, "urea", 0.5),
    (14, None, None),  # active
    (15, "urea", 1.2),
    (16, "urea", 1),
    (17, "urea", 4),
    (18, "urea", 0.4),
    (19, "KDP", 0.5),
    (20, "KDP", 1.1),
    (21, "KDP", 2.5),
    (22, "SiO2", 80),
    (23, "SiO2", 10),
    (24, "SiO2", 70),
    (25, "KDP", 5),
    (26, "KDP", 1),
    (27, "KDP", 1),  # > 1
    (28, "KDP", 3.5),
    (29, "KDP", 5),
    (30, "KDP", 6.5),
    (31, "urea", 0.5),
    (32, "KDP", 0.4),
    (33, "KDP", 0.2),
    (34, "urea", 0.7),
    (35, "urea", 0.7),
    (36, "KDP", 3),
    (37, "KDP", 2),
    (38, "urea", 0.6),
    (39, "urea", 0.7),
    (40, "urea", 0.8),
    (41, "KDP", 4),
    (42, "urea", 0.8),
    (43, "KDP", 2.5),
    (44, None, None),  # active
    (45, "KDP", 3.5),  # 3-4
    (46, "KDP", 1.5),
    (47, "KDP", 2.3),
    (48, "KDP", 0.6),
    (49, None, None),  # active
    (50, "KDP", 10),
    ("51a", "SiO2", 150),  # ???
    ("51b", "SiO2", 155),
    ("51c", "SiO2", 90),
    ("51d", "SiO2", 110),
    ("52a", "SiO2", 15),
    ("52b", "SiO2", 24),
    ("52c", "SiO2", 35),
    ("52d", "SiO2", 11),
    ("52e", "SiO2", 20),
    ("52f", "SiO2", 17),  # .
    (53, "SiO2", 2),
    (54, "SiO2", 1000),
    (56, "SiO2", 400),  # !
    (57, None, None),  # large
    (58, "urea", 50),
    (59, "urea", 0.4),
    (60, "urea", 16.8),
    (61, "KDP", 8),
    (62, "KDP", 2),
    (63, "urea", 0.5),
    (64, "urea", 0.02),
    (65, "urea", 0.9),
    (66, "KDP", 1),
    (67, "KDP", 2),
    (68, "urea", 1),
    (69, None, None),  # active
    (70, "SiO2", 75),
    (71, "urea", 1),
    (72, "KDP", 2),
    (73, "urea", 0.7),
    (74, None, None),  # active
    (75, "urea", 0.8),
    (76, "KDP", 0.4),
    (77, "KDP", 0.5),
    (78, "KDP", 1),
    (79, "KDP", 2),
    (80, "KDP", 1),  # <1
    (81, "KDP", 3),
    (82, "KDP", 1),  # <1
    (83, "urea", 0.7),
    (84, "urea", 0.3),
    (85, "SiO2", 200),  #!
    (87, None, None),  # weak
    (88, None, None),  # weak
    (89, None, None),  # weak
    (90, None, None),  # weak
    (91, None, None),  # None
    (92, "urea", 0.3),
    (93, "urea", 0.6),
    (94, "urea", 0.8),
    (95, "urea", 0.8),
    (96, "urea", 0.8),
    (97, "KDP", 5),
    (98, "KDP", 1.5),
    (99, None, None),  # active
    (100, "SiO2", 6),
    (101, "SiO2", 20),
    (102, None, None),  # active
    (103, None, None),  # active
    (104, None, None),  # active
    (105, "KDP", 3),
    (106, "KDP", 1),  # <1
    (107, "urea", 1),
    (108, "SiO2", 460),
    (109, "KDP", 6),
    (110, "KDP", 1.5),
    (111, "KDP", 0.2),
    (112, "urea", 0.3),
    (113, "KDP", 4),
    (114, "urea", 2.9),
    (115, "urea", 0.2),
    (116, "urea", 0.6),
    (117, "KDP", 0.8),
    (118, "KDP", 0.9),
    (119, "urea", 0.3),
    (120, "urea", 0.8),
    (121, "urea", 0.4),
    (122, "KDP", 2.8),
    (123, "KDP", 2.6),
    (124, None, None),  # active
    (125, "urea", 0.05),
    (126, "urea", 0.06),
    (127, "urea", 1.2),
    (128, "urea", 0.1),
    (129, "urea", 1),
    (130, "urea", 6),
    (131, "urea", 5),
    (132, "KDP", 0.8),
    (133, "KDP", 20),
    (134, None, None),  # active
    (135, None, None),  # active
    (136, "urea", 80),
]

In [None]:
# Entry counts: article 1, article 2, and both together (one number per line).
print(len(data1), len(data2), len(data1) + len(data2), sep="\n")

In [None]:
# Prepend a 1-based positional index to every data1 entry so it has the same
# (index, reference, intensity) layout as data2. The length check makes the
# cell safe to re-run.
if len(data1[0]) == 2:
    data1 = [
        (idx, reference, intensity)
        for idx, (reference, intensity) in enumerate(data1, 1)
    ]

print(data1)

In [None]:
import warnings

# Turn every warning into an exception so that CIFs which can only be parsed
# with warnings are detected below.
# NOTE: this filter is process-wide and stays active for the following cells.
warnings.filterwarnings("error")

# cif file name (part after "(idx) ") -> structure, shg value and provenance
cifs1: dict[str, dict] = dict()
cifs2: dict[str, dict] = dict()
# Counts SiO2-referenced entries AND files whose "(idx)" prefix is not an integer.
skipped_sio2 = 0
skipped_other_reason = 0
# Number of files whose structure could not be loaded (warning or error).
warnings_count = 0
for a, (cifs, path, data) in enumerate(((cifs1, path1, data1), (cifs2, path2, data2))):
    # iterate over existing .cif files; expected file name: "(<idx>) <cif name>"
    for file in os.listdir(path):
        try:
            i = file.split("(")[1].split(")")[0]
            cif_name = file.split(") ")[1]

            # skip non-integer idx
            try:
                per_article_cif_idx = int(i)
            except ValueError:
                skipped_sio2 += 1
                continue

            try:
                data_tuple = None
                # find corresponding manually copied data entry
                # (if several entries share the index, the last one wins)
                for (
                    per_article_data_idx,
                    reference_structure,
                    relative_intensity,
                ) in data:
                    if per_article_data_idx == per_article_cif_idx:
                        data_tuple = (
                            per_article_data_idx,
                            reference_structure,
                            relative_intensity,
                        )
                if data_tuple is None:
                    print(f"{per_article_cif_idx=} not found")
                    continue
                per_article_data_idx, reference_structure, relative_intensity = (
                    data_tuple
                )

                # calculating shg scalar by the eq:
                # d^2_{MOF} = d^2_{ref} * I_{MOF} / I_{ref}

                # d_{KP} ~= d_{MOF} = (d^2_{ref} * I_{MOF} / I_{ref}) ** 0.5
                if reference_structure == "KDP":
                    reference_d_KP = KDP_eff
                elif reference_structure == "urea":
                    reference_d_KP = urea
                elif reference_structure == "SiO2":
                    skipped_sio2 += 1
                    continue
                else:
                    # no usable reference (e.g. None for qualitative entries)
                    skipped_other_reason += 1
                    continue
                shg = ((reference_d_KP**2) * relative_intensity) ** 0.5

                try:
                    structure = Structure.from_file(path.joinpath(file)).as_dict()
                except Exception:
                    warnings_count += 1
                    print(
                        [
                            line
                            for line in traceback.format_exc().split("\n")
                            if "UserWarning" in line
                        ][0]
                    )
                    print(
                        dict(
                            filename=file,
                            article_idx=a + 1,
                            per_article_cif_idx=per_article_cif_idx,
                        )
                    )
                    # Do not store this file: `structure` is either unbound or
                    # still holds the structure of the previously parsed file,
                    # which would pair this shg value with the wrong crystal.
                    continue

                cifs[cif_name] = dict(
                    structure=structure,
                    shg=shg,
                    filename=file,
                    article_idx=a + 1,
                    article_path=str(path),
                    per_article_cif_idx=per_article_cif_idx,
                )
            except Exception:
                skipped_other_reason += 1
                traceback.print_exc()
        except Exception:
            # e.g. file names that do not follow the "(idx) name" pattern
            traceback.print_exc()
print(
    f"{(warnings_count,skipped_other_reason, skipped_sio2, len(cifs1), len(cifs2), len(cifs1) + len(cifs2))=}"
)

In [None]:
# print filename, d_kp for the entries of BOTH articles
# (iterating the bare `cifs` loop variable left over from the parsing cell
# would only show the second article)
for cifs in (cifs1, cifs2):
    for cif_name, v in cifs.items():
        print(cif_name, v["filename"], v["shg"])

In [None]:
# Merge both articles into one mapping keyed by the cif name without ".cif".
# Article 2 is processed last, so it wins on a key collision.
cifs_shg_articles = {}

for article_cifs in (cifs1, cifs2):
    for cif_name, v in article_cifs.items():
        cifs_shg_articles[cif_name.replace(".cif", "")] = v
# sorted(cifs1.keys()), sorted(cifs2.keys())
sorted(cifs_shg_articles.keys()), len(cifs_shg_articles)

In [None]:
# Persist the merged article entries; a later cell reloads them from this path.
articles_data_filename = "../../raw_data/private/cifs_shg_articles_cor.json"
save_to_json(cifs_shg_articles, articles_data_filename)

In [None]:
# articles_data_filename = "../../raw_data/private/cifs_shg_articles.json"
# sg_articles = load_from_json(articles_data_filename)

#### ABINIT data

In [None]:
# abinit json extraction from mpcontribs
raw_abinit_filepath = "../../raw_data/public/abinit.json"


def ab_extraction_from_mp_contribs(to_download: bool = False):
    """Download the MPContribs "shg" project and cache it as JSON.

    Args:
        to_download: when True, every contribution and its first structure are
            fetched and written to ``raw_abinit_filepath``. When False (the
            default) nothing is downloaded and the existing cache file is left
            untouched.

    Raises:
        KeyError: if the ``MP_API_KEY`` environment variable is not set.

    Side effects: network requests via the MPContribs client, progress
    printing, and (only when something was downloaded) overwriting
    ``raw_abinit_filepath``.
    """
    from mpcontribs.client import Client

    MP_API_KEY = os.environ["MP_API_KEY"]
    client = Client(apikey=MP_API_KEY, project="shg")
    # print(client.available_query_params())
    # print(client.query_projects())
    d = client.get_all_ids()["shg"]
    keys = list(d.keys())
    # print(client.get_contribution("654b3c3cad105cb1f2de5230"))
    data = {}
    if to_download:
        # presumably keys[0] holds contribution ids and keys[1] identifiers —
        # TODO confirm; the pairing with zip only limits the iteration length.
        for i, (idx, _ident) in enumerate(zip(d[keys[0]], d[keys[1]]), 1):
            contrib = client.get_contribution(idx)
            data[idx] = dict(
                data=contrib,
                structure=client.get_structure(contrib["structures"][0]["id"]),
            )
            print(f"{i=}")
    print(f"{len(data)=}")

    if not data:
        # Nothing was fetched: do not replace an existing cache with "{}".
        return

    data_json = {
        k: dict(data=v["data"], structure=v["structure"].to_json())
        for k, v in data.items()
    }
    save_to_json(data_json, raw_abinit_filepath)

In [None]:
# ab_extraction_from_mp_contribs(False)

In [None]:
# Load the cached MPContribs download written by ab_extraction_from_mp_contribs.
raw_abinit_data_json = load_from_json(raw_abinit_filepath)

In [None]:
# Re-key the raw entries by their "identifier" field and convert the stored
# structure JSON strings into plain structure dicts.
cifs_shg_abinit = {}
for cif_name, v in raw_abinit_data_json.items():
    identifier = v["data"]["identifier"]
    parsed_structure = Structure.from_str(v["structure"], fmt="json")
    cifs_shg_abinit[identifier] = {
        "data": v["data"],
        "structure": parsed_structure.as_dict(),
    }

In [None]:
# Location of the re-keyed ABINIT data (identifier -> data + structure dict).
abinit_filepath = "../../raw_data/public/shg_abinit.json"

In [None]:
# Write the re-keyed ABINIT data to abinit_filepath.
save_to_json(cifs_shg_abinit, abinit_filepath)

In [None]:
# Reload from disk so the following cells can start from the saved file.
cifs_shg_abinit = load_from_json(abinit_filepath)

#### Create base_shg_eff_dataset

In [None]:
# Prerequisite: run parse_shg and cmp_abinit_vs_amcm first, then this cell.
# Starts the base dataset with the amcm entries, reducing each tensor to a
# scalar and keeping only entries whose scalar is below 250.
raw_amcm_path = "../../raw_data/private/am_cm/amcm.json"
# append to base
base_shg_eff_dataset: dict[str, dict] = {}
cifs_shg_amcm: dict = load_from_json(raw_amcm_path)
m = 0  # largest scalar that was kept
for cif_name, v in cifs_shg_amcm.items():
    # print(list(v["data"]["data"]["dinf"].keys()))
    shg = tensor2scalar(dict2tensor(v["data"]["data"]))

    print(cif_name)
    try:
        if not shg < 250:
            print("filtered:", cif_name)
            continue
        base_shg_eff_dataset[cif_name] = {"structure": v["structure"], "shg": shg}
        m = max(m, shg)
    except Exception:
        print(cif_name, traceback.format_exc())

In [None]:
# Append the article entries to the base dataset under "ccdc-<name>" keys;
# entries whose shg is None are printed but not added.
cifs_shg_articles: dict = load_from_json(articles_data_filename)
for cif_name, v in cifs_shg_articles.items():
    ccdc_key = "ccdc-" + cif_name
    print(ccdc_key)
    if v["shg"] is None:
        continue
    base_shg_eff_dataset[ccdc_key] = {"structure": v["structure"], "shg": v["shg"]}

In [None]:
# Append the ABINIT entries to the base dataset, reducing each tensor to a scalar.
cifs_shg_abinit: dict = load_from_json(abinit_filepath)
for cif_name, v in cifs_shg_abinit.items():
    # print(list(v["data"]["data"]["dinf"].keys()))
    entry = {"structure": v["structure"]}
    entry["shg"] = tensor2scalar(dict2tensor(v["data"]["data"]))
    base_shg_eff_dataset[cif_name] = entry

In [None]:
# base_shg_eff_dataset is full now (amcm, articles and abinit entries appended)

### Filtering

In [None]:
# Drop entries for which the structure cannot be rebuilt or a site does not
# expose `.specie.number`.
for cif_name in tuple(base_shg_eff_dataset):
    try:
        crystal = Structure.from_dict(base_shg_eff_dataset[cif_name]["structure"])
        # the values are discarded; this only probes every site
        [site.specie.number for site in crystal]
    except Exception as e:
        base_shg_eff_dataset.pop(cif_name)
        print(e, cif_name)

In [None]:
# Filter 'problematic' cifs: every structure is relabelled, written to
# "tmp_file.cif" and read back; entries for which any of these steps raises
# are removed from base_shg_eff_dataset.
import warnings

for cif_name in list(base_shg_eff_dataset.keys()):
    # warnings.filterwarnings("error")
    try:
        # Record warnings locally instead of raising them, so only real
        # exceptions drop an entry.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            crystal = Structure.from_dict(base_shg_eff_dataset[cif_name]["structure"])
            crystal.relabel_sites(True).to_file("tmp_file.cif", "cif")
            crystal.from_file("tmp_file.cif")
            # print([e.category for e in w])
            # NOTE(review): only the case of exactly three recorded warnings
            # is reported; why three is not evident from this cell — confirm.
            if len(w) == 3:
                print(w[-1].message)
            # assert len(w) == 1
            # assert issubclass(w[-1].category, DeprecationWarning)
            # assert "deprecated" in str(w[-1].message)

    except Exception as e:
        # if "fractional coordinates rounded" not in str(
        #     e
        # ) and "We strongly discourage using implicit" not in str(e):
        base_shg_eff_dataset.pop(cif_name)
        print(e, cif_name)

In [None]:
# Previous output path, kept for reference only:
# base_shg_eff_dataset_filepath = "../../final_data/base_dataset_of_eff_shg.json" - has problems with cifs_articles_data
# Current output path ("_cor" = corrected version of the file above).
base_shg_eff_dataset_filepath = "../../final_data/base_dataset_of_eff_shg_cor.json" # fixed problems
save_to_json(base_shg_eff_dataset, base_shg_eff_dataset_filepath)

In [None]:
# Reload the filtered dataset from disk; the cells below work on this copy.
base_shg_eff_dataset: dict[str, dict] = load_from_json(base_shg_eff_dataset_filepath)

In [None]:
# Number of sites per structure (input for a histogram by atom count).
atom_count = [
    len(Structure.from_dict(v["structure"])) for v in base_shg_eff_dataset.values()
]

In [None]:
# find potential repetitions by reduced form
# (not removing anything)

# check similar cifs to find the ones from different datasets
# can the duplicate be present in one dataset? - yes, in icsd
# Therefore, this code should be omitted

# reduced_formula -> similar cif_names
# reduced_formula_dict: dict[str, list[str]] = dict()
# counter = 0
# potential_repeat: list[tuple[list[str], str, str]] = []
# for cif_name, v in base_shg_eff_dataset.items():
#     # obtain reduced formula
#     structure = Structure.from_dict(v["structure"])
#     reduced_formula = structure.reduced_formula
#     # check if such reduced formula already present in dataset
#     if reduced_formula in reduced_formula_dict:
#         print(reduced_formula_dict[reduced_formula], cif_name, reduced_formula)
#         counter += 1
#         for potential_duplicate_cif_name in reduced_formula_dict[reduced_formula]:
#             if potential_duplicate_cif_name.split("-")[0] != cif_name.split("-")[0]:
#                 potential_repeat.append(
#                     (reduced_formula_dict[reduced_formula], cif_name, reduced_formula)
#                 )
#                 break
#     reduced_formula_dict.setdefault(reduced_formula, [])
#     reduced_formula_dict[reduced_formula].append(cif_name)
# print(
#     counter,
#     len(base_shg_eff_dataset),
#     len(reduced_formula_dict),
#     len(set(base_shg_eff_dataset.keys())),
# )
# potential_repeat

In [None]:
# manual check in MP (not removing anything)
# for l, mp_id, c in potential_repeat:
#     with MPRester(os.environ["MP_API_KEY"]) as mpr:
#         # gives deprecation warnings now
#         # doc = mpr.materials.provenance.get_data_by_id(mp_id, fields=["database_IDs"])
#         # DeprecationWarning: get_data_by_id is deprecated and will be removed soon. Please use the search method instead.
#         # if "icsd" in doc.database_IDs:
#         #     print(doc.database_IDs["icsd"])
#         #     for icsd_id in l:
#         #         if icsd_id in doc.database_IDs["icsd"]:
#         #             print(icsd_id, mp_id)
#         # else:
#         #     print("no icsd for", mp_id)

#         doc = mpr.materials.provenance.search(mp_id, fields=["database_IDs"])
#         # print(doc[0].database_IDs)
#         if "icsd" in doc[0].database_IDs:
#             print(doc[0].database_IDs["icsd"])
#             for icsd_id in l:
#                 if icsd_id in doc[0].database_IDs["icsd"]:
#                     print(icsd_id, mp_id)
#         else:
#             print("no icsd for", mp_id)

In [None]:
# Group dataset keys by the reduced formula of their structure:
# reduced formula -> list of cif names sharing it.
reduced_formula_dict: dict[str, list[str]] = dict()
for cif_name, v in base_shg_eff_dataset.items():
    structure = Structure.from_dict(v["structure"])
    reduced_formula = structure.reduced_formula
    reduced_formula_dict.setdefault(reduced_formula, []).append(cif_name)
# dataset size, number of distinct formulas, number of distinct keys
print(
    len(base_shg_eff_dataset),
    len(reduced_formula_dict),
    len(set(base_shg_eff_dataset.keys())),
)

In [None]:
def plot_struct(s:Structure, **kwargs):
    """Render a structure to an image with ASE and show it with matplotlib.

    Args:
        s: structure to draw; it is converted with ``to_ase_atoms()``.
        **kwargs: forwarded unchanged to ``ase.io.write``.

    Side effect: (over)writes ``tmp.png`` in the current working directory.
    """
    atoms = s.to_ase_atoms()
    write("tmp.png", atoms, **kwargs)
    # Close the file handle once the image has been drawn.
    with Image.open("tmp.png") as img:
        plt.imshow(img)
        plt.show()

In [None]:
def is_same_struct(s1: Structure, s2: Structure, **kwargs):
    """Return the result of ``StructureMatcher(**kwargs).fit(s1, s2)``."""
    return StructureMatcher(**kwargs).fit(s1, s2)

In [None]:
# check duplicate based on reduced formula and then on structure matcher


def get_dumplicates_by_structure_matcher(
    is_same_struct_fn: Callable,
    local_reduced_formula_dict: dict[str, list[str]] = reduced_formula_dict,
    do_plots=False,
):
    """Group structures that share a reduced formula and are judged identical.

    Every pair of cif names listed under the same reduced formula is compared
    with ``is_same_struct_fn``. Matching pairs are merged into one identity
    group (together with the groups either member already belonged to);
    non-matching members get a group of their own if they have none yet.

    Args:
        is_same_struct_fn: callable ``(Structure, Structure) -> truthy``.
        local_reduced_formula_dict: reduced formula -> cif names. Names are
            looked up in the global ``base_shg_eff_dataset``. The default is
            the ``reduced_formula_dict`` object that existed when this
            function was defined.
        do_plots: when True, both structures of every compared pair are drawn
            with ``plot_struct``.

    Returns:
        Tuple ``(cif2iid, iid2cifs, potential_duplicates, stats_total_steps)``:
        cif name -> identity index; identity index -> cif names (merged-away
        indices map to an empty list; list order is arbitrary); reduced
        formula -> cif names for formulas with more than one entry; number of
        pairwise comparisons performed.
    """
    # cif_id -> similarity group index (identity index)
    cif2iid: dict[str, int] = {}
    # identity index -> list[cif_ids]
    iid2cifs: dict[int, list[str]] = {}
    potential_duplicates = {}
    iid_increment = 0
    stats_total_steps = 0
    for reduced_formula, similar_cifs in tqdm(local_reduced_formula_dict.items()):
        if len(similar_cifs) <= 1:
            continue
        potential_duplicates.setdefault(reduced_formula, []).extend(similar_cifs)
        print(f"{reduced_formula=}")
        # Deserialise every structure once per formula group instead of once
        # per compared pair.
        structures = {
            cif: Structure.from_dict(base_shg_eff_dataset[cif]["structure"])
            for cif in similar_cifs
        }
        for cif_0, cif_1 in itertools.combinations(similar_cifs, 2):
            stats_total_steps += 1
            structure0 = structures[cif_0]
            structure1 = structures[cif_1]
            if is_same_struct_fn(structure0, structure1):
                iid_increment += 1
                print(cif_0, "=", cif_1)
                print(structure0.formula, structure1.formula)
                if do_plots:
                    plot_struct(structure0)
                    plot_struct(structure1)
                # set equal identity idxes for two structures
                # and for all structures from the same identity groups
                relevant_iids = {iid_increment}
                for cif in (cif_0, cif_1):
                    if cif in cif2iid:
                        relevant_iids.add(cif2iid[cif])
                minimal_iid = min(relevant_iids)
                merged_cifs = [cif_0, cif_1]
                # merge cifs from relevant iids to minimal_iid
                for iid in relevant_iids:
                    if iid not in iid2cifs:
                        continue
                    for cif in iid2cifs[iid]:
                        merged_cifs.append(cif)
                        cif2iid[cif] = minimal_iid
                    # remove cifs from old iid
                    iid2cifs[iid] = []
                iid2cifs[minimal_iid] = list(set(merged_cifs))
                cif2iid.setdefault(cif_0, minimal_iid)
                cif2iid.setdefault(cif_1, minimal_iid)
                iid_increment += 1
            else:
                # assign new identity idxes in case of new structure
                # (a cif that already has a group keeps it)
                for cif in (cif_0, cif_1):
                    cif2iid.setdefault(cif, iid_increment)
                    cur_iid = cif2iid[cif]
                    iid2cifs.setdefault(cur_iid, [])
                    iid2cifs[cur_iid].append(cif)
                    iid2cifs[cur_iid] = list(set(iid2cifs[cur_iid]))
                    if cur_iid == iid_increment:
                        # the fresh index was consumed; move on to the next one
                        iid_increment += 1

                print(cif_0, "!=", cif_1)
                print(structure0.formula, structure1.formula)
                if do_plots:
                    plot_struct(structure0)
                    plot_struct(structure1)
    return cif2iid, iid2cifs, potential_duplicates, stats_total_steps


cif2iid, iid2cifs, potential_duplicates, stats_total_steps = (
    get_dumplicates_by_structure_matcher(is_same_struct)
)
stats_total_steps

In [None]:
# Number of reduced formulas shared by several entries, and the total number
# of entries involved; then the candidate groups in sorted order.
print(f"{len(potential_duplicates)=}")
print(f"{sum([len(v) for k,v in potential_duplicates.items()])=}")
potential_duplicates_list = sorted([v for v in potential_duplicates.values()])
potential_duplicates_list

In [None]:
# Keep only identity groups with more than one member, i.e. the pairs or
# groups that is_same_struct_fn judged identical.
duplicates = {k: v for k, v in iid2cifs.items() if len(v) > 1}
print(f"{len(duplicates)=}")
print(f"{sum([len(v) for k,v in duplicates.items()])=}")
duplicates_list = sorted([v for v in duplicates.values()])
duplicates_list

In [None]:
# example from https://docs.materialsproject.org/methodology/materials-methodology/related-materials
# which separates mp-561224 and mp-2739 successfully
# NOTE: this cell also defines `ssf`, which is_same_struct_slow uses later.

import numpy as np
from mp_api.client import MPRester
from matminer.featurizers.site import CrystalNNFingerprint
from matminer.featurizers.structure import SiteStatsFingerprint
import os
# The API key is read from the environment; a KeyError here means it is unset.
MP_API_KEY = os.environ["MP_API_KEY"]

with MPRester(MP_API_KEY) as mpr:

    # Get structures
    diamond = mpr.get_structure_by_material_id("mp-66")
    gaas = mpr.get_structure_by_material_id("mp-2534")
    rocksalt = mpr.get_structure_by_material_id("mp-22862")
    perovskite = mpr.get_structure_by_material_id("mp-5827")
    spinel_caco2s4 = mpr.get_structure_by_material_id("mp-1408976")
    spinel_sicd2O4 = mpr.get_structure_by_material_id("mp-560842")
    tio2_1 = mpr.get_structure_by_material_id("mp-2739")
    tio2_2 = mpr.get_structure_by_material_id("mp-561224")

# Calculate structure fingerprints
ssf = SiteStatsFingerprint(
    CrystalNNFingerprint.from_preset('ops', distance_cutoffs=None, x_diff_weight=0),
    stats=('mean', 'std_dev', 'minimum', 'maximum'))
v_diamond = np.array(ssf.featurize(diamond))
v_gaas = np.array(ssf.featurize(gaas))
v_rocksalt = np.array(ssf.featurize(rocksalt))
v_perovskite = np.array(ssf.featurize(perovskite))
v_spinel_caco2s4 = np.array(ssf.featurize(spinel_caco2s4))
v_spinel_sicd2O4 = np.array(ssf.featurize(spinel_sicd2O4))

# Print out distance between structures (Euclidean norm of fingerprint difference)
print('Distance between diamond and GaAs: {:.4f}'.format(np.linalg.norm(v_diamond - v_gaas)))
print('Distance between diamond and rocksalt: {:.4f}'.format(np.linalg.norm(v_diamond - v_rocksalt)))
print('Distance between diamond and perovskite: {:.4f}'.format(np.linalg.norm(v_diamond - v_perovskite)))
print('Distance between rocksalt and perovskite: {:.4f}'.format(np.linalg.norm(v_rocksalt - v_perovskite)))
print('Distance between Ca(CoS2)2-spinel and Si(CdO2)2-spinel: {:.4f}'.format(np.linalg.norm(v_spinel_caco2s4 - v_spinel_sicd2O4)))
# print(is_same_struct(diamond,gaas))
# print(is_same_struct(tio2_1,tio2_2))
# Similarity (in %) of the two TiO2 entries: 100 * exp(-distance)
print(np.exp(-np.linalg.norm(np.array(ssf.featurize(tio2_1)) - np.array(ssf.featurize(tio2_2)))) * 100)
# Print out structure similarity percentages
# NOTE(review): diamond vs GaAs gives 100%; presumably the fingerprint as
# configured here does not distinguish the two — confirm.
print('Diamond and GaAs Similarity: {:.2f}%'.format(np.exp(-np.linalg.norm(v_diamond - v_gaas)) * 100)) # = 100% for some reason
print('Diamond and rocksalt Similarity: {:.2f}%'.format(np.exp(-np.linalg.norm(v_diamond - v_rocksalt)) * 100))
print('Diamond and perovskite Similarity: {:.2f}%'.format(np.exp(-np.linalg.norm(v_diamond - v_perovskite)) * 100))
print('Rocksalt and perovskite Similarity: {:.2f}%'.format(np.exp(-np.linalg.norm(v_rocksalt - v_perovskite)) * 100))
print('Ca(CoS2)2-spinel and Si(CdO2)2-spinel Similarity: {:.2f}%'.format(np.exp(-np.linalg.norm(v_spinel_caco2s4 - v_spinel_sicd2O4)) * 100))

In [None]:
# a = Structure.from_dict(base_shg_eff_dataset["icsd-9290"]["structure"])
# b = Structure.from_dict(base_shg_eff_dataset["icsd-193973"]["structure"])
# print(
#     np.exp(-np.linalg.norm(np.array(ssf.featurize(a)) - np.array(ssf.featurize(b))))
# )

In [None]:

# from nvcs import viewer
# view = viewer(tio2_2)
# image = view.render_image()
# view2 = viewer(tio2_1)
# image_2 = view2.render_image()
# view2 = viewer(tio2_1)
# image_2 = view2.render_image()

In [None]:

# example 

# atoms = tio2_2.to_ase_atoms()
# view(atoms)
# write("tmp.png", atoms)
# write("tmp.png", atoms, rotation=("0x,0y,0z"))
# # write("tmp.png", atoms, radii=0.5, rotation=("0x,0y,0z"))

In [None]:
# Render every member of each duplicate group and compare the images of
# neighbouring members pixel by pixel.
do_plots = False
for identity_id, t in duplicates.items():
    print(f"{identity_id=}")
    imgs_list = []
    for cif in t:
        s = Structure.from_dict(base_shg_eff_dataset[cif]["structure"])
        atoms = s.to_ase_atoms()
        # every render overwrites tmp.png; the pixels are copied right away
        write("tmp.png", atoms, radii=0.5)
        print(f"{cif=}")
        with Image.open("tmp.png") as img:
            imgs_list.append(np.array(img))
            if do_plots:
                plt.imshow(img)
                plt.show()
    # Compare each image with its predecessor: (0, 1), (1, 2), ...
    # Starting at 1 avoids the wrap-around index -1 (last vs. first image).
    for position in range(1, len(imgs_list)):
        previous_img = imgs_list[position - 1]
        current_img = imgs_list[position]
        # check equality of plots
        if previous_img.shape == current_img.shape:
            # Widen before subtracting: unsigned 8-bit pixel values would wrap
            # around and report a wrong difference.
            img_diff = np.max(
                np.abs(current_img.astype(np.int64) - previous_img.astype(np.int64))
            )
            print(f"{img_diff=}")
            print(f"{img_diff > 0=}")
        else:
            img_diff = "high"
            print(f"{img_diff=}")
            img_diff = 1
            print(f"{img_diff > 0=}")
    print("---------------------------------")

In [None]:
# Reduced formula -> cif names, restricted to the duplicate groups found
# above. The formula is taken from the first member of each group; a later
# group with the same formula replaces an earlier one.
smaller_reduced_formula_dict = {}
for cif_list in duplicates.values():
    representative = Structure.from_dict(
        base_shg_eff_dataset[cif_list[0]]["structure"]
    )
    smaller_reduced_formula_dict[representative.reduced_formula] = cif_list

In [None]:
smaller_reduced_formula_dict

In [None]:
# check duplicate based on reduced formula and then on structure matcher with CrystalNNFingerprint
# redefine similarity function
def is_same_struct_slow(s1: Structure, s2: Structure, **kwargs):
    """Decide whether two structures are the same using two criteria.

    The pair must be accepted by ``StructureMatcher(**kwargs).fit`` and the
    feature vectors produced by the notebook-level featurizer ``ssf`` must be
    (almost) identical: ``exp(-||f(s1) - f(s2)||)`` has to exceed 0.999999999.

    Args:
        s1: First structure.
        s2: Second structure.
        **kwargs: Forwarded unchanged to the ``StructureMatcher`` constructor.

    Returns:
        The falsy result of ``matcher.fit`` when the matcher rejects the pair,
        otherwise the boolean outcome of the fingerprint comparison.
    """
    matcher = StructureMatcher(**kwargs)
    matcher_verdict = matcher.fit(s1, s2)
    if not matcher_verdict:
        # The fingerprint cannot change a negative answer, so skip featurizing.
        return matcher_verdict

    fingerprint_distance = np.linalg.norm(
        np.array(ssf.featurize(s1)) - np.array(ssf.featurize(s2))
    )
    # exp(-distance) equals 1 only for identical fingerprints.
    return np.exp(-fingerprint_distance) > 0.999999999


# Run the duplicate search over smaller_reduced_formula_dict with the
# is_same_struct_slow predicate and plotting disabled.
# NOTE(review): get_dumplicates_by_structure_matcher is defined outside this
# cell; the meaning of the four returned values is inferred from the names
# they are bound to here — confirm against its definition.
cif2iid_smol, iid2cifs_smol, potential_duplicates_smol, stats_total_steps = (
    get_dumplicates_by_structure_matcher(
        is_same_struct_slow, smaller_reduced_formula_dict, do_plots=False
    )
)
# Last expression of the cell: shown as the cell output.
stats_total_steps

In [None]:
# Report the number of potential-duplicate groups and their total member
# count, then show the groups in sorted order.
print(
    f"{len(potential_duplicates_smol)=}",
    f"{sum([len(v) for k,v in potential_duplicates_smol.items()])=}",
    sep="\n",
)
potential_duplicates_list_smol = sorted(potential_duplicates_smol.values())
potential_duplicates_list_smol

In [None]:
# Keep only the identity groups that contain at least two CIF ids.
duplicates_smol = {
    identity_id: cifs
    for identity_id, cifs in iid2cifs_smol.items()
    if len(cifs) >= 2
}
# Report the number of groups and their total member count, then show the
# groups in sorted order.
print(
    f"{len(duplicates_smol)=}",
    f"{sum([len(v) for k,v in duplicates_smol.items()])=}",
    sep="\n",
)
duplicates_list_smol = sorted(duplicates_smol.values())
duplicates_list_smol

In [None]:
# Render every structure of each confirmed duplicate group to "tmp.png", print
# its stored shg value, and compare the renderings of consecutive group members
# pixel by pixel.
do_plots = True
for identity_id, t in duplicates_smol.items():
    print(f"{identity_id=}")
    imgs_list = []
    for cif in t:
        s = Structure.from_dict(base_shg_eff_dataset[cif]["structure"])
        atoms = s.to_ase_atoms()
        write("tmp.png", atoms, radii=0.5)
        shg = base_shg_eff_dataset[cif]["shg"]
        print(f"{cif=}, {shg=}")
        # Release the file handle before the next iteration overwrites the file.
        with Image.open("tmp.png") as img:
            imgs_list.append(np.array(img))
            if do_plots:
                plt.imshow(img)
                plt.show()
    # Start at 1 so every image is compared with its direct predecessor; an
    # index of 0 would pair the first image with imgs_list[-1] (the last one).
    for position in range(1, len(imgs_list)):
        previous_img = imgs_list[position - 1]
        current_img = imgs_list[position]
        # check equality of plots
        if previous_img.shape == current_img.shape:
            # Widen to a signed type before subtracting: pixel arrays are
            # presumably unsigned 8-bit, where a negative difference wraps around.
            img_diff = np.max(
                np.abs(current_img.astype(np.int64) - previous_img.astype(np.int64))
            )
            print(f"{img_diff=}")
            print(f"{img_diff > 0=}")
        else:
            # Different canvas sizes cannot be compared element-wise; report
            # them as different.
            img_diff = "high"
            print(f"{img_diff=}")
            img_diff = 1
            print(f"{img_diff > 0=}")
    print("---------------------------------")

### Remove duplicates


In [None]:
# TODO

In [None]:
# Manual check: rebuild the structures stored under "icsd-65763" and
# "icsd-171421" and evaluate both similarity predicates on the pair.
# NOTE(review): is_same_struct is defined outside the visible part of the
# notebook — confirm it is still in scope when this cell runs.
s1 = Structure.from_dict(base_shg_eff_dataset["icsd-65763"]["structure"])
s2 = Structure.from_dict(base_shg_eff_dataset["icsd-171421"]["structure"])

# The cell output is the tuple (fast predicate result, slow predicate result).
is_same_struct(s1, s2), is_same_struct_slow(s1, s2)

In [None]:
# Histogram of atom_count using 20 bins.
# NOTE(review): atom_count is not assigned in the visible part of the notebook
# — confirm an earlier cell creates it, otherwise this fails on a fresh kernel.
plt.hist(atom_count, bins=20)

In [None]:
# Collect the number of sites of every structure, bucketed by the source tag
# found in the entry key, and overlay one histogram per source.
draw_line = True
atom_count_ccdc, atom_count_icsd, atom_count_mp = [], [], []
# Tags are tested independently (substring match), in this order, so an entry
# whose key contains several tags is counted in every matching bucket.
source_buckets = (
    ("icsd", atom_count_icsd),
    ("ccdc", atom_count_ccdc),
    ("mp", atom_count_mp),
)
for cif_name, v in base_shg_eff_dataset.items():
    for source_tag, bucket in source_buckets:
        if source_tag in cif_name:
            bucket.append(len(Structure.from_dict(v["structure"])))

bins = np.linspace(0, 1200, 100)

histogram_series = (
    ("atom_count_ccdc", atom_count_ccdc, colors[0]),
    ("atom_count_icsd", atom_count_icsd, colors[1]),
    ("atom_count_mp", atom_count_mp, colors[2]),
)
for series_label, series_values, series_color in histogram_series:
    n, x, _ = plt.hist(series_values, bins=bins, alpha=0.5, fc=series_color)
    if draw_line:
        # Connect the bin centres so overlapping histograms stay readable.
        plt.plot((x[1:] + x[:-1]) / 2, n, label=series_label, c=series_color)

plt.legend()
plt.yscale("log")
plt.show()
len(atom_count_ccdc)

In [None]:
import matplotlib.pyplot as plt

# Histogram of atom_count using 20 bins.
# NOTE(review): atom_count is not assigned in the visible part of the notebook
# — confirm an earlier cell creates it, otherwise this fails on a fresh kernel.
plt.hist(atom_count, bins=20)

In [None]:
# Print every key of base_shg_eff_dataset, one per line (the values are unused).
for cif_name, v in base_shg_eff_dataset.items():
    print(cif_name)

In [None]:
# Structures with more sites than this are treated as "big crystals".
big_crystals_threshold = 1000

# List the keys of all entries above the threshold.
for cif_name, v in base_shg_eff_dataset.items():
    site_count = len(Structure.from_dict(v["structure"]))
    if site_count > big_crystals_threshold:
        print(cif_name)

In [None]:
# Remove, in place, every entry whose structure has more sites than
# big_crystals_threshold. Iterating over a snapshot of the keys makes it safe
# to delete from the dict inside the loop.
for cif_name in tuple(base_shg_eff_dataset):
    entry = base_shg_eff_dataset[cif_name]
    if len(Structure.from_dict(entry["structure"])) > big_crystals_threshold:
        del base_shg_eff_dataset[cif_name]

In [None]:
# Relative paths of the two JSON dataset files under ../../final_data/.
# They are resolved against the working directory of the notebook kernel.
base_shg_eff_dataset_filepath = "../../final_data/base_dataset_of_eff_shg.json"
dataset_base_plus_part_of_QMOF_dataset_filepath = (
    "../../final_data/dataset_base_plus_part_of_QMOF.json"
)

In [None]:
# save_to_json(base_shg_eff_dataset, base_shg_eff_dataset_filepath)

In [None]:
# Load the dataset stored at dataset_base_plus_part_of_QMOF_dataset_filepath.
# Reloading the base dataset from disk is left disabled on the next line.
# base_shg_eff_dataset = load_from_json(base_shg_eff_dataset_filepath)
dataset_base_plus_part_of_QMOF_dataset = load_from_json(
    dataset_base_plus_part_of_QMOF_dataset_filepath
)

In [None]:
# Collect the "shg" value of every entry of the base+QMOF dataset and plot it
# as a 30-bin histogram. The commented lines are the equivalent setup for
# base_shg_eff_dataset.
# targets_list = [v["shg"] for v in base_shg_eff_dataset.values()]
targets_list = [v["shg"] for v in dataset_base_plus_part_of_QMOF_dataset.values()]

# plt.title("shg_eff hist for base dataset")
plt.title("d_KP distribution")
plt.hist(targets_list, bins=30)

In [None]:
# Split the "shg" targets of the base+QMOF dataset by the source tag found in
# the entry key and overlay one histogram per source.
draw_line = True
target_ccdc, target_icsd, target_mp, target_qmof = [], [], [], []
# Tags are tested independently (substring match), in this order, so an entry
# whose key contains several tags lands in every matching bucket.
target_buckets = (
    ("icsd", target_icsd),
    ("ccdc", target_ccdc),
    ("mp", target_mp),
    ("qmof", target_qmof),
)
# To plot the base dataset instead, iterate over base_shg_eff_dataset.items().
for cif_name, v in dataset_base_plus_part_of_QMOF_dataset.items():
    for source_tag, bucket in target_buckets:
        if source_tag in cif_name:
            bucket.append(v["shg"])
bins = np.linspace(0, 200, 100)

histogram_series = (
    ("CCDC", target_ccdc, colors[0]),
    ("ICSD", target_icsd, colors[1]),
    ("MP", target_mp, colors[2]),
    ("QMOF", target_qmof, colors[3]),
)
for series_label, series_values, series_color in histogram_series:
    n, x, _ = plt.hist(series_values, bins=bins, alpha=0.5, fc=series_color)
    if draw_line:
        # Connect the bin centres so overlapping histograms stay readable.
        plt.plot((x[1:] + x[:-1]) / 2, n, label=series_label, c=series_color)

plt.xlabel("$d_{KP}$, pm/V")
plt.ylabel("count")
plt.legend()
plt.title("$d_{KP}$ distribution")
plt.yscale("log")
plt.show()

In [None]:
# Summary statistics of the current targets_list.
targets_properties = {
    "mean": np.mean(targets_list),
    "median": np.median(targets_list),
    "std": np.std(targets_list),
}
targets_properties

In [None]:
from utils.experiment_tracking import calculate_metrics

# Score two constant predictors (always the mean / always the median of the
# targets) as baselines for the current targets_list.
for baseline_name, baseline_fn in (("mean", np.mean), ("median", np.median)):
    constant_prediction = np.ones_like(targets_list) * baseline_fn(targets_list)
    print(baseline_name, calculate_metrics(targets_list, constant_prediction))

In [None]:
# Build a dataset whose target is the largest absolute component of the tensor
# returned by dict2tensor for every entry loaded from abinit_filepath.
cifs_shg_abinit: dict = load_from_json(abinit_filepath)
base_max_abs_shg_dataset = {}
for cif_name, v in cifs_shg_abinit.items():
    base_max_abs_shg_dataset[cif_name] = {
        "structure": v["structure"],
        "shg": np.abs(dict2tensor(v["data"]["data"])).max(),
    }

In [None]:
# run parse_shg and cmp_abinit_vs_amcm first
# Add the amcm entries whose largest absolute tensor component is below 250 to
# base_max_abs_shg_dataset; `m` tracks the largest value that was accepted.
cifs_shg_amcm: dict = load_from_json("../../raw_data/private/am_cm/amcm.json")
m = 0
for cif_name, v in cifs_shg_amcm.items():
    shg = np.abs(dict2tensor(v["data"]["data"])).max()

    print(cif_name)
    try:
        # Written as "not <" so that a NaN value is skipped, exactly as a
        # failed "shg < 250" test would skip it.
        if not shg < 250:
            continue
        base_max_abs_shg_dataset[cif_name] = {
            "structure": v["structure"],
            "shg": shg,
        }
        m = max(m, shg)
    except Exception as e:
        # Best effort: report the failing entry and carry on with the rest.
        print(cif_name, e)

In [None]:
# Split the "shg" targets of base_max_abs_shg_dataset by the source tag found
# in the entry key and overlay one histogram per source.
draw_line = True
target_ccdc, target_icsd, target_mp = [], [], []
# Tags are tested independently (substring match), in this order, so an entry
# whose key contains several tags lands in every matching bucket.
target_buckets = (
    ("icsd", target_icsd),
    ("ccdc", target_ccdc),
    ("mp", target_mp),
)
for cif_name, v in base_max_abs_shg_dataset.items():
    for source_tag, bucket in target_buckets:
        if source_tag in cif_name:
            bucket.append(v["shg"])

bins = np.linspace(0, 200, 100)

histogram_series = (
    ("target_ccdc", target_ccdc, colors[0]),
    ("target_icsd", target_icsd, colors[1]),
    ("target_mp", target_mp, colors[2]),
)
for series_label, series_values, series_color in histogram_series:
    n, x, _ = plt.hist(series_values, bins=bins, alpha=0.5, fc=series_color)
    if draw_line:
        # Connect the bin centres so overlapping histograms stay readable.
        plt.plot((x[1:] + x[:-1]) / 2, n, label=series_label, c=series_color)

plt.legend()
plt.yscale("log")
plt.title("shg_max_abs hist for base_max_abs_shg_dataset")
plt.show()

In [None]:
# Rebind targets_list to the "shg" values of base_max_abs_shg_dataset and score
# two constant predictors (mean and median of the targets) as baselines.
targets_list = [v["shg"] for v in base_max_abs_shg_dataset.values()]

for baseline_name, baseline_fn in (("mean", np.mean), ("median", np.median)):
    constant_prediction = np.ones_like(targets_list) * baseline_fn(targets_list)
    print(baseline_name, calculate_metrics(targets_list, constant_prediction))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

# Box-Cox needs strictly positive input, hence the absolute value plus a tiny
# offset that only matters for exact zeros.
data = np.abs(targets_list) + 1e-100
transformed_data, lambda_value = stats.boxcox(data)

# Side-by-side density histograms: raw values (left), transformed (right).
plt.figure(figsize=(12, 6))
density_panels = (
    ("Original Data Density", data, "blue"),
    ("Box-Cox Transformed Data Density", transformed_data, "orange"),
)
for panel_index, (panel_title, panel_values, panel_color) in enumerate(
    density_panels, start=1
):
    plt.subplot(1, 2, panel_index)
    plt.hist(
        panel_values,
        bins=10,
        density=True,
        alpha=0.6,
        color=panel_color,
        edgecolor="black",
    )
    plt.title(panel_title)
    plt.xlabel("Values")
    plt.ylabel("Density")

plt.tight_layout()
plt.show()

print(f"Lambda value for Box-Cox transformation: {lambda_value}")

In [None]:
# Load the final max-abs dataset, rebind targets_list to its "shg" values and
# score two constant predictors (mean and median of the targets) as baselines.
max_abs_shg_dataset = load_from_json("../../final_data/dataset_of_max_abs_shg.json")
targets_list = [v["shg"] for v in max_abs_shg_dataset.values()]

for baseline_name, baseline_fn in (("mean", np.mean), ("median", np.median)):
    constant_prediction = np.ones_like(targets_list) * baseline_fn(targets_list)
    print(baseline_name, calculate_metrics(targets_list, constant_prediction))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats

# Box-Cox needs strictly positive input, hence the absolute value plus a tiny
# offset that only matters for exact zeros.
data = np.abs(targets_list) + 1e-100
transformed_data, lambda_value = stats.boxcox(data)

# Side-by-side density histograms: raw values (left), transformed (right).
plt.figure(figsize=(12, 6))
density_panels = (
    ("Original Data Density", data, "blue"),
    ("Box-Cox Transformed Data Density", transformed_data, "orange"),
)
for panel_index, (panel_title, panel_values, panel_color) in enumerate(
    density_panels, start=1
):
    plt.subplot(1, 2, panel_index)
    plt.hist(
        panel_values,
        bins=10,
        density=True,
        alpha=0.6,
        color=panel_color,
        edgecolor="black",
    )
    plt.title(panel_title)
    plt.xlabel("Values")
    plt.ylabel("Density")

plt.tight_layout()
plt.show()

print(f"Lambda value for Box-Cox transformation: {lambda_value}")