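# Search for chemical compositions of superconductors

Select a set of elements and a critical-temperature range. A pre-trained random forest regressor then predicts, across that range, the proportion of each element in the composition.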

In [1]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error 

In [2]:
import warnings 
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', category=DeprecationWarning)

In [3]:
import numpy as np
import pandas as pd
import time
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sb
from IPython.display import display,HTML
import ipywidgets as widgets
import random

# plt.style.use('dark_background')
# from jupyterthemes import jtplot
# jtplot.style(theme='monokai', context='notebook', ticks=True, grid=True)
import matplotlib as mpl
# Dark theme for all figures: dark grey backgrounds, light grey labels/ticks/text.
mpl.rcParams.update({
    'figure.facecolor': '#222222',
    'axes.facecolor': '#222222',
    'axes.labelcolor': '#edece9',
    'xtick.color': '#edece9',
    'ytick.color': '#edece9',
    'text.color': '#edece9',
})

In [None]:
### Load trained Random Forest Regressor model

In [4]:
import zipfile
import os
import pickle

filebase = "rf_regressor"
dat_path = filebase + ".dat"

# The model is shipped as a zipped pickle; unpack it into the working directory.
with zipfile.ZipFile(filebase + ".zip") as zipref:
    zipref.extractall()

# SECURITY: pickle.load can execute arbitrary code -- only load model files from a trusted source.
try:
    with open(dat_path, 'rb') as f:
        model = pickle.load(f)
finally:
    # Remove the extracted pickle even if unpickling fails, so no stale copy is left behind.
    if os.path.exists(dat_path):
        os.remove(dat_path)

# Alternative for local development: load the uncompressed pickle directly.
# with open("../../sc_data_inc/"+filebase+".dat", 'rb') as f:
#     model = pickle.load(f)

In [None]:
### Initialize input features 

In [5]:
sc_csv = "SC_data_frac.csv"
# Column names only: nrows=0 reads just the header line.
data_head = pd.read_csv(sc_csv, nrows=0)
# The critical-temperature column alone, used below to scale the model's temperature input.
data_Tc = pd.read_csv(sc_csv, usecols=["Critical Temperature"])

In [6]:
# Model input columns: the scaled critical temperature, then one "<name>_bool"
# flag for every CSV column ending in "_x" (the "_x" suffix is swapped for "_bool").
col_bool = [name[:-len("_x")] + "_bool" for name in data_head.columns if name.endswith("_x")]
data_headr = ["Scaled T_c"] + col_bool

# Scale for "Scaled T_c": standard deviation over strictly positive critical temperatures.
data_Tc = data_Tc.loc[data_Tc["Critical Temperature"] > 0]
data_std = data_Tc.std()

In [54]:
### Initialize widgets for interactive plotting

In [147]:
# Critical-temperature window (K) over which `show_plot` evaluates the model.
T_range = widgets.IntRangeSlider(value=[1, 100],
                                 min=0,
                                 max=150,
                                 step=1,
                                 description='Critical Temperature range:',
                                 disabled=False,
                                 continuous_update=False,
                                 orientation='horizontal',
                                 readout=True,
                                 readout_format='d',
                                 style={'description_width': 'initial'},
                                 layout={'width': '800px'})

# One toggle button per element flag column, keyed by the column name with its
# "_bool" suffix stripped. Each starts randomly selected or unselected (unseeded).
# Only real ToggleButton/Layout options are passed: `indent` and a layout `align`
# key are not traits and only produce (silenced) traitlets deprecation warnings.
sym_dict = {}
for sym in col_bool:
    symbol = sym[:-5]
    sym_dict[symbol] = widgets.ToggleButton(value=bool(np.random.choice([True, False])),
                                            description=symbol,
                                            disabled=False,
                                            button_style='',
                                            layout={'width': '45px', 'height': '30px'})

# Widgets whose changes make `widgets.interactive_output` re-run `show_plot`.
# 'Update' is never displayed; the buttons flip its value to force a redraw.
changed_dict = {}
changed_dict['T_range'] = T_range
changed_dict['Update'] = widgets.ToggleButton(value=True,
                                              description='Update',
                                              disabled=False,
                                              button_style='',
                                              layout={'width': '45px', 'height': '30px'})


In [148]:
def _flip_update_toggle():
    """Invert the hidden 'Update' toggle so interactive_output re-runs show_plot."""
    toggle = changed_dict['Update']
    toggle.value = not toggle.value


def _apply_to_elements(new_value):
    """Set each element toggle (in col_bool order) to new_value(current value), then redraw."""
    for col in col_bool:
        button = sym_dict[col[:-5]]
        button.value = new_value(button.value)
    _flip_update_toggle()


def select_all_elem(button_click):
    """'Select All' handler: switch every element on."""
    _apply_to_elements(lambda current: True)


def unselect_all_elem(button_click):
    """'Unselect All' handler: switch every element off."""
    _apply_to_elements(lambda current: False)


def invert_all_elem(button_click):
    """'Invert Selection' handler: flip every element."""
    _apply_to_elements(lambda current: not current)


def randomize_elem(button_click):
    """'Randomize Selection' handler: draw one probability, then select each element with it."""
    prob = np.random.uniform(0, 1)
    _apply_to_elements(lambda current: bool(np.random.choice([True, False], p=[prob, 1 - prob])))


def update_plot(button_click):
    """'Update Plot' handler: redraw without changing the selection."""
    _flip_update_toggle()


sel_all = widgets.Button(description="Select All")
sel_all.on_click(select_all_elem)

unsel_all = widgets.Button(description="Unselect All")
unsel_all.on_click(unselect_all_elem)

inv_all = widgets.Button(description="Invert Selection")
inv_all.on_click(invert_all_elem)

rand_all = widgets.Button(description="Randomize Selection")
rand_all.on_click(randomize_elem)

update_all = widgets.Button(description="Update Plot")
update_all.on_click(update_plot)


In [162]:

def show_plot(**kwargs):
    """Plot predicted element proportions over a critical-temperature range.

    Called by ``widgets.interactive_output`` with the current values of the
    widgets in ``changed_dict`` as keyword arguments. Only ``kwargs['T_range']``
    (the ``(low, high)`` slider value, in K) is used. ``kwargs['Update']`` is
    accepted but ignored; changing it merely triggers a redraw.

    The model input is a grid of 101 temperatures from low to high, each
    divided by the critical-temperature standard deviation, plus a constant
    0/1 column per element flag taken from the toggle buttons in ``sym_dict``.
    Column ``i`` of the prediction is plotted as the proportion of element
    ``col_bool[i]``. Only elements whose predictions sum to a positive value
    are drawn: solid lines for selected elements, dotted for unselected ones.
    The legend reports each proportion at the last grid point (the range maximum).

    Side effects: closes the current pyplot figure and rebinds the module
    globals ``fig`` and ``ax`` to the new figure and axes.
    """
    global fig, ax
    # Close the previous figure so repeated redraws do not accumulate open figures.
    plt.close()
    fig, ax = plt.subplots(figsize=(12, 8))

    t_low, t_high = kwargs['T_range']
    t_values = np.linspace(t_low, t_high, 101)

    # Build the whole model input in one constructor call (numeric dtypes, columns
    # in data_headr order) instead of enlarging an empty frame via .loc.
    features = {"Scaled T_c": t_values / data_std["Critical Temperature"]}
    for sym in col_bool:
        features[sym] = int(sym_dict[sym[:-5]].value)
    test_X = pd.DataFrame(features, columns=data_headr)

    Y_pred = model.predict(test_X)

    selected, unselected = [], []
    for i, sym in enumerate(col_bool):
        if np.sum(Y_pred[:, i]) > 0:
            if sym_dict[sym[:-5]].value:
                selected.append((i, sym))
            else:
                unselected.append((i, sym))

    n_selected = len(selected)
    # plt.get_cmap(name, lut) works on all matplotlib versions; cm.get_cmap was removed in 3.9.
    colors = plt.get_cmap('hsv', n_selected + len(unselected) + 1)
    for j, (i, sym) in enumerate(selected + unselected):
        ax.plot(t_values, 100 * Y_pred[:, i],
                c=colors(j),
                label="%s ~ %.2f%%" % (sym[:-5], Y_pred[-1, i] * 100),
                linestyle='-' if j < n_selected else ':')

    if selected or unselected:
        # Raw strings: "\m" is not a valid Python escape sequence.
        ax.legend(title=r"Composition at Max($\mathregular{T_c}$)", loc='upper left',
                  bbox_to_anchor=(1.0, 1.0), fancybox=True, framealpha=0.1, ncol=1)
    ax.set_xlabel(r"Critical Temperature ($\mathregular{T_c}$) [K]")
    ax.set_ylabel("Predicted proportion of elements [%]")
    fig.canvas.draw()
    plt.show()


In [163]:
def gb(k):
    """Return the widget for one cell of the periodic-table selection grid.

    Parameters
    ----------
    k : str
        Cell key:
        - a key of ``sym_dict``: that element's toggle button (the same object
          each time, so it stays in sync with the selection buttons);
        - ``'->'``: a new disabled red cell with a right-arrow icon;
        - ``''``: a new disabled black spacer cell;
        - anything else: a new disabled blue cell labelled ``k`` (e.g. symbols
          without a toggle, and the ``Ln``/``An`` placeholders).

    Only genuine ``Button``/``Layout`` options are passed; ``indent`` and a
    layout ``align`` key are not widget traits.
    """
    if k in sym_dict:
        return sym_dict[k]

    button_kwargs = {'description': k, 'style': {'button_color': "#0000CC"}}
    if k == '->':
        button_kwargs = {'description': '', 'icon': 'arrow-right',
                         'style': {'button_color': "#CC0000"}}
    elif k == '':
        button_kwargs = {'description': '', 'style': {'button_color': "#000000"}}

    return widgets.Button(disabled=True,
                          layout={'width': '45px', 'height': '30px'},
                          **button_kwargs)
# Periodic-table layout of the element picker: 9 rows x 18 columns, read row by row.
# "" is a blank spacer, "->" an arrow cell; each key is turned into a widget by `gb`.
periodic_layout = [
    ["H", "D", "T"] + [""] * 14 + ["He"],
    ["Li", "Be"] + [""] * 10 + "B C N O F Ne".split(),
    ["Na", "Mg"] + [""] * 10 + "Al Si P S Cl Ar".split(),
    "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr".split(),
    "Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe".split(),
    "Cs Ba Ln Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn".split(),
    "Fr Ra An Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og".split(),
    # Lanthanide row, expanding the "Ln" placeholder of row 6.
    ["", "Ln", "->"] + "La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu".split(),
    # Actinide row, expanding the "An" placeholder of row 7.
    ["", "An", "->"] + "Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr".split(),
]

arg_dict = [gb(symbol) for row in periodic_layout for symbol in row]


In [None]:
## The plot below is interactive
#### After manually toggling elements, click **Update Plot** to refresh the plot.

In [164]:
# Re-run `show_plot` whenever a widget in `changed_dict` changes: the temperature
# slider, or the hidden 'Update' toggle flipped by the buttons.
plot_out = widgets.interactive_output(show_plot, changed_dict)

out1 = widgets.Output(layout={'border': '0px solid black'})

# Periodic-table element picker (18 columns) inside a collapsible accordion.
accordian = widgets.Accordion(children=[widgets.GridBox(arg_dict, layout=widgets.Layout(grid_template_columns="repeat(18, 50px)"))])
accordian.set_title(0, 'Elements')

with out1:
    # Drop any UI rendered by a previous run of this cell.
    out1.clear_output()
    # display() gets only the widget: layout options belong on widgets, and stray
    # keyword arguments to display() are not applied to the VBox.
    display(widgets.VBox([accordian,
                          widgets.HBox([update_all, sel_all, unsel_all, inv_all, rand_all]),
                          T_range,
                          plot_out]))

out1

Output(layout=Layout(border='0px solid black'))