## Load Libraries

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import random
import torch

## SET SEED

In [None]:
# Single seed value shared by every random number generator used in this notebook.
SEED = 456

# Apply the seed, in order, to Python's `random`, NumPy, and PyTorch
# (default generator, current CUDA device, all CUDA devices).
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

## Confusion Matrix

- inference 후에 나온 output과 비교해야 한다.
- SEED는 동일해야 한다.

In [None]:
# Load the full training table from disk.
data = pd.read_csv("train.csv")

# Hold out 20% of the rows, stratified on the 'target' column and seeded with
# SEED so the split is reproduced with the same random state.
_split_options = {
    "test_size": 0.2,
    "stratify": data["target"],
    "random_state": SEED,
}
dataset_train, dataset_valid = train_test_split(data, **_split_options)

# Inference results for dataset_valid.
dataset_test = pd.read_csv("confusion_matrix.csv")

In [None]:
# Pull the 'target' column out of each table as a plain Python list:
# `targets` from the validation split, `preds` from the inference file.
# NOTE(review): assumes the rows of confusion_matrix.csv are in the same order
# as dataset_valid - confirm against the inference script.
targets = dataset_valid["target"].tolist()
preds = dataset_test["target"].tolist()

In [None]:
# Class ids used as axis tick labels by the plotting cells; they also fix the
# row/column order of the confusion matrices built below.
label_list = [0, 1, 2, 3, 4, 5, 6]

val_labels = targets
pred_labels = preds

val_nums_arr = np.array(val_labels)
pred_answer_arr = np.array(pred_labels)

# Pass `labels` explicitly: without it the matrix only contains the classes
# that happen to occur in the data, so its size/order could disagree with the
# tick labels taken from `label_list`.
cm = confusion_matrix(val_nums_arr, pred_answer_arr, labels=label_list)

# Ground truth compared against itself: a diagonal matrix holding the number
# of validation samples per class.
cm_a = confusion_matrix(val_nums_arr, val_nums_arr, labels=label_list)

- basic confusion matrix

In [None]:
# Draw the confusion matrix `cm` as an annotated heatmap on a 17x15 figure.
sns.set(rc={"figure.figsize": (17, 15)})

_heatmap_options = {
    "annot": True,                # write the value inside every cell
    "cmap": "Blues",
    "fmt": "d",                   # integer formatting for the cell annotations
    "linewidths": 0.5,
    "annot_kws": {"size": 10},
}
ax = sns.heatmap(cm, **_heatmap_options)

ax.set_title("Confusion Matrix\n", fontsize=20)
ax.set(xlabel="\nPredicted Labels", ylabel="True Labels ")

# Label the rows with the class ids, kept horizontal and right-aligned.
ax.yaxis.set_ticklabels(label_list, rotation=0, ha="right")

plt.show()

- normalized confusion matrix (each row divided by its row total)

In [None]:
# Row-normalize both matrices: every row is divided by its own total, so a
# cell holds the fraction of that row's samples.
#
# The totals are captured BEFORE `cm` is overwritten, and each matrix uses its
# own totals. Previously `cm_a` was divided by the row sums of the already
# normalized `cm` (all ~1.0), which left `cm_a` un-normalized.
cm_row_totals = cm.sum(axis=1, keepdims=True)
cm_a_row_totals = cm_a.sum(axis=1, keepdims=True)

# Rows whose total is zero are left as zeros instead of turning into NaN
# through a division by zero.
cm = np.divide(
    cm,
    cm_row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=cm_row_totals != 0,
)
cm_a = np.divide(
    cm_a,
    cm_a_row_totals,
    out=np.zeros(cm_a.shape, dtype=float),
    where=cm_a_row_totals != 0,
)


# plot
sns.set(rc={'figure.figsize': (17, 15)})

# The colour scale is pinned to [0, 1] because the cells are now fractions.
ax = sns.heatmap(cm, annot=True, cmap='Blues', fmt=".1f",
                 linewidths=.5,
                 annot_kws={"size": 10},
                 vmin=0.0, vmax=1.0)
ax.set_title('Confusion Matrix\n', fontsize=20)
ax.set_xlabel('\nPredicted Labels')
ax.set_ylabel('True Labels ')

# Label the rows with the class ids, kept horizontal and right-aligned.
ax.yaxis.set_ticklabels(label_list, rotation=0, ha="right")
plt.show()