<a href="https://colab.research.google.com/github/rashwinr/MONAI_tutorials/blob/main/MONAI_Tutorial_DS261_AIMIA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Coding Tutorials on MONAI 4 AI in Medical Image Analysis**

This tutorial was conducted as a part of DS261 3:1 Artificial Intelligence for Medical Image Analysis course offered in August of 2024 at the [Computational Data Sciences Department](https://cds.iisc.ac.in/), Indian Institute of Science Bengaluru.

# **MONAI**
<center><img src="https://github.com/Project-MONAI/monai-bootcamp/blob/main/MONAICore/monai.png?raw=1"/></center>

"MONAI" stands for **Medical Open Network for Artificial Intelligence**

It is a "low-code" infrastructure for building medical image analysis pipelines.

MONAI consists of three frameworks:
*   MONAI Label: seamlessly integrates into the label-generation (annotation) workflow
*   MONAI Core: enables clinicians and researchers to build AI models to work on **Medical Imaging** data
*   MONAI Deploy: facilitates easy transition of Python programs into a deployable application


This Colab notebook introduces the design and architecture of *MONAI Core*. We will work through hands-on examples of its main components and then perform a simple deep learning task.



# **MONAI End to End workflow**

MONAI supports deep learning in medical image analysis at multiple levels. This figure shows a typical example of an end-to-end workflow in a medical deep learning context:

<center><img src="https://github.com/Project-MONAI/monai-bootcamp/blob/main/MONAICore/end_to_end.png?raw=1" style="width: 1400px;"/></center>

## *Install & import*  **MONAI**

MONAI is:

*   open-source
*   freely available
*   a collaborative framework
*   a **low-code framework**

built on **PyTorch** and **Python** for **accelerating research** and **clinical collaboration** in **medical image analysis**.



In [None]:
!pip install monai[all]

In [None]:
import monai
from monai.config import print_config
print_config()

# 1. **MONAI: Datatype**

* Dataset: combines the data and its associated transform into a single object
  * Syntax: ``Dataset(data, transform=None)``

    where ``transform`` is a callable (for example an image or dictionary transform) that is applied to an item each time that item is accessed
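To make the idea concrete, here is a minimal pure-Python sketch of what a dataset with a transform does conceptually: it stores the items and applies the transform lazily, each time an item is indexed. The class ``MiniDataset`` and the ``double`` transform are hypothetical illustrations, not MONAI code; MONAI's ``Dataset`` adds more features on top of this idea.

```python
from typing import Any, Callable, Optional, Sequence


class MiniDataset:
    """Hypothetical, simplified stand-in for a dataset that owns a transform."""

    def __init__(self, data: Sequence[Any], transform: Optional[Callable[[Any], Any]] = None):
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Any:
        item = self.data[index]
        # The transform runs on access, so the stored items are never modified.
        return self.transform(item) if self.transform is not None else item


def double(item: dict) -> dict:
    # Return a new dict instead of mutating the stored item.
    return {"data": item["data"] * 2}


ds = MiniDataset([{"data": 4}, {"data": 9}], transform=double)
print(len(ds))  # 2
print(ds[0])    # {'data': 8}
```

Iterating with ``for item in ds`` also works, because Python falls back to ``__getitem__`` until an ``IndexError`` is raised.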

In [None]:
from monai.data import Dataset
from monai.transforms import ToTensor
items = [{"data": 4},
         {"data": 9},
         {"data": 3},
         {"data": 7},
         {"data": 1},
         {"data": 2},
         {"data": 5}]

print(type(items))

transform_items = ToTensor()
dataset = Dataset(items, transform=transform_items)
print(type(dataset))
print(dataset)
print(f"Length of dataset is {len(dataset)}")
for item in dataset:
    print(item)

## **Medical Image Analysis without MONAI**

Why does MONAI need its own data types? Consider how an image is traditionally loaded and inspected:

1. Import the Python Imaging Library (PIL)
2. Import NumPy (Numerical Python)
3. Import PyTorch
4. Import Matplotlib
5. Convert the image into a NumPy array
6. Display it as a plot
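The traditional steps above can be sketched as follows. To keep the sketch self-contained it builds a synthetic image in memory instead of reading a file (``Image.open`` would be used for a real file); the image content and the manual scaling to [0, 1] are illustrative choices, not part of the tutorial.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Stand-in for a file on disk: a synthetic 8-bit grayscale image with a bright square.
pixels = np.zeros((64, 64), dtype=np.uint8)
pixels[16:48, 16:48] = 200
pil_image = Image.fromarray(pixels)  # for a real file: Image.open("path/to/image.png")

# PIL -> NumPy array
array = np.asarray(pil_image)

# NumPy -> PyTorch tensor; scaling and the channel dimension are added by hand
tensor = torch.from_numpy(array.astype(np.float32) / 255.0).unsqueeze(0)

# Plot with Matplotlib
fig, ax = plt.subplots()
ax.imshow(array, cmap="gray")
ax.set_title("Loaded without MONAI")
plt.close(fig)  # in a notebook, call plt.show() instead to display the figure

print(array.shape, tensor.shape)  # (64, 64) torch.Size([1, 64, 64])
```

Every conversion, scaling and channel step is manual here; MONAI's data and transform classes bundle these steps.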

## Create a dummy image

**Syntax**:
```python
monai.data.synthetic.create_test_image_2d(height, width, num_objs=12, rad_max=30, rad_min=5, noise_max=0.0, num_seg_classes=5, channel_dim=None, random_state=None)
```

**Parameters**
________
**height** – height of the image.

**width** – width of the image.

**num_objs** – number of circles to generate. Defaults to 12.

**rad_max** – maximum circle radius. Defaults to 30.

**rad_min** – minimum circle radius. Defaults to 5.

**noise_max** – if greater than 0 then noise will be added to the image taken from the uniform distribution on range [0,noise_max). Defaults to 0.

**num_seg_classes** – number of classes for segmentations. Defaults to 5.

**channel_dim** – if None, create an image without channel dimension, otherwise create an image with channel dimension as first dim or last dim. Defaults to None.

**random_state** – the random generator to use. Defaults to np.random.



In [None]:
from monai.data import create_test_image_2d

image, seg = create_test_image_2d(height=128, width=128,num_objs=5,rad_max=10,rad_min=2,num_seg_classes=2)

print(f"Image shape: {image.shape}")
print(f"Segmentation shape: {seg.shape}")

print(f"Image min: {image.min()}, max: {image.max()}")
print(f"Segmentation min: {seg.min()}, max: {seg.max()}")

### Visualization

- Matplotlib has a number of built-in colormaps
- Choose a colormap that suits the quantity you are plotting, e.g. ``gray`` for image intensities and a high-contrast map such as ``gnuplot`` for segmentation labels
- More details on ``matplotlib.colormaps`` are available here: https://matplotlib.org/stable/users/explain/colors/colormaps.html


<img src="https://matplotlib.org/stable/_images/sphx_glr_colormaps_014.png">
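A short sketch of what a colormap does: it maps a normalised value in [0, 1] to an RGBA colour, and ``imshow`` normalises the data to that range first (by default using the data minimum and maximum). It assumes Matplotlib 3.5 or newer, where the ``matplotlib.colormaps`` registry exists; the label values are made up.

```python
import matplotlib
import numpy as np

# The registry behaves like a read-only mapping from names to Colormap objects.
print("gray" in matplotlib.colormaps, "gnuplot" in matplotlib.colormaps)  # True True

gray = matplotlib.colormaps["gray"]
print(gray(0.0))  # RGBA for black
print(gray(1.0))  # RGBA for white

# imshow rescales integer label maps such as {0, 1, 2} to [0, 1],
# so the labels are spread across the whole colormap.
label_values = np.array([0, 1, 2])
normalised = (label_values - label_values.min()) / (label_values.max() - label_values.min())
label_colors = [matplotlib.colormaps["gnuplot"](v) for v in normalised]
print(label_colors)
```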

In [None]:
import matplotlib.pyplot as plt

plt.figure("visualize", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image")
plt.imshow(image, cmap="gray")
plt.subplot(1, 2, 2)
plt.title("segmentation")
plt.imshow(seg,cmap="gnuplot")
plt.show()

## MONAI Dataset

In [None]:
from monai.data import Dataset

data = [
    {"image": image, "seg": seg}
]

# Define a dataset using the data list
dataset = Dataset(data=data)


print(f"Dataset length: {len(dataset)}")

# Access a data item by index
item = dataset[0]
print(f"Keys in item: {item.keys()}")

print(f"Image shape: {item['image'].shape}")
print(f"Segmentation shape: {item['seg'].shape}")


# 2. **MONAI: Transforms**

MONAI provides many spatial (geometric) and intensity transforms for data augmentation. They can easily be chained with ``Compose``, together with image loading.

Today we will integrate the following image transforms:
* ``ScaleIntensityd``
* ``Resized``
* ``RandRotated``
* ``RandFlipd``
* ``RandAdjustContrastd``
* ``RandAxisFlipd``
* ``RandZoomd``
* ``RandRotate90d``
* ``ToNumpyd``

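Adding these transforms to one pipeline step by step is a good way to get used to MONAI's transform syntax. The ``d`` suffix marks the dictionary-based version of each transform, which acts only on the dictionary entries listed in ``keys``.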





In [None]:
import monai
from monai.transforms import Compose, EnsureChannelFirstd, ToTensord
from monai.transforms import ScaleIntensityd, Resized, RandRotated, RandFlipd, RandAdjustContrastd, RandAxisFlipd, RandZoomd, RandRotate90d, ToNumpyd

# Define a transform pipeline: add a channel dimension, scale the image intensity,
# resize, apply random augmentations and convert the results to tensors.
# Spatial transforms act on both keys so that image and segmentation stay aligned;
# the segmentation uses nearest-neighbour interpolation so its label values are preserved.
transform = Compose([
    EnsureChannelFirstd(keys=["image", "seg"], channel_dim="no_channel"),
    ScaleIntensityd(keys=["image"]),  # scale image intensities to [0, 1]; labels are left unchanged
    Resized(keys=["image", "seg"], spatial_size=(100, 100), mode=["area", "nearest"]),
    RandRotated(keys=["image", "seg"], range_x=0.3, prob=0.5, mode=["bilinear", "nearest"]),
    RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=0),
    RandAdjustContrastd(keys=["image"], prob=0.5, gamma=(0.5, 2.0)),
    RandAxisFlipd(keys=["image", "seg"], prob=0.5),
    RandZoomd(keys=["image", "seg"], prob=0.5, min_zoom=0.8, max_zoom=1.2, mode=["area", "nearest"]),
    RandRotate90d(keys=["image", "seg"], prob=0.5, spatial_axes=(0, 1)),
    ToTensord(keys=["image", "seg"]),
])
# Create a MONAI dataset with the transform
dataset_transformed = monai.data.Dataset(data=data, transform=transform)


# Access a transformed item; the random transforms are re-drawn on every access
item = dataset_transformed[0]
print(f"Keys in item: {item.keys()}")

print(f"Image shape: {item['image'].shape}")
print(f"Segmentation shape: {item['seg'].shape}")

print(f"Image min: {item['image'].min()}, max: {item['image'].max()}")
print(f"Segmentation min: {item['seg'].min()}, max: {item['seg'].max()}")

# For comparison, access the same item without transforms
item = dataset[0]
print(f"Keys in item: {item.keys()}")

print(f"Image shape: {item['image'].shape}")
print(f"Segmentation shape: {item['seg'].shape}")

print(f"Image min: {item['image'].min()}, max: {item['image'].max()}")
print(f"Segmentation min: {item['seg'].min()}, max: {item['seg'].max()}")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Access a transformed item (the random augmentations are re-drawn here,
# so this sample can differ from the one printed in the previous cell)
item_transformed = dataset_transformed[0]
image_transformed = item_transformed['image'].numpy()  # Access transformed image
seg_transformed = item_transformed['seg'].numpy()    # Access transformed segmentation

item = dataset[0]
image = item['image']
seg = item['seg']


# Remove channel dimension if present for transformed images
if len(image_transformed.shape) == 3 and image_transformed.shape[0] == 1:
    image_transformed = image_transformed.squeeze(0)
if len(seg_transformed.shape) == 3 and seg_transformed.shape[0] == 1:
    seg_transformed = seg_transformed.squeeze(0)

# Remove channel dimension if present for original images
if len(image.shape) == 3 and image.shape[0] == 1:
    image = image.squeeze(0)
if len(seg.shape) == 3 and seg.shape[0] == 1:
    seg = seg.squeeze(0)

# Create a 2x2 subplot to display all four images
plt.figure("visualize", (12, 12))

plt.subplot(2, 2, 1)
plt.title("Original Image")
plt.imshow(image, cmap="gray")

plt.subplot(2, 2, 2)
plt.title("Original Segmentation")
plt.imshow(seg, cmap="gnuplot")

plt.subplot(2, 2, 3)
plt.title("Transformed Image")
plt.imshow(image_transformed, cmap="gray")

plt.subplot(2, 2, 4)
plt.title("Transformed Segmentation")
plt.imshow(seg_transformed, cmap="gnuplot")

plt.show()

# 3. **MONAI datasets**

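The ``monai.apps`` module provides ready-to-use dataset classes that download and index public datasets:
- MedNIST (``MedNISTDataset``)
- Medical Segmentation Decathlon (``DecathlonDataset``)
- TCIA (``TciaDataset``)
- Other public collections (not wrapped in ``monai.apps``): MedMNIST, PhysioNet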

In [None]:
import os
dir_path = os.getcwd()
print(dir_path)

## MEDNIST Dataset

The MedNIST dataset was gathered from several sets from [TCIA](https://wiki.cancerimagingarchive.net/display/Public/Data+Usage+Policies+and+Restrictions),
[the RSNA Bone Age Challenge](http://rsnachallenges.cloudapp.net/competitions/4),
and [the NIH Chest X-ray dataset](https://cloud.google.com/healthcare/docs/resources/public-datasets/nih-chest).

The dataset is kindly made available by [Dr. Bradley J. Erickson M.D., Ph.D.](https://www.mayo.edu/research/labs/radiology-informatics/overview) (Department of Radiology, Mayo Clinic)
under the Creative Commons [CC BY-SA 4.0 license](https://creativecommons.org/licenses/by-sa/4.0/). If you use the MedNIST dataset, please acknowledge the source.

Syntax: ``MedNISTDataset(root_dir, section, transform=(), download=False, seed=0, val_frac=0.1, test_frac=0.1, cache_num=9223372036854775807, cache_rate=1.0, num_workers=1, progress=True, copy_cache=True, as_contiguous=True, runtime_cache=False)``

**Parameters**:
- **root_dir** – target directory in which to download and load the MedNIST dataset.
- **section** – expected data section: training, validation or test.
- **download** – whether to download and extract MedNIST from the resource link; default is False. If the expected file already exists, the download is skipped even when this is True. You can also manually copy the MedNIST.tar.gz file or the MedNIST folder into the root directory.
- **seed** – random seed used to split the data into training, validation and test sets; default is 0.
- **val_frac** – fraction of the whole dataset used for validation; default is 0.1.
- **test_frac** – fraction of the whole dataset used for testing; default is 0.1.
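The ``seed``, ``val_frac`` and ``test_frac`` parameters matter because each section is created independently. The sketch below shows one common way such a split works: a seeded shuffle followed by slicing. The helper ``split_indices`` is hypothetical and simplified; it illustrates why all three sections need identical ``seed``, ``val_frac`` and ``test_frac`` values and is not MONAI's internal code. With the fractions used in this tutorial (10% validation, 70% test), only 20% of the images go to training.

```python
import numpy as np


def split_indices(n_items: int, section: str, seed: int = 0,
                  val_frac: float = 0.1, test_frac: float = 0.1) -> np.ndarray:
    """Hypothetical helper: seeded shuffle, then slice into test/validation/training."""
    indices = np.arange(n_items)
    np.random.RandomState(seed).shuffle(indices)  # same seed -> same permutation
    n_test = int(n_items * test_frac)
    n_val = int(n_items * val_frac)
    if section == "test":
        return indices[:n_test]
    if section == "validation":
        return indices[n_test:n_test + n_val]
    if section == "training":
        return indices[n_test + n_val:]
    raise ValueError(f"unknown section: {section}")


parts = {s: split_indices(100, s, seed=24, val_frac=0.1, test_frac=0.7)
         for s in ("training", "validation", "test")}
print({s: len(v) for s, v in parts.items()})  # {'training': 20, 'validation': 10, 'test': 70}
```

Because every call reproduces the same permutation, the three slices never overlap; a different seed for one section would break that guarantee.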

In [None]:
from monai.apps import MedNISTDataset
# download=True fetches MedNIST on the first call (skipped if the data already exists).
# All three sections must share seed, val_frac and test_frac so that the splits do not overlap.
train_data = MedNISTDataset(root_dir=dir_path, section="training", download=True, seed=24, val_frac=0.1, test_frac=0.7)
val_data = MedNISTDataset(root_dir=dir_path, section="validation", download=False, seed=24, val_frac=0.1, test_frac=0.7)
test_data = MedNISTDataset(root_dir=dir_path, section="test", download=False, seed=24, val_frac=0.1, test_frac=0.7)


In [None]:
print(f"Length of training dataset: {len(train_data)}")
print(f"Length of validation dataset: {len(val_data)}")
print(f"Length of test dataset: {len(test_data)}")
print(f"Type of train data: {type(train_data)}")
# Inspect the first few validation items (each item is a dictionary)
for i in range(5):
    data = val_data[i]
    image = data['image']
    label = data['label']
    print(f"Image shape: {image.shape}, Label: {label}")

#### Data exploration

First, check the dataset files and show some statistics.
The dataset contains six folders (Hand, AbdomenCT, CXR, ChestCT, BreastMRI, HeadCT); the folder names serve as the class labels for our classification model.
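Turning folder names into integer labels only needs a fixed ordering. A minimal sketch of the sorted-name convention that the cell below also uses; ``sorted_names`` and ``label_of`` are illustrative names, and the folder list is typed in by hand instead of read from disk.

```python
# Folder names as listed above; on disk they would come from os.listdir().
folders = ["Hand", "AbdomenCT", "CXR", "ChestCT", "BreastMRI", "HeadCT"]

# Sorting makes the name -> index mapping independent of the directory listing order.
sorted_names = sorted(folders)
label_of = {name: idx for idx, name in enumerate(sorted_names)}

print(sorted_names)  # ['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']
print(label_of["CXR"], label_of["Hand"])  # 2 4
```

Python sorts strings by character code, so the uppercase ``X`` in ``CXR`` places it before ``ChestCT``.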

In [None]:
import os

mednist_folder = os.path.join(dir_path, 'MedNIST')

if os.path.exists(mednist_folder):
    for subfolder in os.listdir(mednist_folder):
        subfolder_path = os.path.join(mednist_folder, subfolder)
        if os.path.isdir(subfolder_path):
            print(f"Subfolder: {subfolder}")
else:
    print("MedNIST folder not found.")


In [None]:
import PIL.Image  # import the submodule explicitly so that PIL.Image.open is available
mednist_folder = os.path.join(dir_path, 'MedNIST')

class_names = sorted(x for x in os.listdir(mednist_folder) if os.path.isdir(os.path.join(mednist_folder, x)))
num_class = len(class_names)
image_files = [
    [os.path.join(mednist_folder, class_names[i], x) for x in os.listdir(os.path.join(mednist_folder, class_names[i]))]
    for i in range(num_class)
]
num_each = [len(image_files[i]) for i in range(num_class)]
image_files_list = []
image_class = []
for i in range(num_class):
    image_files_list.extend(image_files[i])
    image_class.extend([i] * num_each[i])
num_total = len(image_class)
image_width, image_height = PIL.Image.open(image_files_list[0]).size

print(f"Total image count: {num_total}")
print(f"Image dimensions: {image_width} x {image_height}")
print(f"Label names: {class_names}")
print(f"Label counts: {num_each}")

In [None]:
plt.subplots(3, 3, figsize=(8, 8))
for i, k in enumerate(np.random.randint(num_total, size=9)):
    im = PIL.Image.open(image_files_list[k])
    arr = np.array(im)
    plt.subplot(3, 3, i + 1)
    plt.xlabel(class_names[image_class[k]])
    plt.imshow(arr, cmap="gray", vmin=0, vmax=255)
plt.tight_layout()
plt.show()

### Decathlon datasets

- ``DecathlonDataset`` automatically downloads the data of the Medical Segmentation Decathlon challenge (http://medicaldecathlon.com/) and generates items for training, validation or test.

- It also loads the dataset properties (for example the task name, modality and label names) from the dataset's JSON config file; they can be queried with ``get_properties()``.

- Syntax:
```python
DecathlonDataset(root_dir, task, section, transform=(), download=False, seed=0, val_frac=0.2, progress=True)
```
Parameters:
- **root_dir** – local directory for caching and loading the MSD datasets.

- **task** – which task to download and load; one of:
    - “Task01_BrainTumour”
    - “Task02_Heart”
    - “Task03_Liver”
    - “Task04_Hippocampus”
    - “Task05_Prostate”
    - “Task06_Lung”
    - “Task07_Pancreas”
    - “Task08_HepaticVessel”
    - “Task09_Spleen”
    - “Task10_Colon”

- **section** – expected data section: training, validation or test.

- **download** – whether to download and extract the Decathlon data from the resource link; default is False. If the expected file already exists, the download is skipped even when this is True. You can also manually copy the tar file or the dataset folder into the root directory.

- **val_frac** – fraction of the whole dataset used for validation; default is 0.2.

- **seed** – random seed used to shuffle the data list before splitting it into training and validation; default is 0.
  - **Note**: use the same seed for the training and validation sections.

- **progress** – whether to display a progress bar when downloading dataset and computing the transform cache content.



In [None]:
from monai.apps import DecathlonDataset

# Specify the task number you want to access (e.g., Task04_Hippocampus)
task_num = "Task04_Hippocampus"

# Create a DecathlonDataset instance for the specified task
train_decathlondataset = DecathlonDataset(root_dir=dir_path, task=task_num, section="training", download=True,val_frac=0.1)
validation_decathlondataset = DecathlonDataset(root_dir=dir_path, task=task_num, section="validation", download=False,val_frac=0.1)

### TCIA Dataset

- ``TciaDataset`` automatically downloads data from a public collection of The Cancer Imaging Archive (TCIA) and generates items for training, validation or test: [https://www.cancerimagingarchive.net/](https://www.cancerimagingarchive.net/)
- Syntax:
```python
class monai.apps.TciaDataset(root_dir, collection, section, transform=(), download=False)
```

- **Massive Public Database**: TCIA provides a huge collection of de-identified medical images (like CT scans, MRIs, and histopathology slides) across a wide range of cancer types. This allows researchers to access diverse data for analysis, development of image-based diagnostic tools, and discovery of new disease insights.

- **Open and Free**: Most TCIA collections are freely available to the public (a few are access-limited or require a data-use agreement). This open access promotes collaboration, accelerates research, and encourages the development of innovative cancer imaging applications.

- **Standardized Format**: TCIA uses the DICOM (Digital Imaging and Communications in Medicine) standard for storing and distributing images. This ensures compatibility and makes it easier for researchers to use the data with various image processing and analysis tools.

In [None]:
!pip install pydicom
import pydicom
from monai.apps import TciaDataset

# Specify the collection you want to access (e.g., "Lung Phantom")
collection = "Lung Phantom"

# Create a TciaDataset instance for the specified collection
tcia_dataset = TciaDataset(root_dir=dir_path, collection=collection, section="training", download=True)



# 4. **MONAI: Simple Deep Learning Task**

## 5. **MONAI: DataLoader**



In [None]:
from monai.data import DataLoader

train_transforms = Compose(
    [
        EnsureChannelFirstd(keys=["image", "label"], channel_dim="no_channel"),  # add a channel dimension to "image" and "label"
        Resized(keys=["image"], spatial_size=(32, 32)),  # resize only the "image" entry to 32x32
        ToTensord(keys=["image", "label"]),  # convert the "image" and "label" entries to tensors
    ]
)
# The same deterministic transforms are used for the training, validation and test sets
train_dataset = Dataset(data=train_data, transform=train_transforms)
val_dataset = Dataset(data=val_data, transform=train_transforms)
test_dataset = Dataset(data=test_data, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

print(f"Number of batches in train_loader: {len(train_loader)}")
print(f"Number of batches in val_loader: {len(val_loader)}")
print(f"Number of batches in test_loader: {len(test_loader)}")
print(f"Shape of one training image: {train_loader.dataset[0]['image'].shape}")
print(f"Shape of one training label: {train_loader.dataset[0]['label'].shape}")
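``monai.data.DataLoader`` subclasses PyTorch's ``torch.utils.data.DataLoader`` (with a MONAI-specific default collate function). For dictionary samples, collation stacks each key separately, which is why ``batch["image"]`` in the training loop has shape ``(B, 1, 32, 32)``. The sketch below uses PyTorch's own ``default_collate`` (exposed as ``torch.utils.data.default_collate`` in PyTorch 1.11 and later) on four made-up samples shaped like the transformed MedNIST items.

```python
import torch
from torch.utils.data import default_collate

# Four fake samples: image (1, 32, 32) and label (1,), like the transformed MedNIST items
samples = [{"image": torch.zeros(1, 32, 32), "label": torch.tensor([i])} for i in range(4)]

batch = default_collate(samples)
print(batch["image"].shape)  # torch.Size([4, 1, 32, 32])
print(batch["label"].shape)  # torch.Size([4, 1])
print(batch["label"].squeeze(1).tolist())  # [0, 1, 2, 3]
```

Note that ``len(train_loader)`` counts batches, i.e. ``ceil(len(train_dataset) / 16)`` here, because ``drop_last`` defaults to False.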

## 6. **MONAI: Model**

- MONAI provides neural network architectures and pre-trained models for medical imaging tasks. Many networks support both 2D images and 3D volumes (e.g. MRI and CT) through a ``spatial_dims`` argument.

- They cover tasks such as image segmentation, classification, and registration in medical applications.

- More details: https://docs.monai.io/en/stable/networks.html


In [None]:
from monai.networks.nets import DenseNet121
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=num_class)

## 7. **MONAI: Loss function**

- MONAI offers specialized loss functions tailored for medical imaging, addressing challenges like class imbalance and volumetric data.
- They include Dice loss, focal loss and combined losses such as ``DiceCELoss`` (Dice + cross-entropy) and ``DiceFocalLoss``, for segmentation, classification and other medical image analysis tasks.

- More details: https://docs.monai.io/en/stable/losses.html
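The soft Dice loss behind these classes compares a predicted probability map P with a one-hot ground truth G per class as 1 - (2·sum(P·G) + ε) / (sum(P) + sum(G) + ε). The helper ``soft_dice_loss`` below is a hypothetical, simplified sketch of that formula, averaged over batch and classes; MONAI's ``DiceLoss`` uses the same form with small smoothing terms (``smooth_nr``, ``smooth_dr``) plus more options, so treat this only as an illustration.

```python
import torch


def soft_dice_loss(probs: torch.Tensor, target_onehot: torch.Tensor,
                   smooth: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: 1 - Dice, averaged over batch and classes.

    probs and target_onehot have shape (B, C, *spatial); probs are e.g. softmax outputs.
    """
    spatial = tuple(range(2, probs.ndim))
    intersection = (probs * target_onehot).sum(dim=spatial)
    denominator = probs.sum(dim=spatial) + target_onehot.sum(dim=spatial)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - dice.mean()


# One 2x2 image with two classes: class 1 occupies the bottom-right pixel.
target = torch.tensor([[[[1.0, 1.0], [1.0, 0.0]],
                        [[0.0, 0.0], [0.0, 1.0]]]])
print(float(soft_dice_loss(target, target)))        # perfect prediction -> 0.0
print(float(soft_dice_loss(1.0 - target, target)))  # every pixel wrong -> close to 1
```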

In [None]:
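from monai.losses import DiceCELoss

# DiceCELoss computes lambda_dice * Dice + lambda_ce * cross-entropy.
# With lambda_dice=0 only the cross-entropy term remains, which suits this classification task.
loss_function = DiceCELoss(to_onehot_y=True, lambda_dice=0, lambda_ce=1.0, softmax=False)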

## 8. **MONAI: Metrics**

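- MONAI metrics work directly on PyTorch tensors, including lists of per-sample tensors produced by ``decollate_batch``; cumulative metrics such as ``ROCAUCMetric`` buffer the values from each call until ``aggregate()`` is called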
- More details: https://docs.monai.io/en/stable/metrics.html


In [None]:
from monai.metrics import ROCAUCMetric
import sklearn
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report

auc_metric = ROCAUCMetric()


## Optimizer

In [None]:
import torch
optimizer = torch.optim.Adam(model.parameters(), 1e-5)

## Model Training

In [None]:
from monai.data import decollate_batch
from monai.transforms import Activations, AsDiscrete
from tqdm import tqdm

y_pred_trans = Activations(softmax=True)   # logits -> softmax probabilities
y_trans = AsDiscrete(to_onehot=num_class)  # integer label -> one-hot vector

max_epochs = 1
val_interval = 1

best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []


for epoch in range(max_epochs):  # iterate over the epochs
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in tqdm(train_loader):  # iterate over all training batches
        step += 1
        inputs, labels = batch_data["image"], batch_data["label"]
        optimizer.zero_grad()
        outputs = model(inputs)                # predicted logits, shape (B, num_class)
        loss = loss_function(outputs, labels)  # loss for this batch
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            y_pred = torch.tensor([], dtype=torch.float32)
            y = torch.tensor([], dtype=torch.long)
            for val_batch in val_loader:  # iterate over all validation batches
                val_images, val_labels = val_batch["image"], val_batch["label"]
                y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                y = torch.cat([y, val_labels], dim=0)
            # Split into per-sample tensors, then one-hot the labels and softmax the logits
            y_onehot = [y_trans(i) for i in decollate_batch(y, detach=False)]
            y_pred_act = [y_pred_trans(i) for i in decollate_batch(y_pred)]
            auc_metric(y_pred_act, y_onehot)  # buffer the values for the AUROC metric
            result = auc_metric.aggregate()   # compute the AUROC
            auc_metric.reset()
            del y_pred_act, y_onehot
            metric_values.append(result)
            y_pred_class = torch.argmax(y_pred, dim=1)  # logits -> predicted class indices
            acc_metric = accuracy_score(y.squeeze(-1).cpu().numpy(), y_pred_class.cpu().numpy())
            if result > best_metric:
                best_metric = result
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model.pth")
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current AUC: {result:.4f}"
                f" current accuracy: {acc_metric:.4f}"
                f" best AUC: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")



## Evaluation

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Load the best model, saved by the training cell in the current working directory
model.load_state_dict(torch.load("best_metric_model.pth"))

# Set the model to evaluation mode
model.eval()


with torch.no_grad():
    y_pred = torch.tensor([], dtype=torch.float32)
    y = torch.tensor([], dtype=torch.long)
    for val_batch in val_loader:
        val_images, val_labels = val_batch["image"], val_batch["label"]
        y_pred = torch.cat([y_pred, model(val_images)], dim=0)
        y = torch.cat([y, val_labels], dim=0)

y_true = y.squeeze(-1).cpu().numpy()        # labels: (N, 1) -> (N,)
y_pred_class = torch.argmax(y_pred, dim=1)  # logits -> predicted class indices
print(f"The shape of y_pred is: {y_pred.shape}")
print(f"The shape of y_pred_class is: {y_pred_class.shape}")

accuracy = accuracy_score(y_true, y_pred_class.cpu().numpy())  # argument order: (y_true, y_pred)
print(f"Accuracy on validation set: {accuracy}")

# One-hot versions of the labels and predictions, shape (N, num_class)
y_onehot = torch.stack([y_trans(i) for i in decollate_batch(y, detach=False)])
y_pred_onehot = torch.stack([y_trans(i) for i in decollate_batch(y_pred_class.unsqueeze(-1), detach=False)])
print(f"The shape of y_pred_onehot is: {y_pred_onehot.shape}")
print(f"The shape of y_onehot is: {y_onehot.shape}")

conf_matix_sklearn = confusion_matrix(y_true, y_pred_class.cpu().numpy())
print(f"Confusion Matrix:\n{conf_matix_sklearn}")

print(classification_report(y_true, y_pred_class.cpu().numpy(), target_names=class_names))


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# conf_matix_sklearn is the confusion matrix computed with sklearn in the previous cell
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matix_sklearn, annot=True, fmt="d", cmap="Blues",
            xticklabels=class_names, yticklabels=class_names)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.show()