In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from vit_pytorch import ViT
from tqdm import tqdm  # Import tqdm
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [15]:
train_root_dir="./../data_processed/train"
val_root_dir="./../data_processed/validation"
test_root_dir="./../data_processed/test"

In [16]:
# Define the hyperparameters
batch_size = 32
learning_rate = 1e-4
num_epochs = 60

In [17]:
####################################
# Training
####################################

trans={
    # Train uses data augmentation
    'train':
    transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomRotation(degrees=15),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.4762, 0.3054, 0.2368],
                             [0.3345, 0.2407, 0.2164])
    ]),
    # Validation does not use augmentation
    'valid':
    transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.4762, 0.3054, 0.2368],
                             [0.3345, 0.2407, 0.2164])
    ]),
    
    # Test does not use augmentation
    'test':
    transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.4762, 0.3054, 0.2368],
                             [0.3345, 0.2407, 0.2164])
    ]),
}

In [18]:
#Generators
training_dataset = ImageFolder(train_root_dir,transform=trans['train'])
validation_dataset = ImageFolder(val_root_dir,transform=trans['valid'])
test_dataset = ImageFolder(test_root_dir,transform=trans['test'])

train_dataloader = DataLoader(training_dataset,batch_size,shuffle=True) # ** unpacks a dictionary into keyword arguments
val_dataloader = DataLoader(validation_dataset,batch_size)
test_dataloader = DataLoader(test_dataset,batch_size)

print('Number of Training set images:{}'.format(len(training_dataset)))
print('Number of Validation set images:{}'.format(len(validation_dataset)))
print('Number of Test set images:{}'.format(len(test_dataset)))

Number of Training set images:1063
Number of Validation set images:266
Number of Test set images:444


In [19]:
# Initialize the Vision Transformer model
model = ViT(
    image_size=224,
    patch_size=32,
    num_classes=len(training_dataset.classes),
    dim=768,  # Dimension of the model
    depth=12,  # Number of transformer layers
    heads=12,  # Number of attention heads
    mlp_dim=3072,  # Dimension of the MLP layers
    dropout=0.1
)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Set the device to use (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print("Device: ", device)
model.to(device)

Device:  cuda


ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=3072, out_features=768, bias=True)
    (3): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-11): 12 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (dropout): Dropout(p=0.1, inplace=False)
          (to_qkv): Linear(in_features=768, out_features=2304, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=768, out_features=768, bias=True)
            (1): Dropout(p=0.1, inplace=False)
          )
        )
        (1): FeedForward(
          (net): Sequential(
          

## Training

In [20]:
# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # Use tqdm for a progress bar
    with tqdm(total=len(train_dataloader), desc=f'Epoch {epoch+1}/{num_epochs}', unit='batch') as pbar:
        for i, (images, labels) in enumerate(train_dataloader):
            images, labels = images.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            # Update the progress bar
            pbar.set_postfix(loss=running_loss / (i + 1))
            pbar.update(1)

    print(f'Training - Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_dataloader)}')

    # Validation
    model.eval()
    val_loss = 0.0
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for images, labels in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())

    print(f'Validation - Epoch {epoch+1}/{num_epochs}, Loss: {val_loss/len(val_dataloader)}')

    # Compute and print validation accuracy
    accuracy = accuracy_score(true_labels, predicted_labels)
    print(f'Validation - Epoch {epoch+1}/{num_epochs}, Accuracy: {accuracy * 100:.2f}%')

# Save the trained model
torch.save(model.state_dict(), 'vit_model.pth')
print("Training completed and modal saved to vit_model.pth")

Epoch 1/60: 100%|██████████| 34/34 [00:16<00:00,  2.11batch/s, loss=0.975]


Training - Epoch 1/60, Loss: 0.9750384872450548
Validation - Epoch 1/60, Loss: 0.7162366824017631
Validation - Epoch 1/60, Accuracy: 64.29%


Epoch 2/60: 100%|██████████| 34/34 [00:15<00:00,  2.16batch/s, loss=0.578]


Training - Epoch 2/60, Loss: 0.5781842014368843
Validation - Epoch 2/60, Loss: 0.6091475735108057
Validation - Epoch 2/60, Accuracy: 71.80%


Epoch 3/60: 100%|██████████| 34/34 [00:16<00:00,  2.04batch/s, loss=0.552]


Training - Epoch 3/60, Loss: 0.5521366429679534
Validation - Epoch 3/60, Loss: 0.6078521874215868
Validation - Epoch 3/60, Accuracy: 71.80%


Epoch 4/60: 100%|██████████| 34/34 [00:16<00:00,  2.08batch/s, loss=0.513]


Training - Epoch 4/60, Loss: 0.5125803027082892
Validation - Epoch 4/60, Loss: 0.5282057490613725
Validation - Epoch 4/60, Accuracy: 77.82%


Epoch 5/60: 100%|██████████| 34/34 [00:16<00:00,  2.01batch/s, loss=0.434]


Training - Epoch 5/60, Loss: 0.4344012868755004
Validation - Epoch 5/60, Loss: 0.5262392353680398
Validation - Epoch 5/60, Accuracy: 71.43%


Epoch 6/60: 100%|██████████| 34/34 [00:16<00:00,  2.04batch/s, loss=0.383]


Training - Epoch 6/60, Loss: 0.38323995514827613
Validation - Epoch 6/60, Loss: 0.31708698057466084
Validation - Epoch 6/60, Accuracy: 87.22%


Epoch 7/60: 100%|██████████| 34/34 [00:16<00:00,  2.08batch/s, loss=0.316]


Training - Epoch 7/60, Loss: 0.31627604628310485
Validation - Epoch 7/60, Loss: 0.26430418714880943
Validation - Epoch 7/60, Accuracy: 89.10%


Epoch 8/60: 100%|██████████| 34/34 [00:16<00:00,  2.06batch/s, loss=0.244]


Training - Epoch 8/60, Loss: 0.243776004323188
Validation - Epoch 8/60, Loss: 0.24253814833031762
Validation - Epoch 8/60, Accuracy: 89.10%


Epoch 9/60: 100%|██████████| 34/34 [00:16<00:00,  2.07batch/s, loss=0.23] 


Training - Epoch 9/60, Loss: 0.23035268277368126
Validation - Epoch 9/60, Loss: 0.23725675460365084
Validation - Epoch 9/60, Accuracy: 90.98%


Epoch 10/60: 100%|██████████| 34/34 [00:16<00:00,  2.10batch/s, loss=0.177]


Training - Epoch 10/60, Loss: 0.17697497073780089
Validation - Epoch 10/60, Loss: 0.20667441354857552
Validation - Epoch 10/60, Accuracy: 92.48%


Epoch 11/60: 100%|██████████| 34/34 [00:16<00:00,  2.03batch/s, loss=0.154]


Training - Epoch 11/60, Loss: 0.15429790909675992
Validation - Epoch 11/60, Loss: 0.22717159613966942
Validation - Epoch 11/60, Accuracy: 92.86%


Epoch 12/60: 100%|██████████| 34/34 [00:16<00:00,  2.11batch/s, loss=0.184]


Training - Epoch 12/60, Loss: 0.18359637096085968
Validation - Epoch 12/60, Loss: 0.17844223769174683
Validation - Epoch 12/60, Accuracy: 93.98%


Epoch 13/60: 100%|██████████| 34/34 [00:16<00:00,  2.03batch/s, loss=0.128]


Training - Epoch 13/60, Loss: 0.12833824864698246
Validation - Epoch 13/60, Loss: 0.24857278344117933
Validation - Epoch 13/60, Accuracy: 92.11%


Epoch 14/60: 100%|██████████| 34/34 [00:16<00:00,  2.10batch/s, loss=0.155]


Training - Epoch 14/60, Loss: 0.15486405592630892
Validation - Epoch 14/60, Loss: 0.21072234865278006
Validation - Epoch 14/60, Accuracy: 92.48%


Epoch 15/60: 100%|██████████| 34/34 [00:16<00:00,  2.11batch/s, loss=0.146]


Training - Epoch 15/60, Loss: 0.14641663174637976
Validation - Epoch 15/60, Loss: 0.17270902906441027
Validation - Epoch 15/60, Accuracy: 95.11%


Epoch 16/60: 100%|██████████| 34/34 [00:15<00:00,  2.16batch/s, loss=0.135] 


Training - Epoch 16/60, Loss: 0.13490501402274652
Validation - Epoch 16/60, Loss: 0.3190892085226046
Validation - Epoch 16/60, Accuracy: 89.10%


Epoch 17/60: 100%|██████████| 34/34 [00:16<00:00,  2.06batch/s, loss=0.137]


Training - Epoch 17/60, Loss: 0.13744167621959658
Validation - Epoch 17/60, Loss: 0.2091641538362536
Validation - Epoch 17/60, Accuracy: 92.11%


Epoch 18/60: 100%|██████████| 34/34 [00:16<00:00,  2.02batch/s, loss=0.127]


Training - Epoch 18/60, Loss: 0.12659376048866441
Validation - Epoch 18/60, Loss: 0.15397770770101082
Validation - Epoch 18/60, Accuracy: 93.98%


Epoch 19/60: 100%|██████████| 34/34 [00:17<00:00,  1.96batch/s, loss=0.0853]


Training - Epoch 19/60, Loss: 0.08527433692806345
Validation - Epoch 19/60, Loss: 0.22938039526343346
Validation - Epoch 19/60, Accuracy: 92.11%


Epoch 20/60: 100%|██████████| 34/34 [00:17<00:00,  1.98batch/s, loss=0.0888]


Training - Epoch 20/60, Loss: 0.08875076299163458
Validation - Epoch 20/60, Loss: 0.15679437077293792
Validation - Epoch 20/60, Accuracy: 94.36%


Epoch 21/60: 100%|██████████| 34/34 [00:16<00:00,  2.06batch/s, loss=0.0883]


Training - Epoch 21/60, Loss: 0.08834102827891269
Validation - Epoch 21/60, Loss: 0.17747894131268063
Validation - Epoch 21/60, Accuracy: 94.36%


Epoch 22/60: 100%|██████████| 34/34 [00:17<00:00,  1.90batch/s, loss=0.0817]


Training - Epoch 22/60, Loss: 0.08169761928729713
Validation - Epoch 22/60, Loss: 0.22494748369273213
Validation - Epoch 22/60, Accuracy: 93.61%


Epoch 23/60: 100%|██████████| 34/34 [00:17<00:00,  1.99batch/s, loss=0.07]  


Training - Epoch 23/60, Loss: 0.07003449896514855
Validation - Epoch 23/60, Loss: 0.26766304299235344
Validation - Epoch 23/60, Accuracy: 93.23%


Epoch 24/60: 100%|██████████| 34/34 [00:16<00:00,  2.01batch/s, loss=0.101] 


Training - Epoch 24/60, Loss: 0.10111265891122029
Validation - Epoch 24/60, Loss: 0.19677299592230055
Validation - Epoch 24/60, Accuracy: 94.36%


Epoch 25/60: 100%|██████████| 34/34 [00:16<00:00,  2.10batch/s, loss=0.0936]


Training - Epoch 25/60, Loss: 0.09355558364597313
Validation - Epoch 25/60, Loss: 0.19800789571470684
Validation - Epoch 25/60, Accuracy: 93.23%


Epoch 26/60: 100%|██████████| 34/34 [00:15<00:00,  2.14batch/s, loss=0.0856]


Training - Epoch 26/60, Loss: 0.08559039936346166
Validation - Epoch 26/60, Loss: 0.2362109989238282
Validation - Epoch 26/60, Accuracy: 95.11%


Epoch 27/60: 100%|██████████| 34/34 [00:15<00:00,  2.17batch/s, loss=0.0657]


Training - Epoch 27/60, Loss: 0.06572697982739877
Validation - Epoch 27/60, Loss: 0.27786060858407935
Validation - Epoch 27/60, Accuracy: 94.74%


Epoch 28/60: 100%|██████████| 34/34 [00:15<00:00,  2.20batch/s, loss=0.0645]


Training - Epoch 28/60, Loss: 0.06445699817199699
Validation - Epoch 28/60, Loss: 0.2067231144497378
Validation - Epoch 28/60, Accuracy: 93.98%


Epoch 29/60: 100%|██████████| 34/34 [00:15<00:00,  2.19batch/s, loss=0.0744]


Training - Epoch 29/60, Loss: 0.07440256923847996
Validation - Epoch 29/60, Loss: 0.2484858331994878
Validation - Epoch 29/60, Accuracy: 92.11%


Epoch 30/60: 100%|██████████| 34/34 [00:16<00:00,  2.11batch/s, loss=0.0649]


Training - Epoch 30/60, Loss: 0.06490199001478579
Validation - Epoch 30/60, Loss: 0.19944150207771194
Validation - Epoch 30/60, Accuracy: 95.49%


Epoch 31/60: 100%|██████████| 34/34 [00:16<00:00,  2.06batch/s, loss=0.0598]


Training - Epoch 31/60, Loss: 0.059767303152886385
Validation - Epoch 31/60, Loss: 0.23404530292868408
Validation - Epoch 31/60, Accuracy: 94.36%


Epoch 32/60: 100%|██████████| 34/34 [00:16<00:00,  2.05batch/s, loss=0.0461]


Training - Epoch 32/60, Loss: 0.04606600509620929
Validation - Epoch 32/60, Loss: 0.28012469108216465
Validation - Epoch 32/60, Accuracy: 93.98%


Epoch 33/60: 100%|██████████| 34/34 [00:16<00:00,  2.09batch/s, loss=0.0896]


Training - Epoch 33/60, Loss: 0.08964276596260093
Validation - Epoch 33/60, Loss: 0.25602914672344923
Validation - Epoch 33/60, Accuracy: 93.98%


Epoch 34/60: 100%|██████████| 34/34 [00:16<00:00,  2.05batch/s, loss=0.132]


Training - Epoch 34/60, Loss: 0.13165940193678527
Validation - Epoch 34/60, Loss: 0.16018982958565983
Validation - Epoch 34/60, Accuracy: 93.23%


Epoch 35/60: 100%|██████████| 34/34 [00:16<00:00,  2.12batch/s, loss=0.0782]


Training - Epoch 35/60, Loss: 0.07823752454372451
Validation - Epoch 35/60, Loss: 0.16622119841890204
Validation - Epoch 35/60, Accuracy: 94.36%


Epoch 36/60: 100%|██████████| 34/34 [00:15<00:00,  2.13batch/s, loss=0.0708]


Training - Epoch 36/60, Loss: 0.07076516861150808
Validation - Epoch 36/60, Loss: 0.17546855001192954
Validation - Epoch 36/60, Accuracy: 95.11%


Epoch 37/60: 100%|██████████| 34/34 [00:16<00:00,  2.05batch/s, loss=0.0544]


Training - Epoch 37/60, Loss: 0.054443948963821376
Validation - Epoch 37/60, Loss: 0.20307280786154377
Validation - Epoch 37/60, Accuracy: 93.61%


Epoch 38/60: 100%|██████████| 34/34 [00:15<00:00,  2.13batch/s, loss=0.0415]


Training - Epoch 38/60, Loss: 0.04150964925408035
Validation - Epoch 38/60, Loss: 0.25806422571056625
Validation - Epoch 38/60, Accuracy: 94.36%


Epoch 39/60: 100%|██████████| 34/34 [00:15<00:00,  2.14batch/s, loss=0.0588]


Training - Epoch 39/60, Loss: 0.05875898852451321
Validation - Epoch 39/60, Loss: 0.18893576599657536
Validation - Epoch 39/60, Accuracy: 93.98%


Epoch 40/60: 100%|██████████| 34/34 [00:15<00:00,  2.23batch/s, loss=0.0549]


Training - Epoch 40/60, Loss: 0.054948588949628174
Validation - Epoch 40/60, Loss: 0.19917669819874895
Validation - Epoch 40/60, Accuracy: 94.36%


Epoch 41/60: 100%|██████████| 34/34 [00:16<00:00,  2.07batch/s, loss=0.0943]


Training - Epoch 41/60, Loss: 0.09426064975559711
Validation - Epoch 41/60, Loss: 0.18208669777959585
Validation - Epoch 41/60, Accuracy: 93.98%


Epoch 42/60: 100%|██████████| 34/34 [00:16<00:00,  2.11batch/s, loss=0.0507]


Training - Epoch 42/60, Loss: 0.05071537263746209
Validation - Epoch 42/60, Loss: 0.21611474855389032
Validation - Epoch 42/60, Accuracy: 94.74%


Epoch 43/60: 100%|██████████| 34/34 [00:16<00:00,  2.12batch/s, loss=0.0311] 


Training - Epoch 43/60, Loss: 0.03107344889705179
Validation - Epoch 43/60, Loss: 0.19395724359330618
Validation - Epoch 43/60, Accuracy: 95.11%


Epoch 44/60: 100%|██████████| 34/34 [00:16<00:00,  2.06batch/s, loss=0.0682]


Training - Epoch 44/60, Loss: 0.06817863400518785
Validation - Epoch 44/60, Loss: 0.21042988682165742
Validation - Epoch 44/60, Accuracy: 94.36%


Epoch 45/60: 100%|██████████| 34/34 [00:15<00:00,  2.15batch/s, loss=0.0613]


Training - Epoch 45/60, Loss: 0.061262799753903356
Validation - Epoch 45/60, Loss: 0.2210455226401488
Validation - Epoch 45/60, Accuracy: 93.98%


Epoch 46/60: 100%|██████████| 34/34 [00:16<00:00,  2.03batch/s, loss=0.0352]


Training - Epoch 46/60, Loss: 0.03518544209039058
Validation - Epoch 46/60, Loss: 0.25211805671763915
Validation - Epoch 46/60, Accuracy: 94.36%


Epoch 47/60: 100%|██████████| 34/34 [00:16<00:00,  2.07batch/s, loss=0.0634]


Training - Epoch 47/60, Loss: 0.06338364642579108
Validation - Epoch 47/60, Loss: 0.161733935535368
Validation - Epoch 47/60, Accuracy: 93.98%


Epoch 48/60: 100%|██████████| 34/34 [00:16<00:00,  2.10batch/s, loss=0.0379]


Training - Epoch 48/60, Loss: 0.037935689940829485
Validation - Epoch 48/60, Loss: 0.2107118368278154
Validation - Epoch 48/60, Accuracy: 93.98%


Epoch 49/60: 100%|██████████| 34/34 [00:16<00:00,  2.10batch/s, loss=0.0203]


Training - Epoch 49/60, Loss: 0.020299406104159597
Validation - Epoch 49/60, Loss: 0.21864898554566833
Validation - Epoch 49/60, Accuracy: 94.74%


Epoch 50/60: 100%|██████████| 34/34 [00:16<00:00,  2.03batch/s, loss=0.0389]


Training - Epoch 50/60, Loss: 0.038895561201463674
Validation - Epoch 50/60, Loss: 0.19475127482372853
Validation - Epoch 50/60, Accuracy: 94.74%


Epoch 51/60: 100%|██████████| 34/34 [00:16<00:00,  2.09batch/s, loss=0.0263]


Training - Epoch 51/60, Loss: 0.026319667900251635
Validation - Epoch 51/60, Loss: 0.22217274027773076
Validation - Epoch 51/60, Accuracy: 94.36%


Epoch 52/60: 100%|██████████| 34/34 [00:16<00:00,  2.11batch/s, loss=0.0211]


Training - Epoch 52/60, Loss: 0.02106525617358891
Validation - Epoch 52/60, Loss: 0.2546717225470477
Validation - Epoch 52/60, Accuracy: 94.74%


Epoch 53/60: 100%|██████████| 34/34 [00:16<00:00,  2.10batch/s, loss=0.117] 


Training - Epoch 53/60, Loss: 0.11693082632058684
Validation - Epoch 53/60, Loss: 0.1703763496544626
Validation - Epoch 53/60, Accuracy: 93.61%


Epoch 54/60: 100%|██████████| 34/34 [00:16<00:00,  2.04batch/s, loss=0.0507]


Training - Epoch 54/60, Loss: 0.05065485561156974
Validation - Epoch 54/60, Loss: 0.22775685678546628
Validation - Epoch 54/60, Accuracy: 93.23%


Epoch 55/60: 100%|██████████| 34/34 [00:16<00:00,  2.00batch/s, loss=0.0427]


Training - Epoch 55/60, Loss: 0.042748818037045354
Validation - Epoch 55/60, Loss: 0.20692780470320335
Validation - Epoch 55/60, Accuracy: 95.11%


Epoch 56/60: 100%|██████████| 34/34 [00:16<00:00,  2.08batch/s, loss=0.0323]


Training - Epoch 56/60, Loss: 0.03233132071395898
Validation - Epoch 56/60, Loss: 0.2556010837369185
Validation - Epoch 56/60, Accuracy: 93.61%


Epoch 57/60: 100%|██████████| 34/34 [00:16<00:00,  2.06batch/s, loss=0.0351] 


Training - Epoch 57/60, Loss: 0.03513580116593991
Validation - Epoch 57/60, Loss: 0.14982974860403273
Validation - Epoch 57/60, Accuracy: 95.86%


Epoch 58/60: 100%|██████████| 34/34 [00:16<00:00,  2.03batch/s, loss=0.0247]


Training - Epoch 58/60, Loss: 0.02471222268521909
Validation - Epoch 58/60, Loss: 0.20003103069029748
Validation - Epoch 58/60, Accuracy: 95.11%


Epoch 59/60: 100%|██████████| 34/34 [00:16<00:00,  2.02batch/s, loss=0.052] 


Training - Epoch 59/60, Loss: 0.05202341584577773
Validation - Epoch 59/60, Loss: 0.17493108584959474
Validation - Epoch 59/60, Accuracy: 95.11%


Epoch 60/60: 100%|██████████| 34/34 [00:16<00:00,  2.05batch/s, loss=0.0237]


Training - Epoch 60/60, Loss: 0.02371059050132959
Validation - Epoch 60/60, Loss: 0.15420653828833666
Validation - Epoch 60/60, Accuracy: 95.49%
Training completed and modal saved to vit_model.pth


## Testing

## Visualizing Results

In [22]:
# plt.figure(figsize=(20, 16))
# sns.set(font_scale=1.5)
# sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
# plt.xlabel('Predicted')
# plt.ylabel('True')
# plt.title('Confusion Matrix')