In [1]:
from recommender.core import RecommenderSystem, OccupationData
import gzip
import json
import joblib
from gensim.models import Word2Vec
from pymatgen.core import Element

In [2]:
# Location of the gzipped JSON occupation data that is handed to
# OccupationData further down in this notebook.
json_file_path = 'data/occupancy_data.json.gz'

# Load the saved gensim Word2Vec model; only its keyed vectors (`.wv`) are
# needed as the embedding passed to the recommender.
model = Word2Vec.load('models/word2vec.model')
embedding = model.wv

# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code -- only load model files that come from a trusted source.
clf = joblib.load('models/decision_tree.joblib')
# Entry 0 of the fitted tree's threshold array is used as the distance cutoff.
# Presumably `clf` is a scikit-learn decision tree, in which case node 0 is the
# root split -- TODO confirm against the training code.
distance_threshold = clf.tree_.threshold[0]

In [3]:
# Candidate ion symbols considered by the recommender.
ions = [
    'Xe', 'Kr', 'Ar', 'Ne', 'He',
    'I', 'Br', 'Cl', 'F',
    'Te', 'Se', 'S', 'O',
    'Bi', 'Sb', 'As', 'P', 'N',
    'Pb', 'Sn', 'Ge', 'Si', 'C',
    'Tl', 'In', 'Ga', 'Al', 'B',
    'Hg', 'Cd', 'Zn',
    'Au', 'Ag', 'Cu',
    'Pt', 'Pd', 'Ni',
    'Ir', 'Rh', 'Co',
    'Os', 'Ru', 'Fe',
    'Re', 'Tc', 'Mn',
    'W', 'Mo', 'Cr',
    'Ta', 'Nb', 'V',
    'Hf', 'Zr', 'Ti',
    'Pu', 'Np', 'U', 'Pa', 'Th', 'Ac',
    'Lu', 'Yb', 'Tm', 'Er', 'Ho', 'Dy', 'Tb', 'Gd',
    'Eu', 'Sm', 'Pm', 'Nd', 'Pr', 'Ce', 'La',
    'Y', 'Sc',
    'Ba', 'Sr', 'Ca', 'Mg', 'Be',
    'Cs', 'Rb', 'K', 'Na', 'Li',
    'H',
]

# Ions that must never be recommended: anything with atomic number above 83,
# the explicitly listed Tc and Pm, and every group-18 element.
# (Presumably these are excluded as radioactive or inert -- confirm.)
ion_forbidden_list = [
    symbol
    for symbol, element in zip(ions, map(Element, ions))
    if element.Z > 83 or symbol in ('Tc', 'Pm') or element.group == 18
]
ion_forbidden_list

['Xe', 'Kr', 'Ar', 'Ne', 'He', 'Tc', 'Pu', 'Np', 'U', 'Pa', 'Th', 'Ac', 'Pm']

In [4]:
# Build the OccupationData object from the gzipped JSON path set in
# `json_file_path` ('data/occupancy_data.json.gz').
occupation_data = OccupationData(json_file_path)

In [5]:
# OQMD entry id of the example structure (the variable name suggests a
# kagome-type compound -- not verifiable from this notebook alone).
kagome_id = 1514602

# Fetch the corresponding "AM" object from the occupation data by OQMD id.
# NOTE(review): "AM" is not expanded anywhere in this notebook; see
# recommender.core for its definition.
kagome_AM = occupation_data.get_AM_from_OQMD_id(kagome_id)

In [6]:
# Show the per-site index list of the AM. Sites that share a value are
# presumably symmetry-equivalent, labelled by the first site of each
# group -- TODO confirm against recommender.core.
kagome_AM.equivalent_sites_indexes

[0, 1, 1, 3, 1, 1, 6, 6, 6]

In [7]:
# Assemble the recommender from the pieces prepared in the cells above:
# occupation data, Word2Vec keyed vectors, the distance cutoff read from the
# decision tree, and the list of ions that must not be suggested.
rs = RecommenderSystem(
    occupation_data=occupation_data,
    embedding=embedding,
    distance_threshold=distance_threshold,
    ion_forbidden_list=ion_forbidden_list,
)

In [8]:
# Ask for ion recommendations for the sites of the example AM.
# Judging by the displayed result, this yields a mapping from site to a list
# of (ion symbol, float, bool) tuples; the float looks like a distance and the
# bool like a flag -- exact meaning to be confirmed in recommender.core.
rs.get_recommendation_for_AM(kagome_AM)

{(191, 9, (1, '6/mmm'), 0), site index: 0: [('Rb',
   0.018323421478271484,
   False),
  ('K', 0.02188342809677124, False),
  ('Cs', 0.02744007110595703, False),
  ('Tl', 0.11811000108718872, False),
  ('Na', 0.13025516271591187, False),
  ('Ba', 0.261111319065094, True)],
 (191, 9, (1, '6/mmm'), 0), site index: 3: [('Sb',
   0.003526031970977783,
   False),
  ('Bi', 0.028451979160308838, False),
  ('As', 0.07530736923217773, False),
  ('P', 0.15748435258865356, False),
  ('Ir', 0.24137479066848755, True),
  ('Pt', 0.28948670625686646, True),
  ('Ta', 0.2916148900985718, True),
  ('Re', 0.3077079653739929, True),
  ('Rh', 0.3191390633583069, True),
  ('Au', 0.3234304189682007, True),
  ('Os', 0.3247630000114441, True)],
 (191, 9, (3, 'mmm'), 0), site index: 6: [('V', 0.03470265865325928, False),
  ('Nb', 0.07285183668136597, False),
  ('Ta', 0.08547395467758179, False),
  ('Mn', 0.08579069375991821, False),
  ('Ti', 0.09940123558044434, True),
  ('Cr', 0.10330325365066528, False),
  ('

In [9]:
# Same kind of query for a single site, addressed directly by its string
# label. The displayed result is a list of (ion symbol, float, bool) tuples;
# NOTE(review): the label format is defined by recommender.core, not here.
rs.get_recommendation_for_site('221_5_1(m-3m)_1(m-3m)_3(4/mmm)_0[2:3(4/mmm)]')

[('F', 0.0973658561706543, False),
 ('Cl', 0.10440915822982788, False),
 ('Br', 0.12838470935821533, False),
 ('H', 0.16112929582595825, False),
 ('O', 0.17956966161727905, False),
 ('I', 0.2550504207611084, False)]