In [54]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import r2_score, explained_variance_score, mean_absolute_error
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.neighbors import KNeighborsRegressor

In [16]:
def get_raw_df():
    return pd.read_csv(r"C:\Users\91910\Desktop\IIT_Kanpur\Semester6\MSE643A\Project\bandGap\dielectric_constant_data.csv")

def get_propsFromStruct(df): #returns len, angles and periodic properties of a crystal structure for all compounds
    structure_arr = []

    for struct in df['structure']:
        finalStruct_arr = []
        modLen_arr = []
        modAngl_arr = []
        #modPerio_arr = []
        split_arr = struct.split('\n')
        len_arr = split_arr[2].split()[2:]
        angl_arr = split_arr[3].split()[1:]
        #perio_arr = split_arr[4].split()[2:]
        for i in range(3):
            modLen_arr.append(float(len_arr[i]))
            modAngl_arr.append(float(angl_arr[i]))
        for ele in modLen_arr:
            finalStruct_arr.append(ele)
        for ele in modAngl_arr:
            finalStruct_arr.append(ele)
#         for ele in perio_arr:
#             finalStruct_arr.append(ele)
        structure_arr.append(finalStruct_arr)
        
    return np.array(structure_arr)

def transform_raw_e_electronic(df): #transforms e_electronic into model readable values
    res = []
    for tensor_per_struct in df['e_electronic']:
        tensor_split = tensor_per_struct[2:-2].split(',')
        for i in range(len(tensor_split)):
            if (i==2 or i==5):
                tensor_split[i] = tensor_split[i].strip()[:-1]
            elif (i==3 or i==6):
                tensor_split[i] = tensor_split[i].strip()[1:]
        for j in range(len(tensor_split)):
            tensor_split[j] = float(tensor_split[j])
        res.append(tensor_split)
    return np.array(res)

def transform_raw_e_total(df): #transforms e_total into model readable values
    res = []
    for tensor_per_struct in df['e_total']:
        tensor_split = tensor_per_struct[2:-2].split(',')
        for i in range(len(tensor_split)):
            if (i==2 or i==5):
                tensor_split[i] = tensor_split[i].strip()[:-1]
            elif (i==3 or i==6):
                tensor_split[i] = tensor_split[i].strip()[1:]
        for j in range(len(tensor_split)):
            tensor_split[j] = float(tensor_split[j])
        res.append(tensor_split)
    return np.array(res)

def get_mod_df():
    raw_df = get_raw_df()
    mod_df = raw_df.copy()
    mod_df.drop(raw_df.columns[[0,1,2,6,8,9,14,15,16]], inplace= True, axis= 1)
    transformed_e_electronic = transform_raw_e_electronic(raw_df)
    transformed_e_total = transform_raw_e_total(raw_df)
    propsFromStruct_arr = get_propsFromStruct(raw_df)
    for i in range(9):
        mod_df['e_e_' + str(i)] = transformed_e_electronic[:, i]
    for i in range(9):
        mod_df['e_t_' + str(i)] = transformed_e_total[:, i]
    mod_df['len_a'] = propsFromStruct_arr[:,0]
    mod_df['len_b'] = propsFromStruct_arr[:,1]
    mod_df['len_c'] = propsFromStruct_arr[:,2]
    mod_df['angl_alpha'] = propsFromStruct_arr[:,3]
    mod_df['angl_beta'] = propsFromStruct_arr[:,4]
    mod_df['angl_gamma'] = propsFromStruct_arr[:,5]
    #below are excluded as they are true for all the compounds
#     mod_df['prd_p'] = propsFromStruct_arr[:,6]
#     mod_df['prd_b'] = propsFromStruct_arr[:,7]
#     mod_df['prd_c'] = propsFromStruct_arr[:,8]
    return mod_df

def transform_boolData(df, col_name): #1d array must be given
    le = LabelEncoder()
    df[col_name] = le.fit_transform(df[col_name])
    
def scale_features(df):
    scaler = StandardScaler()
    scaled_df = scaler.fit_transform(df)
    return scaled_df
    
def normal_split(df, test_size, shuffle):
    X = df.drop('band_gap', axis=1)
    y = df['band_gap']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42, shuffle=shuffle)
    
    return (X_train, X_test, y_train, y_test)

def evaluate_metrics(y_test, y_preds):
    r2 = r2_score(y_test, y_preds)
    mae = mean_absolute_error(y_test, y_preds)
    evs = explained_variance_score(y_test, y_preds)

    print('r2 score is ', r2)
    print('mean absolute error is ', mae)
    print('explained variance score is ', evs)

In [10]:
mod_df = get_mod_df()

mod_df.drop(columns = ['len_a', 'len_b', 'len_c', 'angl_alpha', 'angl_beta', 'angl_gamma'], axis= 1, inplace=True)

mod_df

Unnamed: 0,nsites,space_group,volume,band_gap,n,poly_electronic,poly_total,pot_ferroelectric,e_e_0,e_e_1,...,e_e_8,e_t_0,e_t_1,e_t_2,e_t_3,e_t_4,e_t_5,e_t_6,e_t_7,e_t_8
0,3,225,159.501208,1.88,1.86,3.44,6.23,False,3.441158,-3.097000e-05,...,3.441319,6.234147,-3.525200e-04,-0.000098,-3.499200e-04,6.235413,2.481000e-05,-0.000095,2.175000e-05,6.235207
1,3,166,84.298097,3.52,1.78,3.16,6.73,False,3.346884,-4.498543e-02,...,2.731545,7.970187,-2.942389e-01,-1.463590,-2.942397e-01,8.264334,-9.046643e-01,-1.463589,-9.046600e-01,3.945366
2,3,164,108.335875,1.17,2.23,4.97,10.64,False,5.543085,-5.280000e-06,...,3.838153,13.806061,6.911900e-04,0.000097,6.969600e-04,13.805466,4.435100e-04,0.000123,4.420300e-04,4.315681
3,4,186,88.162562,1.12,2.65,7.04,17.99,False,7.093167,7.990000e-06,...,6.922545,16.795354,8.200000e-07,-0.009482,7.870000e-06,16.786730,2.064269e-02,-0.008708,1.761772e-02,20.396643
4,6,136,82.826401,2.87,1.53,2.35,7.12,False,2.423962,7.452000e-05,...,2.317388,6.440556,2.044660e-03,0.001320,1.889010e-03,7.456696,1.344158e-02,0.001279,1.441676e-02,7.459124
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1051,7,111,212.493121,0.87,2.77,7.67,11.76,True,7.748968,0.000000e+00,...,7.522554,11.851595,1.000000e-08,0.000000,1.000000e-08,11.851596,0.000000e+00,0.000000,0.000000e+00,11.562826
1052,8,194,220.041363,3.60,2.00,3.99,7.08,True,4.405044,6.100000e-07,...,3.165817,8.771364,1.650000e-06,0.000000,1.650000e-06,8.784490,-1.000000e-08,0.000000,-1.000000e-08,3.696193
1053,4,216,73.882306,0.14,14.58,212.61,232.60,True,212.607502,-1.843000e-05,...,212.607516,232.597074,-5.407400e-04,0.002588,-5.407400e-04,232.593258,1.830120e-03,0.002588,1.830120e-03,232.596394
1054,5,221,177.269065,0.21,2.53,6.41,22.44,True,6.405117,0.000000e+00,...,6.405117,22.437998,0.000000e+00,0.000000,0.000000e+00,22.438018,0.000000e+00,0.000000,0.000000e+00,22.438270


In [14]:
transform_boolData(mod_df, 'pot_ferroelectric')

mod_df

Unnamed: 0,nsites,space_group,volume,band_gap,n,poly_electronic,poly_total,pot_ferroelectric,e_e_0,e_e_1,...,e_e_8,e_t_0,e_t_1,e_t_2,e_t_3,e_t_4,e_t_5,e_t_6,e_t_7,e_t_8
0,3,225,159.501208,1.88,1.86,3.44,6.23,0,3.441158,-3.097000e-05,...,3.441319,6.234147,-3.525200e-04,-0.000098,-3.499200e-04,6.235413,2.481000e-05,-0.000095,2.175000e-05,6.235207
1,3,166,84.298097,3.52,1.78,3.16,6.73,0,3.346884,-4.498543e-02,...,2.731545,7.970187,-2.942389e-01,-1.463590,-2.942397e-01,8.264334,-9.046643e-01,-1.463589,-9.046600e-01,3.945366
2,3,164,108.335875,1.17,2.23,4.97,10.64,0,5.543085,-5.280000e-06,...,3.838153,13.806061,6.911900e-04,0.000097,6.969600e-04,13.805466,4.435100e-04,0.000123,4.420300e-04,4.315681
3,4,186,88.162562,1.12,2.65,7.04,17.99,0,7.093167,7.990000e-06,...,6.922545,16.795354,8.200000e-07,-0.009482,7.870000e-06,16.786730,2.064269e-02,-0.008708,1.761772e-02,20.396643
4,6,136,82.826401,2.87,1.53,2.35,7.12,0,2.423962,7.452000e-05,...,2.317388,6.440556,2.044660e-03,0.001320,1.889010e-03,7.456696,1.344158e-02,0.001279,1.441676e-02,7.459124
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1051,7,111,212.493121,0.87,2.77,7.67,11.76,1,7.748968,0.000000e+00,...,7.522554,11.851595,1.000000e-08,0.000000,1.000000e-08,11.851596,0.000000e+00,0.000000,0.000000e+00,11.562826
1052,8,194,220.041363,3.60,2.00,3.99,7.08,1,4.405044,6.100000e-07,...,3.165817,8.771364,1.650000e-06,0.000000,1.650000e-06,8.784490,-1.000000e-08,0.000000,-1.000000e-08,3.696193
1053,4,216,73.882306,0.14,14.58,212.61,232.60,1,212.607502,-1.843000e-05,...,212.607516,232.597074,-5.407400e-04,0.002588,-5.407400e-04,232.593258,1.830120e-03,0.002588,1.830120e-03,232.596394
1054,5,221,177.269065,0.21,2.53,6.41,22.44,1,6.405117,0.000000e+00,...,6.405117,22.437998,0.000000e+00,0.000000,0.000000e+00,22.438018,0.000000e+00,0.000000,0.000000e+00,22.438270


In [22]:
X_train, X_test, y_train, y_test = normal_split(mod_df, 0.3, True)

X_train

Unnamed: 0,nsites,space_group,volume,n,poly_electronic,poly_total,pot_ferroelectric,e_e_0,e_e_1,e_e_2,...,e_e_8,e_t_0,e_t_1,e_t_2,e_t_3,e_t_4,e_t_5,e_t_6,e_t_7,e_t_8
310,4,36,145.521118,3.02,9.10,12.48,0,9.092807,-0.128835,0.000640,...,9.111896,12.293298,-1.824314e-01,-7.414520e-03,-1.824441e-01,12.293183,-1.448878e-02,-8.631280e-03,-1.642273e-02,12.848129
493,9,162,125.649892,2.28,5.20,14.44,1,5.282135,-0.000118,0.000567,...,5.044919,14.196741,-2.020685e-02,3.954385e-02,-2.020685e-02,14.173463,6.849700e-02,3.954515e-02,6.849926e-02,14.939830
104,4,129,92.406572,2.19,4.80,11.57,0,5.116323,0.000003,0.000028,...,4.165931,14.902334,-9.728000e-05,9.780800e-04,-9.738000e-05,14.957064,9.887800e-04,9.782000e-04,9.881500e-04,4.847717
97,4,71,176.931872,1.73,3.00,5.86,0,2.945779,0.030535,-0.025786,...,2.993131,5.776181,1.313331e-01,4.525379e-02,1.313275e-01,6.033881,-2.152550e-01,4.525042e-02,-2.152684e-01,5.758591
328,2,225,55.005179,1.67,2.80,8.88,0,2.798744,0.000001,-0.000001,...,2.798764,8.877549,4.567100e-04,2.005900e-04,4.552800e-04,8.878194,1.452000e-04,2.005900e-04,1.452000e-04,8.878139
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
330,8,194,353.715043,2.84,8.09,13.83,0,7.718620,-0.000011,-0.000034,...,8.842169,13.279189,2.702000e-05,-2.128500e-04,5.917000e-05,13.274460,-1.217060e-03,-2.653100e-04,-1.092290e-03,14.939210
466,4,166,54.633956,2.15,4.62,10.55,1,4.970524,0.000000,0.000000,...,3.906257,11.510467,1.000000e-08,-1.000000e-08,1.000000e-08,11.510434,4.730000e-06,-1.000000e-08,4.730000e-06,8.637382
121,5,217,87.578887,1.28,1.63,2.17,0,1.630283,0.000035,0.000021,...,1.630310,2.168768,5.974400e-04,-2.071500e-04,5.272500e-04,2.169429,-1.544160e-03,-2.505900e-04,-1.514870e-03,2.168502
1044,4,166,139.017820,2.26,5.12,13.09,1,5.473830,-0.056863,-0.338594,...,4.344032,14.933416,-2.928209e-01,-1.743245e+00,-2.928209e-01,15.238735,-1.057886e+00,-1.743245e+00,-1.057886e+00,9.109768


In [23]:
X_train_no_bools = X_train.drop('pot_ferroelectric', axis=1)
X_test_no_bools = X_test.drop('pot_ferroelectric', axis=1)

X_train_no_bools

Unnamed: 0,nsites,space_group,volume,n,poly_electronic,poly_total,e_e_0,e_e_1,e_e_2,e_e_3,...,e_e_8,e_t_0,e_t_1,e_t_2,e_t_3,e_t_4,e_t_5,e_t_6,e_t_7,e_t_8
310,4,36,145.521118,3.02,9.10,12.48,9.092807,-0.128835,0.000640,-1.288473e-01,...,9.111896,12.293298,-1.824314e-01,-7.414520e-03,-1.824441e-01,12.293183,-1.448878e-02,-8.631280e-03,-1.642273e-02,12.848129
493,9,162,125.649892,2.28,5.20,14.44,5.282135,-0.000118,0.000567,-1.179800e-04,...,5.044919,14.196741,-2.020685e-02,3.954385e-02,-2.020685e-02,14.173463,6.849700e-02,3.954515e-02,6.849926e-02,14.939830
104,4,129,92.406572,2.19,4.80,11.57,5.116323,0.000003,0.000028,2.440000e-06,...,4.165931,14.902334,-9.728000e-05,9.780800e-04,-9.738000e-05,14.957064,9.887800e-04,9.782000e-04,9.881500e-04,4.847717
97,4,71,176.931872,1.73,3.00,5.86,2.945779,0.030535,-0.025786,3.052938e-02,...,2.993131,5.776181,1.313331e-01,4.525379e-02,1.313275e-01,6.033881,-2.152550e-01,4.525042e-02,-2.152684e-01,5.758591
328,2,225,55.005179,1.67,2.80,8.88,2.798744,0.000001,-0.000001,-1.700000e-07,...,2.798764,8.877549,4.567100e-04,2.005900e-04,4.552800e-04,8.878194,1.452000e-04,2.005900e-04,1.452000e-04,8.878139
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
330,8,194,353.715043,2.84,8.09,13.83,7.718620,-0.000011,-0.000034,2.083000e-05,...,8.842169,13.279189,2.702000e-05,-2.128500e-04,5.917000e-05,13.274460,-1.217060e-03,-2.653100e-04,-1.092290e-03,14.939210
466,4,166,54.633956,2.15,4.62,10.55,4.970524,0.000000,0.000000,0.000000e+00,...,3.906257,11.510467,1.000000e-08,-1.000000e-08,1.000000e-08,11.510434,4.730000e-06,-1.000000e-08,4.730000e-06,8.637382
121,5,217,87.578887,1.28,1.63,2.17,1.630283,0.000035,0.000021,-3.504000e-05,...,1.630310,2.168768,5.974400e-04,-2.071500e-04,5.272500e-04,2.169429,-1.544160e-03,-2.505900e-04,-1.514870e-03,2.168502
1044,4,166,139.017820,2.26,5.12,13.09,5.473830,-0.056863,-0.338594,-5.686262e-02,...,4.344032,14.933416,-2.928209e-01,-1.743245e+00,-2.928209e-01,15.238735,-1.057886e+00,-1.743245e+00,-1.057886e+00,9.109768


In [32]:
scaled_X_train_no_bools = scale_features(X_train_no_bools)
scaled_y_train = scale_features(np.array(y_train).reshape(len(y_train),1))

scaled_X_test_no_bools = scale_features(X_test_no_bools)
scaled_y_test = scale_features(np.array(y_test).reshape(len(y_test), 1))

scaled_X_train_df = pd.DataFrame(scaled_X_train_no_bools, columns= X_train_no_bools.columns)
scaled_X_test_df = pd.DataFrame(scaled_X_test_no_bools, columns= X_test_no_bools.columns)

scaled_y_train = scaled_y_train.reshape(len(scaled_y_train),)
scaled_y_test = scaled_y_test.reshape(len(scaled_y_test),)

scaled_X_train_df

Unnamed: 0,nsites,space_group,volume,n,poly_electronic,poly_total,e_e_0,e_e_1,e_e_2,e_e_3,...,e_e_8,e_t_0,e_t_1,e_t_2,e_t_3,e_t_4,e_t_5,e_t_6,e_t_7,e_t_8
0,-1.036277,-1.609201,-0.212833,0.584859,0.199161,-0.120952,0.145676,-1.075392,0.057398,-1.073317,...,0.291365,-0.152762,-0.597079,0.090201,-0.596792,-0.130043,0.078367,0.086701,0.073913,-0.040526
1,0.438133,0.274455,-0.413661,-0.112772,-0.151391,-0.020650,-0.157823,0.006775,0.057079,0.008314,...,-0.126372,-0.061854,-0.027717,0.172205,-0.027087,-0.052711,0.253192,0.170811,0.252811,0.074756
2,-1.036277,-0.218884,-0.749633,-0.197619,-0.187346,-0.167520,-0.171029,0.007788,0.054706,0.009325,...,-0.216657,-0.028155,0.042862,0.104857,0.043528,-0.020483,0.110974,0.103478,0.110591,-0.481460
3,-1.036277,-1.085964,0.104619,-0.631281,-0.349139,-0.459725,-0.343901,0.264485,-0.058808,0.265824,...,-0.337120,-0.464016,0.504144,0.182176,0.505034,-0.387476,-0.344584,0.180772,-0.344977,-0.431258
4,-1.626042,1.216282,-1.127630,-0.687846,-0.367116,-0.305179,-0.355612,0.007778,0.054578,0.009303,...,-0.357084,-0.315896,0.044806,0.103499,0.045469,-0.270495,0.109197,0.102120,0.108815,-0.259328
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
734,0.143251,0.752843,1.891273,0.415165,0.108377,-0.051867,0.036229,0.007672,0.054434,0.009480,...,0.263660,-0.105676,0.043298,0.102777,0.044078,-0.089685,0.106327,0.101307,0.106208,0.074722
735,-1.036277,0.334253,-1.131381,-0.235329,-0.203525,-0.219718,-0.182642,0.007767,0.054585,0.009305,...,-0.243329,-0.190149,0.043203,0.103149,0.043870,-0.162236,0.108901,0.101770,0.108519,-0.272597
736,-0.741395,1.096685,-0.798424,-1.055516,-0.472282,-0.648557,-0.448673,0.008063,0.054677,0.009010,...,-0.477101,-0.636304,0.045300,0.102787,0.045722,-0.546413,0.105638,0.101332,0.105318,-0.629123
737,-1.036277,0.334253,-0.278558,-0.131627,-0.158582,-0.089736,-0.142556,-0.470297,-1.434334,-0.468476,...,-0.198363,-0.026671,-0.984514,-2.941105,-0.984387,-0.008898,-2.119741,-2.941748,-2.120043,-0.246562


In [36]:
X_train = X_train.reset_index().drop('index', axis=1)
X_test = X_test.reset_index().drop('index', axis=1)

scaled_X_train_df['pot_ferroelectric'] = X_train['pot_ferroelectric']
scaled_X_test_df['pot_ferroelectric'] = X_test['pot_ferroelectric']

scaled_X_train_df

Unnamed: 0,nsites,space_group,volume,n,poly_electronic,poly_total,e_e_0,e_e_1,e_e_2,e_e_3,...,e_t_0,e_t_1,e_t_2,e_t_3,e_t_4,e_t_5,e_t_6,e_t_7,e_t_8,pot_ferroelectric
0,-1.036277,-1.609201,-0.212833,0.584859,0.199161,-0.120952,0.145676,-1.075392,0.057398,-1.073317,...,-0.152762,-0.597079,0.090201,-0.596792,-0.130043,0.078367,0.086701,0.073913,-0.040526,0
1,0.438133,0.274455,-0.413661,-0.112772,-0.151391,-0.020650,-0.157823,0.006775,0.057079,0.008314,...,-0.061854,-0.027717,0.172205,-0.027087,-0.052711,0.253192,0.170811,0.252811,0.074756,1
2,-1.036277,-0.218884,-0.749633,-0.197619,-0.187346,-0.167520,-0.171029,0.007788,0.054706,0.009325,...,-0.028155,0.042862,0.104857,0.043528,-0.020483,0.110974,0.103478,0.110591,-0.481460,0
3,-1.036277,-1.085964,0.104619,-0.631281,-0.349139,-0.459725,-0.343901,0.264485,-0.058808,0.265824,...,-0.464016,0.504144,0.182176,0.505034,-0.387476,-0.344584,0.180772,-0.344977,-0.431258,0
4,-1.626042,1.216282,-1.127630,-0.687846,-0.367116,-0.305179,-0.355612,0.007778,0.054578,0.009303,...,-0.315896,0.044806,0.103499,0.045469,-0.270495,0.109197,0.102120,0.108815,-0.259328,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
734,0.143251,0.752843,1.891273,0.415165,0.108377,-0.051867,0.036229,0.007672,0.054434,0.009480,...,-0.105676,0.043298,0.102777,0.044078,-0.089685,0.106327,0.101307,0.106208,0.074722,0
735,-1.036277,0.334253,-1.131381,-0.235329,-0.203525,-0.219718,-0.182642,0.007767,0.054585,0.009305,...,-0.190149,0.043203,0.103149,0.043870,-0.162236,0.108901,0.101770,0.108519,-0.272597,1
736,-0.741395,1.096685,-0.798424,-1.055516,-0.472282,-0.648557,-0.448673,0.008063,0.054677,0.009010,...,-0.636304,0.045300,0.102787,0.045722,-0.546413,0.105638,0.101332,0.105318,-0.629123,0
737,-1.036277,0.334253,-0.278558,-0.131627,-0.158582,-0.089736,-0.142556,-0.470297,-1.434334,-0.468476,...,-0.026671,-0.984514,-2.941105,-0.984387,-0.008898,-2.119741,-2.941748,-2.120043,-0.246562,1


In [37]:
scaled_y_train

array([-9.61877380e-01, -1.02410620e+00, -7.31630740e-01, -2.83583224e-01,
        4.69385518e-01, -8.12528208e-01,  2.51584642e-01, -3.76926456e-01,
       -5.95594662e-02, -1.52902699e-01,  3.13813464e-01, -5.94727332e-01,
        8.92541505e-01, -1.04277485e+00, -7.81413797e-01, -1.52902699e-01,
        5.62728750e-01,  1.27126999e-01,  4.32048225e-01,  9.79661855e-01,
        1.85708824e+00, -6.69401918e-01,  1.04811356e+00, -9.06738770e-02,
       -5.33365841e-02, -5.63612921e-01, -4.95161218e-01,  1.27836020e+00,
        1.18501697e+00,  1.02235470e-01, -3.02251871e-01, -1.77794227e-01,
        1.68907042e+00,  2.07488912e+00, -5.63612921e-01, -7.37853622e-01,
        1.62061872e+00, -7.20052305e-02, -1.09878079e+00,  5.93843161e-01,
       -5.69835804e-01, -9.36985851e-01,  4.07156696e-01, -8.68534147e-01,
        1.17879408e+00, -9.68967592e-02, -2.33800167e-01,  2.75608841e-02,
        2.00021453e+00, -2.83583224e-01,  9.23655915e-01, -1.34234052e-01,
        3.82265167e-01, -

In [42]:
#yes scaling
regr1 = svm.SVR()
regr1.fit(scaled_X_train_df, scaled_y_train)
y_preds1 = regr1.predict(scaled_X_test_df)

evaluate_metrics(scaled_y_test, y_preds1)

r2 score is  0.5962302924124594
mean absolute error is  0.44054311088767634
explained variance score is  0.5968095536578264


In [45]:
X_train

Unnamed: 0,nsites,space_group,volume,n,poly_electronic,poly_total,pot_ferroelectric,e_e_0,e_e_1,e_e_2,...,e_e_8,e_t_0,e_t_1,e_t_2,e_t_3,e_t_4,e_t_5,e_t_6,e_t_7,e_t_8
0,4,36,145.521118,3.02,9.10,12.48,0,9.092807,-0.128835,0.000640,...,9.111896,12.293298,-1.824314e-01,-7.414520e-03,-1.824441e-01,12.293183,-1.448878e-02,-8.631280e-03,-1.642273e-02,12.848129
1,9,162,125.649892,2.28,5.20,14.44,1,5.282135,-0.000118,0.000567,...,5.044919,14.196741,-2.020685e-02,3.954385e-02,-2.020685e-02,14.173463,6.849700e-02,3.954515e-02,6.849926e-02,14.939830
2,4,129,92.406572,2.19,4.80,11.57,0,5.116323,0.000003,0.000028,...,4.165931,14.902334,-9.728000e-05,9.780800e-04,-9.738000e-05,14.957064,9.887800e-04,9.782000e-04,9.881500e-04,4.847717
3,4,71,176.931872,1.73,3.00,5.86,0,2.945779,0.030535,-0.025786,...,2.993131,5.776181,1.313331e-01,4.525379e-02,1.313275e-01,6.033881,-2.152550e-01,4.525042e-02,-2.152684e-01,5.758591
4,2,225,55.005179,1.67,2.80,8.88,0,2.798744,0.000001,-0.000001,...,2.798764,8.877549,4.567100e-04,2.005900e-04,4.552800e-04,8.878194,1.452000e-04,2.005900e-04,1.452000e-04,8.878139
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
734,8,194,353.715043,2.84,8.09,13.83,0,7.718620,-0.000011,-0.000034,...,8.842169,13.279189,2.702000e-05,-2.128500e-04,5.917000e-05,13.274460,-1.217060e-03,-2.653100e-04,-1.092290e-03,14.939210
735,4,166,54.633956,2.15,4.62,10.55,1,4.970524,0.000000,0.000000,...,3.906257,11.510467,1.000000e-08,-1.000000e-08,1.000000e-08,11.510434,4.730000e-06,-1.000000e-08,4.730000e-06,8.637382
736,5,217,87.578887,1.28,1.63,2.17,0,1.630283,0.000035,0.000021,...,1.630310,2.168768,5.974400e-04,-2.071500e-04,5.272500e-04,2.169429,-1.544160e-03,-2.505900e-04,-1.514870e-03,2.168502
737,4,166,139.017820,2.26,5.12,13.09,1,5.473830,-0.056863,-0.338594,...,4.344032,14.933416,-2.928209e-01,-1.743245e+00,-2.928209e-01,15.238735,-1.057886e+00,-1.743245e+00,-1.057886e+00,9.109768


In [46]:
#no scaling
regr2 = DecisionTreeRegressor()
regr2.fit(X_train, y_train)
y_preds2 = regr2.predict(X_test)

evaluate_metrics(y_test, y_preds2)

r2 score is  0.3678162770243959
mean absolute error is  0.8036277602523657
explained variance score is  0.36986177104551277


In [48]:
#no scaling
regr3 = RandomForestRegressor()
regr3.fit(X_train, y_train)
y_preds3 = regr3.predict(X_test)

evaluate_metrics(y_test, y_preds3)

r2 score is  0.6627677743203102
mean absolute error is  0.6329690851735016
explained variance score is  0.6694665200954992


In [51]:
#yes scaling
regr4 = KNeighborsRegressor()
regr4.fit(scaled_X_train_df, scaled_y_train)
y_preds4 = regr4.predict(scaled_X_test_df)

evaluate_metrics(scaled_y_test, y_preds4)

r2 score is  0.5483447412706364
mean absolute error is  0.4641065000424111
explained variance score is  0.5485912347220026


In [52]:
#no scaling
regr5 = GradientBoostingRegressor()
regr5.fit(X_train, y_train)
y_preds5 = regr5.predict(X_test)

evaluate_metrics(y_test, y_preds5)

r2 score is  0.6550940800012128
mean absolute error is  0.6555878437119748
explained variance score is  0.6622640699732043
