# Searching for chemical compositions of Superconductors

In [1]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error 

In [2]:
import warnings 
# Silence every warning category for the rest of the session so library
# notices do not clutter the notebook output.
warnings.filterwarnings('ignore')
# NOTE(review): redundant -- the blanket filter above already ignores
# DeprecationWarning along with every other category.
warnings.filterwarnings('ignore', category=DeprecationWarning)

In [3]:
import numpy as np
import pandas as pd
import time
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sb
from IPython.display import display
import ipywidgets as widgets
import random

# plt.style.use('dark_background')
# from jupyterthemes import jtplot
# jtplot.style(theme='monokai', context='notebook', ticks=True, grid=True)

## Load trained Random Forest Regressor model

In [4]:
import zipfile
import os
import pickle

# Name of the pickled model stored inside the archive.
file = 'rf_regressor.dat'

# Read the pickled regressor straight out of the archive instead of extracting
# it to the working directory and deleting it afterwards. This leaves nothing
# on disk if unpickling fails, never touches an unrelated file that happens to
# share the member's name, and works from a read-only directory.
#
# SECURITY: pickle.load can execute arbitrary code while deserializing; only
# load an archive that comes from a trusted source.
with zipfile.ZipFile("rf_regressor.zip") as zipref:
    with zipref.open(file) as f:
        model = pickle.load(f)

In [5]:
# Report the version of every library used by this notebook, followed by the
# interpreter version, so the environment can be reproduced.
import numpy
print(f"numpy=={numpy.__version__}")

import sklearn
print(f"scikit-learn=={sklearn.__version__}")

import pandas
print(f"pandas=={pandas.__version__}")

import matplotlib
print(f"matplotlib=={matplotlib.__version__}")

import ipywidgets
print(f"ipywidgets=={ipywidgets.__version__}")

import seaborn
print(f"seaborn=={seaborn.__version__}")

import IPython
print(f"IPython=={IPython.__version__}")

import scipy
print(f"scipy=={scipy.__version__}")

from platform import python_version
print(f"Python {python_version()}")

numpy==1.19.2
scikit-learn==0.23.2
pandas==1.2.0
matplotlib==3.3.2
ipywidgets==7.6.3
seaborn==0.11.1
IPython==7.19.0
scipy==1.5.2
Python 3.7.9


## Initialize input features 

In [6]:
# nrows=0 reads only the header row: data_head is an empty frame that carries
# the CSV's column names.
data_head = pd.read_csv("SC_data_frac.csv",nrows=0) 
# Load just the "Critical Temperature" column from the same file.
data_Tc = pd.read_csv("SC_data_frac.csv",usecols=["Critical Temperature"]) 

In [7]:
# For every "<name>_x" column in the CSV header, derive a "<name>_bool" flag name.
col_bool = [
    name[:-2] + "_bool"
    for name in data_head.columns
    if name.endswith("_x")
]

# Column order of the prediction input: scaled temperature first, then the flags.
data_headr = ["Scaled T_c", *col_bool]

# Keep only rows with a strictly positive critical temperature, then take the
# per-column standard deviation (used in show_plot to scale temperatures).
data_Tc = data_Tc.loc[data_Tc["Critical Temperature"] > 0]
data_std = data_Tc.std()

## Initialize variables for interactive plotting

In [8]:

# Range slider for the temperature window (0-150, initially 1-100); with
# continuous_update=False it only reports a new value once the drag ends.
T_range = widgets.IntRangeSlider(
    value=[1, 100],
    min=0,
    max=150,
    step=1,
    description='Temperature Range:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

# One checkbox per flag in col_bool, keyed and labelled by the flag name with
# its "_bool" suffix stripped; each starts in a randomly chosen state.
sym_dict = {}
for sym in col_bool:
    sym_dict[sym[:-5]] = widgets.Checkbox(
        value=bool(np.random.choice([True, False])),
        description=sym[:-5],
        disabled=False,
        button_style='',
    )

# Widgets whose values are passed to the plotting callback as keyword arguments.
# 'Update' is a checkbox that the button handlers flip to force a redraw.
changed_dict = {
    'T_range': T_range,
    'Update': widgets.Checkbox(
        value=True,
        description='Update',
        disabled=False,
        button_style='',
    ),
}


In [9]:
# Bulk-selection buttons, created in one pass from their labels.
sel_all, unsel_all, inv_all, rand_all = (
    widgets.Button(description=label)
    for label in (
        "Select All",
        "Unselect All",
        "Invert Selection",
        "Randomize Selection",
    )
)

def select_all_elem(button_click):
    """Tick every element checkbox, then flip the 'Update' flag.

    Parameters
    ----------
    button_click : object
        Argument supplied by the button's click callback; not used.
    """
    for key in (flag[:-5] for flag in col_bool):
        sym_dict[key].value = True
    # Flipping the flag changes a value in changed_dict, which signals a redraw.
    trigger = changed_dict['Update']
    trigger.value = not trigger.value

def unselect_all_elem(button_click):
    """Clear every element checkbox, then flip the 'Update' flag.

    Parameters
    ----------
    button_click : object
        Argument supplied by the button's click callback; not used.
    """
    for key in (flag[:-5] for flag in col_bool):
        sym_dict[key].value = False
    # Flipping the flag changes a value in changed_dict, which signals a redraw.
    trigger = changed_dict['Update']
    trigger.value = not trigger.value
    
def invert_all_elem(button_click):
    """Toggle every element checkbox, then flip the 'Update' flag.

    Parameters
    ----------
    button_click : object
        Argument supplied by the button's click callback; not used.
    """
    for key in (flag[:-5] for flag in col_bool):
        checkbox = sym_dict[key]
        checkbox.value = not checkbox.value
    # Flipping the flag changes a value in changed_dict, which signals a redraw.
    trigger = changed_dict['Update']
    trigger.value = not trigger.value

def randomize_elem(button_click):
    """Set every element checkbox at random, then flip the 'Update' flag.

    A single probability is drawn uniformly from [0, 1) per click; each
    checkbox is then independently ticked with that probability.

    Parameters
    ----------
    button_click : object
        Argument supplied by the button's click callback; not used.
    """
    prob_true = np.random.uniform(0, 1)
    weights = [prob_true, 1 - prob_true]
    for flag in col_bool:
        drawn = np.random.choice([True, False], p=weights)
        sym_dict[flag[:-5]].value = bool(drawn)
    # Flipping the flag changes a value in changed_dict, which signals a redraw.
    trigger = changed_dict['Update']
    trigger.value = not trigger.value
        
# Attach each bulk-selection button to its click handler.
for button, handler in (
    (sel_all, select_all_elem),
    (unsel_all, unselect_all_elem),
    (inv_all, invert_all_elem),
    (rand_all, randomize_elem),
):
    button.on_click(handler)

In [10]:

def show_plot(**kwargs):
    """Predict and plot element proportions over the selected temperature range.

    Builds a 101-row input frame (evenly spaced temperatures between the two
    ends of ``T_range``, divided by the standard deviation of the critical
    temperature, plus one 0/1 flag per entry of ``col_bool`` taken from the
    current checkbox states), runs ``model.predict`` on it and draws one curve
    per output column whose predictions sum to a positive value. Curves of
    ticked elements are solid, the others dotted.

    Parameters
    ----------
    **kwargs
        Current widget values. ``kwargs['T_range']`` must be a
        ``(low, high)`` pair; any other key (e.g. ``'Update'``) is only a
        redraw trigger and is ignored.

    Notes
    -----
    Rebinds the module-level ``fig`` and ``ax`` to the newly created figure.
    Assumes column ``i`` of the prediction corresponds to ``col_bool[i]``
    (NOTE(review): confirm against how the model was trained).
    """
    global fig, ax

    # Drop the figure from the previous call so repeated widget updates do not
    # accumulate open figures; plt.close() is a no-op when none is open, so no
    # exception handling is needed here.
    plt.close()
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111)

    t_low, t_high = kwargs['T_range']
    temperatures = np.linspace(t_low, t_high, 101)
    tc_std = data_std["Critical Temperature"]

    # Build the whole input frame in one go (numeric dtypes) instead of growing
    # an object-dtype frame cell by cell. `columns=` fixes the column order.
    features = {"Scaled T_c": temperatures / tc_std}
    for sym in col_bool:
        features[sym] = int(sym_dict[sym[:-5]].value)
    test_X = pd.DataFrame(features, columns=data_headr)

    Y_pred = np.asarray(model.predict(test_X))

    # Only outputs with a positive total are drawn; compute that set once.
    active = [
        (i, sym) for i, sym in enumerate(col_bool)
        if Y_pred[:, i].sum() > 0
    ]

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # recent matplotlib releases; one extra colour keeps the first and last
    # hues of the cyclic 'hsv' map from coinciding.
    color_hsv = plt.get_cmap('hsv', len(active) + 1)
    for j, (i, sym) in enumerate(active, start=1):
        name = sym[:-5]
        ax.plot(temperatures,
                Y_pred[:, i],
                c=color_hsv(j),
                label=name + "_x",
                linestyle='-' if sym_dict[name].value else ':')

    # Skip the legend when nothing was plotted to avoid an empty-legend warning.
    if active:
        ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0),
                  fancybox=True, framealpha=0.1, ncol=1)
    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Predicted proportion of elements")
    fig.canvas.draw()
    plt.show()

#     display(fig)


In [11]:
# Button the user clicks to redraw the plot after changing checkboxes by hand.
update_all = widgets.Button(description="Update Plot")
def update_plot(button_click):
    """Flip the 'Update' flag so the plot is redrawn.

    Parameters
    ----------
    button_click : object
        Argument supplied by the button's click callback; not used.
    """
    trigger = changed_dict['Update']
    trigger.value = not trigger.value
# Run update_plot whenever the "Update Plot" button is clicked.
update_all.on_click(update_plot)

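## The plot below is interactive
After ticking or unticking elements by hand, click **Update Plot** to redraw the plot. The selection buttons and the temperature-range slider redraw it automatically.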

In [12]:
# Re-run show_plot whenever one of the widgets in changed_dict changes value.
plot_out = widgets.interactive_output(show_plot, changed_dict)

# Bordered container holding the whole interface.
out1 = widgets.Output(layout={'border': '1px solid black'})

with out1:
    out1.clear_output()
    # Top to bottom: 8-column checkbox grid, button row, range slider, plot.
    display(
        widgets.VBox([
            widgets.GridBox(
                list(sym_dict.values()),
                layout=widgets.Layout(grid_template_columns="repeat(8, 100px)"),
            ),
            widgets.HBox([update_all, sel_all, unsel_all, inv_all, rand_all]),
            T_range,
            plot_out,
        ])
    )

out1

Output(layout=Layout(border='1px solid black'))