 # CellDART for ST (built on the ADDA model class with a shared encoder)

In [1]:
import os
import datetime
from itertools import chain
from copy import deepcopy
import warnings

from tqdm.autonotebook import tqdm

import h5py
import pickle
import numpy as np

import torch
from torch import nn

from src.da_models.adda import ADDAST
from src.da_models.datasets import SpotDataset
from src.da_models.utils import set_requires_grad

# Timestamp string for the moment this cell ran, formatted as
# YYYY-MM-DD_HHhMMmSS (strftime returns a str, not a datetime object).
script_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%Hh%Mm%S")


  from tqdm.autonotebook import tqdm


In [2]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Compare the device *type* string: a torch.device object never compares equal
# to the plain string "cpu", so `device == "cpu"` would silently skip the warning.
if device.type == "cpu":
    warnings.warn("Using CPU", stacklevel=2)

In [3]:

# NOTE(review): NUM_MARKERS, N_MIX and N_SPOTS are not read anywhere in the
# visible part of this notebook; presumably they mirror the settings used when
# the preprocessed data was generated — confirm.
NUM_MARKERS = 20
N_MIX = 8
N_SPOTS = 20000
# True: adversarial training uses the pooled "all" ST matrix (single run).
# False: one adversarial run per ST sample.
TRAIN_USING_ALL_ST_SAMPLES = False

# Only referenced from commented-out code in the visible notebook.
SAMPLE_ID_N = "151673"

# Mini-batch size, DataLoader worker count and number of pretraining epochs.
BATCH_SIZE = 512
NUM_WORKERS = 4
INITIAL_TRAIN_EPOCHS = 10

# Pretraining stops early once EARLY_STOP_CRIT epochs pass without a new best
# validation loss, but not before MIN_EPOCHS epochs have run.
EARLY_STOP_CRIT = 100
MIN_EPOCHS = INITIAL_TRAIN_EPOCHS


# NOTE(review): the *_ADV early-stopping settings are not read in the visible
# code (the adversarial loop runs a fixed number of iterations).
EARLY_STOP_CRIT_ADV = 10
MIN_EPOCHS_ADV = 10

# NOTE(review): raw-data locations; not read in the visible code.
SPATIALLIBD_DIR = "./data/spatialLIBD"
SC_DLPFC_PATH = "./data/sc_dlpfc/adata_sc_dlpfc.h5ad"

# Folder holding the preprocessed HDF5 / pickle inputs read by this notebook.
PROCESSED_DATA_DIR = "./data/preprocessed"

# Used as a path component of the model output folder.
MODEL_NAME = "CellDART"


In [4]:
# Timestamped output folder for this run.
model_folder = os.path.join("model", MODEL_NAME, script_start_time)

# NOTE(review): this immediately replaces the timestamped folder above with a
# fixed "TESTING" folder, so every run writes to the same location — confirm
# this is intended and not a leftover debugging override.
model_folder = os.path.join("model", MODEL_NAME, "TESTING")

# The path is printed only when the folder is newly created.
if not os.path.isdir(model_folder):
    os.makedirs(model_folder)
    print(model_folder)


 # Data load


In [5]:
def _processed_path(filename):
    """Return the path of `filename` inside PROCESSED_DATA_DIR."""
    return os.path.join(PROCESSED_DATA_DIR, filename)


def _read_pickle(filename):
    """Unpickle one helper object stored in PROCESSED_DATA_DIR."""
    # NOTE: pickle can execute arbitrary code on load; only read files that
    # were produced by the trusted preprocessing step.
    with open(_processed_path(filename), "rb") as fh:
        return pickle.load(fh)


# --- Spatial (target-domain) data: one array per top-level HDF5 key ---
with h5py.File(_processed_path("mat_sp_test_s_d.hdf5"), "r") as f:
    mat_sp_test_s_d = {sample_id: f[sample_id][()] for sample_id in f}

if not TRAIN_USING_ALL_ST_SAMPLES:
    # Per-sample training reuses the very same per-sample dict.
    mat_sp_train_s_d = mat_sp_test_s_d
else:
    with h5py.File(_processed_path("mat_sp_train_s.hdf5"), "r") as f:
        mat_sp_train_s = f["all"][()]

# --- Single-cell mixture (source-domain) data: features X and labels y ---
with h5py.File(_processed_path("sc.hdf5"), "r") as f:
    sc_mix_train_s, sc_mix_val_s, sc_mix_test_s = (
        f[f"X/{split}"][()] for split in ("train", "val", "test")
    )
    lab_mix_train, lab_mix_val, lab_mix_test = (
        f[f"y/{split}"][()] for split in ("train", "val", "test")
    )

# --- Helper dicts / lists ---
sc_sub_dict = _read_pickle("sc_sub_dict.pkl")
sc_sub_dict2 = _read_pickle("sc_sub_dict2.pkl")
st_sample_id_l = _read_pickle("st_sample_id_l.pkl")


 # Training: Adversarial domain adaptation for cell fraction estimation

 ## Prepare dataloaders

In [6]:
def _spot_loader(dataset, training):
    """Wrap `dataset` in a DataLoader using the notebook-wide batch settings.

    Training loaders shuffle and pin memory; evaluation loaders do neither.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=training,
        num_workers=NUM_WORKERS,
        pin_memory=training,
    )


### source dataloaders
# Arrays are deep-copied so each dataset holds its own copy of the data.
source_train_set = SpotDataset(deepcopy(sc_mix_train_s), deepcopy(lab_mix_train))
source_val_set = SpotDataset(deepcopy(sc_mix_val_s), deepcopy(lab_mix_val))
source_test_set = SpotDataset(deepcopy(sc_mix_test_s), deepcopy(lab_mix_test))

dataloader_source_train = _spot_loader(source_train_set, training=True)
dataloader_source_val = _spot_loader(source_val_set, training=False)
dataloader_source_test = _spot_loader(source_test_set, training=False)

### target dataloaders
# One unlabeled evaluation dataset / loader per ST sample
target_test_set_d = {
    sid: SpotDataset(deepcopy(mat_sp_test_s_d[sid])) for sid in st_sample_id_l
}
dataloader_target_test_d = {
    sid: _spot_loader(target_test_set_d[sid], training=False)
    for sid in st_sample_id_l
}

if not TRAIN_USING_ALL_ST_SAMPLES:
    # One training dataset / loader per ST sample
    target_train_set_d = {}
    dataloader_target_train_d = {}
    for sample_id in st_sample_id_l:
        per_sample_set = SpotDataset(deepcopy(mat_sp_test_s_d[sample_id]))
        target_train_set_d[sample_id] = per_sample_set
        dataloader_target_train_d[sample_id] = _spot_loader(
            per_sample_set, training=True
        )
else:
    # A single training dataset / loader over the pooled ST matrix
    target_train_set = SpotDataset(deepcopy(mat_sp_train_s))
    dataloader_target_train = _spot_loader(target_train_set, training=True)


 ## Define Model

In [7]:
# Input width comes from the source feature matrix; the number of source
# classes comes from the width of the source label matrix.
model = ADDAST(sc_mix_train_s.shape[1], emb_dim=64, ncls_source=lab_mix_train.shape[1])

## CellDART uses just one encoder!
# Alias the target encoder to the source encoder so both attribute names refer
# to the same module object.
model.target_encoder = model.source_encoder
# Last expression of the cell, so the notebook also displays the module repr.
model.to(device)


ADDAST(
  (source_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=367, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (target_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=367, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (clf): Predictor(
    (head): Sequential(
      (0): Linear(in_features=64, out_features=33, bias=True)
      (1): Log

 ## Pretrain

In [8]:
# Checkpoints written during pretraining go here; create the folder if needed.
pretrain_folder = os.path.join(model_folder, "pretrain")
os.makedirs(pretrain_folder, exist_ok=True)


In [9]:
# Adam over every parameter of the model, used for the pretraining stage.
# NOTE(review): eps=1e-07 differs from PyTorch's Adam default; presumably
# chosen to match another framework's default — confirm.
pre_optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-07
)

# KL divergence between predictions and true proportions, summed over classes
# and averaged over the batch ("batchmean"). nn.KLDivLoss expects its first
# argument to be log-probabilities.
criterion_clf = nn.KLDivLoss(reduction="batchmean")


In [10]:
def model_loss(x, y_true, model):
    """Compute the classification loss of `model` on one batch.

    Both tensors are cast to float32 and moved to the global `device`; the
    model output is then scored against `y_true` with the global
    `criterion_clf`.

    Args:
        x: Batch of input features.
        y_true: Batch of target values, one row per row of `x`.
        model: Callable mapping the inputs to predictions.

    Returns:
        The loss tensor produced by `criterion_clf`.
    """
    inputs = x.to(torch.float32).to(device)
    targets = y_true.to(torch.float32).to(device)
    return criterion_clf(model(inputs), targets)


def compute_acc(dataloader, model):
    """Return the sample-weighted mean loss of `model` over `dataloader`.

    Despite its name, this returns a loss, not an accuracy. The model is put
    in eval mode and gradients are disabled while iterating.

    Args:
        dataloader: Iterable yielding `(x, y_true)` batches.
        model: Model to evaluate; passed through to `model_loss`.

    Returns:
        Mean of the per-batch losses, weighted by the number of samples in
        each batch.
    """
    batch_losses = []
    batch_sizes = []
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            loss = model_loss(*batch, model)
            batch_losses.append(loss.item())

            # Weight by the number of samples. `len(batch)` would be the number
            # of tensors in the batch (always 2 for (x, y)), which made the
            # "weighted" average effectively unweighted.
            batch_sizes.append(len(batch[0]))

    return np.average(batch_losses, weights=batch_sizes)


In [11]:
# Switch the model into its pretraining configuration before the source-only
# training loop. NOTE(review): the exact effect is defined in
# src.da_models.adda and is not visible here — confirm.
model.pretraining()


In [12]:
# Per-epoch training / validation losses (sample-weighted means over batches)
loss_history = []
loss_history_val = []

# Raw per-batch training losses, one list per epoch
loss_history_running = []

# Early Stopping
best_loss_val = np.inf
early_stop_count = 0

# Train
print("Start pretrain...")
outer = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)
inner = tqdm(total=len(dataloader_source_train), desc="Batch", position=1)

# The dict references the live model and optimizer, so each torch.save call
# below serializes their state at the moment of saving.
checkpoint = {
    "epoch": -1,
    "model": model,
    "optimizer": pre_optimizer,
}
for epoch in range(INITIAL_TRAIN_EPOCHS):
    checkpoint["epoch"] = epoch

    # Train mode
    model.train()
    loss_running = []
    mean_weights = []

    inner.refresh()  # force print final state
    inner.reset()  # reuse bar
    for batch in dataloader_source_train:

        pre_optimizer.zero_grad()
        loss = model_loss(*batch, model)
        loss_running.append(loss.item())
        # Weight by number of samples; len(batch) would be the number of
        # tensors in the batch (2), not the batch size.
        mean_weights.append(len(batch[0]))

        loss.backward()
        pre_optimizer.step()

        inner.update(1)

    loss_history.append(np.average(loss_running, weights=mean_weights))
    loss_history_running.append(loss_running)

    # Evaluate mode (compute_acc switches to eval and disables gradients itself)
    model.eval()
    with torch.no_grad():
        curr_loss_val = compute_acc(dataloader_source_val, model)
        loss_history_val.append(curr_loss_val)

    # Print the results
    outer.update(1)
    print(
        "epoch:",
        epoch,
        "train loss:",
        round(loss_history[-1], 6),
        "validation loss:",
        round(loss_history_val[-1], 6),
        end=" ",
    )

    # Save the best weights
    if curr_loss_val < best_loss_val:
        best_loss_val = curr_loss_val
        torch.save(checkpoint, os.path.join(pretrain_folder, "best_model.pth"))
        early_stop_count = 0

        print("<-- new best val loss")
    else:
        print("")

    # Save checkpoint every 10 epochs and on the last scheduled epoch
    if epoch % 10 == 0 or epoch >= INITIAL_TRAIN_EPOCHS - 1:
        torch.save(checkpoint, os.path.join(pretrain_folder, f"checkpt{epoch}.pth"))

    # check to see if validation loss has plateau'd
    # (early_stop_count == number of epochs since the last new best)
    if early_stop_count >= EARLY_STOP_CRIT and epoch >= MIN_EPOCHS - 1:
        print(f"Validation loss plateaued after {early_stop_count} at epoch {epoch}")
        torch.save(checkpoint, os.path.join(pretrain_folder, f"earlystop{epoch}.pth"))
        break

    early_stop_count += 1

# Release the progress bars so they do not linger under later cell output.
inner.close()
outer.close()

# Save final model (the state after the last epoch run, not necessarily the best)
torch.save(checkpoint, os.path.join(pretrain_folder, "final_model.pth"))


Start pretrain...


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batch:   0%|          | 0/40 [00:00<?, ?it/s]

epoch: 0 train loss: 1.246527 validation loss: 0.991463 <-- new best val loss
epoch: 1 train loss: 0.894556 validation loss: 0.79241 <-- new best val loss
epoch: 2 train loss: 0.709414 validation loss: 0.690445 <-- new best val loss
epoch: 3 train loss: 0.621515 validation loss: 0.661922 <-- new best val loss
epoch: 4 train loss: 0.584557 validation loss: 0.65897 <-- new best val loss
epoch: 5 train loss: 0.561804 validation loss: 0.644848 <-- new best val loss
epoch: 6 train loss: 0.542548 validation loss: 0.642356 <-- new best val loss
epoch: 7 train loss: 0.527927 validation loss: 0.640265 <-- new best val loss
epoch: 8 train loss: 0.512742 validation loss: 0.630194 <-- new best val loss
epoch: 9 train loss: 0.500608 validation loss: 0.665679 


 ## Adversarial Adaptation

In [13]:
# Number of adversarial training iterations (one source + one target
# mini-batch per iteration).
N_ITER = 3000
# NOTE(review): EPOCHS is not read anywhere in the visible code.
EPOCHS = 1000
# Multiplier applied to the base learning rate (0.001) of the discriminator
# optimizer.
ALPHA_LR = 5
# Weight of the domain-confusion (flipped-label) loss relative to the
# classification loss in the encoder update.
ALPHA = 1 / 0.6


In [14]:
# Checkpoints from adversarial training go here; create the folder if needed.
advtrain_folder = os.path.join(model_folder, "advtrain")
os.makedirs(advtrain_folder, exist_ok=True)


In [15]:
def batch_generator(data, batch_size):
    """Endlessly yield random mini-batches from a list of parallel arrays.

    On every step, `batch_size` distinct indices are drawn (without
    replacement) from `range(len(data[0]))` and the same indices are applied
    to every array in `data`, so rows stay aligned across arrays. Each step is
    an independent draw; the generator never terminates.

    Args:
        data: List of arrays that support integer-array indexing; the first
            array's length defines the index range.
        batch_size: Number of examples per batch. Must not exceed
            `len(data[0])`, otherwise `np.random.choice` raises ValueError.

    Yields:
        A list with one indexed array per entry of `data`.
    """
    n_examples = len(data[0])
    while True:
        picked = np.random.choice(n_examples, size=batch_size, replace=False)
        yield [arr[picked] for arr in data]


In [16]:
# Cross-entropy criterion applied to the discriminator's domain predictions
# (integer labels: 0 for source samples, 1 for target samples).
criterion_dis = nn.CrossEntropyLoss()


def discrim_loss_accu(x, y_dis, model):
    """Score the domain discriminator on one batch.

    The inputs are cast to float32, moved to the global `device`, embedded
    with `model.source_encoder` and classified with `model.dis`.

    Args:
        x: Batch of input features.
        y_dis: Domain labels for the batch (long tensor on `device`).
        model: Model exposing `source_encoder` and `dis`.

    Returns:
        Tuple `(loss, accu)`: the `criterion_dis` loss tensor and the
        fraction of correctly classified samples as a 0-d CPU tensor.
    """
    features = x.to(torch.float32).to(device)
    logits = model.dis(model.source_encoder(features))

    loss = criterion_dis(logits, y_dis)

    predicted = torch.argmax(logits, dim=1).flatten()
    hits = (predicted == y_dis).to(torch.float32)
    accu = hits.mean().cpu()

    return loss, accu


def compute_acc_dis(dataloader_source, dataloader_target, model):
    """Evaluate the domain discriminator over both domains.

    Every target batch is labeled 1 and every source batch 0. The model is
    put in eval mode and gradients are disabled while iterating.

    Args:
        dataloader_source: Loader yielding 2-tuples whose first item is the
            source feature batch.
        dataloader_target: Loader yielding 2-tuples whose first item is the
            target feature batch.
        model: Model exposing `source_encoder` and `dis`.

    Returns:
        Tuple `(loss, accuracy)`, each a mean over all batches of both
        loaders weighted by batch size.
    """
    loss_running = []
    accu_running = []
    mean_weights = []
    model.eval()
    model.source_encoder.eval()
    model.dis.eval()
    with torch.no_grad():
        for y_val, dl in zip([1, 0], [dataloader_target, dataloader_source]):
            for x, _ in dl:
                y_dis = torch.full(
                    (x.shape[0],), y_val, device=device, dtype=torch.long
                )

                loss, accu = discrim_loss_accu(x, y_dis, model)

                # Store plain Python floats so np.average does not have to
                # coerce a list of 0-d tensors.
                accu_running.append(float(accu))
                loss_running.append(loss.item())

                # we will weight average by batch size later
                mean_weights.append(len(x))

    return (
        np.average(loss_running, weights=mean_weights),
        np.average(accu_running, weights=mean_weights),
    )


In [17]:
def train_adversarial(
    model,
    save_folder,
    sc_mix_train_s,
    lab_mix_train,
    mat_sp_train_s,
    dataloader_source_train_eval,
    dataloader_target_train_eval,
):
    """Run N_ITER iterations of alternating adversarial training.

    Each iteration draws one random source batch and one random target batch
    and performs two updates:

    1. Encoder/classifier step: minimizes the classification loss plus
       ALPHA times the discriminator loss on *flipped* domain labels, then
       restores the discriminator weights from a snapshot taken beforehand.
    2. Discriminator step: minimizes the discriminator loss on the *real*
       domain labels, then restores the encoder and classifier weights from a
       snapshot taken beforehand.

    All modules keep requires_grad=True in both steps; freezing is emulated
    purely by the snapshot/restore of state dicts.

    Args:
        model: Model exposing `source_encoder`, `clf` and `dis`.
        save_folder: Existing folder that receives the checkpoints.
        sc_mix_train_s: Source feature array (copied before sampling).
        lab_mix_train: Source label array (copied before sampling).
        mat_sp_train_s: Target feature array (copied before sampling).
        dataloader_source_train_eval: Source loader used for the periodic
            evaluation printout.
        dataloader_target_train_eval: Target loader used for the periodic
            evaluation printout.

    Returns:
        None. Checkpoints are written to `save_folder`; progress is printed.
    """

    model.to(device)
    model.advtraining()
    model.set_encoder("source")

    # Infinite random mini-batch samplers over copies of the arrays.
    S_batches = batch_generator(
        [sc_mix_train_s.copy(), lab_mix_train.copy()], BATCH_SIZE
    )
    # The second array is a dummy placeholder; its batches are discarded below.
    T_batches = batch_generator(
        [mat_sp_train_s.copy(), np.ones(shape=(len(mat_sp_train_s), 2))], BATCH_SIZE
    )

    # This optimizer also covers the discriminator parameters; the
    # discriminator update it makes is undone by the restore after its step.
    enc_optimizer = torch.optim.Adam(
        chain(
            model.source_encoder.parameters(),
            model.dis.parameters(),
            model.clf.parameters(),
        ),
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    # This optimizer also covers the encoder parameters; the encoder update it
    # makes is undone by the restore after its step. Learning rate is scaled
    # by ALPHA_LR.
    dis_optimizer = torch.optim.Adam(
        chain(model.source_encoder.parameters(), model.dis.parameters()),
        lr=ALPHA_LR * 0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    # Initialize lists to store loss and accuracy values
    # NOTE(review): nothing is ever appended to this list in this function.
    loss_history_running = []

    # Train
    print("Start adversarial training...")
    outer = tqdm(total=N_ITER, desc="Iterations", position=0)

    # References the live model / optimizers, so each torch.save call
    # serializes their state at the moment of saving.
    checkpoint = {
        "epoch": -1,
        "model": model,
        "dis_optimizer": dis_optimizer,
        "enc_optimizer": enc_optimizer,
    }
    for iters in range(N_ITER):
        checkpoint["epoch"] = iters

        model.train()
        model.dis.train()
        model.clf.train()
        model.source_encoder.train()

        x_source, y_true = next(S_batches)
        x_target, _ = next(T_batches)

        ## Train encoder
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.clf, True)
        set_requires_grad(model.dis, True)

        # save discriminator weights
        # (entries whose key contains "num_batches_tracked" are left out of
        # the snapshot, so they are not rolled back by the restore below)
        dis_weights = deepcopy(model.dis.state_dict())
        new_dis_weights = {}
        for k in dis_weights:
            if "num_batches_tracked" not in k:
                new_dis_weights[k] = dis_weights[k]

        dis_weights = new_dis_weights

        # numpy batches -> float tensors -> device
        x_source, x_target, y_true, = (
            torch.Tensor(x_source),
            torch.Tensor(x_target),
            torch.Tensor(y_true),
        )
        x_source, x_target, y_true, = (
            x_source.to(device),
            x_target.to(device),
            y_true.to(device),
        )

        # Source rows first, target rows last; the slicing below relies on
        # this order.
        x = torch.cat((x_source, x_target))

        # save for discriminator later
        x_d = x.detach()

        # y_dis is the REAL one
        # (0 = source, 1 = target)
        y_dis = torch.cat(
            [
                torch.zeros(x_source.shape[0], device=device, dtype=torch.long),
                torch.ones(x_target.shape[0], device=device, dtype=torch.long),
            ]
        )
        y_dis_flipped = 1 - y_dis.detach()

        emb = model.source_encoder(x).view(x.shape[0], -1)

        y_dis_pred = model.dis(emb)
        y_clf_pred = model.clf(emb)

        # we use flipped because we want to confuse discriminator
        loss_dis = criterion_dis(y_dis_pred, y_dis_flipped)

        # Set true = predicted for target samples since we don't know what it is
        # NOTE(review): criterion_clf is nn.KLDivLoss, which expects
        # probabilities as the target while the detached predictions used here
        # are the classifier's raw outputs — confirm the target rows
        # contribute what is intended.
        y_clf_true = torch.cat((y_true, y_clf_pred[-x_target.shape[0] :].detach()))
        # Loss fn does mean over all samples including target
        loss_clf = criterion_clf(y_clf_pred, y_clf_true)

        # Scale back up loss so mean doesn't include target
        loss = (x.shape[0] / x_source.shape[0]) * loss_clf + ALPHA * loss_dis

        # loss = loss_clf + ALPHA * loss_dis

        enc_optimizer.zero_grad()
        loss.backward()
        enc_optimizer.step()

        # Undo the discriminator update made by enc_optimizer.
        # NOTE(review): only the module weights are rolled back; the
        # optimizer's internal state is not — confirm this is intended.
        model.dis.load_state_dict(dis_weights, strict=False)

        ## Train discriminator
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.clf, True)
        set_requires_grad(model.dis, True)

        # save encoder and clf weights
        # (taken after the encoder step above, so the restore below keeps that
        # step's encoder / classifier update)
        source_encoder_weights = deepcopy(model.source_encoder.state_dict())
        clf_weights = deepcopy(model.clf.state_dict())

        new_clf_weights = {}
        for k in clf_weights:
            if "num_batches_tracked" not in k:
                new_clf_weights[k] = clf_weights[k]

        clf_weights = new_clf_weights

        new_source_encoder_weights = {}
        for k in source_encoder_weights:
            if "num_batches_tracked" not in k:
                new_source_encoder_weights[k] = source_encoder_weights[k]

        source_encoder_weights = new_source_encoder_weights

        # Fresh forward pass on the same inputs with the updated encoder.
        emb = model.source_encoder(x_d).view(x_d.shape[0], -1)
        y_pred = model.dis(emb)

        # we use the real domain labels to train discriminator
        loss = criterion_dis(y_pred, y_dis)

        dis_optimizer.zero_grad()
        loss.backward()
        dis_optimizer.step()

        # Undo the encoder update made by dis_optimizer (the classifier is not
        # in dis_optimizer; it is restored from its snapshot all the same).
        model.clf.load_state_dict(clf_weights, strict=False)
        model.source_encoder.load_state_dict(source_encoder_weights, strict=False)

        # Save checkpoint every 100
        # (at iterations 99, 199, ... and on the final iteration)
        if iters % 100 == 99 or iters >= N_ITER - 1:
            torch.save(checkpoint, os.path.join(save_folder, f"checkpt{iters}.pth"))

            model.eval()
            source_loss = compute_acc(dataloader_source_train_eval, model)
            _, dis_accu = compute_acc_dis(
                dataloader_source_train_eval, dataloader_target_train_eval, model
            )

            # Print the results
            print(
                "iter:",
                iters,
                "source loss:",
                round(source_loss, 6),
                "dis accu:",
                round(dis_accu, 6),
            )

        outer.update(1)

    torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [18]:
# st_sample_id_l = [SAMPLE_ID_N]


In [19]:
# Every adversarial run restarts from the checkpoint written at the end of
# pretraining.
pretrained_ckpt_path = os.path.join(pretrain_folder, f"final_model.pth")

if not TRAIN_USING_ALL_ST_SAMPLES:
    # One independent adversarial run per ST slide, each in its own subfolder.
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        save_folder = os.path.join(advtrain_folder, sample_id)
        os.makedirs(save_folder, exist_ok=True)

        best_checkpoint = torch.load(pretrained_ckpt_path)
        model = best_checkpoint["model"]
        model.to(device)
        model.advtraining()

        train_adversarial(
            model=model,
            save_folder=save_folder,
            sc_mix_train_s=sc_mix_train_s,
            lab_mix_train=lab_mix_train,
            mat_sp_train_s=mat_sp_train_s_d[sample_id],
            dataloader_source_train_eval=dataloader_source_train,
            dataloader_target_train_eval=dataloader_target_train_d[sample_id],
        )

else:
    # A single adversarial run over the pooled ST data.
    print(f"Adversarial training for all ST slides")
    save_folder = advtrain_folder

    best_checkpoint = torch.load(pretrained_ckpt_path)
    model = best_checkpoint["model"]
    model.to(device)
    model.advtraining()

    train_adversarial(
        model=model,
        save_folder=save_folder,
        sc_mix_train_s=sc_mix_train_s,
        lab_mix_train=lab_mix_train,
        mat_sp_train_s=mat_sp_train_s,
        dataloader_source_train_eval=dataloader_source_train,
        dataloader_target_train_eval=dataloader_target_train,
    )


Adversarial training for ST slide 151509: 
Start adversarial training...


Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.236978 dis accu: 0.720562
iter: 199 source loss: 1.260481 dis accu: 0.827787
iter: 299 source loss: 1.201087 dis accu: 0.897172
iter: 399 source loss: 1.271319 dis accu: 0.726008
iter: 499 source loss: 1.299129 dis accu: 0.887329
iter: 599 source loss: 1.278309 dis accu: 0.910928
iter: 699 source loss: 1.286322 dis accu: 0.86607
iter: 799 source loss: 1.276529 dis accu: 0.902537
iter: 899 source loss: 1.235346 dis accu: 0.931219
iter: 999 source loss: 1.185005 dis accu: 0.907217
iter: 1099 source loss: 1.30424 dis accu: 0.782565
iter: 1199 source loss: 1.220146 dis accu: 0.909073
iter: 1299 source loss: 1.128674 dis accu: 0.90754
iter: 1399 source loss: 1.180274 dis accu: 0.684094
iter: 1499 source loss: 1.083691 dis accu: 0.515148
iter: 1599 source loss: 1.043124 dis accu: 0.671346
iter: 1699 source loss: 1.021036 dis accu: 0.589939
iter: 1799 source loss: 1.050688 dis accu: 0.398967
iter: 1899 source loss: 0.993762 dis accu: 0.871959
iter: 1999 source loss: 0.

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.204959 dis accu: 0.925266
iter: 199 source loss: 1.219318 dis accu: 0.909881
iter: 299 source loss: 1.196306 dis accu: 0.992328
iter: 399 source loss: 1.265769 dis accu: 0.928635
iter: 499 source loss: 1.22779 dis accu: 0.94191
iter: 599 source loss: 1.267878 dis accu: 0.986563
iter: 699 source loss: 1.210657 dis accu: 0.720346
iter: 799 source loss: 1.199746 dis accu: 0.826459
iter: 899 source loss: 1.194351 dis accu: 0.589348
iter: 999 source loss: 1.066228 dis accu: 0.823171
iter: 1099 source loss: 1.028084 dis accu: 0.833279


In [None]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     best_checkpoint = torch.load(os.path.join(advtrain_folder, f"final_model.pth"))
# else:
#     best_checkpoint = torch.load(
#         os.path.join(advtrain_folder, SAMPLE_ID_N, f"final_model.pth")
#     )

# model = best_checkpoint["model"]
# model.to(device)

# model.eval()
# model.set_encoder("source")

# pred_mix = (
#     torch.exp(model(torch.Tensor(sc_mix_test_s).to(device)))
#     .detach()
#     .cpu()
#     .numpy()
# )

# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(25, 5 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, visnum in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, visnum],
#         y=lab_mix_test[:, visnum],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[visnum])

#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlabel("Predicted Proportion")

#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = (
#         f"MSE: {mean_squared_error(pred_mix[:,visnum], lab_mix_test[:,visnum]):.5f}"
#     )

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()
