 # ADDA for ST

In [1]:
import os
import datetime
from itertools import chain
from copy import deepcopy
import warnings

from tqdm.autonotebook import tqdm

import h5py
import pickle
import numpy as np

import torch
from torch import nn

from src.da_models.adda import ADDAST
from src.da_models.datasets import SpotDataset
from src.da_models.utils import set_requires_grad
from src.utils.data_loading import load_spatial, load_sc

# datetime object containing current date and time
script_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%Hh%Mm%S")


  from tqdm.autonotebook import tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "cpu":
    warnings.warn("Using CPU", stacklevel=2)


In [3]:
NUM_MARKERS = 20
N_MIX = 8
N_SPOTS = 20000
TRAIN_USING_ALL_ST_SAMPLES = False

ST_SPLIT = False

SAMPLE_ID_N = "151673"

BATCH_SIZE = 512
NUM_WORKERS = 4
INITIAL_TRAIN_EPOCHS = 10

EARLY_STOP_CRIT = 100
MIN_EPOCHS = INITIAL_TRAIN_EPOCHS


EARLY_STOP_CRIT_ADV = 10
MIN_EPOCHS_ADV = 10

SPATIALLIBD_DIR = "./data/spatialLIBD"
SC_DLPFC_PATH = "./data/sc_dlpfc/adata_sc_dlpfc.h5ad"

PROCESSED_DATA_DIR = "./data/preprocessed_markers_celldart"

MODEL_NAME = "CellDART"


celldart_kwargs = {
    "emb_dim": 64,
    "bn_momentum": 0.01,
}


In [4]:
N_ITER = 3000
ALPHA_LR = 5
ALPHA = 0.6


In [5]:
model_folder = os.path.join("model", MODEL_NAME, script_start_time)

model_folder = os.path.join("model", MODEL_NAME, "bn_fix")

if not os.path.isdir(model_folder):
    os.makedirs(model_folder)
    print(model_folder)


 # Data load


In [6]:
# Load spatial data
mat_sp_d, mat_sp_train_s, st_sample_id_l = load_spatial(
    TRAIN_USING_ALL_ST_SAMPLES, PROCESSED_DATA_DIR, ST_SPLIT
)

# Load sc data
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(PROCESSED_DATA_DIR)


 # Training: Adversarial domain adaptation for cell fraction estimation

 ## Prepare dataloaders

In [7]:
### source dataloaders
source_train_set = SpotDataset(
    deepcopy(sc_mix_d["train"]), deepcopy(lab_mix_d["train"])
)
source_val_set = SpotDataset(deepcopy(sc_mix_d["val"]), deepcopy(lab_mix_d["val"]))
source_test_set = SpotDataset(deepcopy(sc_mix_d["test"]), deepcopy(lab_mix_d["test"]))

dataloader_source_train = torch.utils.data.DataLoader(
    source_train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader_source_val = torch.utils.data.DataLoader(
    source_val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)
dataloader_source_test = torch.utils.data.DataLoader(
    source_test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

### target dataloaders
target_test_set_d = {}
for sample_id in st_sample_id_l:
    target_test_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d["test"][sample_id]))

dataloader_target_test_d = {}
for sample_id in st_sample_id_l:
    dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
        target_test_set_d[sample_id],
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=False,
    )

if TRAIN_USING_ALL_ST_SAMPLES:
    target_train_set = SpotDataset(deepcopy(mat_sp_train_s))
    dataloader_target_train = torch.utils.data.DataLoader(
        target_train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
else:
    target_train_set_d = {}
    dataloader_target_train_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d["train"][sample_id])
        )
        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )


 ## Define Model

In [8]:
model = ADDAST(
    inp_dim=sc_mix_d["train"].shape[1],
    ncls_source=lab_mix_d["train"].shape[1],
    **celldart_kwargs
)

## CellDART uses just one encoder!
model.target_encoder = model.source_encoder
model.to(device)


ADDAST(
  (source_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=367, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (target_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=367, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (dis): Discriminator(
    (head): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1):

 ## Pretrain

In [9]:
pretrain_folder = os.path.join(model_folder, "pretrain")

if not os.path.isdir(pretrain_folder):
    os.makedirs(pretrain_folder)


In [10]:
pre_optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-07
)

criterion_clf = nn.KLDivLoss(reduction="batchmean")


In [11]:
def model_loss(x, y_true, model):
    x = x.to(torch.float32).to(device)
    y_true = y_true.to(torch.float32).to(device)

    y_pred = model(x)

    loss = criterion_clf(y_pred, y_true)

    return loss


def compute_acc(dataloader, model):
    loss_running = []
    mean_weights = []
    model.eval()
    with torch.no_grad():
        for _, batch in enumerate(dataloader):

            loss = model_loss(*batch, model)

            loss_running.append(loss.item())

            # we will weight average by batch size later
            mean_weights.append(len(batch))

    return np.average(loss_running, weights=mean_weights)


In [12]:
model.pretraining()


In [13]:
# Initialize lists to store loss and accuracy values
loss_history = []
loss_history_val = []

loss_history_running = []

# Early Stopping
best_loss_val = np.inf
early_stop_count = 0

# Train
print("Start pretrain...")
outer = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)
inner = tqdm(total=len(dataloader_source_train), desc=f"Batch", position=1)

checkpoint = {
    "epoch": -1,
    "model": model,
    "optimizer": pre_optimizer,
}
for epoch in range(INITIAL_TRAIN_EPOCHS):
    checkpoint["epoch"] = epoch

    # Train mode
    model.train()
    loss_running = []
    mean_weights = []

    inner.refresh()  # force print final state
    inner.reset()  # reuse bar
    for _, batch in enumerate(dataloader_source_train):

        pre_optimizer.zero_grad()
        loss = model_loss(*batch, model)
        loss_running.append(loss.item())
        mean_weights.append(len(batch))  # we will weight average by batch size later

        loss.backward()
        pre_optimizer.step()

        inner.update(1)

    loss_history.append(np.average(loss_running, weights=mean_weights))
    loss_history_running.append(loss_running)

    # Evaluate mode
    model.eval()
    with torch.no_grad():
        curr_loss_val = compute_acc(dataloader_source_val, model)
        loss_history_val.append(curr_loss_val)

    # Print the results
    outer.update(1)
    print(
        "epoch:",
        epoch,
        "train loss:",
        round(loss_history[-1], 6),
        "validation loss:",
        round(loss_history_val[-1], 6),
        end=" ",
    )

    # Save the best weights
    if curr_loss_val < best_loss_val:
        best_loss_val = curr_loss_val
        torch.save(checkpoint, os.path.join(pretrain_folder, f"best_model.pth"))
        early_stop_count = 0

        print("<-- new best val loss")
    else:
        print("")

    # Save checkpoint every 10
    if epoch % 10 == 0 or epoch >= INITIAL_TRAIN_EPOCHS - 1:
        torch.save(checkpoint, os.path.join(pretrain_folder, f"checkpt{epoch}.pth"))

    # check to see if validation loss has plateau'd
    if early_stop_count >= EARLY_STOP_CRIT and epoch >= MIN_EPOCHS - 1:
        print(f"Validation loss plateaued after {early_stop_count} at epoch {epoch}")
        torch.save(checkpoint, os.path.join(pretrain_folder, f"earlystop{epoch}.pth"))
        break

    early_stop_count += 1


# Save final model
torch.save(checkpoint, os.path.join(pretrain_folder, f"final_model.pth"))


Start pretrain...


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batch:   0%|          | 0/40 [00:00<?, ?it/s]

epoch: 0 train loss: 1.200553 validation loss: 1.432417 <-- new best val loss
epoch: 1 train loss: 0.857845 validation loss: 1.169185 <-- new best val loss
epoch: 2 train loss: 0.691272 validation loss: 1.005482 <-- new best val loss
epoch: 3 train loss: 0.613629 validation loss: 0.984148 <-- new best val loss
epoch: 4 train loss: 0.580604 validation loss: 0.95129 <-- new best val loss
epoch: 5 train loss: 0.55745 validation loss: 0.898289 <-- new best val loss
epoch: 6 train loss: 0.539335 validation loss: 0.888219 <-- new best val loss
epoch: 7 train loss: 0.524384 validation loss: 0.834256 <-- new best val loss
epoch: 8 train loss: 0.510968 validation loss: 0.770836 <-- new best val loss
epoch: 9 train loss: 0.49733 validation loss: 0.772089 


 ## Adversarial Adaptation

In [14]:
advtrain_folder = os.path.join(model_folder, "advtrain")

if not os.path.isdir(advtrain_folder):
    os.makedirs(advtrain_folder)


In [15]:
def batch_generator(data, batch_size):
    """Generate batches of data.
    
    Given a list of numpy data, it iterates over the list and returns batches of
    the same size.
    """
    all_examples_indices = len(data[0])
    while True:
        mini_batch_indices = np.random.choice(
            all_examples_indices, size=batch_size, replace=False
        )
        tbr = [k[mini_batch_indices] for k in data]
        yield tbr


In [16]:
criterion_dis = nn.CrossEntropyLoss()


def discrim_loss_accu(x, y_dis, model):
    x = x.to(torch.float32).to(device)

    emb = model.source_encoder(x)
    y_pred = model.dis(emb)

    loss = criterion_dis(y_pred, y_dis)

    accu = torch.mean(
        (torch.flatten(torch.argmax(y_pred, dim=1)) == y_dis).to(torch.float32)
    ).cpu()

    return loss, accu


def compute_acc_dis(dataloader_source, dataloader_target, model):
    len_target = len(dataloader_target)
    len_source = len(dataloader_source)

    loss_running = []
    accu_running = []
    mean_weights = []
    model.eval()
    model.source_encoder.eval()
    model.dis.eval()
    with torch.no_grad():
        for y_val, dl in zip([1, 0], [dataloader_target, dataloader_source]):
            for _, (x, _) in enumerate(dl):

                y_dis = torch.full(
                    (x.shape[0],), y_val, device=device, dtype=torch.long
                )

                loss, accu = discrim_loss_accu(x, y_dis, model)

                accu_running.append(accu)
                loss_running.append(loss.item())

                # we will weight average by batch size later
                mean_weights.append(len(x))

    return (
        np.average(loss_running, weights=mean_weights),
        np.average(accu_running, weights=mean_weights),
    )


In [17]:
def train_adversarial(
    model,
    save_folder,
    sc_mix_train_s,
    lab_mix_train_s,
    mat_sp_train_s,
    dataloader_source_train_eval,
    dataloader_target_train_eval,
):

    model.to(device)
    model.advtraining()
    model.set_encoder("source")

    S_batches = batch_generator(
        [sc_mix_train_s.copy(), lab_mix_train_s.copy()], BATCH_SIZE
    )
    T_batches = batch_generator(
        [mat_sp_train_s.copy(), np.ones(shape=(len(mat_sp_train_s), 2))], BATCH_SIZE
    )

    enc_optimizer = torch.optim.Adam(
        chain(
            model.source_encoder.parameters(),
            model.dis.parameters(),
            model.clf.parameters(),
        ),
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    dis_optimizer = torch.optim.Adam(
        chain(model.source_encoder.parameters(), model.dis.parameters()),
        lr=ALPHA_LR * 0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    # Initialize lists to store loss and accuracy values
    loss_history_running = []

    # Train
    print("Start adversarial training...")
    outer = tqdm(total=N_ITER, desc="Iterations", position=0)

    checkpoint = {
        "epoch": -1,
        "model": model,
        "dis_optimizer": dis_optimizer,
        "enc_optimizer": enc_optimizer,
    }
    for iters in range(N_ITER):
        checkpoint["epoch"] = iters

        model.train()
        model.dis.train()
        model.clf.train()
        model.source_encoder.train()

        x_source, y_true = next(S_batches)
        x_target, _ = next(T_batches)

        ## Train encoder
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.clf, True)
        set_requires_grad(model.dis, True)

        # save discriminator weights
        dis_weights = deepcopy(model.dis.state_dict())
        new_dis_weights = {}
        for k in dis_weights:
            if "num_batches_tracked" not in k:
                new_dis_weights[k] = dis_weights[k]

        dis_weights = new_dis_weights

        x_source, x_target, y_true, = (
            torch.Tensor(x_source),
            torch.Tensor(x_target),
            torch.Tensor(y_true),
        )
        x_source, x_target, y_true, = (
            x_source.to(device),
            x_target.to(device),
            y_true.to(device),
        )

        x = torch.cat((x_source, x_target))

        # save for discriminator later
        x_d = x.detach()

        # y_dis is the REAL one
        y_dis = torch.cat(
            [
                torch.zeros(x_source.shape[0], device=device, dtype=torch.long),
                torch.ones(x_target.shape[0], device=device, dtype=torch.long),
            ]
        )
        y_dis_flipped = 1 - y_dis.detach()

        emb = model.source_encoder(x).view(x.shape[0], -1)

        y_dis_pred = model.dis(emb)
        y_clf_pred = model.clf(emb)

        # we use flipped because we want to confuse discriminator
        loss_dis = criterion_dis(y_dis_pred, y_dis_flipped)

        # Set true = predicted for target samples since we don't know what it is
        y_clf_true = torch.cat((y_true, y_clf_pred[-x_target.shape[0] :].detach()))
        # Loss fn does mean over all samples including target
        loss_clf = criterion_clf(y_clf_pred, y_clf_true)

        # Scale back up loss so mean doesn't include target
        loss = (x.shape[0] / x_source.shape[0]) * loss_clf + ALPHA * loss_dis

        # loss = loss_clf + ALPHA * loss_dis

        enc_optimizer.zero_grad()
        loss.backward()
        enc_optimizer.step()

        model.dis.load_state_dict(dis_weights, strict=False)

        ## Train discriminator
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.clf, True)
        set_requires_grad(model.dis, True)

        # save encoder and clf weights
        source_encoder_weights = deepcopy(model.source_encoder.state_dict())
        clf_weights = deepcopy(model.clf.state_dict())

        new_clf_weights = {}
        for k in clf_weights:
            if "num_batches_tracked" not in k:
                new_clf_weights[k] = clf_weights[k]

        clf_weights = new_clf_weights

        new_source_encoder_weights = {}
        for k in source_encoder_weights:
            if "num_batches_tracked" not in k:
                new_source_encoder_weights[k] = source_encoder_weights[k]

        source_encoder_weights = new_source_encoder_weights

        emb = model.source_encoder(x_d).view(x_d.shape[0], -1)
        y_pred = model.dis(emb)

        # we use the real domain labels to train discriminator
        loss = criterion_dis(y_pred, y_dis)

        dis_optimizer.zero_grad()
        loss.backward()
        dis_optimizer.step()

        model.clf.load_state_dict(clf_weights, strict=False)
        model.source_encoder.load_state_dict(source_encoder_weights, strict=False)

        # Save checkpoint every 100
        if iters % 100 == 99 or iters >= N_ITER - 1:
            torch.save(checkpoint, os.path.join(save_folder, f"checkpt{iters}.pth"))

            model.eval()
            source_loss = compute_acc(dataloader_source_train_eval, model)
            _, dis_accu = compute_acc_dis(
                dataloader_source_train_eval, dataloader_target_train_eval, model
            )

            # Print the results
            print(
                "iter:",
                iters,
                "source loss:",
                round(source_loss, 6),
                "dis accu:",
                round(dis_accu, 6),
            )

        outer.update(1)

    torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [18]:
# st_sample_id_l = [SAMPLE_ID_N]


In [19]:
if TRAIN_USING_ALL_ST_SAMPLES:
    print(f"Adversarial training for all ST slides")
    save_folder = advtrain_folder

    best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
    model = best_checkpoint["model"]
    model.to(device)
    model.advtraining()

    train_adversarial(
        model,
        save_folder,
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train_s,
        dataloader_source_train,
        dataloader_target_train,
    )

else:
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        save_folder = os.path.join(advtrain_folder, sample_id)
        if not os.path.isdir(save_folder):
            os.makedirs(save_folder)

        best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
        model = best_checkpoint["model"]
        model.to(device)
        model.advtraining()

        train_adversarial(
            model,
            save_folder,
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d["train"][sample_id],
            dataloader_source_train,
            dataloader_target_train_d[sample_id],
        )


Adversarial training for ST slide 151507: 
Start adversarial training...


Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.402684 dis accu: 0.175101
iter: 199 source loss: 1.200676 dis accu: 0.174565
iter: 299 source loss: 1.044143 dis accu: 0.044333
iter: 399 source loss: 1.020976 dis accu: 0.172501
iter: 499 source loss: 0.949005 dis accu: 0.015521
iter: 599 source loss: 1.181578 dis accu: 0.002064
iter: 699 source loss: 0.960063 dis accu: 0.174441
iter: 799 source loss: 0.974099 dis accu: 0.030835
iter: 899 source loss: 0.925667 dis accu: 0.991868
iter: 999 source loss: 0.889115 dis accu: 0.976719
iter: 1099 source loss: 0.878345 dis accu: 0.174771
iter: 1199 source loss: 0.815522 dis accu: 0.165896
iter: 1299 source loss: 1.041307 dis accu: 0.178527
iter: 1399 source loss: 0.927543 dis accu: 0.929951
iter: 1499 source loss: 0.839059 dis accu: 0.999835
iter: 1599 source loss: 0.742302 dis accu: 0.935854
iter: 1699 source loss: 0.75665 dis accu: 0.989887
iter: 1799 source loss: 1.097513 dis accu: 0.126888
iter: 1899 source loss: 0.642576 dis accu: 0.974903
iter: 1999 source loss: 

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.827075 dis accu: 0.17979
iter: 199 source loss: 1.392839 dis accu: 8.2e-05
iter: 299 source loss: 1.28145 dis accu: 0.162115
iter: 399 source loss: 1.070872 dis accu: 0.179585
iter: 499 source loss: 0.861586 dis accu: 0.007423
iter: 599 source loss: 0.788384 dis accu: 0.019234
iter: 699 source loss: 0.737167 dis accu: 0.274688
iter: 799 source loss: 0.69745 dis accu: 8.2e-05
iter: 899 source loss: 0.681851 dis accu: 0.012344
iter: 999 source loss: 0.584925 dis accu: 0.003404
iter: 1099 source loss: 0.559091 dis accu: 0.311311
iter: 1199 source loss: 0.521794 dis accu: 0.479372
iter: 1299 source loss: 0.539486 dis accu: 0.781455
iter: 1399 source loss: 0.491309 dis accu: 0.900221
iter: 1499 source loss: 0.508203 dis accu: 0.668799
iter: 1599 source loss: 0.48087 dis accu: 0.145874
iter: 1699 source loss: 0.499325 dis accu: 0.601214
iter: 1799 source loss: 0.439486 dis accu: 0.513985
iter: 1899 source loss: 0.430008 dis accu: 0.429216
iter: 1999 source loss: 0.416

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.845189 dis accu: 0.193231
iter: 199 source loss: 0.977501 dis accu: 0.837872
iter: 299 source loss: 1.401892 dis accu: 0.001896
iter: 399 source loss: 0.95845 dis accu: 0.74525
iter: 499 source loss: 0.910862 dis accu: 0.247327
iter: 599 source loss: 0.996869 dis accu: 0.001372
iter: 699 source loss: 0.728548 dis accu: 0.998548
iter: 799 source loss: 0.761187 dis accu: 0.999395
iter: 899 source loss: 0.74859 dis accu: 0.106822
iter: 999 source loss: 1.097891 dis accu: 0.001533
iter: 1099 source loss: 0.738414 dis accu: 0.327645
iter: 1199 source loss: 0.743521 dis accu: 0.52136
iter: 1299 source loss: 0.649497 dis accu: 0.210819
iter: 1399 source loss: 0.660462 dis accu: 0.003994
iter: 1499 source loss: 0.693484 dis accu: 0.002219
iter: 1599 source loss: 0.603726 dis accu: 0.077413
iter: 1699 source loss: 0.573389 dis accu: 0.055105
iter: 1799 source loss: 0.536898 dis accu: 0.083787
iter: 1899 source loss: 0.529118 dis accu: 0.201017
iter: 1999 source loss: 0.5

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.289033 dis accu: 0.188114
iter: 199 source loss: 2.304462 dis accu: 4.1e-05
iter: 299 source loss: 1.263104 dis accu: 0.060851
iter: 399 source loss: 1.076439 dis accu: 0.038077
iter: 499 source loss: 0.91568 dis accu: 0.17557
iter: 599 source loss: 0.893856 dis accu: 0.188114
iter: 699 source loss: 0.894783 dis accu: 0.089998
iter: 799 source loss: 0.747736 dis accu: 0.188114
iter: 899 source loss: 0.80205 dis accu: 0.267273
iter: 999 source loss: 0.791282 dis accu: 0.063327
iter: 1099 source loss: 0.665198 dis accu: 0.999878
iter: 1199 source loss: 0.802796 dis accu: 0.001949
iter: 1299 source loss: 0.640365 dis accu: 0.997605
iter: 1399 source loss: 0.605864 dis accu: 0.999878
iter: 1499 source loss: 0.598108 dis accu: 0.744702
iter: 1599 source loss: 0.69497 dis accu: 4.1e-05
iter: 1699 source loss: 0.569449 dis accu: 0.022408
iter: 1799 source loss: 0.537333 dis accu: 0.082812
iter: 1899 source loss: 0.527593 dis accu: 0.416295
iter: 1999 source loss: 0.508

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.919003 dis accu: 0.154727
iter: 199 source loss: 2.164257 dis accu: 0.333798
iter: 299 source loss: 2.318265 dis accu: 4.2e-05
iter: 399 source loss: 1.026289 dis accu: 0.999282
iter: 499 source loss: 0.81468 dis accu: 1.0
iter: 599 source loss: 0.862584 dis accu: 0.002705
iter: 699 source loss: 0.911986 dis accu: 0.138413
iter: 799 source loss: 0.692254 dis accu: 0.998098
iter: 899 source loss: 0.726174 dis accu: 0.97181
iter: 999 source loss: 0.630601 dis accu: 0.999282
iter: 1099 source loss: 0.626835 dis accu: 0.020709
iter: 1199 source loss: 0.761906 dis accu: 0.00038
iter: 1299 source loss: 0.817235 dis accu: 0.009087
iter: 1399 source loss: 0.924953 dis accu: 0.0
iter: 1499 source loss: 0.568574 dis accu: 0.01961
iter: 1599 source loss: 0.513458 dis accu: 0.391573
iter: 1699 source loss: 0.51 dis accu: 0.729132
iter: 1799 source loss: 0.507902 dis accu: 0.392122
iter: 1899 source loss: 0.519839 dis accu: 0.329825
iter: 1999 source loss: 0.505518 dis accu:

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.697487 dis accu: 0.151375
iter: 199 source loss: 1.561198 dis accu: 0.112137
iter: 299 source loss: 1.163092 dis accu: 0.005532
iter: 399 source loss: 1.117894 dis accu: 0.001532
iter: 499 source loss: 0.845698 dis accu: 0.277726
iter: 599 source loss: 0.760172 dis accu: 0.99966
iter: 699 source loss: 0.853383 dis accu: 0.00915
iter: 799 source loss: 0.872892 dis accu: 0.002979
iter: 899 source loss: 0.668921 dis accu: 0.990638
iter: 999 source loss: 0.943622 dis accu: 8.5e-05
iter: 1099 source loss: 0.901301 dis accu: 0.157205
iter: 1199 source loss: 0.700944 dis accu: 1.0
iter: 1299 source loss: 0.988138 dis accu: 0.0
iter: 1399 source loss: 0.666302 dis accu: 0.999957
iter: 1499 source loss: 1.571737 dis accu: 0.05277
iter: 1599 source loss: 0.674535 dis accu: 0.999872
iter: 1699 source loss: 1.330128 dis accu: 0.326794
iter: 1799 source loss: 0.7106 dis accu: 0.999362
iter: 1899 source loss: 0.773144 dis accu: 0.998298
iter: 1999 source loss: 0.985146 dis ac

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.033683 dis accu: 0.170469
iter: 199 source loss: 1.25766 dis accu: 0.169722
iter: 299 source loss: 1.462087 dis accu: 0.170469
iter: 399 source loss: 1.56348 dis accu: 0.010535
iter: 499 source loss: 1.040711 dis accu: 0.80701
iter: 599 source loss: 0.886168 dis accu: 0.171672
iter: 699 source loss: 0.967984 dis accu: 0.013687
iter: 799 source loss: 0.805141 dis accu: 0.170469
iter: 899 source loss: 0.849231 dis accu: 0.161095
iter: 999 source loss: 0.64244 dis accu: 0.003152
iter: 1099 source loss: 0.631501 dis accu: 0.406346
iter: 1199 source loss: 0.574947 dis accu: 0.094442
iter: 1299 source loss: 0.553641 dis accu: 0.829531
iter: 1399 source loss: 0.555276 dis accu: 0.658565
iter: 1499 source loss: 0.525443 dis accu: 0.685981
iter: 1599 source loss: 0.510967 dis accu: 0.689714
iter: 1699 source loss: 0.489733 dis accu: 0.82949
iter: 1799 source loss: 0.477851 dis accu: 0.835214
iter: 1899 source loss: 0.456266 dis accu: 0.848445
iter: 1999 source loss: 0.42

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.011891 dis accu: 0.167187
iter: 199 source loss: 1.036628 dis accu: 0.474495
iter: 299 source loss: 1.087833 dis accu: 0.999709
iter: 399 source loss: 0.951484 dis accu: 0.153904
iter: 499 source loss: 0.973676 dis accu: 0.167187
iter: 599 source loss: 1.15366 dis accu: 0.071122
iter: 699 source loss: 0.793042 dis accu: 0.943785
iter: 799 source loss: 1.319874 dis accu: 4.2e-05
iter: 899 source loss: 0.802943 dis accu: 0.769103
iter: 999 source loss: 1.012458 dis accu: 0.110889
iter: 1099 source loss: 0.831635 dis accu: 0.063793
iter: 1199 source loss: 0.627593 dis accu: 0.999958
iter: 1299 source loss: 0.804679 dis accu: 0.044431
iter: 1399 source loss: 0.795242 dis accu: 0.167187
iter: 1499 source loss: 0.75787 dis accu: 0.905767
iter: 1599 source loss: 1.328108 dis accu: 0.999334
iter: 1699 source loss: 0.797201 dis accu: 0.910098
iter: 1799 source loss: 0.756857 dis accu: 0.99925
iter: 1899 source loss: 0.688865 dis accu: 0.054966
iter: 1999 source loss: 0.9

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.733454 dis accu: 0.153941
iter: 199 source loss: 1.471033 dis accu: 0.151318
iter: 299 source loss: 1.54124 dis accu: 0.140869
iter: 399 source loss: 1.141766 dis accu: 0.253183
iter: 499 source loss: 0.775254 dis accu: 0.999535
iter: 599 source loss: 0.722436 dis accu: 0.999069
iter: 699 source loss: 0.733135 dis accu: 0.991074
iter: 799 source loss: 1.255779 dis accu: 0.091586
iter: 899 source loss: 0.733028 dis accu: 0.996658
iter: 999 source loss: 1.258695 dis accu: 0.0
iter: 1099 source loss: 0.845919 dis accu: 0.01726
iter: 1199 source loss: 0.626721 dis accu: 1.0
iter: 1299 source loss: 0.638242 dis accu: 0.872499
iter: 1399 source loss: 0.671165 dis accu: 0.031347
iter: 1499 source loss: 0.678587 dis accu: 0.110622
iter: 1599 source loss: 0.578487 dis accu: 0.476966
iter: 1699 source loss: 0.55691 dis accu: 0.79022
iter: 1799 source loss: 0.543093 dis accu: 0.879944
iter: 1899 source loss: 0.525204 dis accu: 0.467448
iter: 1999 source loss: 0.539463 dis 

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.750132 dis accu: 0.155156
iter: 199 source loss: 1.94549 dis accu: 0.154691
iter: 299 source loss: 1.201523 dis accu: 0.155156
iter: 399 source loss: 1.223608 dis accu: 0.155156
iter: 499 source loss: 1.106527 dis accu: 0.430744
iter: 599 source loss: 1.010105 dis accu: 0.397542
iter: 699 source loss: 0.908044 dis accu: 0.134415
iter: 799 source loss: 0.849668 dis accu: 0.734719
iter: 899 source loss: 1.071171 dis accu: 0.000338
iter: 999 source loss: 0.845489 dis accu: 0.998606
iter: 1099 source loss: 0.893878 dis accu: 0.061124
iter: 1199 source loss: 1.09322 dis accu: 0.155156
iter: 1299 source loss: 0.850256 dis accu: 0.844844
iter: 1399 source loss: 0.786354 dis accu: 0.887805
iter: 1499 source loss: 0.856975 dis accu: 0.138639
iter: 1599 source loss: 0.834501 dis accu: 0.110506
iter: 1699 source loss: 0.719859 dis accu: 0.963545
iter: 1799 source loss: 0.831796 dis accu: 0.860178
iter: 1899 source loss: 0.781274 dis accu: 0.079796
iter: 1999 source loss: 0

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.323576 dis accu: 0.152255
iter: 199 source loss: 1.229293 dis accu: 0.042811
iter: 299 source loss: 0.971015 dis accu: 0.152255
iter: 399 source loss: 1.085588 dis accu: 0.00106
iter: 499 source loss: 1.009515 dis accu: 0.044549
iter: 599 source loss: 0.746061 dis accu: 0.999576
iter: 699 source loss: 0.726505 dis accu: 0.491141
iter: 799 source loss: 0.737102 dis accu: 0.078586
iter: 899 source loss: 0.617481 dis accu: 0.000254
iter: 999 source loss: 0.607698 dis accu: 0.07155
iter: 1099 source loss: 0.538361 dis accu: 0.377967
iter: 1199 source loss: 0.516608 dis accu: 0.034673
iter: 1299 source loss: 0.512809 dis accu: 0.507121
iter: 1399 source loss: 0.528709 dis accu: 0.754154
iter: 1499 source loss: 0.474007 dis accu: 0.670778
iter: 1599 source loss: 0.488987 dis accu: 0.907257
iter: 1699 source loss: 0.435866 dis accu: 0.857113
iter: 1799 source loss: 0.450681 dis accu: 0.650856
iter: 1899 source loss: 0.406227 dis accu: 0.428408
iter: 1999 source loss: 0

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.302041 dis accu: 0.147485
iter: 199 source loss: 0.903799 dis accu: 0.987681
iter: 299 source loss: 1.480647 dis accu: 0.503922
iter: 399 source loss: 1.137979 dis accu: 0.863896
iter: 499 source loss: 0.986015 dis accu: 0.024893
iter: 599 source loss: 1.456981 dis accu: 0.903282
iter: 699 source loss: 0.900354 dis accu: 0.035166
iter: 799 source loss: 0.756257 dis accu: 0.89416
iter: 899 source loss: 0.994393 dis accu: 0.074766
iter: 999 source loss: 0.878186 dis accu: 0.496505
iter: 1099 source loss: 0.966908 dis accu: 0.147485
iter: 1199 source loss: 0.805791 dis accu: 0.455286
iter: 1299 source loss: 0.795168 dis accu: 0.991858
iter: 1399 source loss: 0.949996 dis accu: 0.141816
iter: 1499 source loss: 0.687583 dis accu: 0.867221
iter: 1599 source loss: 0.894832 dis accu: 0.857886
iter: 1699 source loss: 1.087398 dis accu: 0.146974
iter: 1799 source loss: 0.876766 dis accu: 0.147613
iter: 1899 source loss: 0.799366 dis accu: 0.130946
iter: 1999 source loss: 

In [20]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     best_checkpoint = torch.load(os.path.join(advtrain_folder, f"final_model.pth"))
# else:
#     best_checkpoint = torch.load(
#         os.path.join(advtrain_folder, SAMPLE_ID_N, f"final_model.pth")
#     )

# model = best_checkpoint["model"]
# model.to(device)

# model.eval()
# model.set_encoder("source")

# pred_mix = (
#     torch.exp(model(torch.Tensor(sc_mix_test_s).to(device)))
#     .detach()
#     .cpu()
#     .numpy()
# )

# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(25, 5 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, visnum in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, visnum],
#         y=lab_mix_d["test"][:, visnum],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[visnum])

#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlabel("Predicted Proportion")

#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = (
#         f"MSE: {mean_squared_error(pred_mix[:,visnum], lab_mix_d["test"][:,visnum]):.5f}"
#     )

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()
