# Searching for chemical compositions of Superconductors

In [1]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error 

In [2]:
import warnings 
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', category=DeprecationWarning)

In [3]:
import numpy as np
import pandas as pd
import time
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sb
from IPython.display import display,HTML
import ipywidgets as widgets
import random

# plt.style.use('dark_background')
# from jupyterthemes import jtplot
# jtplot.style(theme='monokai', context='notebook', ticks=True, grid=True)
import matplotlib as mpl
# Dark figure theme: dark-grey backgrounds, light-grey labels, ticks and text.
mpl.rcParams.update({
    # backgrounds
    'figure.facecolor': '#222222',
    'axes.facecolor': '#222222',
    # foreground elements
    'axes.labelcolor': '#edece9',
    'xtick.color': '#edece9',
    'ytick.color': '#edece9',
    'text.color': '#edece9',
})

### Load trained Random Forest Regressor model

In [4]:
import zipfile
import os
import pickle

filebase = "rf_regressor"
dat_path = filebase + ".dat"

# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only load rf_regressor.zip from a trusted source.
# extractall() with no path unpacks the archive into the current directory.
with zipfile.ZipFile(filebase + ".zip") as zipref:
    zipref.extractall()

try:
    with open(dat_path, 'rb') as f:
        model = pickle.load(f)
finally:
    # Delete the extracted pickle even if unpickling fails, so a failed run
    # does not leave a stray copy in the working directory.
    if os.path.exists(dat_path):
        os.remove(dat_path)

### Initialize input features 

In [5]:
# Two reads of the same file: the header alone (no rows) supplies the column
# names used to derive the element features; the target column supplies T_c.
data_head, data_Tc = (
    pd.read_csv("SC_data_frac.csv", nrows=0),
    pd.read_csv("SC_data_frac.csv", usecols=["Critical Temperature"]),
)

In [6]:
# Every CSV column named "<symbol>_x" (e.g. "Ag_x") becomes a "<symbol>_bool"
# model feature.
col_bool = [name[:-2] + "_bool" for name in data_head if name.endswith("_x")]

# Model input columns: the scaled temperature first, then the element flags.
data_headr = ["Scaled T_c", *col_bool]

# Keep only positive critical temperatures before measuring their spread;
# the standard deviation is the scale applied to "Scaled T_c".
data_Tc = data_Tc.loc[data_Tc["Critical Temperature"] > 0]
data_std = data_Tc.std()

### Initialize variables for interactive plotting

In [48]:
# `widgets.IntRangeSlider` (used for `T_range` below) holds its value as a
# (lower, upper) pair of ints bounded by `min`/`max`; see the ipywidgets
# widget reference for its full parameter list.

In [49]:
# Critical-temperature range [K] scanned by the plot; show_plot receives it
# as kwargs['T_range'].
T_range = widgets.IntRangeSlider(value=[1, 100],
                                 min=0,
                                 max=150,
                                 step=1,
                                 description='Critical Temperature range:',
                                 disabled=False,
                                 continuous_update=False,
                                 orientation='horizontal',
                                 readout=True,
                                 readout_format='d',
                                 style={'description_width': 'initial'},
                                 layout={'width': '800px'})

# One toggle per element, keyed by symbol (the "_bool" feature name without
# its suffix). Initial states are drawn at random without a seed, so each run
# starts from a different selection.
sym_dict = {}
for sym in col_bool:
    symbol = sym[:-5]
    sym_dict[symbol] = widgets.ToggleButton(
        value=bool(np.random.choice([True, False])),
        description=symbol,
        disabled=False,
        button_style='',
        # Only real Layout traits: the former 'align' key and the `indent`
        # kwarg are not traits of Layout/ToggleButton and were dropped.
        layout={'width': '45px', 'height': '30px'},
    )

# Inputs wired to show_plot. 'Update' is a pure redraw trigger: the buttons
# flip it to force a redraw; show_plot never reads its value.
changed_dict = {
    'T_range': T_range,
    'Update': widgets.Checkbox(value=True, description='Update', disabled=False),
}


In [50]:
def _request_redraw():
    """Flip the 'Update' checkbox to trigger a redraw.

    changed_dict['Update'] is one of the inputs passed to
    widgets.interactive_output together with show_plot, while the element
    toggle buttons are not. Changing the element toggles alone therefore does
    not redraw; every button below ends with this call.
    """
    changed_dict['Update'].value = not changed_dict['Update'].value


def _set_all_elements(new_value):
    """Set each element toggle to ``new_value(old_value)``, then redraw once.

    Elements are visited in ``col_bool`` order.
    """
    for sym in col_bool:
        toggle = sym_dict[sym[:-5]]
        toggle.value = new_value(toggle.value)
    _request_redraw()


sel_all = widgets.Button(description="Select All")
def select_all_elem(button_click):
    """'Select All' callback: switch every element on."""
    _set_all_elements(lambda _old: True)
sel_all.on_click(select_all_elem)

unsel_all = widgets.Button(description="Unselect All")
def unselect_all_elem(button_click):
    """'Unselect All' callback: switch every element off."""
    _set_all_elements(lambda _old: False)
unsel_all.on_click(unselect_all_elem)

inv_all = widgets.Button(description="Invert Selection")
def invert_all_elem(button_click):
    """'Invert Selection' callback: flip every element's state."""
    _set_all_elements(lambda old: not old)
inv_all.on_click(invert_all_elem)

rand_all = widgets.Button(description="Randomize Selection")
def randomize_elem(button_click):
    """'Randomize Selection' callback.

    Draws one probability p ~ U(0, 1), then switches each element on
    independently with probability p.
    """
    p = np.random.uniform(0, 1)
    _set_all_elements(
        lambda _old: bool(np.random.choice([True, False], p=[p, 1 - p])))
rand_all.on_click(randomize_elem)

update_all = widgets.Button(description="Update Plot")
def update_plot(button_click):
    """'Update Plot' callback: redraw with the current element selection."""
    _request_redraw()
update_all.on_click(update_plot)

In [51]:

def _scan_features(t_lim, n_points=101):
    """Build the model input frame for a scan over the temperature range ``t_lim``.

    Parameters
    ----------
    t_lim : (lower, upper)
        Critical-temperature range in K, as produced by the ``T_range`` slider.
    n_points : int
        Number of evenly spaced temperatures in the scan (endpoints included).

    Returns
    -------
    (features, temps)
        ``features`` has the columns ``data_headr`` in that order: the scaled
        temperature, then one 0/1 flag per element taken from the current
        toggle state in ``sym_dict``. ``temps`` is the unscaled grid in K.
    """
    temps = np.linspace(t_lim[0], t_lim[1], n_points)
    columns = {"Scaled T_c": temps / data_std["Critical Temperature"]}
    for sym in col_bool:
        # Same flag on every row: the selection is fixed, only T_c varies.
        columns[sym] = int(sym_dict[sym[:-5]].value)
    # Build the frame in one shot (scalars broadcast against the grid) instead
    # of filling an empty object-dtype frame column by column.
    return pd.DataFrame(columns, columns=data_headr), temps


def _split_elements(Y_pred):
    """Group elements whose predictions sum to > 0 over the scan.

    Returns ``(selected, unselected)``: lists of ``[output_index, bool_col]``
    in ``col_bool`` order, split by the element's toggle state. Elements whose
    predictions sum to <= 0 are left out of both lists.
    """
    selected, unselected = [], []
    for i, sym in enumerate(col_bool):
        if np.sum(Y_pred[:, i]) > 0:
            bucket = selected if sym_dict[sym[:-5]].value else unselected
            bucket.append([i, sym])
    return selected, unselected


def show_plot(**kwargs):
    """Redraw the predicted-composition plot for the current widget state.

    Called by ``widgets.interactive_output`` with the values of
    ``changed_dict``: ``kwargs['T_range']`` is the (lower, upper) temperature
    range in K. ``kwargs['Update']`` is only a redraw trigger and is not read.

    The model is evaluated over the range with the selected elements' flags set
    to 1 and the rest to 0. Selected elements are drawn solid, unselected ones
    dotted, and each legend entry shows the prediction at the upper end of the
    range.

    Side effects: closes the current pyplot figure and rebinds the module
    globals ``fig`` and ``ax`` to the new figure and axes.
    """
    global fig, ax
    # Close the previous figure so repeated redraws do not accumulate figures.
    plt.close()
    fig, ax = plt.subplots(figsize=(12, 8))

    test_X, temps = _scan_features(kwargs['T_range'])
    # NOTE(review): assumes a multi-output model whose i-th output column is
    # the fraction of element col_bool[i] — confirm against the training code.
    Y_pred = model.predict(test_X)
    non_zero_elems, zero_elems = _split_elements(Y_pred)

    # Colours run continuously across both groups; only the line style differs.
    lines = ([(entry, '-') for entry in non_zero_elems]
             + [(entry, ':') for entry in zero_elems])
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in
    # Matplotlib 3.9.
    color_hsv = plt.get_cmap('hsv', len(lines) + 1)
    for j, ([i, sym], linestyle) in enumerate(lines):
        ax.plot(temps, 100 * Y_pred[:, i], c=color_hsv(j),
                label=f"{sym[:-5]} ~ {Y_pred[-1, i] * 100:.2f}%",
                linestyle=linestyle)
    if lines:
        ax.legend(title="Composition at Max(T_c)", loc='upper left',
                  bbox_to_anchor=(1.0, 1.0), fancybox=True, framealpha=0.1, ncol=1)
    ax.set_xlabel("Critical Temperature (T_c) [K]")
    ax.set_ylabel("Predicted proportion of elements [%]")
    plt.show()


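## The plot below is interactive
After toggling individual elements, click **Update Plot** to redraw. The bulk-selection buttons (Select All, Unselect All, Invert Selection, Randomize Selection) and the temperature slider redraw the plot automatically.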

In [52]:
# Re-run show_plot whenever any widget in changed_dict changes.
plot_out = widgets.interactive_output(show_plot, changed_dict)

out1 = widgets.Output(layout={'border': '0px solid black'})
# Element toggles sorted by symbol, laid out 16 per row inside a collapsible panel.
arg_dict = {name: sym_dict[name] for name in sorted(sym_dict)}
accordian = widgets.Accordion(children=[
    widgets.GridBox(list(arg_dict.values()),
                    layout=widgets.Layout(grid_template_columns="repeat(16, 50px)"))
])
accordian.set_title(0, 'Elements')

controls = widgets.VBox([
    accordian,
    widgets.HBox([update_all, sel_all, unsel_all, inv_all, rand_all]),
    T_range,
    plot_out,
])

with out1:
    out1.clear_output()
    # No extra keyword arguments: display() does not forward `layout` to the
    # widget, so the previous layout={'color': 'black'} had no useful effect.
    display(controls)

out1

VBox(children=(Accordion(children=(GridBox(children=(ToggleButton(value=True, description='Ag', layout=Layout(…

Output(layout=Layout(border='0px solid black'))