 # ADDA for ST

 Creating something like CellDART but it actually follows Adda in PyTorch as a first step

In [1]:
import math
import os
import datetime
from copy import deepcopy
from itertools import count

from tqdm.autonotebook import tqdm

import h5py
import pickle
import numpy as np


import torch
from torch.nn import functional as F
from torch import nn

from src.da_models.adda import ADDAST
from src.da_models.datasets import SpotDataset
from src.da_models.utils import set_requires_grad
from src.da_models.utils import initialize_weights


# datetime object containing current date and time
script_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%Hh%Mm%S")


  from tqdm.autonotebook import tqdm


In [2]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
print(device)

cuda:3


In [3]:
TRAIN_USING_ALL_ST_SAMPLES = False

SAMPLE_ID_N = "151673"

BATCH_SIZE = 1024
NUM_WORKERS = 4
INITIAL_TRAIN_EPOCHS = 200


MIN_EPOCHS = 0.4 * INITIAL_TRAIN_EPOCHS
EARLY_STOP_CRIT = INITIAL_TRAIN_EPOCHS

PROCESSED_DATA_DIR = "data/preprocessed"

MODEL_NAME = "ADDA"


In [4]:
## Adversarial Hyperparameters

EPOCHS = 100
MIN_EPOCHS_ADV = 0.4 * EPOCHS
EARLY_STOP_CRIT_ADV = EPOCHS
ENC_LR = 0.0002
ADAM_BETA_1 = 0.5
ALPHA = 2
DIS_LOOP_FACTOR = 5

In [5]:
model_folder = os.path.join("model", MODEL_NAME, script_start_time)

model_folder = os.path.join("model", MODEL_NAME, "V2")

if not os.path.isdir(model_folder):
    os.makedirs(model_folder)
    print(model_folder)

# if not os.path.isdir(results_folder):
#     os.makedirs(results_folder)


In [6]:
# # sc.logging.print_versions()
# sc.set_figure_params(facecolor="white", figsize=(8, 8))
# sc.settings.verbosity = 3


 # Data load

In [7]:
# Load spatial data
mat_sp_test_s_d = {}
with h5py.File(os.path.join(PROCESSED_DATA_DIR, "mat_sp_test_s_d.hdf5"), "r") as f:
    for sample_id in f:
        mat_sp_test_s_d[sample_id] = f[sample_id][()]

if TRAIN_USING_ALL_ST_SAMPLES:
    with h5py.File(os.path.join(PROCESSED_DATA_DIR, "mat_sp_train_s.hdf5"), "r") as f:
        mat_sp_train_s = f["all"][()]
else:
    mat_sp_train_s_d = mat_sp_test_s_d

# Load sc data
with h5py.File(os.path.join(PROCESSED_DATA_DIR, "sc.hdf5"), "r") as f:
    sc_mix_train_s = f["X/train"][()]
    sc_mix_val_s = f["X/val"][()]
    sc_mix_test_s = f["X/test"][()]

    lab_mix_train = f["y/train"][()]
    lab_mix_val = f["y/val"][()]
    lab_mix_test = f["y/test"][()]

# Load helper dicts / lists
with open(os.path.join(PROCESSED_DATA_DIR, "sc_sub_dict.pkl"), "rb") as f:
    sc_sub_dict = pickle.load(f)

with open(os.path.join(PROCESSED_DATA_DIR, "sc_sub_dict2.pkl"), "rb") as f:
    sc_sub_dict2 = pickle.load(f)

with open(os.path.join(PROCESSED_DATA_DIR, "st_sample_id_l.pkl"), "rb") as f:
    st_sample_id_l = pickle.load(f)


 # Training: Adversarial domain adaptation for cell fraction estimation

 ## Prepare dataloaders

In [8]:
### source dataloaders
source_train_set = SpotDataset(sc_mix_train_s, lab_mix_train)
source_val_set = SpotDataset(sc_mix_val_s, lab_mix_val)
source_test_set = SpotDataset(sc_mix_test_s, lab_mix_test)

dataloader_source_train = torch.utils.data.DataLoader(
    source_train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader_source_val = torch.utils.data.DataLoader(
    source_val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader_source_test = torch.utils.data.DataLoader(
    source_test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

### target dataloaders
target_test_set_d = {}
for sample_id in st_sample_id_l:
    target_test_set_d[sample_id] = SpotDataset(mat_sp_test_s_d[sample_id])

dataloader_target_test_d = {}
for sample_id in st_sample_id_l:
    dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
        target_test_set_d[sample_id],
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

if TRAIN_USING_ALL_ST_SAMPLES:
    target_train_set = SpotDataset(mat_sp_train_s)
    dataloader_target_train = torch.utils.data.DataLoader(
        target_train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    target_train_set_dis = SpotDataset(deepcopy(mat_sp_train_s))
    dataloader_target_train_dis = torch.utils.data.DataLoader(
        target_train_set_dis,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
else:
    target_train_set_d = {}
    dataloader_target_train_d = {}

    target_train_set_dis_d = {}
    dataloader_target_train_dis_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_test_s_d[sample_id]))
        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )

        target_train_set_dis_d[sample_id] = SpotDataset(deepcopy(mat_sp_test_s_d[sample_id]))
        dataloader_target_train_dis_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_dis_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )


 ## Define Model

In [9]:
model = ADDAST(sc_mix_train_s.shape[1], emb_dim=64, ncls_source=lab_mix_train.shape[1], is_adda=True)
model.apply(initialize_weights)
model.to(device)


ADDAST(
  (source_encoder): ADDAMLPEncoder(
    (encoder): Sequential(
      (0): BatchNorm1d(367, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (1): Dropout(p=0.5, inplace=False)
      (2): Linear(in_features=367, out_features=1024, bias=True)
      (3): BatchNorm1d(1024, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.01)
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_features=1024, out_features=64, bias=True)
      (7): Tanh()
    )
  )
  (target_encoder): ADDAMLPEncoder(
    (encoder): Sequential(
      (0): BatchNorm1d(367, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (1): Dropout(p=0.5, inplace=False)
      (2): Linear(in_features=367, out_features=1024, bias=True)
      (3): BatchNorm1d(1024, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.01)
      (5): Dropout(p=0.5, inplace=False)
      (6): Linear(in_f

 ## Pretrain

In [10]:
pretrain_folder = os.path.join(model_folder, "pretrain")

if not os.path.isdir(pretrain_folder):
    os.makedirs(pretrain_folder)


In [11]:
pre_optimizer = torch.optim.Adam(
    model.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-07
)

pre_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    pre_optimizer,
    max_lr=0.002,
    steps_per_epoch=len(dataloader_source_train),
    epochs=INITIAL_TRAIN_EPOCHS,
)

criterion_clf = nn.KLDivLoss(reduction="batchmean")


In [12]:
def model_loss(x, y_true, model):
    x = x.to(torch.float32).to(device)
    y_true = y_true.to(torch.float32).to(device)

    y_pred = model(x)

    loss = criterion_clf(y_pred, y_true)

    return loss


def compute_acc(dataloader, model):
    loss_running = []
    mean_weights = []
    model.eval()
    with torch.no_grad():
        for _, batch in enumerate(dataloader):

            loss = model_loss(*batch, model)

            loss_running.append(loss.item())

            # we will weight average by batch size later
            mean_weights.append(len(batch))

    return np.average(loss_running, weights=mean_weights)


In [13]:
model.pretraining()


In [14]:
"""
# Initialize lists to store loss and accuracy values
loss_history = []
loss_history_val = []

loss_history_running = []

# Early Stopping
best_loss_val = np.inf
early_stop_count = 0


# Train
print("Start pretrain...")
outer = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)
inner = tqdm(total=len(dataloader_source_train), desc=f"Batch", position=1)

checkpoint = {
    "epoch": -1,
    "model": model,
    "optimizer": pre_optimizer,
    "scheduler": pre_scheduler,
    # 'scaler': scaler
}
for epoch in range(INITIAL_TRAIN_EPOCHS):
    checkpoint["epoch"] = epoch

    # Train mode
    model.train()
    loss_running = []
    mean_weights = []

    inner.refresh()  # force print final state
    inner.reset()  # reuse bar
    for _, batch in enumerate(dataloader_source_train):
        # lr_history_running.append(scheduler.get_last_lr())

        pre_optimizer.zero_grad()
        loss = model_loss(*batch, model)
        loss_running.append(loss.item())
        mean_weights.append(len(batch))  # we will weight average by batch size later

        # scaler.scale(loss).backward()
        # scaler.step(optimizer)
        # scaler.update()

        loss.backward()
        pre_optimizer.step()
        pre_scheduler.step()

        inner.update(1)

    loss_history.append(np.average(loss_running, weights=mean_weights))
    loss_history_running.append(loss_running)

    # Evaluate mode
    model.eval()
    with torch.no_grad():
        curr_loss_val = compute_acc(dataloader_source_val, model)
        loss_history_val.append(curr_loss_val)

    # Print the results
    outer.update(1)
    print(
        "epoch:",
        epoch,
        "train loss:",
        round(loss_history[-1], 6),
        "validation loss:",
        round(loss_history_val[-1], 6),
        # "next_lr:", scheduler.get_last_lr(),
        end=" ",
    )
    # Save the best weights
    if curr_loss_val < best_loss_val:
        best_loss_val = curr_loss_val
        torch.save(checkpoint, os.path.join(pretrain_folder, f"best_model.pth"))
        early_stop_count = 0

        print("<-- new best val loss")
    else:
        print("")

    # Save checkpoint every 10
    if epoch % 10 == 0 or epoch >= INITIAL_TRAIN_EPOCHS - 1:
        torch.save(checkpoint, os.path.join(pretrain_folder, f"checkpt{epoch}.pth"))

    # check to see if validation loss has plateau'd
    if early_stop_count >= EARLY_STOP_CRIT and epoch >= MIN_EPOCHS - 1:
        print(f"Validation loss plateaued after {early_stop_count} at epoch {epoch}")
        torch.save(checkpoint, os.path.join(pretrain_folder, f"earlystop{epoch}.pth"))
        break

    early_stop_count += 1


# Save final model
best_checkpoint = torch.load(os.path.join(pretrain_folder, f"best_model.pth"))
torch.save(best_checkpoint, os.path.join(pretrain_folder, f"final_model.pth"))
"""

'\n# Initialize lists to store loss and accuracy values\nloss_history = []\nloss_history_val = []\n\nloss_history_running = []\n\n# Early Stopping\nbest_loss_val = np.inf\nearly_stop_count = 0\n\n\n# Train\nprint("Start pretrain...")\nouter = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)\ninner = tqdm(total=len(dataloader_source_train), desc=f"Batch", position=1)\n\ncheckpoint = {\n    "epoch": -1,\n    "model": model,\n    "optimizer": pre_optimizer,\n    "scheduler": pre_scheduler,\n    # \'scaler\': scaler\n}\nfor epoch in range(INITIAL_TRAIN_EPOCHS):\n    checkpoint["epoch"] = epoch\n\n    # Train mode\n    model.train()\n    loss_running = []\n    mean_weights = []\n\n    inner.refresh()  # force print final state\n    inner.reset()  # reuse bar\n    for _, batch in enumerate(dataloader_source_train):\n        # lr_history_running.append(scheduler.get_last_lr())\n\n        pre_optimizer.zero_grad()\n        loss = model_loss(*batch, model)\n        loss_running.appen

 ## Adversarial Adaptation

In [15]:
advtrain_folder = os.path.join(model_folder, "advtrain")

if not os.path.isdir(advtrain_folder):
    os.makedirs(advtrain_folder)


In [16]:
def cycle_iter(iter):
    while True:
        yield from iter


def iter_skip(iter, n=1):
    for i in range(len(iter)*n):
        if (i % n) == n - 1:
            yield next(iter)
        else:
            yield None, None


In [17]:
criterion_dis = nn.BCEWithLogitsLoss()


In [18]:
def discrim_loss_accu(x, domain, model):
    x = x.to(device)

    if domain == 'source':
        y_dis = torch.zeros(x.shape[0], device=device, dtype=x.dtype).view(-1, 1)
        emb = model.source_encoder(x) #.view(x.shape[0], -1)
    elif domain == 'target':
        y_dis = torch.ones(x.shape[0], device=device, dtype=x.dtype).view(-1, 1)
        emb = model.target_encoder(x) #.view(x.shape[0], -1)
    else:
        raise(ValueError, f"invalid domain {domain} given, must be 'source' or 'target'")
    
    y_pred = model.dis(emb)
    
    loss = criterion_dis(y_pred, y_dis)
    accu = torch.mean(
        (torch.round(y_pred).to(torch.long) == y_dis).to(torch.float32)
    ).cpu()

    return loss, accu


def compute_acc_dis(dataloader_source, dataloader_target, model):
    loss_history = []
    accu_history = []
    # iters = max(len(dataloader_source), len(dataloader_target))
    model.eval()
    model.dis.eval()
    model.target_encoder.eval()
    model.source_encoder.eval()
    with torch.no_grad():
        loss_running = []
        accu_running = []
        mean_weights = []
        # batch_cycler = zip(cycle_iter(dataloader_source), cycle_iter(dataloader_target))
        for _, (X, _) in enumerate(dataloader_source):
            X = X.to(device)

            y_dis = torch.zeros(X.shape[0], device=device, dtype=X.dtype).view(-1, 1)

            emb = model.source_encoder(X) #.view(X.shape[0], -1)

            y_pred = model.dis(emb)

            loss_running.append(criterion_dis(y_pred, y_dis).item())
            accu_running.append(
                torch.mean(
                    (torch.flatten(torch.argmax(y_pred, dim=1)) == y_dis).to(
                        torch.float32
                    )
                ).cpu()
            )
        loss_history.append(np.average(loss_running, weights=mean_weights))
        accu_history.append(np.average(accu_running, weights=mean_weights))

        loss_running = []
        accu_running = []
        mean_weights = []
        for _, (X, _) in enumerate(dataloader_target):
            X = X.to(device)

            y_dis = torch.ones(X.shape[0], device=device, dtype=X.dtype).view(-1, 1)

            emb = model.source_encoder(X) #.view(X.shape[0], -1)

            y_pred = model.dis(emb)

            loss_running.append(criterion_dis(y_pred, y_dis).item())
            accu_running.append(
                torch.mean(
                    (torch.flatten(torch.argmax(y_pred, dim=1)) == y_dis).to(
                        torch.float32
                    )
                ).cpu()
            )
        loss_history.append(np.average(loss_running, weights=mean_weights))
        accu_history.append(np.average(accu_running, weights=mean_weights))

    return np.average(loss_history), np.average(accu_history)


def encoder_loss(x_target, model):
    x_target = x_target.to(device)

    # flip label
    y_dis = torch.zeros(x_target.shape[0], device=device, dtype=x_target.dtype).view(-1, 1)

    emb_target = model.target_encoder(x_target) #.view(x_target.shape[0], -1)
    y_pred = model.dis(emb_target)
    loss = criterion_dis(y_pred, y_dis)

    return loss


In [19]:
# def train_adversarial(
#     model,
#     save_folder,
#     dataloader_source_train,
#     dataloader_source_val,
#     dataloader_target_train,
# ):
#     model.to(device)
#     model.advtraining()

#     target_optimizer = torch.optim.Adam(
#         model.target_encoder.parameters(), lr=0.0005, betas=(0.9, 0.999), eps=1e-07
#     )
#     dis_optimizer = torch.optim.Adam(
#         model.dis.parameters(), lr=0.00025, betas=(0.9, 0.999), eps=1e-07
#     )

#     iters = max(len(dataloader_source_train), len(dataloader_target_train))

#     dis_scheduler = torch.optim.lr_scheduler.OneCycleLR(
#         dis_optimizer, max_lr=0.0005, steps_per_epoch=iters, epochs=EPOCHS
#     )
#     target_scheduler = torch.optim.lr_scheduler.OneCycleLR(
#         target_optimizer, max_lr=0.0005, steps_per_epoch=iters, epochs=EPOCHS
#     )

#     n_samples_source = len(dataloader_source_train.dataset)
#     n_samples_target = len(dataloader_target_train.dataset)
#     p = n_samples_source / (n_samples_source + n_samples_target)
#     rand_loss = -(p * np.log(0.5)) - (1 - p) * np.log(0.5)

#     # Initialize lists to store loss and accuracy values
#     loss_history = []
#     accu_history = []
#     loss_history_running = []

#     # Early Stopping
#     best_loss_val = np.inf
#     early_stop_count = 0

#     # Train
#     print("Start adversarial training...")
#     print("Discriminator target loss:", rand_loss)
#     outer = tqdm(total=EPOCHS, desc="Epochs", position=0)
#     inner1 = tqdm(total=iters, desc=f"Batch (Discriminator)", position=1)
#     inner2 = tqdm(total=iters, desc=f"Batch (Encoder)", position=2)
#     checkpoint = {
#         "epoch": -1,
#         "model": model,
#         "dis_optimizer": dis_optimizer,
#         "target_optimizer": target_optimizer,
#         "dis_scheduler": dis_scheduler,
#         "target_scheduler": target_scheduler,
#     }
#     for epoch in range(EPOCHS):
#         checkpoint["epoch"] = epoch

#         # Train mode
#         model.train()

#         loss_running = []
#         accu_running = []
#         mean_weights = []

#         inner1.refresh()  # force print final state
#         inner1.reset()  # reuse bar
#         inner2.refresh()  # force print final state
#         inner2.reset()  # reuse bar

#         model.train_discriminator()
#         model.target_encoder.eval()
#         model.source_encoder.eval()
#         model.dis.train()
#         batch_cycler = zip(
#             cycle_iter(dataloader_source_train), cycle_iter(dataloader_target_train)
#         )
#         for _ in range(iters):
#             # lr_history_running.append(scheduler.get_last_lr())
#             dis_optimizer.zero_grad()

#             (x_source, _), (x_target, _) = next(batch_cycler)
#             loss, accu = discrim_loss_accu(x_source, x_target, model)
#             loss_running.append(loss.item())
#             accu_running.append(accu)
#             mean_weights.append(len(x_source) + len(x_target))

#             # scaler.scale(loss).backward()
#             # scaler.step(optimizer)
#             # scaler.update()

#             loss.backward()
#             dis_optimizer.step()
#             dis_scheduler.step()

#             inner1.update(1)

#         loss_history.append(np.average(loss_running, weights=mean_weights))
#         accu_history.append(np.average(accu_running, weights=mean_weights))
#         loss_history_running.append(loss_running)

#         model.train_target_encoder()
#         model.target_encoder.train()
#         model.source_encoder.eval()
#         model.dis.eval()
#         batch_cycler = zip(
#             cycle_iter(dataloader_source_train), cycle_iter(dataloader_target_train)
#         )
#         for _ in range(iters):
#             target_optimizer.zero_grad()

#             _, (x_target, _) = next(batch_cycler)
#             loss = encoder_loss(x_target, model)

#             loss.backward()
#             target_optimizer.step()
#             target_scheduler.step()

#             inner2.update(1)

#         diff_from_rand = math.fabs(loss_history[-1] - rand_loss)

#         # Print the results
#         outer.update(1)
#         print(
#             "epoch:",
#             epoch,
#             "dis loss:",
#             round(loss_history[-1], 6),
#             "dis accu:",
#             round(accu_history[-1], 6),
#             "difference from random loss:",
#             round(diff_from_rand, 6),
#             # "next_lr:", scheduler.get_last_lr(),
#             end=" ",
#         )

#         # Save the best weights
#         if diff_from_rand < best_loss_val:
#             best_loss_val = diff_from_rand
#             torch.save(checkpoint, os.path.join(save_folder, f"best_model.pth"))
#             early_stop_count = 0

#             print("<-- new best difference from random loss")
#         else:
#             print("")

#         # Save checkpoint every 10
#         if epoch % 10 == 0 or epoch >= EPOCHS - 1:
#             torch.save(checkpoint, os.path.join(save_folder, f"checkpt{epoch}.pth"))

#         # check to see if validation loss has plateau'd
#         if early_stop_count >= EARLY_STOP_CRIT_ADV and epoch > MIN_EPOCHS_ADV - 1:
#             print(
#                 f"Discriminator loss plateaued after {early_stop_count} at epoch {epoch}"
#             )
#             torch.save(checkpoint, os.path.join(save_folder, f"earlystop_{epoch}.pth"))
#             break

#         early_stop_count += 1

#     # Save final model
#     torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [20]:
def train_adversarial_iters(
    model,
    save_folder,
    dataloader_source_train,
    dataloader_source_val,
    dataloader_target_train,
    dataloader_target_train_dis,
):
    model.to(device)
    model.advtraining()

    target_optimizer = torch.optim.Adam(
        model.target_encoder.parameters(), lr=ENC_LR, betas=(ADAM_BETA_1, 0.999), eps=1e-07
    )
    dis_optimizer = torch.optim.Adam(
        model.dis.parameters(), lr=ALPHA * ENC_LR, betas=(ADAM_BETA_1, 0.999), eps=1e-07
    )

    # iters = -(max_len_dataloader // -(1 + DIS_LOOP_FACTOR))  # ceiling divide

    dataloader_lengths = [
        len(dataloader_source_train),
        len(dataloader_target_train),
        len(dataloader_target_train_dis) * DIS_LOOP_FACTOR,
    ]
    max_len_dataloader = np.amax(dataloader_lengths)
    longest = np.argmax(dataloader_lengths)

    iters_val = max(len(dataloader_source_val), len(dataloader_target_train))

    # dis_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #     dis_optimizer, max_lr=0.0005, steps_per_epoch=iters, epochs=EPOCHS
    # )
    # target_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    #     target_optimizer, max_lr=0.0005, steps_per_epoch=iters, epochs=EPOCHS
    # )

    # Initialize lists to store loss and accuracy values
    loss_history = []
    accu_history = []
    loss_history_val = []
    accu_history_val = []
    loss_history_running = []

    loss_history_gen = []
    loss_history_gen_running = []
    mean_weights_gen = []

    # Early Stopping
    best_loss_val = np.inf
    early_stop_count = 0
    # Train
    print("Start adversarial training...")
    outer = tqdm(total=EPOCHS, desc="Epochs", position=0)
    inner1 = tqdm(total=max_len_dataloader, desc=f"Batch", position=1)
    checkpoint = {
        "epoch": -1,
        "model": model,
        "dis_optimizer": dis_optimizer,
        "target_optimizer": target_optimizer,
        # "dis_scheduler": dis_scheduler,
        # "target_scheduler": target_scheduler,
    }
    for epoch in range(EPOCHS):
        checkpoint["epoch"] = epoch

        # Train mode
        model.train()
        model.target_encoder.train()
        model.source_encoder.train()
        model.dis.train()

        loss_running = []
        accu_running = []
        mean_weights = []

        loss_running_gen = []
        mean_weights_gen = []

        inner1.refresh()  # force print final state
        inner1.reset()  # reuse bar


        # if longest == 0:
        #     batch_cycler = zip(
        #         dataloader_source_train,
        #         cycle_iter(dataloader_target_train),
        #         cycle_iter(iter_skip(iter(dataloader_target_train_dis), DIS_LOOP_FACTOR))
        #     )
        # elif longest == 1:
        #     batch_cycler = zip(
        #         cycle_iter(dataloader_source_train),
        #         dataloader_target_train,
        #         cycle_iter(iter_skip(iter(dataloader_target_train_dis), DIS_LOOP_FACTOR))
        #     )
        # else:
        #     batch_cycler = zip(
        #         cycle_iter(dataloader_source_train),
        #         cycle_iter(dataloader_target_train),
        #         iter_skip(iter(dataloader_target_train_dis), DIS_LOOP_FACTOR)
        #     )

        s_train_iter = iter(dataloader_source_train)
        t_train_iter = iter(dataloader_target_train)
        t_train_dis_iter = iter(dataloader_target_train_dis)
        for i in range(max_len_dataloader):
            try:
                x_source, _ = next(s_train_iter)
            except StopIteration:
                s_train_iter = iter(dataloader_source_train)
                x_source, _ = next(s_train_iter)
            try:
                x_target, _ = next(t_train_iter)
            except StopIteration:
                t_train_iter = iter(dataloader_target_train)
                x_target, _ = next(t_train_iter)

            train_encoder_step = (i % DIS_LOOP_FACTOR) == DIS_LOOP_FACTOR - 1
            
            # print(x_target_enc, (i % DIS_LOOP_FACTOR))
            model.train_discriminator()
            # model.target_encoder.train()
            # model.source_encoder.train()
            # model.dis.train()

            set_requires_grad(model.target_encoder, False)
            set_requires_grad(model.source_encoder, False)
            set_requires_grad(model.dis, True)

            # lr_history_running.append(scheduler.get_last_lr())
            dis_optimizer.zero_grad()

            loss, accu = discrim_loss_accu(x_source, 'source', model)
            loss_running.append(loss.item())
            accu_running.append(accu)
            mean_weights.append(len(x_source))

            
            # scaler.scale(loss).backward()
            # scaler.step(optimizer)
            # scaler.update()

            loss.backward()
            dis_optimizer.step()

            dis_optimizer.zero_grad()

            loss, accu = discrim_loss_accu(x_target,'target', model)
            loss_running.append(loss.item())
            accu_running.append(accu)
            mean_weights.append(len(x_target))

            
            # scaler.scale(loss).backward()
            # scaler.step(optimizer)
            # scaler.update()

            loss.backward()
            dis_optimizer.step()
            # dis_scheduler.step()
                
            # print(i % DIS_LOOP_FACTOR)
            if train_encoder_step:
                try:
                    x_target_enc, _ = next(t_train_dis_iter)
                except StopIteration:
                    t_train_dis_iter = iter(dataloader_target_train_dis)
                    x_target_enc, _ = next(t_train_dis_iter)
                model.train_target_encoder()
                # model.target_encoder.train()
                # model.source_encoder.train()
                # model.dis.train()

                set_requires_grad(model.target_encoder, True)
                set_requires_grad(model.source_encoder, False)
                set_requires_grad(model.dis, False)

                target_optimizer.zero_grad()

                loss = encoder_loss(x_target_enc, model)

                loss_running_gen.append(loss.item())
                mean_weights_gen.append(len(x_target_enc))

                loss.backward()
                target_optimizer.step()
            # target_scheduler.step()

            inner1.update(1)
        loss_history.append(np.average(loss_running, weights=mean_weights))
        accu_history.append(np.average(accu_running, weights=mean_weights))
        loss_history_running.append(loss_running)
        loss_history_gen.append(np.average(loss_running_gen, weights=mean_weights_gen))
        loss_history_gen_running.append(loss_running_gen)
        # ensure batch_cycler gets garbage collected, freeing the dataloaders
        
        model.eval()
        model.dis.eval()
        model.target_encoder.eval()
        model.source_encoder.eval()

        set_requires_grad(model, True)
        set_requires_grad(model.target_encoder, True)
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.dis, True)

        # del batch_cycler
        # with torch.no_grad():
        #     curr_loss_val, curr_acc_val = compute_acc_dis(
        #         dataloader_source_val, dataloader_target_train, model
        #     )
        #     loss_history_val.append(curr_loss_val)
        #     accu_history_val.append(curr_loss_val)

        # Print the results
        outer.update(1)
        print(
            "epoch:",
            epoch,
            "gen train loss:",
            round(loss_history_gen[-1], 6),
            "dis train loss:",
            round(loss_history[-1], 6),
            "dis train accu:",
            round(accu_history[-1], 6),
            # "dis val loss:",
            # round(loss_history_val[-1], 6),
            # "dis val accu:",
            # round(accu_history_val[-1], 6),
            # "next_lr:", scheduler.get_last_lr(),
            end=" ",
        )

        # # Save the best weights
        # if diff_from_rand < best_loss_val:
        #     best_loss_val = diff_from_rand
        #     torch.save(checkpoint, os.path.join(save_folder, f"best_model.pth"))
        #     early_stop_count = 0

        #     print("<-- new best difference from random loss")
        # else:
        #     print("")

        print("")

        # Save checkpoint every 10
        if epoch % 10 == 0 or epoch >= EPOCHS - 1:
            torch.save(checkpoint, os.path.join(save_folder, f"checkpt{epoch}.pth"))

        # check to see if validation loss has plateau'd
        if early_stop_count >= EARLY_STOP_CRIT_ADV and epoch > MIN_EPOCHS_ADV - 1:
            print(
                f"Discriminator loss plateaued after {early_stop_count} at epoch {epoch}"
            )
            torch.save(checkpoint, os.path.join(save_folder, f"earlystop_{epoch}.pth"))
            break

        early_stop_count += 1

    # Save final model
    torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [21]:
# st_sample_id_l = [SAMPLE_ID_N]


In [22]:
if TRAIN_USING_ALL_ST_SAMPLES:
    print(f"Adversarial training for all ST slides")
    save_folder = advtrain_folder

    best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
    model = best_checkpoint["model"]
    model.to(device)
    model.advtraining()

    train_adversarial_iters(
        model,
        save_folder,
        dataloader_source_train,
        dataloader_source_val,
        dataloader_target_train,
        dataloader_target_train_dis
    )

else:
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        save_folder = os.path.join(advtrain_folder, sample_id)
        if not os.path.isdir(save_folder):
            os.makedirs(save_folder)

        best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
        model = best_checkpoint["model"]
        model.to(device)
        model.advtraining()

        train_adversarial_iters(
            model,
            save_folder,
            dataloader_source_train,
            dataloader_source_val,
            dataloader_target_train_d[sample_id],
            dataloader_target_train_dis_d[sample_id]
        )


Adversarial training for ST slide 151509: 
Start adversarial training...


Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/25 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.644878 dis train loss: 0.726105 dis train accu: 0.38608 
epoch: 1 gen train loss: 0.654195 dis train loss: 0.720103 dis train accu: 0.407582 
epoch: 2 gen train loss: 0.657657 dis train loss: 0.71618 dis train accu: 0.423133 
epoch: 3 gen train loss: 0.664204 dis train loss: 0.715938 dis train accu: 0.428554 
epoch: 4 gen train loss: 0.667946 dis train loss: 0.714565 dis train accu: 0.436238 
epoch: 5 gen train loss: 0.67313 dis train loss: 0.713817 dis train accu: 0.441618 
epoch: 6 gen train loss: 0.673681 dis train loss: 0.71155 dis train accu: 0.443738 
epoch: 7 gen train loss: 0.679589 dis train loss: 0.711386 dis train accu: 0.450056 
epoch: 8 gen train loss: 0.676709 dis train loss: 0.711322 dis train accu: 0.454907 
epoch: 9 gen train loss: 0.678353 dis train loss: 0.709704 dis train accu: 0.455518 
epoch: 10 gen train loss: 0.687062 dis train loss: 0.711099 dis train accu: 0.460471 
epoch: 11 gen train loss: 0.688761 dis train loss: 0.708933 dis trai

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/25 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.647467 dis train loss: 0.724435 dis train accu: 0.38859 
epoch: 1 gen train loss: 0.652521 dis train loss: 0.717691 dis train accu: 0.413191 
epoch: 2 gen train loss: 0.656106 dis train loss: 0.716252 dis train accu: 0.425513 
epoch: 3 gen train loss: 0.664434 dis train loss: 0.712843 dis train accu: 0.435204 
epoch: 4 gen train loss: 0.667705 dis train loss: 0.712632 dis train accu: 0.438973 
epoch: 5 gen train loss: 0.667322 dis train loss: 0.713092 dis train accu: 0.445993 
epoch: 6 gen train loss: 0.673688 dis train loss: 0.712343 dis train accu: 0.452185 
epoch: 7 gen train loss: 0.682112 dis train loss: 0.710444 dis train accu: 0.458936 
epoch: 8 gen train loss: 0.686831 dis train loss: 0.709582 dis train accu: 0.459847 
epoch: 9 gen train loss: 0.684036 dis train loss: 0.709403 dis train accu: 0.46345 
epoch: 10 gen train loss: 0.686308 dis train loss: 0.709224 dis train accu: 0.466101 
epoch: 11 gen train loss: 0.683104 dis train loss: 0.708287 dis tr

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/25 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.649842 dis train loss: 0.721493 dis train accu: 0.405102 
epoch: 1 gen train loss: 0.654955 dis train loss: 0.718384 dis train accu: 0.424502 
epoch: 2 gen train loss: 0.659443 dis train loss: 0.715116 dis train accu: 0.440661 
epoch: 3 gen train loss: 0.670298 dis train loss: 0.713756 dis train accu: 0.448215 
epoch: 4 gen train loss: 0.671628 dis train loss: 0.711753 dis train accu: 0.459974 
epoch: 5 gen train loss: 0.667552 dis train loss: 0.711325 dis train accu: 0.463499 
epoch: 6 gen train loss: 0.676726 dis train loss: 0.709235 dis train accu: 0.4709 
epoch: 7 gen train loss: 0.685112 dis train loss: 0.710186 dis train accu: 0.475126 
epoch: 8 gen train loss: 0.680999 dis train loss: 0.70929 dis train accu: 0.482198 
epoch: 9 gen train loss: 0.680756 dis train loss: 0.709343 dis train accu: 0.481804 
epoch: 10 gen train loss: 0.684632 dis train loss: 0.709927 dis train accu: 0.488789 
epoch: 11 gen train loss: 0.684882 dis train loss: 0.706007 dis tra

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/25 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.649123 dis train loss: 0.722318 dis train accu: 0.398172 
epoch: 1 gen train loss: 0.652893 dis train loss: 0.718067 dis train accu: 0.424256 
epoch: 2 gen train loss: 0.661271 dis train loss: 0.713233 dis train accu: 0.432696 
epoch: 3 gen train loss: 0.662851 dis train loss: 0.713406 dis train accu: 0.441624 
epoch: 4 gen train loss: 0.669672 dis train loss: 0.711986 dis train accu: 0.451807 
epoch: 5 gen train loss: 0.669902 dis train loss: 0.711126 dis train accu: 0.455995 
epoch: 6 gen train loss: 0.67315 dis train loss: 0.710448 dis train accu: 0.460247 
epoch: 7 gen train loss: 0.676827 dis train loss: 0.709294 dis train accu: 0.464881 
epoch: 8 gen train loss: 0.681585 dis train loss: 0.708915 dis train accu: 0.470259 
epoch: 9 gen train loss: 0.682028 dis train loss: 0.709502 dis train accu: 0.474298 
epoch: 10 gen train loss: 0.684287 dis train loss: 0.707879 dis train accu: 0.476807 
epoch: 11 gen train loss: 0.688081 dis train loss: 0.70711 dis tr

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/20 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.646722 dis train loss: 0.722689 dis train accu: 0.396532 
epoch: 1 gen train loss: 0.658346 dis train loss: 0.719287 dis train accu: 0.41131 
epoch: 2 gen train loss: 0.66193 dis train loss: 0.713429 dis train accu: 0.42622 
epoch: 3 gen train loss: 0.66165 dis train loss: 0.713012 dis train accu: 0.438037 
epoch: 4 gen train loss: 0.66686 dis train loss: 0.712996 dis train accu: 0.444519 
epoch: 5 gen train loss: 0.668344 dis train loss: 0.712764 dis train accu: 0.451214 
epoch: 6 gen train loss: 0.668638 dis train loss: 0.710394 dis train accu: 0.457402 
epoch: 7 gen train loss: 0.669699 dis train loss: 0.711926 dis train accu: 0.457215 
epoch: 8 gen train loss: 0.677337 dis train loss: 0.709809 dis train accu: 0.466284 
epoch: 9 gen train loss: 0.681205 dis train loss: 0.710903 dis train accu: 0.465431 
epoch: 10 gen train loss: 0.679178 dis train loss: 0.709789 dis train accu: 0.466178 
epoch: 11 gen train loss: 0.680067 dis train loss: 0.710627 dis train

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/25 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.644113 dis train loss: 0.721091 dis train accu: 0.405795 
epoch: 1 gen train loss: 0.658612 dis train loss: 0.714705 dis train accu: 0.425189 
epoch: 2 gen train loss: 0.660868 dis train loss: 0.713151 dis train accu: 0.43987 
epoch: 3 gen train loss: 0.663533 dis train loss: 0.711854 dis train accu: 0.447481 
epoch: 4 gen train loss: 0.66581 dis train loss: 0.711649 dis train accu: 0.456195 
epoch: 5 gen train loss: 0.669876 dis train loss: 0.711011 dis train accu: 0.461016 
epoch: 6 gen train loss: 0.670035 dis train loss: 0.710135 dis train accu: 0.467978 
epoch: 7 gen train loss: 0.684427 dis train loss: 0.709596 dis train accu: 0.473881 
epoch: 8 gen train loss: 0.6818 dis train loss: 0.710137 dis train accu: 0.476562 
epoch: 9 gen train loss: 0.682576 dis train loss: 0.707761 dis train accu: 0.477038 
epoch: 10 gen train loss: 0.685383 dis train loss: 0.7088 dis train accu: 0.48173 
epoch: 11 gen train loss: 0.690997 dis train loss: 0.707874 dis train a

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/20 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.642616 dis train loss: 0.726462 dis train accu: 0.389626 
epoch: 1 gen train loss: 0.654271 dis train loss: 0.719632 dis train accu: 0.405682 
epoch: 2 gen train loss: 0.657243 dis train loss: 0.71634 dis train accu: 0.419236 
epoch: 3 gen train loss: 0.665424 dis train loss: 0.716569 dis train accu: 0.430027 
epoch: 4 gen train loss: 0.664431 dis train loss: 0.712989 dis train accu: 0.437378 
epoch: 5 gen train loss: 0.672859 dis train loss: 0.711945 dis train accu: 0.442852 
epoch: 6 gen train loss: 0.667123 dis train loss: 0.71139 dis train accu: 0.448482 
epoch: 7 gen train loss: 0.676614 dis train loss: 0.711462 dis train accu: 0.449107 
epoch: 8 gen train loss: 0.678005 dis train loss: 0.711178 dis train accu: 0.456041 
epoch: 9 gen train loss: 0.682399 dis train loss: 0.710492 dis train accu: 0.459898 
epoch: 10 gen train loss: 0.685382 dis train loss: 0.711147 dis train accu: 0.457552 
epoch: 11 gen train loss: 0.681232 dis train loss: 0.70953 dis tra

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/20 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.648872 dis train loss: 0.722975 dis train accu: 0.400509 
epoch: 1 gen train loss: 0.654178 dis train loss: 0.717687 dis train accu: 0.416863 
epoch: 2 gen train loss: 0.658703 dis train loss: 0.714609 dis train accu: 0.431206 
epoch: 3 gen train loss: 0.661306 dis train loss: 0.713351 dis train accu: 0.43992 
epoch: 4 gen train loss: 0.667514 dis train loss: 0.712349 dis train accu: 0.446971 
epoch: 5 gen train loss: 0.670668 dis train loss: 0.71098 dis train accu: 0.450322 
epoch: 6 gen train loss: 0.669719 dis train loss: 0.711331 dis train accu: 0.459625 
epoch: 7 gen train loss: 0.677503 dis train loss: 0.711177 dis train accu: 0.461126 
epoch: 8 gen train loss: 0.672973 dis train loss: 0.709852 dis train accu: 0.466488 
epoch: 9 gen train loss: 0.680663 dis train loss: 0.710372 dis train accu: 0.469169 
epoch: 10 gen train loss: 0.679697 dis train loss: 0.709834 dis train accu: 0.470617 
epoch: 11 gen train loss: 0.683095 dis train loss: 0.71027 dis tra

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/20 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.651117 dis train loss: 0.722422 dis train accu: 0.396654 
epoch: 1 gen train loss: 0.655717 dis train loss: 0.718226 dis train accu: 0.413277 
epoch: 2 gen train loss: 0.660126 dis train loss: 0.717802 dis train accu: 0.426159 
epoch: 3 gen train loss: 0.658532 dis train loss: 0.713304 dis train accu: 0.434089 
epoch: 4 gen train loss: 0.662617 dis train loss: 0.715056 dis train accu: 0.439726 
epoch: 5 gen train loss: 0.671339 dis train loss: 0.712226 dis train accu: 0.445285 
epoch: 6 gen train loss: 0.677042 dis train loss: 0.711429 dis train accu: 0.450448 
epoch: 7 gen train loss: 0.678874 dis train loss: 0.711639 dis train accu: 0.453583 
epoch: 8 gen train loss: 0.675649 dis train loss: 0.710363 dis train accu: 0.459642 
epoch: 9 gen train loss: 0.683378 dis train loss: 0.708387 dis train accu: 0.463857 
epoch: 10 gen train loss: 0.681479 dis train loss: 0.710654 dis train accu: 0.464278 
epoch: 11 gen train loss: 0.680922 dis train loss: 0.708954 dis 

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/20 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.650477 dis train loss: 0.725372 dis train accu: 0.393062 
epoch: 1 gen train loss: 0.653152 dis train loss: 0.71934 dis train accu: 0.411049 
epoch: 2 gen train loss: 0.658555 dis train loss: 0.715889 dis train accu: 0.423223 
epoch: 3 gen train loss: 0.662367 dis train loss: 0.712399 dis train accu: 0.431104 
epoch: 4 gen train loss: 0.667519 dis train loss: 0.714626 dis train accu: 0.435109 
epoch: 5 gen train loss: 0.668361 dis train loss: 0.713007 dis train accu: 0.4421 
epoch: 6 gen train loss: 0.67696 dis train loss: 0.712196 dis train accu: 0.446289 
epoch: 7 gen train loss: 0.674628 dis train loss: 0.711362 dis train accu: 0.450818 
epoch: 8 gen train loss: 0.681658 dis train loss: 0.710976 dis train accu: 0.45286 
epoch: 9 gen train loss: 0.684908 dis train loss: 0.710787 dis train accu: 0.459955 
epoch: 10 gen train loss: 0.679855 dis train loss: 0.709828 dis train accu: 0.463726 
epoch: 11 gen train loss: 0.68782 dis train loss: 0.710002 dis train 

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/20 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.645571 dis train loss: 0.728751 dis train accu: 0.379139 
epoch: 1 gen train loss: 0.650909 dis train loss: 0.722513 dis train accu: 0.393138 
epoch: 2 gen train loss: 0.656088 dis train loss: 0.718807 dis train accu: 0.40791 
epoch: 3 gen train loss: 0.658337 dis train loss: 0.716847 dis train accu: 0.414173 
epoch: 4 gen train loss: 0.665344 dis train loss: 0.717411 dis train accu: 0.420412 
epoch: 5 gen train loss: 0.676496 dis train loss: 0.715784 dis train accu: 0.424828 
epoch: 6 gen train loss: 0.669737 dis train loss: 0.712857 dis train accu: 0.430393 
epoch: 7 gen train loss: 0.680742 dis train loss: 0.71313 dis train accu: 0.436282 
epoch: 8 gen train loss: 0.678541 dis train loss: 0.713102 dis train accu: 0.436382 
epoch: 9 gen train loss: 0.678044 dis train loss: 0.71278 dis train accu: 0.442171 
epoch: 10 gen train loss: 0.681862 dis train loss: 0.71147 dis train accu: 0.443868 
epoch: 11 gen train loss: 0.684306 dis train loss: 0.711686 dis trai

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/20 [00:00<?, ?it/s]

epoch: 0 gen train loss: 0.651826 dis train loss: 0.724318 dis train accu: 0.390236 
epoch: 1 gen train loss: 0.646084 dis train loss: 0.717851 dis train accu: 0.40825 
epoch: 2 gen train loss: 0.658287 dis train loss: 0.71573 dis train accu: 0.419867 
epoch: 3 gen train loss: 0.663761 dis train loss: 0.714349 dis train accu: 0.426968 
epoch: 4 gen train loss: 0.667496 dis train loss: 0.711922 dis train accu: 0.438141 
epoch: 5 gen train loss: 0.674037 dis train loss: 0.713907 dis train accu: 0.438115 
epoch: 6 gen train loss: 0.673471 dis train loss: 0.712213 dis train accu: 0.44827 
epoch: 7 gen train loss: 0.670821 dis train loss: 0.711293 dis train accu: 0.451534 
epoch: 8 gen train loss: 0.676747 dis train loss: 0.7119 dis train accu: 0.453596 
epoch: 9 gen train loss: 0.67341 dis train loss: 0.711574 dis train accu: 0.454197 
epoch: 10 gen train loss: 0.681215 dis train loss: 0.710306 dis train accu: 0.46088 
epoch: 11 gen train loss: 0.683661 dis train loss: 0.709246 dis train a

## Evaluation of latent space

In [23]:
# from sklearn.decomposition import PCA
# from sklearn import model_selection
# from sklearn.ensemble import RandomForestClassifier


# for sample_id in st_sample_id_l:
#     best_checkpoint = torch.load(
#         os.path.join(advtrain_folder, sample_id, f"final_model.pth")
#     )
#     model = best_checkpoint["model"]
#     model.to(device)

#     model.eval()
#     model.target_inference()

#     with torch.no_grad():
#         source_emb = model.source_encoder(torch.Tensor(sc_mix_train_s).to(device))
#         target_emb = model.target_encoder(
#             torch.Tensor(mat_sp_test_s_d[sample_id]).to(device)
#         )

#         y_dis = torch.cat(
#             [
#                 torch.zeros(source_emb.shape[0], device=device, dtype=torch.long),
#                 torch.ones(target_emb.shape[0], device=device, dtype=torch.long),
#             ]
#         )

#         emb = torch.cat([source_emb, target_emb])

#         emb = emb.detach().cpu().numpy()
#         y_dis = y_dis.detach().cpu().numpy()

#     (emb_train, emb_test, y_dis_train, y_dis_test,) = model_selection.train_test_split(
#         emb,
#         y_dis,
#         test_size=0.2,
#         random_state=225,
#         stratify=y_dis,
#     )

#     pca = PCA(n_components=50)
#     pca.fit(emb_train)

#     emb_train_50 = pca.transform(emb_train)
#     emb_test_50 = pca.transform(emb_test)

#     clf = RandomForestClassifier(random_state=145, n_jobs=-1)
#     clf.fit(emb_train_50, y_dis_train)
#     accu_train = clf.score(emb_train_50, y_dis_train)
#     accu_test = clf.score(emb_test_50, y_dis_test)
#     class_proportions = np.mean(y_dis)

#     print(
#         "Training accuracy: {}, Test accuracy: {}, Class proportions: {}".format(
#             accu_train, accu_test, class_proportions
#         )
#     )


 # 4. Predict cell fraction of spots and visualization

In [24]:
# pred_sp_d, pred_sp_noda_d = {}, {}
# if TRAIN_USING_ALL_ST_SAMPLES:
#     best_checkpoint = torch.load(os.path.join(advtrain_folder, f"final_model.pth"))
#     model = best_checkpoint["model"]
#     model.to(device)

#     model.eval()
#     model.target_inference()
#     with torch.no_grad():
#         for sample_id in st_sample_id_l:
#             pred_sp_d[sample_id] = (
#                 torch.exp(
#                     model(torch.Tensor(mat_sp_test_s_d[sample_id]).to(device))
#                 )
#                 .detach()
#                 .cpu()
#                 .numpy()
#             )

# else:
#     for sample_id in st_sample_id_l:
#         best_checkpoint = torch.load(
#             os.path.join(advtrain_folder, sample_id, f"final_model.pth")
#         )
#         model = best_checkpoint["model"]
#         model.to(device)

#         model.eval()
#         model.target_inference()

#         with torch.no_grad():
#             pred_sp_d[sample_id] = (
#                 torch.exp(
#                     model(torch.Tensor(mat_sp_test_s_d[sample_id]).to(device))
#                 )
#                 .detach()
#                 .cpu()
#                 .numpy()
#             )


# best_checkpoint = torch.load(os.path.join(pretrain_folder, f"best_model.pth"))
# model = best_checkpoint["model"]
# model.to(device)

# model.eval()
# model.set_encoder("source")

# with torch.no_grad():
#     for sample_id in st_sample_id_l:
#         pred_sp_noda_d[sample_id] = (
#             torch.exp(model(torch.Tensor(mat_sp_test_s_d[sample_id]).to(device)))
#             .detach()
#             .cpu()
#             .numpy()
#         )


In [25]:
# adata_spatialLIBD = sc.read_h5ad(
#     os.path.join(PROCESSED_DATA_DIR, "adata_spatialLIBD.h5ad")
# )

# adata_spatialLIBD_d = {}
# for sample_id in st_sample_id_l:
#     adata_spatialLIBD_d[sample_id] = adata_spatialLIBD[
#         adata_spatialLIBD.obs.sample_id == sample_id
#     ]
#     adata_spatialLIBD_d[sample_id].obsm["spatial"] = (
#         adata_spatialLIBD_d[sample_id].obs[["X", "Y"]].values
#     )


In [26]:
# num_name_exN_l = []
# for k, v in sc_sub_dict.items():
#     if "Ex" in v:
#         num_name_exN_l.append((k, v, int(v.split("_")[1])))
# num_name_exN_l.sort(key=lambda a: a[2])
# num_name_exN_l


In [27]:
# Ex_to_L_d = {
#     1: {5, 6},
#     2: {5},
#     3: {4, 5},
#     4: {6},
#     5: {5},
#     6: {4, 5, 6},
#     7: {4, 5, 6},
#     8: {5, 6},
#     9: {5, 6},
#     10: {2, 3, 4},
# }


In [28]:
# numlist = [t[0] for t in num_name_exN_l]
# Ex_l = [t[2] for t in num_name_exN_l]
# num_to_ex_d = dict(zip(numlist, Ex_l))


In [29]:
# def plot_cellfraction(visnum, adata, pred_sp, ax=None):
#     """Plot predicted cell fraction for a given visnum"""
#     adata.obs["Pred_label"] = pred_sp[:, visnum]
#     # vmin = 0
#     # vmax = np.amax(pred_sp)

#     sc.pl.spatial(
#         adata,
#         img_key="hires",
#         color="Pred_label",
#         palette="Set1",
#         size=1.5,
#         legend_loc=None,
#         title=f"{sc_sub_dict[visnum]}",
#         spot_size=100,
#         show=False,
#         # vmin=vmin,
#         # vmax=vmax,
#         ax=ax,
#     )


In [30]:
# def plot_roc(visnum, adata, pred_sp, name, ax=None):
#     """Plot ROC for a given visnum"""

#     def layer_to_layer_number(x):
#         for char in x:
#             if char.isdigit():
#                 if int(char) in Ex_to_L_d[num_to_ex_d[visnum]]:
#                     return 1
#         return 0

#     y_pred = pred_sp[:, visnum]
#     y_true = adata.obs["spatialLIBD"].map(layer_to_layer_number).fillna(0)
#     # print(y_true)
#     # print(y_true.isna().sum())
#     RocCurveDisplay.from_predictions(y_true=y_true, y_pred=y_pred, name=name, ax=ax)


In [31]:
# fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), constrained_layout=True)

# sc.pl.spatial(
#     adata_spatialLIBD_d[SAMPLE_ID_N],
#     img_key=None,
#     color="spatialLIBD",
#     palette="Accent_r",
#     size=1.5,
#     title=SAMPLE_ID_N,
#     # legend_loc = 4,
#     spot_size=100,
#     show=False,
#     ax=ax,
# )

# ax.axis("equal")
# ax.set_xlabel("")
# ax.set_ylabel("")

# fig.show()


In [32]:
# fig, ax = plt.subplots(2, 5, figsize=(20, 8), constrained_layout=True)

# for i, num in enumerate(numlist):
#     plot_cellfraction(
#         num, adata_spatialLIBD_d[SAMPLE_ID_N], pred_sp_d[SAMPLE_ID_N], ax.flat[i]
#     )
#     ax.flat[i].axis("equal")
#     ax.flat[i].set_xlabel("")
#     ax.flat[i].set_ylabel("")

# fig.show()

# fig, ax = plt.subplots(
#     2, 5, figsize=(20, 8), constrained_layout=True, sharex=True, sharey=True
# )

# for i, num in enumerate(numlist):
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_d[SAMPLE_ID_N],
#         "ADDA",
#         ax.flat[i],
#     )
#     plot_roc(
#         num,
#         adata_spatialLIBD_d[SAMPLE_ID_N],
#         pred_sp_noda_d[SAMPLE_ID_N],
#         "NN_wo_da",
#         ax.flat[i],
#     )
#     ax.flat[i].plot([0, 1], [0, 1], transform=ax.flat[i].transAxes, ls="--", color="k")
#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     ax.flat[i].set_title(f"{sc_sub_dict[num]}")

#     if i >= len(numlist) - 5:
#         ax.flat[i].set_xlabel("FPR")
#     else:
#         ax.flat[i].set_xlabel("")
#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("TPR")
#     else:
#         ax.flat[i].set_ylabel("")

# fig.show()


In [33]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     best_checkpoint = torch.load(os.path.join(advtrain_folder, f"final_model.pth"))
# else:
#     best_checkpoint = torch.load(
#         os.path.join(advtrain_folder, SAMPLE_ID_N, f"final_model.pth")
#     )

# model = best_checkpoint["model"]
# model.to(device)

# model.eval()
# model.set_encoder("source")

# with torch.no_grad():
#     pred_mix = (
#         torch.exp(model(torch.Tensor(sc_mix_test_s).to(device)))
#         .detach()
#         .cpu()
#         .numpy()
#     )

# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(25, 5 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, visnum in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, visnum],
#         y=lab_mix_test[:, visnum],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[visnum])

#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlabel("Predicted Proportion")

#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = (
#         f"MSE: {mean_squared_error(pred_mix[:,visnum], lab_mix_test[:,visnum]):.5f}"
#     )

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()
