In [None]:
import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import breakdown
import preprocess

from matplotlib import colors
from sklearn.metrics import confusion_matrix

def confusion_matrix_dataframe(Y_true, Y_pred):
  """Build a labelled confusion matrix from target and predicted score columns.

  The class of each row is taken as the argmax across columns, for both the
  targets and the predictions.

  Args:
    Y_true: DataFrame with one column per class (presumably one-hot targets);
      its column names become the row and column labels of the result.
    Y_pred: 2-D array-like of predicted scores whose columns are assumed to
      follow the same class order as ``Y_true.columns`` (e.g. the output of
      ``model.predict``).

  Returns:
    A square DataFrame of counts of shape (n_classes, n_classes). Following
    sklearn's convention, rows are the true class and columns the predicted
    class. Both axes are labelled with ``Y_true.columns``.
  """
  n_classes = len(Y_true.columns)
  true_idx = np.argmax(Y_true.values, axis=1)
  pred_idx = np.argmax(np.asarray(Y_pred), axis=1)
  # Pin the full label set so the matrix is always n_classes x n_classes.
  # Without `labels=`, sklearn only includes classes that actually occur in
  # the truth or the predictions. A class missing from both would shrink the
  # matrix, and the DataFrame construction below would fail on shape.
  cm = confusion_matrix(true_idx, pred_idx, labels=np.arange(n_classes))
  return pd.DataFrame(cm, columns=Y_true.columns, index=Y_true.columns)

def plot_confusion_matrix(cmdf, cmap='rocket', ax=None, gamma=2/5):
  """Draw an annotated heatmap of a confusion matrix.

  Args:
    cmdf: square DataFrame of integer counts, with rows plotted against the
      y-axis ("True label") and columns against the x-axis ("Predicted
      label"). Example: the output of ``confusion_matrix_dataframe``.
    cmap: colormap passed to ``sns.heatmap``.
    ax: matplotlib Axes to draw on. ``None`` (the default) keeps seaborn's
      behaviour of drawing on the current Axes.
    gamma: exponent for ``matplotlib.colors.PowerNorm``. Values below 1
      stretch the low end of the colour scale, so small off-diagonal counts
      remain visible next to large ones.

  Returns:
    The Axes containing the heatmap.
  """
  # fmt='d' formats the annotations as integers, so cmdf must hold integer counts.
  ax = sns.heatmap(data=cmdf, annot=True, fmt='d', square=True,
                   norm=colors.PowerNorm(gamma=gamma), cmap=cmap, ax=ax)
  ax.set_xlabel('Predicted label')
  ax.set_ylabel('True label')
  return ax

In [None]:
# Dataset location. This is a relative path, resolved against the kernel's
# working directory.
FILE = '../../finalproject/wildfires-shuffled.parquet'
# preprocess.load_dataset is a project helper whose result is unpacked as
# ((X_train, Y_train), (X_test, Y_test)). Later cells use Y_test as a
# DataFrame whose columns name the classes (see confusion_matrix_dataframe).
(X_train, Y_train), (X_test, Y_test) = preprocess.load_dataset(FILE)

## Confusion matrices

In [None]:
# Trained MLP checkpoint to evaluate. The path is relative to the kernel's
# working directory. It is kept as a named constant, like the dataset's FILE,
# so every input path is easy to find and change.
MODEL_FILE = '../models/mlp-4x512-with-bnorm.h5'
model = keras.models.load_model(MODEL_FILE)

In [None]:
# Score every test row with the loaded model. The result is consumed by
# confusion_matrix_dataframe, which takes an argmax per row and labels the
# columns with Y_test.columns. NOTE(review): this assumes the model's output
# columns are in the same class order as Y_test — confirm against training.
Y_pred = model.predict(x=X_test)

In [None]:
# MLP: tabulate test-set truth vs. predictions, then draw the heatmap and
# export it as a PDF (bbox_inches='tight' trims surrounding whitespace).
cmdf = confusion_matrix_dataframe(Y_test, Y_pred)
mlp_ax = plot_confusion_matrix(cmdf)
mlp_ax.figure.savefig('mlp_confusion_mtx.pdf', bbox_inches='tight')

In [None]:
# Logistic-regression ("logreg") confusion matrix, copy-pasted from Jerry's
# results rather than recomputed in this notebook.
# NOTE(review): rows are presumably the true label and columns the predicted
# label (sklearn's convention), matching the axes of plot_confusion_matrix.
# This label order is written out by hand and is independent of
# Y_test.columns — confirm it matches the order Jerry's matrix used.
logreg_labels = [
  'Lightning',
  'Equipment Use',
  'Smoking',
  'Campfire',
  'Debris Burning',
  'Railroad',
  'Arson',
  'Children',
  'Fireworks',
  'Powerline',
  'Structure',
]
logreg_cm = np.array([
  [21160, 2, 0, 0,  6370, 0, 148, 0, 0, 0, 0],
  [ 5566, 1, 0, 0,  9066, 0,  49, 0, 0, 0, 0],
  [ 1499, 1, 0, 0,  3693, 0,  26, 0, 0, 0, 0],
  [ 4369, 0, 0, 0,  3276, 0,  31, 0, 0, 0, 0],
  [ 4517, 0, 0, 0, 38444, 0,  68, 0, 0, 0, 0],
  [  364, 0, 0, 0,  2961, 0,   9, 0, 0, 0, 0],
  [ 3488, 0, 0, 0, 24430, 0, 174, 0, 0, 0, 0],
  [ 1579, 0, 0, 0,  4473, 0,  28, 0, 0, 0, 0],
  [  492, 0, 0, 0,   564, 0,   1, 0, 0, 0, 0],
  [  377, 0, 0, 0,  1056, 0,   7, 0, 0, 0, 0],
  [  121, 0, 0, 0,   256, 0,   0, 0, 0, 0, 0],
])
# Same labels on both axes, as confusion_matrix_dataframe produces for the MLP.
logreg_cmdf = pd.DataFrame(logreg_cm, index=logreg_labels, columns=logreg_labels)
# Draw the logreg heatmap and export it as a PDF, mirroring the MLP cell.
logreg_ax = plot_confusion_matrix(logreg_cmdf)
logreg_ax.figure.savefig('logreg_confusion_mtx.pdf', bbox_inches='tight')