In [1]:
import pandas as pd
import numpy as np
from scipy import stats
from tqdm import tqdm

In [2]:
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve

In [3]:
data = pd.read_csv("OneNullDataset.csv")

In [4]:
data

Unnamed: 0,PATIENTID,Site,MORPH_ICD10_O2,BEHAVIOUR_ICD10_O2,T_BEST,N_BEST,M_BEST,GRADE,AGE,SEX,...,CLINICAL_TRIAL,CHEMO_RADIATION,REGIMEN_MOD_TIME_DELAY,REGIMEN_MOD_STOPPED_EARLY,REGIMEN_OUTCOME_SUMMARY,CYCLE_NUMBER,ACTUAL_DOSE_PER_ADMINISTRATION,ADMINISTRATION_ROUTE,DRUG_GROUP,diff
0,40129115,C34,8041.0,MALIGNANT,1a,0.0,0,GX,83,FEMALE,...,Not Taking Part,NO,NO,YES,0,5,152.0,1,PACLITAXEL,2119
1,40100400,C34,8041.0,MALIGNANT,1b,0.0,0,G3,72,FEMALE,...,Not Taking Part,NO,NO,NO,0,3,200.0,1,NOT CHEMO,1725
2,40010618,C34,8070.0,MALIGNANT,2a,1.0,0,G3,73,MALE,...,Not Taking Part,NO,NO,NO,0.0,4,2800.0,2,ERLOTINIB,1981
3,40065028,C34,8070.0,MALIGNANT,1,0.0,0,G3,67,FEMALE,...,Not Taking Part,NO,YES,NO,0.0,3,120.0,1,ETOPOSIDE,272
4,40035484,C34,8046.0,MALIGNANT,3,3.0,1b,G3,61,MALE,...,Not Taking Part,NO,NO,NO,0,3,900.0,1,PEMETREXED,235
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
945,40130314,C34,8140.0,MALIGNANT,3,2.0,0,GX,69,FEMALE,...,Not Taking Part,NO,NO,NO,0,1,8.0,2,STEROID,97
946,110021844,C34,8140.0,MALIGNANT,1a,0.0,0,GX,86,MALE,...,Not Taking Part,NO,YES,NO,0,6,113.0,1,NOT CHEMO,759
947,40153872,C34,8140.0,MALIGNANT,2a,1.0,1b,GX,62,FEMALE,...,Not Taking Part,NO,NO,NO,2,1,2128.0,1,GEMCITABINE,66
948,40063001,C34,8070.0,MALIGNANT,3,0.0,0,GX,73,MALE,...,Not Taking Part,NO,NO,NO,0,6,8.0,2,NOT CHEMO,1716


In [5]:
data= data.rename(columns={"MORPH_ICD10_O2":"Morph","BEHAVIOUR_ICD10_O2":"Behaviour","T_BEST": "T Best",
                          "N_BEST" : "N Best", "M_BEST" : "M Best", "GRADE" : "Grade", "AGE" : "Age",
                          "SEX":"Sex","CANCERCAREPLANINTENT":"Cancer Plan", "NEWVITALSTATUS" : "Vital Status",
                          "HEIGHT_AT_START_OF_REGIMEN":"Height","WEIGHT_AT_START_OF_REGIMEN":"Weight","MAPPED_REGIMEN":"Regimen","CLINICAL_TRIAL":"Clinical Trial",
                          "CHEMO_RADIATION":"Chemo Radiation","REGIMEN_MOD_TIME_DELAY":"Regimen Time Delay","REGIMEN_MOD_STOPPED_EARLY":"Regimen Stopped Early",
                          "REGIMEN_OUTCOME_SUMMARY":"Outcome","CYCLE_NUMBER":"Cycle","ACTUAL_DOSE_PER_ADMINISTRATION":"Dose Administration",
                           "ADMINISTRATION_ROUTE":"Administration Route", "DRUG_GROUP":"Drug Group","ACE27":"ACE"})


In [6]:
from sklearn import preprocessing
import matplotlib.pyplot as plt
plt.rc("font", size = 14)
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import seaborn as sns
sns.set(style="white")
sns.set(style="whitegrid",color_codes=True)

from sklearn.preprocessing import OneHotEncoder

In [7]:
data['Vital Status'].value_counts()

D    520
A    430
Name: Vital Status, dtype: int64

In [8]:
data = data.rename(columns={"Vital Status":"Alive"})

In [9]:
cat_vars = data.columns.tolist()
cat_vars.remove("Site")
cat_vars.remove("Height")
cat_vars.remove("Weight")
cat_vars.remove("Morph")
cat_vars.remove("Age")
cat_vars.remove("Cycle")
cat_vars.remove("Dose Administration")
cat_vars.remove("diff")
cat_vars

['PATIENTID',
 'Behaviour',
 'T Best',
 'N Best',
 'M Best',
 'Grade',
 'Sex',
 'Cancer Plan',
 'CNS',
 'ACE',
 'Alive',
 'Regimen',
 'Clinical Trial',
 'Chemo Radiation',
 'Regimen Time Delay',
 'Regimen Stopped Early',
 'Outcome',
 'Administration Route',
 'Drug Group']

In [10]:
num_vars = ["diff","Height","Weight","Morph","Age","Cycle","Dose Administration"]

In [11]:
data_final = data.copy()
dete = pd.get_dummies(data_final)
dete

Unnamed: 0,PATIENTID,Morph,N Best,Age,ACE,Height,Weight,Cycle,Dose Administration,Administration Route,...,Drug Group_PROCARBAZINE,Drug Group_RITUXIMAB,Drug Group_STEROID,Drug Group_THIOTEPA,Drug Group_TOPOTECAN,Drug Group_TRASTUZUMAB,Drug Group_TRIAL,Drug Group_VINCRISTINE,Drug Group_VINORELBINE,Drug Group_ZOLEDRONIC ACID
0,40129115,8041.0,0.0,83,3,1.77,66.0,5,152.0,1,...,0,0,0,0,0,0,0,0,0,0
1,40100400,8041.0,0.0,72,3,1.50,90.0,3,200.0,1,...,0,0,0,0,0,0,0,0,0,0
2,40010618,8070.0,1.0,73,9,1.55,68.0,4,2800.0,2,...,0,0,0,0,0,0,0,0,0,0
3,40065028,8070.0,0.0,67,9,1.76,75.7,3,120.0,1,...,0,0,0,0,0,0,0,0,0,0
4,40035484,8046.0,3.0,61,2,1.71,71.7,3,900.0,1,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
945,40130314,8140.0,2.0,69,1,1.50,85.4,1,8.0,2,...,0,0,1,0,0,0,0,0,0,0
946,110021844,8140.0,0.0,86,9,1.56,76.0,6,113.0,1,...,0,0,0,0,0,0,0,0,0,0
947,40153872,8140.0,1.0,62,9,1.63,66.4,1,2128.0,1,...,0,0,0,0,0,0,0,0,0,0
948,40063001,8070.0,0.0,73,0,1.59,41.0,6,8.0,2,...,0,0,0,0,0,0,0,0,0,0


In [12]:
#X = X.drop(columns=["Site","Alive"], axis = 1)
for var in cat_vars:
    data_final = pd.concat((data_final,pd.get_dummies(data[var], prefix = var)),1)


Ytrain = data_final[["diff"]]
Xtrain = data_final.drop(columns=["Site","diff"])
for var in cat_vars:
    Xtrain = Xtrain.drop(columns=[var])


  data_final = pd.concat((data_final,pd.get_dummies(data[var], prefix = var)),1)


In [13]:
Ytrain

Unnamed: 0,diff
0,2119
1,1725
2,1981
3,272
4,235
...,...
945,97
946,759
947,66
948,1716


In [14]:
Xtrain

Unnamed: 0,Morph,Age,Height,Weight,Cycle,Dose Administration,PATIENTID_10281336,PATIENTID_10284884,PATIENTID_10298791,PATIENTID_10302807,...,Drug Group_PROCARBAZINE,Drug Group_RITUXIMAB,Drug Group_STEROID,Drug Group_THIOTEPA,Drug Group_TOPOTECAN,Drug Group_TRASTUZUMAB,Drug Group_TRIAL,Drug Group_VINCRISTINE,Drug Group_VINORELBINE,Drug Group_ZOLEDRONIC ACID
0,8041.0,83,1.77,66.0,5,152.0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,8041.0,72,1.50,90.0,3,200.0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,8070.0,73,1.55,68.0,4,2800.0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,8070.0,67,1.76,75.7,3,120.0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,8046.0,61,1.71,71.7,3,900.0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
945,8140.0,69,1.50,85.4,1,8.0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,0
946,8140.0,86,1.56,76.0,6,113.0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
947,8140.0,62,1.63,66.4,1,2128.0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
948,8070.0,73,1.59,41.0,6,8.0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [15]:
X_train, X_test, Y_train, Y_test = train_test_split(Xtrain, Ytrain, test_size = 0.3, random_state = 30)

# XGBoost

In [16]:
import xgboost as xgb
from sklearn.metrics import mean_squared_error as MSE
from sklearn.metrics import r2_score


In [17]:
xgb_r = xgb.XGBRegressor(objective ='reg:squarederror', n_estimators = 100, seed = 123)
  
# Fitting the model
xgb_r.fit(X_train, Y_train)

XGBRegressor(base_score=0.5, booster='gbtree', callbacks=None,
             colsample_bylevel=1, colsample_bynode=1, colsample_bytree=1,
             early_stopping_rounds=None, enable_categorical=False,
             eval_metric=None, gamma=0, gpu_id=-1, grow_policy='depthwise',
             importance_type=None, interaction_constraints='',
             learning_rate=0.300000012, max_bin=256, max_cat_to_onehot=4,
             max_delta_step=0, max_depth=6, max_leaves=0, min_child_weight=1,
             missing=nan, monotone_constraints='()', n_estimators=100, n_jobs=0,
             num_parallel_tree=1, predictor='auto', random_state=123,
             reg_alpha=0, reg_lambda=1, ...)

In [18]:
pred = xgb_r.predict(X_test)

In [19]:
# RMSE Computation
rmse = np.sqrt(MSE(Y_test, pred))
print("RMSE : % f" %(rmse))

RMSE :  344.722773


In [20]:
r2_score(Y_test, pred)

0.7461217478539397

# Explainable Boosting Machine (EBM)


In [21]:
from interpret.glassbox import ExplainableBoostingRegressor
from interpret import show

In [22]:
ebm = ExplainableBoostingRegressor()
ebm.fit(X_train, Y_train)

#ebm_global = ebm.explain_global()
#show(emb_global)

#ebm_local = ebm.explain_local(X_test[:5], Y_test[:5])
#show(ebm_local)

ExplainableBoostingRegressor()

In [23]:
preds = ebm.predict(X_test)

In [24]:
# RMSE Computation
rmse = np.sqrt(MSE(Y_test, preds))
print("RMSE : % f" %(rmse))

RMSE :  393.889862


In [25]:
r2_score(Y_test, preds)

0.6685369238204444

# Random Forest

In [26]:
from sklearn.ensemble import RandomForestRegressor

In [27]:
#Create a Gaussian Classifier
clf=RandomForestRegressor(n_estimators=100)

#Train the model using the training sets y_pred=clf.predict(X_test)
clf.fit(X_train,Y_train.values.ravel())
pred=clf.predict(X_test)

In [28]:
# RMSE Computation
rmse = np.sqrt(MSE(Y_test, pred))
print("RMSE : % f" %(rmse))

RMSE :  343.533442


In [29]:
r2_score(Y_test, pred)

0.7478705406990671

# CatBoost

In [30]:
from catboost import CatBoostRegressor

In [31]:
model = CatBoostRegressor()
# Fit model
model.fit(X_train, Y_train)
# Get predictions
pred = model.predict(X_test)

Learning rate set to 0.038387
0:	learn: 665.5612360	total: 139ms	remaining: 2m 18s
1:	learn: 653.1222069	total: 142ms	remaining: 1m 10s
2:	learn: 641.3785775	total: 145ms	remaining: 48.2s
3:	learn: 629.6383301	total: 148ms	remaining: 36.9s
4:	learn: 618.5194575	total: 152ms	remaining: 30.2s
5:	learn: 608.0349788	total: 155ms	remaining: 25.7s
6:	learn: 598.8634667	total: 158ms	remaining: 22.4s
7:	learn: 589.8071738	total: 161ms	remaining: 20s
8:	learn: 581.1299355	total: 164ms	remaining: 18.1s
9:	learn: 572.7877066	total: 168ms	remaining: 16.6s
10:	learn: 564.7976286	total: 171ms	remaining: 15.3s
11:	learn: 556.9626557	total: 174ms	remaining: 14.3s
12:	learn: 548.9407582	total: 177ms	remaining: 13.4s
13:	learn: 542.1865253	total: 180ms	remaining: 12.7s
14:	learn: 535.5957998	total: 183ms	remaining: 12s
15:	learn: 529.1063340	total: 187ms	remaining: 11.5s
16:	learn: 523.3658128	total: 190ms	remaining: 11s
17:	learn: 517.4339517	total: 193ms	remaining: 10.6s
18:	learn: 512.1867747	total: 

162:	learn: 337.2731610	total: 674ms	remaining: 3.46s
163:	learn: 336.9037077	total: 677ms	remaining: 3.45s
164:	learn: 336.5910007	total: 680ms	remaining: 3.44s
165:	learn: 335.5380252	total: 684ms	remaining: 3.43s
166:	learn: 335.3943203	total: 686ms	remaining: 3.42s
167:	learn: 335.0613212	total: 689ms	remaining: 3.41s
168:	learn: 334.6911893	total: 692ms	remaining: 3.4s
169:	learn: 334.1509490	total: 695ms	remaining: 3.39s
170:	learn: 333.6632211	total: 698ms	remaining: 3.38s
171:	learn: 333.4122056	total: 701ms	remaining: 3.38s
172:	learn: 333.1459590	total: 705ms	remaining: 3.37s
173:	learn: 332.8870305	total: 708ms	remaining: 3.36s
174:	learn: 332.5782861	total: 711ms	remaining: 3.35s
175:	learn: 332.3526971	total: 715ms	remaining: 3.35s
176:	learn: 332.1071922	total: 718ms	remaining: 3.34s
177:	learn: 331.2255733	total: 721ms	remaining: 3.33s
178:	learn: 330.9266674	total: 725ms	remaining: 3.32s
179:	learn: 330.7715693	total: 728ms	remaining: 3.31s
180:	learn: 330.4267534	total

318:	learn: 287.2769269	total: 1.21s	remaining: 2.58s
319:	learn: 287.0770873	total: 1.21s	remaining: 2.58s
320:	learn: 286.8875703	total: 1.22s	remaining: 2.58s
321:	learn: 286.7652112	total: 1.22s	remaining: 2.57s
322:	learn: 286.5690145	total: 1.23s	remaining: 2.57s
323:	learn: 286.3741964	total: 1.23s	remaining: 2.56s
324:	learn: 286.2529528	total: 1.23s	remaining: 2.56s
325:	learn: 286.1324081	total: 1.24s	remaining: 2.55s
326:	learn: 285.9446398	total: 1.24s	remaining: 2.55s
327:	learn: 285.7532293	total: 1.24s	remaining: 2.54s
328:	learn: 285.6335214	total: 1.25s	remaining: 2.54s
329:	learn: 285.0376027	total: 1.25s	remaining: 2.54s
330:	learn: 284.8542552	total: 1.25s	remaining: 2.53s
331:	learn: 284.6347979	total: 1.26s	remaining: 2.53s
332:	learn: 284.5246527	total: 1.26s	remaining: 2.52s
333:	learn: 284.4055615	total: 1.26s	remaining: 2.52s
334:	learn: 284.2225479	total: 1.27s	remaining: 2.52s
335:	learn: 284.1042934	total: 1.27s	remaining: 2.51s
336:	learn: 283.9177346	tota

481:	learn: 248.7119365	total: 1.8s	remaining: 1.93s
482:	learn: 248.5757145	total: 1.8s	remaining: 1.93s
483:	learn: 248.4880522	total: 1.8s	remaining: 1.92s
484:	learn: 248.3630912	total: 1.81s	remaining: 1.92s
485:	learn: 248.2764340	total: 1.81s	remaining: 1.91s
486:	learn: 248.1896432	total: 1.81s	remaining: 1.91s
487:	learn: 248.0552516	total: 1.82s	remaining: 1.91s
488:	learn: 247.9019497	total: 1.82s	remaining: 1.9s
489:	learn: 247.7748540	total: 1.82s	remaining: 1.9s
490:	learn: 247.3798764	total: 1.83s	remaining: 1.89s
491:	learn: 247.2937896	total: 1.83s	remaining: 1.89s
492:	learn: 247.1971853	total: 1.83s	remaining: 1.89s
493:	learn: 247.1109600	total: 1.84s	remaining: 1.88s
494:	learn: 247.0250578	total: 1.84s	remaining: 1.88s
495:	learn: 246.9395789	total: 1.84s	remaining: 1.87s
496:	learn: 246.8550155	total: 1.85s	remaining: 1.87s
497:	learn: 246.6978230	total: 1.85s	remaining: 1.86s
498:	learn: 246.5651309	total: 1.85s	remaining: 1.86s
499:	learn: 246.0425353	total: 1.

634:	learn: 224.1341784	total: 2.29s	remaining: 1.32s
635:	learn: 224.0631654	total: 2.3s	remaining: 1.31s
636:	learn: 223.9923820	total: 2.3s	remaining: 1.31s
637:	learn: 223.9152479	total: 2.3s	remaining: 1.31s
638:	learn: 223.8450000	total: 2.31s	remaining: 1.3s
639:	learn: 223.7743431	total: 2.31s	remaining: 1.3s
640:	learn: 223.7039153	total: 2.31s	remaining: 1.3s
641:	learn: 223.6324371	total: 2.32s	remaining: 1.29s
642:	learn: 223.5630687	total: 2.32s	remaining: 1.29s
643:	learn: 222.9183545	total: 2.32s	remaining: 1.28s
644:	learn: 222.8483120	total: 2.33s	remaining: 1.28s
645:	learn: 222.7792754	total: 2.33s	remaining: 1.28s
646:	learn: 222.4129676	total: 2.33s	remaining: 1.27s
647:	learn: 222.3440280	total: 2.34s	remaining: 1.27s
648:	learn: 222.2747371	total: 2.34s	remaining: 1.26s
649:	learn: 222.2061628	total: 2.34s	remaining: 1.26s
650:	learn: 222.1377217	total: 2.35s	remaining: 1.26s
651:	learn: 222.0295008	total: 2.35s	remaining: 1.25s
652:	learn: 221.9612214	total: 2.3

797:	learn: 202.9332581	total: 2.83s	remaining: 718ms
798:	learn: 202.8759787	total: 2.84s	remaining: 714ms
799:	learn: 202.8183960	total: 2.84s	remaining: 711ms
800:	learn: 202.3242193	total: 2.85s	remaining: 707ms
801:	learn: 202.2670508	total: 2.85s	remaining: 704ms
802:	learn: 201.8385888	total: 2.85s	remaining: 700ms
803:	learn: 201.7807889	total: 2.86s	remaining: 696ms
804:	learn: 201.7232440	total: 2.86s	remaining: 693ms
805:	learn: 201.3527553	total: 2.86s	remaining: 689ms
806:	learn: 201.2979273	total: 2.87s	remaining: 685ms
807:	learn: 201.2112904	total: 2.87s	remaining: 682ms
808:	learn: 201.1537962	total: 2.87s	remaining: 678ms
809:	learn: 201.0795041	total: 2.88s	remaining: 674ms
810:	learn: 201.0214068	total: 2.88s	remaining: 671ms
811:	learn: 200.9674042	total: 2.88s	remaining: 667ms
812:	learn: 200.7047630	total: 2.88s	remaining: 664ms
813:	learn: 200.6467913	total: 2.89s	remaining: 660ms
814:	learn: 200.5899363	total: 2.89s	remaining: 657ms
815:	learn: 200.4781130	tota

955:	learn: 183.3982253	total: 3.37s	remaining: 155ms
956:	learn: 183.3333751	total: 3.38s	remaining: 152ms
957:	learn: 182.8255847	total: 3.38s	remaining: 148ms
958:	learn: 182.7500849	total: 3.38s	remaining: 145ms
959:	learn: 182.6869527	total: 3.39s	remaining: 141ms
960:	learn: 182.3273215	total: 3.39s	remaining: 138ms
961:	learn: 182.2831961	total: 3.39s	remaining: 134ms
962:	learn: 182.1675104	total: 3.4s	remaining: 131ms
963:	learn: 182.1219830	total: 3.4s	remaining: 127ms
964:	learn: 182.0730166	total: 3.4s	remaining: 123ms
965:	learn: 182.0243498	total: 3.41s	remaining: 120ms
966:	learn: 181.9563239	total: 3.41s	remaining: 116ms
967:	learn: 181.5704821	total: 3.41s	remaining: 113ms
968:	learn: 181.0779579	total: 3.42s	remaining: 109ms
969:	learn: 181.0289924	total: 3.42s	remaining: 106ms
970:	learn: 180.9805215	total: 3.42s	remaining: 102ms
971:	learn: 180.9322871	total: 3.43s	remaining: 98.7ms
972:	learn: 180.6222282	total: 3.43s	remaining: 95.2ms
973:	learn: 180.1284025	total

In [32]:
# RMSE Computation
rmse = np.sqrt(MSE(Y_test, pred))
print("RMSE : % f" %(rmse))

RMSE :  346.376622


In [33]:
r2_score(Y_test, pred)

0.7436798807717723

# Artificial Neural Network (ANN)

In [34]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [35]:
torch.manual_seed(1)  # Set seed for reproducibility.
trainloader = torch.utils.data.DataLoader(data, batch_size=10, shuffle=True, num_workers=1)

In [46]:
n_input, n_hidden, n_out, batch_size, learning_rate = 837, 400, 1, 100, 0.01

input_tensor = torch.from_numpy(X_train.to_numpy()).type(torch.FloatTensor)
label_tensor = torch.from_numpy(Y_train.to_numpy()).type(torch.FloatTensor)
test_input_tensor = torch.from_numpy(X_test.to_numpy()).type(torch.FloatTensor)

model = nn.Sequential(nn.Linear(n_input, n_hidden),
                      nn.ReLU(),
                      nn.Linear(n_hidden, n_out),
                      )
loss_function = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

losses = []
for epoch in range(5000):
    
    pred = model(input_tensor)
    loss = loss_function(pred, label_tensor)
    losses.append(loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [47]:
new_pred = model(test_input_tensor)
new_pred

tensor([[ 506.9208],
        [ 341.4448],
        [ 638.8336],
        [ 474.7303],
        [ 946.7957],
        [ 113.8599],
        [1454.0094],
        [1857.1428],
        [ 594.4770],
        [1315.2083],
        [ 725.6999],
        [ 495.3300],
        [1755.6742],
        [ 242.5662],
        [1720.2639],
        [ 836.6931],
        [ 683.5887],
        [ 948.6465],
        [  85.3613],
        [ 435.2157],
        [ 715.9808],
        [1907.6342],
        [1157.2522],
        [1446.3096],
        [1423.5701],
        [ 769.7714],
        [ 453.2327],
        [1328.3160],
        [ 952.9349],
        [1350.5042],
        [1254.0610],
        [1259.7931],
        [  32.3064],
        [ 841.5859],
        [1044.2864],
        [1111.8967],
        [1008.2763],
        [ 379.7319],
        [ 147.4953],
        [ 568.6996],
        [1440.7675],
        [1105.8916],
        [ 865.0492],
        [1609.0725],
        [ 925.0082],
        [1552.0439],
        [1110.1482],
        [ 548

In [48]:
test_label_tensor = torch.from_numpy(Y_test.to_numpy()).type(torch.FloatTensor)

In [49]:
from torchmetrics.functional import mean_squared_error
from torchmetrics.functional import r2_score

In [50]:
mean_squared_error(new_pred, test_label_tensor, squared = False)

tensor(393.3559, grad_fn=<SqrtBackward0>)

In [51]:
r2_score(new_pred, test_label_tensor)

tensor(0.6694, grad_fn=<MeanBackward0>)