In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import cycle
from matplotlib.cm import get_cmap
from tqdm import tqdm
import torch
import numpy as np
from sklearn.metrics import roc_curve, auc, roc_auc_score
from tqdm import tqdm
from sklearn.preprocessing import label_binarize

# ---- Run the model over the test set and collect its outputs ----
# Bookkeeping values initialised here; nothing in this cell updates them.
test_loss = 0.0
class_correct = [0] * len(classes)
class_total = [0] * len(classes)

# Per-sample accumulators, turned into NumPy arrays once the loop is done.
all_targets, all_probs, all_preds = [], [], []

model.eval()
for data, target in tqdm(test_loader):
    data = data.to(device)
    target = target.to(device)
    with torch.no_grad():
        output = model(data)
        # Hard prediction: index of the largest score along dim 1.
        _, preds = output.max(dim=1)
        # Scores normalised along dim 1 (presumably the class axis of
        # (batch, num_classes) logits -- confirm against the model).
        probs = torch.softmax(output, dim=1)
    # Move results to host memory and append them one sample at a time.
    all_preds.extend(preds.cpu().numpy())
    all_targets.extend(target.cpu().numpy())
    all_probs.extend(probs.cpu().numpy())

all_targets, all_probs, all_preds = (
    np.array(collected) for collected in (all_targets, all_probs, all_preds)
)
# ---- One-vs-rest ROC curve and AUC for every class ----
# Turn the integer labels into an indicator matrix with one column per class,
# so column i can be scored against all_probs[:, i].
all_targets_one_hot = label_binarize(all_targets, classes=np.arange(len(classes)))
if all_targets_one_hot.ndim == 1:
    all_targets_one_hot = np.expand_dims(all_targets_one_hot, axis=1)
# For a two-class problem label_binarize returns a single column (the
# indicator of the second class) instead of two. Expand it to
# [is_class_0, is_class_1] so the per-class indexing below does not raise
# IndexError on column 1.
if len(classes) == 2 and all_targets_one_hot.shape[1] == 1:
    all_targets_one_hot = np.hstack(
        [1 - all_targets_one_hot, all_targets_one_hot]
    )

# Keyed by class index: false-positive rates, true-positive rates and the
# area under each curve.
fpr = {}
tpr = {}
roc_auc = {}
for i in range(len(classes)):
    fpr[i], tpr[i], _ = roc_curve(all_targets_one_hot[:, i], all_probs[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
# ---- Plot the per-class ROC curves with micro- and macro-averages ----
cmap = plt.get_cmap('tab10')
# Matplotlib 3.6 renamed the seaborn style sheets to 'seaborn-v0_8-*' and
# later releases dropped the old names, so apply whichever one is installed
# instead of failing on an unknown style.
for style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if style_name in plt.style.available:
        plt.style.use(style_name)
        break
plt.figure(figsize=(10, 8))

# One colour per class, sampled evenly across the colormap's [0, 1] range.
colors = cycle([cmap(i) for i in np.linspace(0, 1, len(classes))])
for i, color in zip(range(len(classes)), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=2,
             label=f'ROC curve for {classes[i]} (area = {roc_auc[i]:.2f})')

# Micro-average: flatten the indicator matrix and the score matrix so every
# (sample, class) pair is treated as one binary decision.
fpr["micro"], tpr["micro"], _ = roc_curve(all_targets_one_hot.ravel(), all_probs.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
# This curve was previously computed but never drawn.
plt.plot(fpr["micro"], tpr["micro"],
         label=f'Micro-average ROC curve (area = {roc_auc["micro"]:.2f})',
         color='deeppink', linestyle=':', linewidth=4)

# Macro-average: interpolate every class curve onto the union of all
# false-positive-rate points, then take the unweighted mean of the TPRs.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(len(classes))]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(len(classes)):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= len(classes)
# Keep the macro result next to the micro one for later use.
fpr["macro"], tpr["macro"] = all_fpr, mean_tpr
roc_auc["macro"] = auc(all_fpr, mean_tpr)
plt.plot(all_fpr, mean_tpr,
         label=f'Macro-average ROC curve (area = {roc_auc["macro"]:.2f})',
         color='darkorange', linestyle='--', linewidth=4)

plt.xlabel('False Positive Rate', fontsize=16)
plt.ylabel('True Positive Rate', fontsize=16)
plt.legend(loc='lower right', fontsize=16)
plt.grid(True)
plt.show()