version 5. Generate predictions of hold-out set in this code + 2019 dataset

# 1 Introduction

In [None]:
# Copy the banner image referenced by the markdown below into the working directory.
!cp -r ../input/cassava-public-ckpt/filter.jpeg ./ 

![filtering](filter.jpeg)
> [Image Source](https://www.star-spain.com/en/blog/transittermstar-nxt-tooltips/filtering-data-records-termstar-nxt)

This notebook analyzes the noisy labels of the Cassava dataset using CleanLab.
It aims to give you **an explanation of the noisy labels and some ideas for handling them.**

The goal of this competition is to help increase crop yields by quickly identifying cassava plant diseases.

This is a **multi-class classification task**.
All images contain Cassava and there is a label for the disease status of the plant.
The evaluation metric is Accuracy.

*train_images/* and *train.csv* are in familiar form containing images and correct answers. 
Since this competition is a code competition, only one sample image exists for provided *test_images/*.
In addition, the 2019 competition was held, so you can use the 2019 competition data.

The information that is actively shared on Discussion and Notebook is that the labels on this dataset are noisy.
This notebook provides an in-depth analysis of the noisy labels using [CleanLab](https://github.com/cgnorthcutt/cleanlab).

Thanks [@tmhrkt](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/210557) for sharing CleanLab and I've learnt many things from this community.
I referred to [@jacoporepossi's discussion](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/198143) to recognize each disease's characteristics.
If this notebook helps, please leave an upvote.

# 2 Preparation

This section loads the required libraries and data.

In [None]:
# Stage the bundled wheels in one local directory so pip can find them without network access.
!mkdir -p /tmp/pip/cache/
!cp ../input/omegaconf/PyYAML-5.4b2-cp38-cp38-manylinux1_x86_64.whl /tmp/pip/cache/
!cp ../input/omegaconf/omegaconf-2.0.5-py3-none-any.whl /tmp/pip/cache/
!cp ../input/omegaconf/typing_extensions-3.7.4.3-py3-none-any.whl /tmp/pip/cache/
# Install omegaconf from the staged wheels only (--no-index); pip's stdout is discarded.
!pip install --no-index --find-links /tmp/pip/cache/ omegaconf > /dev/null

In [None]:
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
sys.path.append("../input/cleanlab/")

import os
import sys
from glob import glob
import numpy as np
import pandas as pd
from omegaconf import DictConfig, OmegaConf
import cleanlab
import cv2
from ast import literal_eval
from collections import OrderedDict
from sklearn.model_selection import StratifiedKFold
import warnings
warnings.filterwarnings('ignore')

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.core.decorators import auto_move_data
from pytorch_lightning.utilities.cloud_io import load as pl_load
import timm

import seaborn as sns
import matplotlib.pyplot as plt

## 2-1 Generate Predictions

Generate predictions for hold-out set with EfficientNet-b0

In [None]:
from albumentations import *
from albumentations.pytorch import ToTensorV2


# Per-channel normalization constants (the values commonly used with
# ImageNet-pretrained backbones).
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

def get_transforms(img_size=(512, 512)):
    """Build the deterministic preprocessing pipeline used at inference time.

    Args:
        img_size: Target size as a ``(height, width)`` pair. A single int is
            also accepted and treated as a square size, so callers that pass
            e.g. ``384`` no longer fail when the size is indexed.

    Returns:
        A ``Compose`` that pads the image up to the target size when it is
        smaller, center-crops it to exactly that size, normalizes it with the
        module-level ``mean``/``std`` and applies ``ToTensorV2``.
    """
    if isinstance(img_size, int):
        img_size = (img_size, img_size)
    height, width = img_size[0], img_size[1]

    transformations = Compose([
        PadIfNeeded(min_height=height, min_width=width),
        CenterCrop(height, width),
        Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ], p=1.0)
    return transformations

In [None]:
class TestDataset(Dataset):
    """Label-free image dataset used for inference.

    Each item is the image named in ``df['image_id']``, read from ``img_dir``,
    converted from BGR to RGB and passed through ``get_transforms``.

    Args:
        img_dir: Directory that contains the image files.
        df: DataFrame with an ``image_id`` column holding the file names.
        img_size: Target size, either a ``(height, width)`` sequence or a
            single int for a square size.
    """

    def __init__(self, img_dir, df, img_size=384):
        # get_transforms indexes img_size, so expand a bare int (including the
        # default) into a (height, width) pair first.
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.img_dir = img_dir
        self.df = df
        self.file_names = df['image_id'].values
        self.transform = get_transforms(img_size=img_size)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the transformed image at position ``idx``.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        file_path = os.path.join(self.img_dir, self.file_names[idx])
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising;
        # surface the offending path rather than an opaque cvtColor error.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image=image)['image']
        return image

In [None]:
import torch
import torch.nn as nn
import timm


def create_model(model_name: str,
                 pretrained: bool,
                 num_classes: int,
                 in_chans: int):
    """Create a network through ``timm.create_model``.

    Args:
        model_name: Architecture name forwarded to timm.
        pretrained: Forwarded to timm as ``pretrained``.
        num_classes: Forwarded to timm as ``num_classes``.
        in_chans: Forwarded to timm as ``in_chans``.

    Returns:
        Whatever ``timm.create_model`` returns for these arguments.
    """
    timm_kwargs = dict(
        model_name=model_name,
        pretrained=pretrained,
        num_classes=num_classes,
        in_chans=in_chans,
    )
    return timm.create_model(**timm_kwargs)

In [None]:
import os
from glob import glob
from collections import OrderedDict
from pytorch_lightning.utilities.cloud_io import load as pl_load


def get_state_dict_from_checkpoint(log_dir, fold_num):
    """Load the weights saved for one fold and strip the wrapper prefix.

    Args:
        log_dir: Run directory that contains a ``checkpoints`` sub-directory.
        fold_num: Fold number; the checkpoint file name must contain
            ``fold{fold_num}``.

    Returns:
        An ``OrderedDict`` of parameter name to value in which a leading
        ``model.`` has been removed from each key, so the weights can be
        loaded directly into the bare network.

    Raises:
        FileNotFoundError: If no checkpoint file matches the fold.
    """
    pattern = os.path.join(log_dir, f'checkpoints/*fold{fold_num}*.ckpt')
    # Sort so the choice is deterministic if several files match the pattern.
    matches = sorted(glob(pattern))
    if not matches:
        raise FileNotFoundError(f'No checkpoint found matching {pattern}')

    state_dict = pl_load(matches[0], map_location='cpu')
    # Some checkpoints nest the weights under 'state_dict'; others are the
    # weights mapping itself.
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    # Remove only a leading 'model.'; a plain str.replace would also mangle
    # keys that merely contain that substring (e.g. 'head.submodel.bias').
    prefix = 'model.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

In [None]:
class LitTester(pl.LightningModule):
    """Inference-only Lightning wrapper around a network built by ``create_model``.

    Args:
        network_cfg: Mapping unpacked into ``create_model(**network_cfg)``.
        state_dict: Weights loaded into the created network.
    """

    def __init__(self, network_cfg, state_dict):
        super().__init__()
        network = create_model(**network_cfg)
        network.load_state_dict(state_dict)
        network.eval()
        self.model = network

    def forward(self, x):
        """Return the wrapped network's raw output for ``x``."""
        return self.model(x)

    def test_step(self, batch, batch_idx):
        """Average softmax scores over the batch and two flipped copies of it.

        The copies are flipped along the last and the second-to-last axis
        (presumably horizontal and vertical flips for NCHW image batches).
        """
        probs = []
        for flip_dims in (None, [-1], [-2]):
            view = batch if flip_dims is None else torch.flip(batch, flip_dims)
            probs.append(torch.nn.functional.softmax(self(view), dim=1))

        averaged = (probs[0] + probs[1] + probs[2]) / 3.0
        return {"pred": averaged}

    def test_epoch_end(self, output_results):
        """Concatenate the per-batch predictions into one NumPy array under 'prob'."""
        stacked = torch.cat([step["pred"] for step in output_results], dim=0)
        return {'prob': stacked.cpu().numpy()}

In [None]:
# Inference settings for the EfficientNet-b0 checkpoint, written as YAML text
# and parsed into a config object with OmegaConf below.
eff_b0_cfg_s = """
batch_size: 160
img_size: [512, 512]
network:
  model_name: tf_efficientnet_b0_ns
  pretrained: False
  num_classes: 5
  in_chans: 3
"""
# The `network` section is later passed to LitTester, which unpacks it into create_model.
eff_b0_cfg = OmegaConf.create(eff_b0_cfg_s)

In [None]:
### Configurations

# Checkpoint
# `name` is the sub-directory of the checkpoint dataset that holds this run.
name = '14-10-36'
cfg = eff_b0_cfg

# Generate Predictions (GPU needed)
do_predict = True
# Generate Submission (GPU needed)
do_submit = False

# Image folder and label CSV of the merged dataset, plus the run directory of the checkpoints.
img_dir = '../input/cassava-leaf-disease-merged/train/'
label_path = '../input/cassava-leaf-disease-merged/merged.csv'
log_dir = os.path.join('../input/cassava-public-ckpt', name)
# The number of checkpoint files is used as the number of folds
# (assumes exactly one .ckpt per fold — TODO confirm).
n_folds = len(glob(os.path.join(log_dir, 'checkpoints/*.ckpt')))

# Human-readable class names, indexed by the integer label.
num2class = ["Cassava Bacterial Blight (CBB)", "Cassava Brown Streak Disease (CBSD)", "Cassava Green Mottle (CGM)", "Cassava Mosaic Disease (CMD)", "Healthy"]

In [None]:
# Set seed for reproducing
seed_everything(42)

# If there is no fold in label dataframe, generate with Stratified K-fold
label_df = pd.read_csv(label_path)
if 'fold' not in label_df.columns:
    # No random_state is passed, so the shuffle presumably draws from the
    # global RNG seeded above — TODO confirm.
    # NOTE(review): the predictions below are only truly held out if these
    # folds match the ones the checkpoints were trained with — confirm.
    skf = StratifiedKFold(n_splits=5, shuffle=True)
    label_df.loc[:, 'fold'] = 0
    for fold_num, (train_index, val_index) in enumerate(skf.split(X=label_df.index, y=label_df.label.values)):
        # Each row gets the number of the fold in which it is a validation sample.
        label_df.loc[label_df.iloc[val_index].index, 'fold'] = fold_num

# Generate predictions for hold-out sets (True) or load a saved prediction (False).
if do_predict:
    infer = pl.Trainer(gpus=1)
    # NOTE: the 'label' column of this frame stores the per-class probability
    # list of each image, not an integer class label.
    oof_dict = {'image_id': [], 'label': [], 'fold': []}
    for fold_num in range(n_folds):
        # Score each fold's validation rows with the checkpoint of the same fold number.
        val_df = label_df[label_df.fold == fold_num]
        test_dataset = TestDataset(img_dir, val_df, img_size=cfg.img_size)
        # shuffle=False keeps predictions in the same order as val_df.
        test_dataloader = DataLoader(test_dataset,
                                     batch_size=cfg.batch_size,
                                     num_workers=4,
                                     shuffle=False)

        state_dict = get_state_dict_from_checkpoint(log_dir, fold_num)
        model = LitTester(cfg.network, state_dict)
        # The first element of the result is read for its 'prob' entry, the key
        # LitTester.test_epoch_end returns.
        pred = infer.test(model, test_dataloaders=test_dataloader, verbose=False)[0]
        oof_dict['image_id'].extend(val_df.image_id.values)
        oof_dict['label'].extend(pred['prob'].tolist())
        oof_dict['fold'].extend([fold_num] * len(pred['prob']))
    pred_df = pd.DataFrame(oof_dict)
    pred_df.to_csv('oof.csv', index=False)
else:
    # Reuse predictions stored next to the checkpoints.
    pred_df = pd.read_csv(os.path.join(log_dir, 'oof.csv'))

## 2-2 Generate Submission (Optional)

If you want to submit the result from EfficientNet-b0(baseline), you can set *do_submit = True*

In [None]:
# Optional: predict the competition test images and write submission.csv.
if do_submit:
    sub = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
    infer = pl.Trainer(gpus=1)
    test_dataset = TestDataset('../input/cassava-leaf-disease-classification/test_images',
                               sub,
                               img_size=cfg.img_size)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=cfg.batch_size,
                                 num_workers=4,
                                 shuffle=False)
    # One probability array per fold checkpoint, all computed on the same test images.
    preds = []
    for fold_num in range(n_folds):
        state_dict = get_state_dict_from_checkpoint(log_dir, fold_num)
        model = LitTester(cfg.network, state_dict)
        pred = infer.test(model, test_dataloaders=test_dataloader, verbose=False)[0]
        preds.append(pred['prob'])
    # Average the probabilities across folds, then take the most likely class per image.
    sub['label'] = np.argmax(np.mean(preds, axis=0), axis=1)
    sub.to_csv(os.path.join(os.getcwd(), 'submission.csv'), index=False)

In [None]:
# Put labels and predictions in the same image_id order so that row i of
# `preds` belongs to `labels[i]`.
label_df = label_df.sort_values(by='image_id', ascending=True)
pred_df = pred_df.sort_values(by='image_id', ascending=True)

ids, labels = label_df.image_id.values, label_df.label.values

# Sorting both frames only aligns them if they contain exactly the same images.
# Fail loudly here instead of silently pairing predictions with wrong labels
# (e.g. when fewer checkpoints than folds were found).
pred_ids = pred_df.image_id.values
if len(pred_ids) != len(ids) or (pred_ids != ids).any():
    raise ValueError(
        f'Predictions ({len(pred_ids)} rows) and labels ({len(ids)} rows) '
        'do not cover the same image_ids'
    )

# Probabilities are Python lists when freshly predicted, but strings when read
# back from oof.csv; parse the string form into lists.
preds = np.array([literal_eval(pred) if isinstance(pred, str) else pred for pred in pred_df.label.values])

print(f'total {len(ids)} images')
print(f'prediction shape: {preds.shape}, label shape: {labels.shape}')

# 3 CleanLab: Denoise Dataset Labels

Here, we use **CleanLab** to guess the data with Noisy Label.

The data used here are the labels given in the competition and the prediction probabilities of the trained model.
The model is *EfficientNet-b0*, and the probabilities were obtained only on the hold-out sets (5 folds).

**(Note: This below code is taken from [cleanlab](https://github.com/cgnorthcutt/cleanlab) tutorial)**

In [None]:
# STEP 1 - Compute confident joint

# Short aliases used in the rest of this cell:
# s = given labels, psx = predicted class probabilities (one row per image).
s = labels
psx = preds

# Number of classes, taken from the distinct given labels
K = len(np.unique(s))

# Estimate the probability thresholds for confident counting
# You can specify these thresholds yourself if you want
# as you may want to optimize them using a validation set.
# By default (and provably so) they are set to the average class prob.
# Threshold k = mean predicted probability of class k over the images labeled k.
thresholds = [np.mean(psx[:,k][s == k]) for k in range(K)] # P(s^=k|s=k)
thresholds = np.asarray(thresholds)

# Compute confident joint: rows are the given label, columns the confidently predicted class
confident_joint = np.zeros((K, K), dtype = int)
for i, row in enumerate(psx):
    s_label = s[i]
    # Find out which classes this example reaches the threshold for
    # (the 1e-6 slack tolerates tiny floating-point differences)
    confident_bins = row >= thresholds - 1e-6
    num_confident_bins = sum(confident_bins)
    # Exactly one confident class: count the example for that class
    if num_confident_bins == 1:
        confident_joint[s_label][np.argmax(confident_bins)] += 1
    # More than one confident class: count it for the class with the highest probability
    elif num_confident_bins > 1:
        confident_joint[s_label][np.argmax(row)] += 1
    # Examples with no confident class are not counted at all

# Calibrate the raw counts with cleanlab (presumably rescales them to agree
# with the per-class counts of s — see the cleanlab documentation)
confident_joint = cleanlab.latent_estimation.calibrate_confident_joint(
    confident_joint, s)

cleanlab.util.print_joint_matrix(confident_joint)

# STEP 2 - Find label errors

# We arbitrarily choose at least 5 examples left in every class.
# Regardless of whether some of them might be label errors.
MIN_NUM_PER_CLASS = 5
# Leave at least MIN_NUM_PER_CLASS examples per class.
# NOTE prune_count_matrix is transposed (relative to confident_joint)
prune_count_matrix = cleanlab.pruning.keep_at_least_n_per_class(
    prune_count_matrix=confident_joint.T,
    n=MIN_NUM_PER_CLASS,
)

# Number of images per given label
s_counts = np.bincount(s)
# noise_masks_per_class[k] flags the images that should be relabeled as class k
noise_masks_per_class = []
# For each row in the transposed confident joint
for k in range(K):
    noise_mask = np.zeros(len(psx), dtype=bool)
    psx_k = psx[:, k]
    if s_counts[k] > MIN_NUM_PER_CLASS:  # Don't prune if not MIN_NUM_PER_CLASS
        for j in range(K):  # noisy label index (k is the true label index)
            if k != j:  # Only prune for noise rates, not diagonal entries
                num2prune = prune_count_matrix[k][j]
                if num2prune > 0:
                    # num2prune'th largest p(classk) - p(class j)
                    # for x with noisy label j
                    margin = psx_k - psx[:, j]
                    s_filter = s == j
                    # Negating before and after np.partition selects the
                    # num2prune-th largest margin among images labeled j
                    threshold = -np.partition(
                        -margin[s_filter], num2prune - 1
                    )[num2prune - 1]
                    # Flag the images labeled j whose margin reaches that cut-off
                    noise_mask = noise_mask | (s_filter & (margin >= threshold))
        noise_masks_per_class.append(noise_mask)
    else:
        noise_masks_per_class.append(np.zeros(len(s), dtype=bool))

# Boolean label error mask: flagged by at least one class
label_errors_bool = np.stack(noise_masks_per_class).any(axis=0)

# Remove label errors if given label == model prediction
for i, pred_label in enumerate(psx.argmax(axis=1)):
    # np.all lets this work for multi_label and single label
    if label_errors_bool[i] and np.all(pred_label == s[i]):
        label_errors_bool[i] = False

# Convert boolean mask to an ordered list of indices for label errors
label_errors_idx = np.arange(len(s))[label_errors_bool]
# self confidence is the holdout probability that an example
# belongs to its given class label
self_confidence = np.array(
    [np.mean(psx[i][s[i]]) for i in label_errors_idx]
)
# Order the errors by (self confidence - highest predicted probability),
# most negative first, i.e. the strongest disagreements come first
margin = self_confidence - psx[label_errors_bool].max(axis=1)
label_errors_idx = label_errors_idx[np.argsort(margin)]

In [None]:
total_idx = np.arange(len(ids))
# Every index that CleanLab did not flag. np.setdiff1d returns them sorted and
# as an integer array even when empty, and avoids the quadratic
# `idx not in label_errors_idx` membership test of a Python loop.
clean_idx = np.setdiff1d(total_idx, label_errors_idx)

# For a flagged image, the guess is the first class whose noise mask marks it;
# clean images simply keep their given label.
guesses = np.stack(noise_masks_per_class).argmax(axis=0)
guesses[clean_idx] = labels[clean_idx]

clean_ids = ids[clean_idx]
clean_labels = labels[clean_idx]
clean_guesses = guesses[clean_idx]

# label_errors_idx is ordered by margin, so the noisy arrays keep that ordering.
noisy_ids = ids[label_errors_idx]
noisy_labels = labels[label_errors_idx]
noisy_guesses = guesses[label_errors_idx]

print(f'[clean ratio] \t {len(clean_idx) / len(total_idx) * 100:.2f}%')
print(f'[noise ratio] \t {len(noisy_ids) / len(total_idx) * 100:.2f}%')

# 4 Visualize Clean or Noisy Data

Before analyzing the distribution, let's take a visual look at the results of CleanLab.

<span style="color:blue; font-size:14pt">*Blue: a 'guess' class*</span>

<span style="color:red; font-size:14pt">*Red: a 'given' class*</span>

In [None]:
def visualize_images(ids, labels, guesses, target_class,
                     n_rows=4, n_cols=4):
    """Show a grid of images whose guessed class equals ``target_class``.

    Each tile is titled (in red) with the image's given class name and its
    file name; the figure title (in blue) names the guessed class.

    Args:
        ids: Array of image file names, relative to the global ``img_dir``.
        labels: Array of given integer labels, aligned with ``ids``.
        guesses: Array of guessed integer labels, aligned with ``ids``.
        target_class: Guessed class to display.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.

    Raises:
        FileNotFoundError: If one of the selected images cannot be read.
    """
    selected = guesses == target_class
    c_ids, c_labels = ids[selected], labels[selected]

    # squeeze=False always yields a 2-D axes array, even for a single row or column.
    fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True,
                             figsize=(15, 12), squeeze=False)
    # ravel() walks the grid row by row, which is correct for non-square grids too.
    for i, ax in enumerate(axes.ravel()):
        if i >= len(c_ids):
            # Fewer matching images than tiles: leave the remaining tiles blank.
            ax.axis('off')
            continue
        img_path = os.path.join(img_dir, c_ids[i])
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f'Could not read image: {img_path}')
        ax.imshow(img[..., ::-1])  # reverse the channel axis: BGR -> RGB
        ax.set_title(f'{num2class[c_labels[i]]}\n{c_ids[i]}', color='r')
    # Every displayed image has guess == target_class, so name it directly
    # instead of relying on the loop variable.
    plt.suptitle(f'Guess: {num2class[target_class]}', y=1.03, color='b', fontsize=20, fontweight='bold')
    plt.tight_layout()

## 4-1 Clean Label

In [None]:
# Images CleanLab did not flag, with class 0 (Cassava Bacterial Blight).
visualize_images(clean_ids, clean_labels, clean_guesses, target_class=0)

- At first, angular, water-soaked spots occur on the leaves which are restricted by the veins; the spots are more clearly seen on the lower leaf surface. The spots expand rapidly, join together, especially along the margins of the leaves, and turn brown with yellow borders. Droplets of a creamy-white ooze occur at the centre of the spots; later, they turn yellow.
- Stem infections block the flow of water and food and the leaves above wilt, die and fall, and branches die back.

- Main characteristics to leverage: **angular spots, brown spots with yellow borders, yellow leaves, leaves wilting**

In [None]:
# Images CleanLab did not flag, with class 1 (Cassava Brown Streak Disease).
visualize_images(clean_ids, clean_labels, clean_guesses, target_class=1)

- Symptoms of cassava brown streak disease appear as patches of yellow areas mixed with the normal green colour. The characteristic yellow or necrotic vein banding may enlarge and coalesce to form large yellow patches.
- The infected leaves do not become distorted in shape as occurs with leaves infected by Cassava mosaic disease.

- Main characteristics to leverage: **yellow spots**

In [None]:
# Images CleanLab did not flag, with class 2 (Cassava Green Mottle).
visualize_images(clean_ids, clean_labels, clean_guesses, target_class=2)

- Young leaves are puckered with faint to distinct yellow spots, green patterns (mosaics), and twisted margins. Occasionally, plants become severely stunted.

- Main characteristics to leverage: **yellow patterns, irregular patches of yellow and green, leaf margins distortion, stunted**

In [None]:
# Images CleanLab did not flag, with class 3 (Cassava Mosaic Disease).
visualize_images(clean_ids, clean_labels, clean_guesses, target_class=3)

- Cassava Mosaic Disease is characterized by severe mosaic symptoms on leaves, with affected leaves showing mottling and light-green, yellow or white spots. Discoloration, malformation and puckering of the leaf blade occur.

- Main characteristics to leverage: **severe shape distortion, mosaic patterns**

In [None]:
# Images CleanLab did not flag, with class 4 (Healthy).
visualize_images(clean_ids, clean_labels, clean_guesses, target_class=4)

## 4-2 Noisy Label

In [None]:
# Flagged images whose guessed class is 0 (Cassava Bacterial Blight); tile titles show the given label.
visualize_images(noisy_ids, noisy_labels, noisy_guesses, target_class=0)

- At first, angular, water-soaked spots occur on the leaves which are restricted by the veins; the spots are more clearly seen on the lower leaf surface. The spots expand rapidly, join together, especially along the margins of the leaves, and turn brown with yellow borders. Droplets of a creamy-white ooze occur at the centre of the spots; later, they turn yellow.
- Stem infections block the flow of water and food and the leaves above wilt, die and fall, and branches die back.

- Main characteristics to leverage: **angular spots, brown spots with yellow borders, yellow leaves, leaves wilting**

In [None]:
# Flagged images whose guessed class is 1 (Cassava Brown Streak Disease); tile titles show the given label.
visualize_images(noisy_ids, noisy_labels, noisy_guesses, target_class=1)

- Symptoms of cassava brown streak disease appear as patches of yellow areas mixed with the normal green colour. The characteristic yellow or necrotic vein banding may enlarge and coalesce to form large yellow patches.
- The infected leaves do not become distorted in shape as occurs with leaves infected by Cassava mosaic disease.

- Main characteristics to leverage: **yellow spots**

In [None]:
# Flagged images whose guessed class is 2 (Cassava Green Mottle); tile titles show the given label.
visualize_images(noisy_ids, noisy_labels, noisy_guesses, target_class=2)

- Young leaves are puckered with faint to distinct yellow spots, green patterns (mosaics), and twisted margins. Occasionally, plants become severely stunted.

- Main characteristics to leverage: **yellow patterns, irregular patches of yellow and green, leaf margins distortion, stunted**

In [None]:
# Flagged images whose guessed class is 3 (Cassava Mosaic Disease); tile titles show the given label.
visualize_images(noisy_ids, noisy_labels, noisy_guesses, target_class=3)

- Cassava Mosaic Disease is characterized by severe mosaic symptoms on leaves, with affected leaves showing mottling and light-green, yellow or white spots. Discoloration, malformation and puckering of the leaf blade occur.

- Main characteristics to leverage: **severe shape distortion, mosaic patterns**

In [None]:
# Flagged images whose guessed class is 4 (Healthy); tile titles show the given label.
visualize_images(noisy_ids, noisy_labels, noisy_guesses, target_class=4)

# 5 Noisy Label Analysis

In [None]:
# One row per image: the given label, CleanLab's guess, whether the two differ,
# and the model's highest class probability.
all_data = pd.DataFrame(
    {'image_id': ids, 'given_label': labels, 'guess_label': guesses}
).assign(
    is_noisy=lambda frame: frame.given_label != frame.guess_label,
    max_prob=preds.max(axis=1),
)

## 5-1 Noise Distribution

In [None]:
plt.figure(figsize=(6, 4.5))
# Fix the bar order explicitly so the tick labels and percentages below are
# guaranteed to refer to the right bar.
bar_order = [False, True]
ax = sns.countplot(x='is_noisy', data=all_data, order=bar_order,
                   palette=["#55967e", "#263959"])

plt.xticks(np.arange(2), ['Clean', 'Noise'])
plt.title('Noise Distribution (hold-out set)', fontsize=14)
plt.xlabel('')
plt.ylabel('Number of images')

# value_counts() sorts by frequency, which need not match the bar order;
# reindex so counts[i] belongs to bar i (0 if a category is absent).
counts = all_data['is_noisy'].value_counts().reindex(bar_order, fill_value=0)
counts_pct = [f'{elem * 100:.2f}%' for elem in counts / counts.sum()]
for i, v in enumerate(counts_pct):
    ax.text(i, 0, v, horizontalalignment='center', size=14, color='w', fontweight='bold')

plt.show()

- Noisy labels are less frequent than expected (8.02%).

In [None]:
# Share of clean (False) and noisy (True) images within each given label.
# fill_value=0 keeps a class that has no noisy (or no clean) images from
# producing NaN in the table.
noise_by_class = (
    all_data.groupby('given_label')['is_noisy']
    .value_counts(normalize=True)
    .unstack(fill_value=0)
)
# `stacked` takes a boolean; the string 'True' only worked because it is truthy.
noise_by_class.plot(kind='bar', stacked=True, color=["#55967e", "#263959"])
plt.legend(loc=(1.04, 0))
_ = plt.xticks(
    rotation=45,
    horizontalalignment='right',
    fontweight='light'
)

plt.title('Noise Distribution by class', fontsize=14)
plt.xlabel('Class of the images')
plt.tight_layout()

- The 0-Cassava Bacterial Blight (CBB) and 4-Healthy classes contain many noisy labels.
- 3-Cassava Mosaic Disease (CMD) contains only a very small share of noisy labels (1.7%).

In [None]:
# Distribution of the model's highest class probability, grouped by the guessed class.
plt.figure(figsize=(20, 5))
palette = "Set3"

# NOTE(review): only the first of three subplot slots is used, so the plot
# occupies the left third of the figure — confirm this is intended.
plt.subplot(1, 3, 1)
# fliersize = 0 hides the box plot's outlier markers; the strip plot below
# draws every point anyway.
sns.boxplot(x = 'guess_label', y = 'max_prob', data = all_data,
     palette = palette, fliersize = 0)
sns.stripplot(x = 'guess_label', y = 'max_prob', data = all_data,
     linewidth = 0.6, palette = palette)
plt.title('Prediction Probability Distribution by Class',fontsize= 14)
plt.ylim(0.2, 1.05)
plt.ylabel('Prediction Probability')

plt.show()

- Predictions for 0-Cassava Bacterial Blight (CBB) and 4-Healthy have relatively low confidence (probability).
- Predictions for 3-Cassava Mosaic Disease (CMD) have high confidence.

## 5-2 How does a label change?

In [None]:
# One bar colour per class, indexed by the integer label.
class_colors = np.array(['#fe4365', '#fc9d9a', '#f9cdad', '#c8c8a9', '#aacfd0'])
# Prefix each class name with its integer label (e.g. '0-...'). The check keeps
# the cell idempotent: re-running it no longer stacks prefixes ('0-0-...').
num2class = [elem if elem.startswith(f'{idx}-') else f'{idx}-{elem}'
             for idx, elem in enumerate(num2class)]

In [None]:
class_num = 0
# For images whose given label is `class_num`: share of each guessed class,
# in percent so the values match the axis label and the '%' annotations.
label_shift = all_data[all_data.given_label==class_num]['guess_label'].value_counts(normalize=True) * 100

sns.set_style("whitegrid")
bar, ax = plt.subplots(figsize=(10,6))
# Pass the bar lengths explicitly instead of relying on seaborn looking the
# Series up by its own index.
ax = sns.barplot(x=label_shift.values, y=np.array(num2class)[label_shift.index],
                 ci=None, palette=class_colors[label_shift.index], orient='h', ax=ax)
ax.set_title(f"{num2class[class_num]} Changed to", fontsize=15)
ax.set_xlabel("Percentage")
ax.set_ylabel("Disease Type")
# Annotate every bar with its value at the bar's end.
for rect in ax.patches:
    ax.text(rect.get_width(), rect.get_y() + rect.get_height() / 2, "%.2f%%" % rect.get_width(), weight='bold')

In [None]:
class_num = 1
# For images whose given label is `class_num`: share of each guessed class,
# in percent so the values match the axis label and the '%' annotations.
label_shift = all_data[all_data.given_label==class_num]['guess_label'].value_counts(normalize=True) * 100

sns.set_style("whitegrid")
bar, ax = plt.subplots(figsize=(10,6))
# Pass the bar lengths explicitly instead of relying on seaborn looking the
# Series up by its own index.
ax = sns.barplot(x=label_shift.values, y=np.array(num2class)[label_shift.index],
                 ci=None, palette=class_colors[label_shift.index], orient='h', ax=ax)
ax.set_title(f"{num2class[class_num]} Changed to", fontsize=15)
ax.set_xlabel("Percentage")
ax.set_ylabel("Disease Type")
# Annotate every bar with its value at the bar's end.
for rect in ax.patches:
    ax.text(rect.get_width(), rect.get_y() + rect.get_height() / 2, "%.2f%%" % rect.get_width(), weight='bold')

In [None]:
class_num = 2
# For images whose given label is `class_num`: share of each guessed class,
# in percent so the values match the axis label and the '%' annotations.
label_shift = all_data[all_data.given_label==class_num]['guess_label'].value_counts(normalize=True) * 100

sns.set_style("whitegrid")
bar, ax = plt.subplots(figsize=(10,6))
# Pass the bar lengths explicitly instead of relying on seaborn looking the
# Series up by its own index.
ax = sns.barplot(x=label_shift.values, y=np.array(num2class)[label_shift.index],
                 ci=None, palette=class_colors[label_shift.index], orient='h', ax=ax)
ax.set_title(f"{num2class[class_num]} Changed to", fontsize=15)
ax.set_xlabel("Percentage")
ax.set_ylabel("Disease Type")
# Annotate every bar with its value at the bar's end.
for rect in ax.patches:
    ax.text(rect.get_width(), rect.get_y() + rect.get_height() / 2, "%.2f%%" % rect.get_width(), weight='bold')

In [None]:
class_num = 3
# For images whose given label is `class_num`: share of each guessed class,
# in percent so the values match the axis label and the '%' annotations.
label_shift = all_data[all_data.given_label==class_num]['guess_label'].value_counts(normalize=True) * 100

sns.set_style("whitegrid")
bar, ax = plt.subplots(figsize=(10,6))
# Pass the bar lengths explicitly instead of relying on seaborn looking the
# Series up by its own index.
ax = sns.barplot(x=label_shift.values, y=np.array(num2class)[label_shift.index],
                 ci=None, palette=class_colors[label_shift.index], orient='h', ax=ax)
ax.set_title(f"{num2class[class_num]} Changed to", fontsize=15)
ax.set_xlabel("Percentage")
ax.set_ylabel("Disease Type")
# Annotate every bar with its value at the bar's end.
for rect in ax.patches:
    ax.text(rect.get_width(), rect.get_y() + rect.get_height() / 2, "%.2f%%" % rect.get_width(), weight='bold')

In [None]:
class_num = 4
# For images whose given label is `class_num`: share of each guessed class,
# in percent so the values match the axis label and the '%' annotations.
label_shift = all_data[all_data.given_label==class_num]['guess_label'].value_counts(normalize=True) * 100

sns.set_style("whitegrid")
bar, ax = plt.subplots(figsize=(10,6))
# Pass the bar lengths explicitly instead of relying on seaborn looking the
# Series up by its own index.
ax = sns.barplot(x=label_shift.values, y=np.array(num2class)[label_shift.index],
                 ci=None, palette=class_colors[label_shift.index], orient='h', ax=ax)
ax.set_title(f"{num2class[class_num]} Changed to", fontsize=15)
ax.set_xlabel("Percentage")
ax.set_ylabel("Disease Type")
# Annotate every bar with its value at the bar's end.
for rect in ax.patches:
    ax.text(rect.get_width(), rect.get_y() + rect.get_height() / 2, "%.2f%%" % rect.get_width(), weight='bold')