In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

%matplotlib inline

In [16]:
from pathlib import Path

inputs_dir = Path("/projects/rlmolecule/pstjohn/crystal_inputs/")
fingerprint_dir = Path("/projects/rlmolecule/pstjohn/crystal_fingerprints/")

data = pd.read_pickle(Path(inputs_dir, "20220513_outliers_removed.p"))
fps = pd.read_parquet(Path(fingerprint_dir, 'battery_fingerprints_no_chem.parquet')).dropna()

# drop a few relaxed crystals that were marked as outliers
fps = fps[~(~fps.id.isin(data[data.type == 'relax'].id) & (fps.type == 'relaxed'))]

In [17]:
from sklearn.cluster import AgglomerativeClustering

cluster = AgglomerativeClustering(affinity='cosine',
                                  distance_threshold=0.001,
                                  n_clusters=None,
                                  linkage='single',
                                  memory='/tmp/scratch/')

cluster_labels = cluster.fit_predict(fps.drop(['id', 'type'], axis=1))

In [20]:
clustered_structures = fps[['id', 'type']].copy()
clustered_structures['cluster'] = cluster_labels

In [26]:
cluster_counts = clustered_structures.groupby('cluster').type.agg(['unique', 'count'])

In [71]:
relaxed_counts = cluster_counts[cluster_counts['unique'].astype(str) == "['relaxed']"]['count'].sort_values(ascending=False)

In [73]:
relaxed_counts.head(10)

cluster
335     76
7022    68
1       53
1698    46
3683    41
3488    37
146     37
7374    36
412     35
262     34
Name: count, dtype: int64

In [75]:
(335)

0

In [48]:
clustered_structures[clustered_structures.cluster == 335]

Unnamed: 0,id,type,cluster
4908,Mg1Cl2_sg166_icsd_024755_1,relaxed,335
5404,Mg1Br2_sg166_icsd_024755_1,relaxed,335
5688,Zn1I2_sg2_icsd_051572_1,relaxed,335
8683,Mg1Cl2_sg147_icsd_058847_1,relaxed,335
9360,Zn1Br2_sg143_icsd_200593_1,relaxed,335
...,...,...,...
64349,Zn1Sb2S6_sg148_icsd_078684_1,relaxed,335
65013,Mg1Br2_sg164_icsd_020745_1,relaxed,335
66990,Zn1I2_sg164_icsd_280743_1,relaxed,335
67903,Zn1Cl2_sg71_icsd_038274_1,relaxed,335


In [85]:
(relaxed_counts >= 10).sum()

94

In [91]:
from pymatgen.core.composition import Composition

In [96]:
c = Composition('Zn1Br2')

In [98]:
len(c.elements)

2

In [99]:
new_protos.id.str.extract('([a-zA-Z0-9-]*)_').apply(lambda x: len(Composition(x).elements))

TypeError: '<' not supported between instances of 'str' and 'float'

In [92]:
new_protos.id.str.extract('([a-zA-Z0-9-]*)_').apply(Composition)

TypeError: '<' not supported between instances of 'str' and 'float'

In [83]:
new_protos = clustered_structures[
    clustered_structures.cluster.isin(relaxed_counts[relaxed_counts >= 10].index)
].sort_values('cluster', key=lambda xi: [relaxed_counts.index.get_loc(x) for x in xi]).drop('type', axis=1)

In [84]:
new_protos.to_csv('20220601_relaxed_prototypes.csv', index=False)