In [1]:
import numpy as np
np.random.seed(42)
import pandas as pd
from sklearn.model_selection import cross_val_predict
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics
import time

In [2]:
# time.sleep(1000)

In [3]:
def get_metrics(y_pred, y_test, to_print=True):
    """Compute and optionally print classification metrics for a set of predictions.

    Parameters
    ----------
    y_pred : array-like
        Predicted labels. Note that predictions come FIRST in this signature,
        the reverse of scikit-learn's ``(y_true, y_pred)`` convention.
    y_test : array-like
        True labels, same length as ``y_pred``.
    to_print : bool, default True
        When True, print the number of correct labels, accuracy,
        macro-averaged precision / recall / F1, the confusion matrix and the
        classification report.

    Returns
    -------
    None
        Nothing is returned; with ``to_print=False`` the metrics are computed
        and discarded.
    """
    # Coerce to ndarrays so that plain lists and pandas Series behave like
    # arrays: `list == list` yields a single bool and lists have no `.shape`.
    y_pred = np.asarray(y_pred)
    y_test = np.asarray(y_test)

    correct_labels = np.where(y_pred == y_test)[0]
    accuracy = metrics.accuracy_score(y_test, y_pred)
    precision = metrics.precision_score(y_test, y_pred, average='macro')
    recall = metrics.recall_score(y_test, y_pred, average='macro')
    f1score = metrics.f1_score(y_test, y_pred, average='macro')
    # ROC AUC is not reported: multiclass roc_auc_score expects per-class
    # probability scores rather than hard label predictions.
    confusion_matrix = metrics.confusion_matrix(y_test, y_pred)
    classification_report = metrics.classification_report(y_test, y_pred)

    if to_print:
        print("Identified {} correct labels out of {} labels".format(len(correct_labels), y_test.shape[0]))
        print("Accuracy:",accuracy)
        print("Precision:",precision)
        print("Recall:",recall)
        print("F1 Score:",f1score)
        print("Confusion Matrix:\n", confusion_matrix)
        print("Classification_Report:\n", classification_report)

In [4]:
def rfcfit(fname, n_estimators=200, cv=10, random_state=42):
    """Cross-validate a random forest on one saved set of light-curve fit parameters.

    Reads ``saved_fits/{fname}.csv``, reports how many rows contain at least
    one NaN and the mean ``timetaken`` over the NaN-free rows, then replaces
    NaNs with 0 and obtains out-of-fold predictions from a
    ``RandomForestClassifier`` via ``cross_val_predict``. The full metric
    summary is printed through ``get_metrics``.

    Parameters
    ----------
    fname : str
        Base name (without ``.csv``) of the file inside ``saved_fits/``.
    n_estimators : int, default 200
        Number of trees in the forest.
    cv : int, default 10
        ``cv`` argument forwarded to ``cross_val_predict``.
    random_state : int, default 42
        Seed for the forest.

    Returns
    -------
    tuple
        ``(avgtime, f1score, failedobjs)``: mean ``timetaken`` over NaN-free
        rows, macro-averaged F1 of the cross-validated predictions, and the
        number of rows that contained at least one NaN.
    """
    print("*****************")
    print(fname)

    df = pd.read_csv(f"saved_fits/{fname}.csv")

    # Rows with any NaN are counted as failed; compute the NaN-free view once
    # and reuse it for both the timing average and the failure count.
    complete_rows = df.dropna()
    avgtime = complete_rows['timetaken'].mean()
    failedobjs = len(df) - len(complete_rows)

    print(f"There are {failedobjs} objects for whom at least 1 passband's fit failed.")
    print(f"Average time taken for 1 full object is {avgtime} seconds per object for all 6 passbands")

    # Missing values are zero-filled so that every row can still be classified.
    df_filled = df.fillna(0)
    # Features are taken positionally: everything from the 4th column onwards.
    # NOTE(review): if `timetaken` sits at column index >= 3 it is being used
    # as a feature here -- confirm against the CSV layout.
    X = df_filled[df_filled.columns[3:]].values
    y = df_filled["objclass"].values

    rfc = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state)
    preds = cross_val_predict(rfc, X, y, cv=cv)

    get_metrics(preds, y)
    # scikit-learn's signature is (y_true, y_pred). Macro-F1 is symmetric in
    # the two, so the value is unchanged, but the call now follows the API.
    f1score = metrics.f1_score(y, preds, average='macro')

    return avgtime, f1score, failedobjs

In [5]:
# File stems of the saved fits, in the order they are evaluated.
allparams = ["oldbazin","bazin","fred",
             "karpenka","villar","alercev1","alercev2"]
# Display label for each entry of `allparams` (same order, same length).
paramname = ["Bazin - old implementation","Bazin - new implementation","FRED",
             "Karpenka","Villar","Alercev1", "Alercev2"]
# Summary-table headers. The RFC label states n=200 to match the
# n_estimators=200 that rfcfit actually uses (it previously read n=150).
columnname = ["Parameterisation Used","Average Time per Object for all 6 passbands (s)",
              "Average F1 Score for RFC (n=200)", "Total objects (of 480) where fit failed (for ≥ 1 passband)"]

In [6]:
# Build one summary row per parameterisation: its display label followed by
# the (avgtime, f1score, failedobjs) values returned by rfcfit.
summarylist = []
for i, param in enumerate(allparams):
    res = list(rfcfit(param))
    row = [paramname[i]]
    row.extend(res)
    summarylist.append(row)

*****************
oldbazin
There are 0 objects for whom at least 1 passband's fit failed.
Average time taken for 1 full object is 0.00814443197515271 seconds per object for all 6 passbands
Identified 279 correct labels out of 480 labels
Accuracy: 0.58125
Precision: 0.5612961875561917
Recall: 0.58125
F1 Score: 0.5675510058897013
Confusion Matrix:
 [[55  0  0  0  1  0  0  4]
 [ 1 39  5  2  1  1  5  6]
 [ 0  4 43  3  0  5  4  1]
 [ 0  3  1 29 10 13  4  0]
 [ 0  1  2  8 39  7  3  0]
 [ 0  2  9 12  6 22  9  0]
 [ 0 11 13  8 11 12  5  0]
 [ 1  3  5  1  2  1  0 47]]
Classification_Report:
               precision    recall  f1-score   support

         AGN       0.96      0.92      0.94        60
      SLSN-I       0.62      0.65      0.63        60
        SNII       0.55      0.72      0.62        60
        SNIa       0.46      0.48      0.47        60
   SNIa-91bg       0.56      0.65      0.60        60
       SNIax       0.36      0.37      0.36        60
       SNIbc       0.17      0.

# SUMMARY

In [7]:
# One row per parameterisation, columns labelled by `columnname`.
summarydf = pd.DataFrame(data=summarylist, columns=columnname)
# Styler.set_properties forwards its keywords as CSS declarations; CSS has no
# `align` property, so `text-align` is needed to right-align the cell text.
summarydf.style.set_properties(**{"text-align": "right"})

Unnamed: 0,Parameterisation Used,Average Time per Object for all 6 passbands (s),Average F1 Score for RFC (n=150),Total objects (of 480) where fit failed (for ≥ 1 passband)
0,Bazin - old implementation,0.008144,0.567551,0
1,Bazin - new implementation,0.037019,0.61107,11
2,FRED,0.051502,0.453149,54
3,Karpenka,0.140439,0.423564,12
4,Villar,0.042789,0.425545,0
5,Alercev1,0.054733,0.490861,0
6,Alercev2,0.095001,0.481691,10


In [10]:
# Render the summary table as Markdown text so it can be pasted elsewhere.
summary_markdown = summarydf.to_markdown()
print(summary_markdown)

|    | Parameterisation Used      |   Average Time per Object for all 6 passbands (s) |   Average F1 Score for RFC (n=150) |   Total objects (of 480) where fit failed (for ≥ 1 passband) |
|---:|:---------------------------|--------------------------------------------------:|-----------------------------------:|-------------------------------------------------------------:|
|  0 | Bazin - old implementation |                                        0.00814443 |                           0.567551 |                                                            0 |
|  1 | Bazin - new implementation |                                        0.0370185  |                           0.61107  |                                                           11 |
|  2 | FRED                       |                                        0.0515021  |                           0.453149 |                                                           54 |
|  3 | Karpenka                   |                         

<!-- |Parameterisation Used| Average Time per Object for all   6 passbands (s) | Average F1 Score for RFC (n=150) | Total objects (of 480) where fit failed   (for ≥ 1 passband) |   |
|:----------------------:|:---------------------------------------------------:|:----------------:|:------------------------------------------------------:|---|
| Bazin                  | 0.0309                                              | 0.56 ± 0.02      | 11                                                     |   |
| FRED                   | 0.0452                                              | 0.47 ± 0.02      | 54                                                    |   |
| Alerce: Sid Guesses    | 0.0789                                              | 0.47 ± 0.02      | 10                                                     |   |
| Alerce: Alerce Guesses | 0.0947                                              | 0.47 ± 0.02      | 10                                                     |   |
| Alerce: Using R        | 0.0952                                              | 0.45 ± 0.02      | 13                                                     |   |
| Villar                 | 0.0375                                              | 0.50 ± 0.02      | 0                                                      |   | -->