In [39]:
import numpy as np
from modules.models.modelling_functions import *
from modules.eda_functions import *
from xgboost import XGBClassifier
from sklearn.metrics import confusion_matrix, classification_report

In [40]:
path = '../data/lags/'

In [41]:
all_data_as_dict = import_all_files_as_dict(path)

In [42]:
all_data_as_dict

{'ABEO':              Close   lag_1   lag_2   lag_3   lag_4   lag_5  label
 Date                                                             
 2017-09-12  345.00  350.00  346.25  346.25  335.00  355.00      0
 2017-09-13  350.00  345.00  350.00  346.25  346.25  335.00      1
 2017-09-14  350.00  350.00  345.00  350.00  346.25  346.25      1
 2017-09-15  430.00  350.00  350.00  345.00  350.00  346.25      1
 2017-09-18  412.50  430.00  350.00  350.00  345.00  350.00      0
 ...            ...     ...     ...     ...     ...     ...    ...
 2022-08-15    4.65    4.43    4.66    4.16    3.95    4.02      1
 2022-08-16    4.77    4.65    4.43    4.66    4.16    3.95      1
 2022-08-17    4.60    4.77    4.65    4.43    4.66    4.16      0
 2022-08-18    4.36    4.60    4.77    4.65    4.43    4.66      0
 2022-08-19    4.11    4.36    4.60    4.77    4.65    4.43      0
 
 [1244 rows x 7 columns],
 'ABIO':                 Close  lag_1  lag_2      lag_3      lag_4      lag_5  label
 Date   

In [43]:
models = {}
for company in all_data_as_dict.keys():
    X, y = all_data_as_dict[company][['Close', 'lag_1']], all_data_as_dict[company]['label']
    X_train, y_train, X_dev, y_dev, X_test, y_test = split_train_dev_test(X, y)
    model = XGBClassifier(n_estimators=100)
    model.fit(X_train, y_train)
    y_hat = model.predict(X_dev)

    models[company] = {
        'model': model,
        'y_hat': y_hat,
        'X_dev': X_dev,
        'y_dev': y_dev,
        'X_test': X_test,
        'y_test': y_test
    }

In [44]:
costs = {}
for company in all_data_as_dict.keys():
    cost = models[company]['model'].score(models[company]['X_dev'], models[company]['y_hat'])
    costs[company] = cost

In [45]:
costs

{'ABEO': 1.0,
 'ABIO': 1.0,
 'ABUS': 1.0,
 'ACAD': 1.0,
 'ACER': 1.0,
 'ACHN': 1.0,
 'ACHV': 1.0,
 'ACIU': 1.0,
 'ACOR': 1.0,
 'ACRS': 1.0,
 'ACST': 1.0,
 'ADAP': 1.0,
 'ADIL': 1.0,
 'ADMA': 1.0,
 'ADVM': 1.0,
 'ADXS': 1.0,
 'AEZS': 1.0,
 'AFMD': 1.0,
 'AGE': 1.0,
 'AGEN': 1.0,
 'AGIO': 1.0,
 'AGLE': 1.0,
 'AGTC': 1.0,
 'AKBA': 1.0,
 'AKTX': 1.0,
 'ALBO': 1.0,
 'ALDX': 1.0,
 'ALKS': 1.0,
 'ALLK': 1.0,
 'ALLO': 1.0,
 'ALNA': 1.0,
 'ALNY': 1.0,
 'ALPN': 1.0,
 'ALRN': 1.0,
 'ALT': 1.0,
 'AMGN': 1.0,
 'AMPE': 1.0,
 'AMRN': 1.0,
 'ANAB': 1.0,
 'ANIK': 1.0,
 'ANIP': 1.0,
 'APLS': 1.0,
 'APM': 1.0,
 'APTO': 1.0,
 'APTX': 1.0,
 'APVO': 1.0,
 'AQB': 1.0,
 'AQST': 1.0,
 'ARAV': 1.0,
 'ARCT': 1.0,
 'ARDS': 1.0,
 'ARDX': 1.0,
 'ARGX': 1.0,
 'ARRY': 1.0,
 'ARVN': 1.0,
 'ARWR': 1.0,
 'ASLN': 1.0,
 'ASMB': 1.0,
 'ASND': 1.0,
 'ASNS': 1.0,
 'ATHX': 1.0,
 'ATNM': 1.0,
 'ATNX': 1.0,
 'ATRA': 1.0,
 'AUPH': 1.0,
 'AUTL': 1.0,
 'AVDL': 1.0,
 'AVEO': 1.0,
 'AVRO': 1.0,
 'AVXL': 1.0,
 'AXON': 1.0,
 'AXSM': 1

In [46]:
predictions = {}
for company in all_data_as_dict.keys():
    X_test = models[company]['X_test']
    y_hat = models[company]['model'].predict(X_test)
    predictions[company] = y_hat

In [47]:
for company in all_data_as_dict.keys():
    print(f"Confusion Matrix {company}:\n{confusion_matrix(models[company]['y_test'], predictions[company])}")

Confusion Matrix ABEO:
[[ 82   0]
 [106   0]]
Confusion Matrix ABIO:
[[93  0]
 [94  1]]
Confusion Matrix ABUS:
[[91  4]
 [18 75]]
Confusion Matrix ACAD:
[[80 12]
 [ 6 90]]
Confusion Matrix ACER:
[[59 32]
 [30 67]]
Confusion Matrix ACHN:
[[30]]
Confusion Matrix ACHV:
[[55 46]
 [15 72]]
Confusion Matrix ACIU:
[[81 23]
 [46 38]]
Confusion Matrix ACOR:
[[ 5 97]
 [ 0 86]]
Confusion Matrix ACRS:
[[78 10]
 [37 63]]
Confusion Matrix ACST:
[[ 1 89]
 [ 0 98]]
Confusion Matrix ADAP:
[[90 13]
 [38 47]]
Confusion Matrix ADIL:
[[77  6]
 [16 53]]
Confusion Matrix ADMA:
[[37 39]
 [18 94]]
Confusion Matrix ADVM:
[[ 0 93]
 [ 0 95]]
Confusion Matrix ADXS:
[[102   2]
 [ 76   8]]
Confusion Matrix AEZS:
[[78  2]
 [54 54]]
Confusion Matrix AFMD:
[[82 17]
 [ 7 82]]
Confusion Matrix AGE:
[[69  2]
 [43 25]]
Confusion Matrix AGEN:
[[95  0]
 [27 66]]
Confusion Matrix AGIO:
[[103   2]
 [ 70  13]]
Confusion Matrix AGLE:
[[106   1]
 [ 65  16]]
Confusion Matrix AGTC:
[[  2 101]
 [  0  85]]
Confusion Matrix AKBA:
[[96

In [48]:
for company in all_data_as_dict.keys():
    print(f"Report: {company}:\n{classification_report(models[company]['y_test'], predictions[company])}")

Report: ABEO:
              precision    recall  f1-score   support

           0       0.44      1.00      0.61        82
           1       0.00      0.00      0.00       106

    accuracy                           0.44       188
   macro avg       0.22      0.50      0.30       188
weighted avg       0.19      0.44      0.26       188

Report: ABIO:
              precision    recall  f1-score   support

           0       0.50      1.00      0.66        93
           1       1.00      0.01      0.02        95

    accuracy                           0.50       188
   macro avg       0.75      0.51      0.34       188
weighted avg       0.75      0.50      0.34       188

Report: ABUS:
              precision    recall  f1-score   support

           0       0.83      0.96      0.89        95
           1       0.95      0.81      0.87        93

    accuracy                           0.88       188
   macro avg       0.89      0.88      0.88       188
weighted avg       0.89      0.8

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Report: APLS:
              precision    recall  f1-score   support

           0       0.82      0.82      0.82        90
           1       0.82      0.82      0.82        89

    accuracy                           0.82       179
   macro avg       0.82      0.82      0.82       179
weighted avg       0.82      0.82      0.82       179

Report: APM:
              precision    recall  f1-score   support

           0       0.50      0.95      0.66        65
           1       0.77      0.14      0.24        72

    accuracy                           0.53       137
   macro avg       0.63      0.55      0.45       137
weighted avg       0.64      0.53      0.43       137

Report: APTO:
              precision    recall  f1-score   support

           0       0.83      0.25      0.38        96
           1       0.55      0.95      0.69        92

    accuracy                           0.59       188
   macro avg       0.69      0.60      0.54       188
weighted avg       0.69      0.59

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Report: CCXI:
              precision    recall  f1-score   support

           0       0.56      0.92      0.69        95
           1       0.75      0.26      0.38        93

    accuracy                           0.59       188
   macro avg       0.65      0.59      0.54       188
weighted avg       0.65      0.59      0.54       188

Report: CDMO:
              precision    recall  f1-score   support

           0       0.81      0.18      0.29        97
           1       0.52      0.96      0.67        91

    accuracy                           0.55       188
   macro avg       0.67      0.57      0.48       188
weighted avg       0.67      0.55      0.48       188

Report: CDTX:
              precision    recall  f1-score   support

           0       0.89      0.16      0.26       103
           1       0.49      0.98      0.65        85

    accuracy                           0.53       188
   macro avg       0.69      0.57      0.46       188
weighted avg       0.71      0.5

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Report: GTHX:
              precision    recall  f1-score   support

           0       0.81      0.23      0.35        93
           1       0.56      0.95      0.70        95

    accuracy                           0.59       188
   macro avg       0.68      0.59      0.53       188
weighted avg       0.68      0.59      0.53       188

Report: HALO:
              precision    recall  f1-score   support

           0       0.89      0.37      0.52        93
           1       0.61      0.96      0.74        95

    accuracy                           0.66       188
   macro avg       0.75      0.66      0.63       188
weighted avg       0.75      0.66      0.63       188

Report: IBIO:
              precision    recall  f1-score   support

           0       0.74      0.84      0.79        88
           1       0.84      0.74      0.79       100

    accuracy                           0.79       188
   macro avg       0.79      0.79      0.79       188
weighted avg       0.79      0.7

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Report: MRKR:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        89
           1       0.53      1.00      0.69        99

    accuracy                           0.53       188
   macro avg       0.26      0.50      0.34       188
weighted avg       0.28      0.53      0.36       188

Report: MRNA:
              precision    recall  f1-score   support

           0       0.88      0.76      0.81        78
           1       0.74      0.87      0.80        61

    accuracy                           0.81       139
   macro avg       0.81      0.81      0.81       139
weighted avg       0.82      0.81      0.81       139

Report: MRNS:
              precision    recall  f1-score   support

           0       0.88      0.87      0.88       106
           1       0.83      0.85      0.84        82

    accuracy                           0.86       188
   macro avg       0.86      0.86      0.86       188
weighted avg       0.86      0.8

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

Report: QURE:
              precision    recall  f1-score   support

           0       0.75      0.83      0.78        92
           1       0.81      0.73      0.77        96

    accuracy                           0.78       188
   macro avg       0.78      0.78      0.78       188
weighted avg       0.78      0.78      0.78       188

Report: RARE:
              precision    recall  f1-score   support

           0       0.93      0.90      0.91       100
           1       0.89      0.92      0.91        88

    accuracy                           0.91       188
   macro avg       0.91      0.91      0.91       188
weighted avg       0.91      0.91      0.91       188

Report: RCKT:
              precision    recall  f1-score   support

           0       0.93      0.92      0.93        92
           1       0.93      0.94      0.93        96

    accuracy                           0.93       188
   macro avg       0.93      0.93      0.93       188
weighted avg       0.93      0.9

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr