In [1]:
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
# Global plot styling for the notebook: 'talk' context, 'ticks' style,
# single-letter color codes enabled, and legends drawn without a frame.
sns.set(
    context='talk',
    style='ticks',
    color_codes=True,
    rc={'legend.frameon': False},
)

%matplotlib inline

In [2]:
from pathlib import Path

# Directories holding the crystal input records and the fingerprint table.
inputs_dir = Path("/projects/rlmolecule/pstjohn/crystal_inputs/")
fingerprint_dir = Path("/projects/rlmolecule/pstjohn/crystal_fingerprints/")

# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
data = pd.read_pickle(inputs_dir / "20220513_outliers_removed.p")

# Fingerprint table; any row containing a missing value is discarded.
fps = pd.read_parquet(fingerprint_dir / 'battery_fingerprints_no_chem.parquet').dropna()

# Keep a fingerprint row if it is not of type 'relaxed', or if its id still appears
# among the 'relax' rows of `data`. This removes the few relaxed crystals that were
# marked as outliers (and are therefore absent from `data`).
fps = fps[fps['id'].isin(data.loc[data['type'] == 'relax', 'id']) | (fps['type'] != 'relaxed')]

In [3]:
from sklearn.metrics.pairwise import cosine_distances

In [4]:
# Pairwise cosine distances: one row per relaxed fingerprint, one column per
# unrelaxed fingerprint. The 'id' and 'type' columns are dropped before computing.
relaxed_unrelaxed = cosine_distances(
    fps.loc[fps['type'] == 'relaxed'].drop(columns=['id', 'type']),
    fps.loc[fps['type'] == 'unrelaxed'].drop(columns=['id', 'type']),
)

In [6]:
# Start the per-structure table from the 'id'/'type' columns of the relaxed rows.
clustered_structures = fps.loc[fps['type'] == 'relaxed', ['id', 'type']].copy()

# For each relaxed structure, the smallest cosine distance to any unrelaxed structure.
clustered_structures['min_distance'] = relaxed_unrelaxed.min(axis=1)

In [96]:
from sklearn.cluster import AgglomerativeClustering

import inspect

# scikit-learn renamed AgglomerativeClustering's `affinity` argument to `metric`
# (deprecated in 1.2, removed in 1.4). Passing `affinity=` raises a TypeError on
# current releases, so pick whichever keyword the installed version accepts.
_distance_kwarg = (
    'metric'
    if 'metric' in inspect.signature(AgglomerativeClustering).parameters
    else 'affinity'
)

# Single-linkage agglomerative clustering on cosine distance. With n_clusters=None
# the number of clusters is determined by the 0.005 distance threshold.
cluster = AgglomerativeClustering(
    distance_threshold=0.005,
    n_clusters=None,
    linkage='single',
    memory='/tmp/scratch/',  # directory used by scikit-learn to cache the tree computation
    **{_distance_kwarg: 'cosine'},
)

# Fit on the relaxed fingerprints only. Rows are in the same order as
# `clustered_structures`, which is built from the same relaxed selection of `fps`.
clustered_structures['cluster'] = cluster.fit_predict(
    fps[fps.type == 'relaxed'].drop(['id', 'type'], axis=1))

In [97]:
from pymatgen.core import Composition

In [98]:
# The composition formula is the run of letters, digits and hyphens that precedes an
# underscore in the id (e.g. 'K1Cl1' from 'K1Cl1_sg194_icsd_262069_1').
clustered_structures['composition'] = clustered_structures['id'].str.extract('([a-zA-Z0-9-]*)_')

# Composition "type": the sorted tuple of integer element amounts, ignoring which
# elements they belong to (e.g. 'K1La1S2' -> (1, 1, 2)).
clustered_structures['comptype'] = clustered_structures['composition'].apply(
    lambda formula: tuple(sorted(map(int, Composition(formula).as_dict().values()))))

In [99]:
# Per-cluster summary table, indexed by cluster label.
# count: number of relaxed structures assigned to the cluster.
clusters = clustered_structures.groupby('cluster')['cluster'].count().to_frame(name='count')

# min_dist: smallest relaxed-to-unrelaxed distance found among the cluster's members.
clusters['min_dist'] = clustered_structures.groupby('cluster')['min_distance'].min()

# num_comptypes: how many distinct composition types occur within the cluster.
clusters['num_comptypes'] = clustered_structures.groupby('cluster')['comptype'].unique().apply(len)

In [100]:
# Show the ten clusters farthest from any unrelaxed structure, restricted to clusters
# with min_dist above 0.01 that contain at least three composition types.
(
    clusters
    .loc[lambda df: (df['min_dist'] > 0.01) & (df['num_comptypes'] >= 3)]
    .sort_values('min_dist', ascending=False)
    .head(10)
)

Unnamed: 0_level_0,count,min_dist,num_comptypes
cluster,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5248,5,0.057626,3
2926,9,0.047262,3
786,7,0.0259,3
1298,6,0.021922,3
1556,5,0.017814,3
731,21,0.01573,3


In [101]:
# Members of cluster 5248.
clustered_structures.loc[clustered_structures['cluster'].eq(5248)]

Unnamed: 0,id,type,min_distance,cluster,composition,comptype
5663,K1Cl1_sg194_icsd_262069_1,relaxed,0.079166,5248,K1Cl1,"(1, 1)"
6209,K1La1S2_sg194_icsd_000262_1,relaxed,0.058615,5248,K1La1S2,"(1, 1, 2)"
17046,K1Br1_sg194_icsd_262069_1,relaxed,0.105858,5248,K1Br1,"(1, 1)"
32590,K1La1S2_sg194_icsd_102034_1,relaxed,0.057626,5248,K1La1S2,"(1, 1, 2)"
69884,Li3P1_sg178_icsd_080477_1,relaxed,0.087793,5248,Li3P1,"(1, 3)"


In [102]:
# Members of cluster 2926.
clustered_structures.loc[clustered_structures['cluster'].eq(2926)]

Unnamed: 0,id,type,min_distance,cluster,composition,comptype
8708,Zn1Ge1O3_sg123_icsd_168904_1,relaxed,0.047262,2926,Zn1Ge1O3,"(1, 1, 3)"
13418,Li2Hf1N2_sg123_icsd_044353_2,relaxed,0.061429,2926,Li2Hf1N2,"(1, 2, 2)"
14121,Li2Ge1N2_sg123_icsd_044353_2,relaxed,0.057534,2926,Li2Ge1N2,"(1, 2, 2)"
15228,Zn1Zr1S3_sg123_icsd_168904_1,relaxed,0.056524,2926,Zn1Zr1S3,"(1, 1, 3)"
17760,K1Sn1Br2N1_sg140_icsd_411137_2,relaxed,0.053603,2926,K1Sn1Br2N1,"(1, 1, 1, 2)"
38480,Mg1Hg2N2_sg123_icsd_044353_2,relaxed,0.052841,2926,Mg1Hg2N2,"(1, 2, 2)"
41040,Zn1Ge1O3_sg123_icsd_168904_2,relaxed,0.061131,2926,Zn1Ge1O3,"(1, 1, 3)"
49213,K3F1O1_sg123_icsd_168904_2,relaxed,0.057575,2926,K3F1O1,"(1, 1, 3)"
51157,Zn1Sn1S3_sg123_icsd_168904_1,relaxed,0.05346,2926,Zn1Sn1S3,"(1, 1, 3)"


In [None]:
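# NOTE(review): `new_protos_subset` is not defined in any cell visible in this notebook;
# define it before re-enabling the export below.
# new_protos_subset.to_csv('20220601_relaxed_prototypes.csv', index=False)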