 # ADDA for ST

In [1]:
import os
import datetime
from itertools import chain
from copy import deepcopy
import warnings

from tqdm.autonotebook import tqdm

import h5py
import pickle
import numpy as np

import torch
from torch import nn

from src.da_models.adda import ADDAST
from src.da_models.datasets import SpotDataset
from src.da_models.utils import set_requires_grad
from src.utils.data_loading import load_spatial, load_sc

# datetime object containing current date and time
script_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%Hh%Mm%S")


  from tqdm.autonotebook import tqdm


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == "cpu":
    warnings.warn("Using CPU", stacklevel=2)


In [3]:
NUM_MARKERS = 20
N_MIX = 8
N_SPOTS = 20000
TRAIN_USING_ALL_ST_SAMPLES = False

ST_SPLIT = False

SAMPLE_ID_N = "151673"

BATCH_SIZE = 512
NUM_WORKERS = 4
INITIAL_TRAIN_EPOCHS = 10

EARLY_STOP_CRIT = 100
MIN_EPOCHS = INITIAL_TRAIN_EPOCHS


EARLY_STOP_CRIT_ADV = 10
MIN_EPOCHS_ADV = 10

SPATIALLIBD_DIR = "./data/spatialLIBD"
SC_DLPFC_PATH = "./data/sc_dlpfc/adata_sc_dlpfc.h5ad"

PROCESSED_DATA_DIR = "./data/preprocessed_markers_celldart"

MODEL_NAME = "CellDART"


celldart_kwargs = {
    "emb_dim": 64,
    "bn_momentum": 0.01,
}


In [4]:
N_ITER = 3000
ALPHA_LR = 5
ALPHA = 0.6


In [5]:
model_folder = os.path.join("model", MODEL_NAME, script_start_time)

model_folder = os.path.join("model", MODEL_NAME, "bn_fix")

if not os.path.isdir(model_folder):
    os.makedirs(model_folder)
    print(model_folder)


 # Data load


In [6]:
# Load spatial data
mat_sp_d, mat_sp_train_s, st_sample_id_l = load_spatial(
    TRAIN_USING_ALL_ST_SAMPLES, PROCESSED_DATA_DIR, ST_SPLIT
)

# Load sc data
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(PROCESSED_DATA_DIR)


 # Training: Adversarial domain adaptation for cell fraction estimation

 ## Prepare dataloaders

In [7]:
### source dataloaders
source_train_set = SpotDataset(
    deepcopy(sc_mix_d["train"]), deepcopy(lab_mix_d["train"])
)
source_val_set = SpotDataset(deepcopy(sc_mix_d["val"]), deepcopy(lab_mix_d["val"]))
source_test_set = SpotDataset(deepcopy(sc_mix_d["test"]), deepcopy(lab_mix_d["test"]))

dataloader_source_train = torch.utils.data.DataLoader(
    source_train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader_source_val = torch.utils.data.DataLoader(
    source_val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)
dataloader_source_test = torch.utils.data.DataLoader(
    source_test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

### target dataloaders
target_test_set_d = {}
for sample_id in st_sample_id_l:
    target_test_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d["test"][sample_id]))

dataloader_target_test_d = {}
for sample_id in st_sample_id_l:
    dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
        target_test_set_d[sample_id],
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=False,
    )

if TRAIN_USING_ALL_ST_SAMPLES:
    target_train_set = SpotDataset(deepcopy(mat_sp_train_s))
    dataloader_target_train = torch.utils.data.DataLoader(
        target_train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
else:
    target_train_set_d = {}
    dataloader_target_train_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d["train"][sample_id])
        )
        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )


 ## Define Model

In [8]:
model = ADDAST(
    inp_dim=sc_mix_d["train"].shape[1],
    ncls_source=lab_mix_d["train"].shape[1],
    **celldart_kwargs
)

## CellDART uses just one encoder!
model.target_encoder = model.source_encoder
model.to(device)


ADDAST(
  (source_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=367, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (target_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=367, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (dis): Discriminator(
    (head): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1):

 ## Pretrain

In [9]:
pretrain_folder = os.path.join(model_folder, "pretrain")

if not os.path.isdir(pretrain_folder):
    os.makedirs(pretrain_folder)


In [10]:
pre_optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-07
)

criterion_clf = nn.KLDivLoss(reduction="batchmean")


In [11]:
def model_loss(x, y_true, model):
    x = x.to(torch.float32).to(device)
    y_true = y_true.to(torch.float32).to(device)

    y_pred = model(x)

    loss = criterion_clf(y_pred, y_true)

    return loss


def compute_acc(dataloader, model):
    loss_running = []
    mean_weights = []
    model.eval()
    with torch.no_grad():
        for _, batch in enumerate(dataloader):

            loss = model_loss(*batch, model)

            loss_running.append(loss.item())

            # we will weight average by batch size later
            mean_weights.append(len(batch))

    return np.average(loss_running, weights=mean_weights)


In [12]:
model.pretraining()


In [13]:
# Initialize lists to store loss and accuracy values
loss_history = []
loss_history_val = []

loss_history_running = []

# Early Stopping
best_loss_val = np.inf
early_stop_count = 0

# Train
print("Start pretrain...")
outer = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)
inner = tqdm(total=len(dataloader_source_train), desc=f"Batch", position=1)

checkpoint = {
    "epoch": -1,
    "model": model,
    "optimizer": pre_optimizer,
}
for epoch in range(INITIAL_TRAIN_EPOCHS):
    checkpoint["epoch"] = epoch

    # Train mode
    model.train()
    loss_running = []
    mean_weights = []

    inner.refresh()  # force print final state
    inner.reset()  # reuse bar
    for _, batch in enumerate(dataloader_source_train):

        pre_optimizer.zero_grad()
        loss = model_loss(*batch, model)
        loss_running.append(loss.item())
        mean_weights.append(len(batch))  # we will weight average by batch size later

        loss.backward()
        pre_optimizer.step()

        inner.update(1)

    loss_history.append(np.average(loss_running, weights=mean_weights))
    loss_history_running.append(loss_running)

    # Evaluate mode
    model.eval()
    with torch.no_grad():
        curr_loss_val = compute_acc(dataloader_source_val, model)
        loss_history_val.append(curr_loss_val)

    # Print the results
    outer.update(1)
    print(
        "epoch:",
        epoch,
        "train loss:",
        round(loss_history[-1], 6),
        "validation loss:",
        round(loss_history_val[-1], 6),
        end=" ",
    )

    # Save the best weights
    if curr_loss_val < best_loss_val:
        best_loss_val = curr_loss_val
        torch.save(checkpoint, os.path.join(pretrain_folder, f"best_model.pth"))
        early_stop_count = 0

        print("<-- new best val loss")
    else:
        print("")

    # Save checkpoint every 10
    if epoch % 10 == 0 or epoch >= INITIAL_TRAIN_EPOCHS - 1:
        torch.save(checkpoint, os.path.join(pretrain_folder, f"checkpt{epoch}.pth"))

    # check to see if validation loss has plateau'd
    if early_stop_count >= EARLY_STOP_CRIT and epoch >= MIN_EPOCHS - 1:
        print(f"Validation loss plateaued after {early_stop_count} at epoch {epoch}")
        torch.save(checkpoint, os.path.join(pretrain_folder, f"earlystop{epoch}.pth"))
        break

    early_stop_count += 1


# Save final model
torch.save(checkpoint, os.path.join(pretrain_folder, f"final_model.pth"))


Start pretrain...


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batch:   0%|          | 0/40 [00:00<?, ?it/s]

epoch: 0 train loss: 1.208606 validation loss: 1.498633 <-- new best val loss
epoch: 1 train loss: 0.860363 validation loss: 1.221814 <-- new best val loss
epoch: 2 train loss: 0.687908 validation loss: 1.014193 <-- new best val loss
epoch: 3 train loss: 0.612534 validation loss: 0.947307 <-- new best val loss
epoch: 4 train loss: 0.579146 validation loss: 0.948112 
epoch: 5 train loss: 0.558538 validation loss: 0.883462 <-- new best val loss
epoch: 6 train loss: 0.537817 validation loss: 0.831158 <-- new best val loss
epoch: 7 train loss: 0.523133 validation loss: 0.7611 <-- new best val loss
epoch: 8 train loss: 0.507153 validation loss: 0.698754 <-- new best val loss
epoch: 9 train loss: 0.49611 validation loss: 0.756352 


 ## Adversarial Adaptation

In [14]:
advtrain_folder = os.path.join(model_folder, "advtrain")

if not os.path.isdir(advtrain_folder):
    os.makedirs(advtrain_folder)


In [15]:
def batch_generator(data, batch_size):
    """Generate batches of data.
    
    Given a list of numpy data, it iterates over the list and returns batches of
    the same size.
    """
    all_examples_indices = len(data[0])
    while True:
        mini_batch_indices = np.random.choice(
            all_examples_indices, size=batch_size, replace=False
        )
        tbr = [k[mini_batch_indices] for k in data]
        yield tbr


In [16]:
criterion_dis = nn.CrossEntropyLoss()


def discrim_loss_accu(x, y_dis, model):
    x = x.to(torch.float32).to(device)

    emb = model.source_encoder(x)
    y_pred = model.dis(emb)

    loss = criterion_dis(y_pred, y_dis)

    accu = torch.mean(
        (torch.flatten(torch.argmax(y_pred, dim=1)) == y_dis).to(torch.float32)
    ).cpu()

    return loss, accu


def compute_acc_dis(dataloader_source, dataloader_target, model):
    len_target = len(dataloader_target)
    len_source = len(dataloader_source)

    loss_running = []
    accu_running = []
    mean_weights = []
    model.eval()
    model.source_encoder.eval()
    model.dis.eval()
    with torch.no_grad():
        for y_val, dl in zip([1, 0], [dataloader_target, dataloader_source]):
            for _, (x, _) in enumerate(dl):

                y_dis = torch.full(
                    (x.shape[0],), y_val, device=device, dtype=torch.long
                )

                loss, accu = discrim_loss_accu(x, y_dis, model)

                accu_running.append(accu)
                loss_running.append(loss.item())

                # we will weight average by batch size later
                mean_weights.append(len(x))

    return (
        np.average(loss_running, weights=mean_weights),
        np.average(accu_running, weights=mean_weights),
    )


In [17]:
def train_adversarial(
    model,
    save_folder,
    sc_mix_train_s,
    lab_mix_train_s,
    mat_sp_train_s,
    dataloader_source_train_eval,
    dataloader_target_train_eval,
):

    model.to(device)
    model.advtraining()
    model.set_encoder("source")

    S_batches = batch_generator(
        [sc_mix_train_s.copy(), lab_mix_train_s.copy()], BATCH_SIZE
    )
    T_batches = batch_generator(
        [mat_sp_train_s.copy(), np.ones(shape=(len(mat_sp_train_s), 2))], BATCH_SIZE
    )

    enc_optimizer = torch.optim.Adam(
        chain(
            model.source_encoder.parameters(),
            model.dis.parameters(),
            model.clf.parameters(),
        ),
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    dis_optimizer = torch.optim.Adam(
        chain(model.source_encoder.parameters(), model.dis.parameters()),
        lr=ALPHA_LR * 0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    # Initialize lists to store loss and accuracy values
    loss_history_running = []

    # Train
    print("Start adversarial training...")
    outer = tqdm(total=N_ITER, desc="Iterations", position=0)

    checkpoint = {
        "epoch": -1,
        "model": model,
        "dis_optimizer": dis_optimizer,
        "enc_optimizer": enc_optimizer,
    }
    for iters in range(N_ITER):
        checkpoint["epoch"] = iters

        model.train()
        model.dis.train()
        model.clf.train()
        model.source_encoder.train()

        x_source, y_true = next(S_batches)
        x_target, _ = next(T_batches)

        ## Train encoder
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.clf, True)
        set_requires_grad(model.dis, True)

        # save discriminator weights
        dis_weights = deepcopy(model.dis.state_dict())
        new_dis_weights = {}
        for k in dis_weights:
            if "num_batches_tracked" not in k:
                new_dis_weights[k] = dis_weights[k]

        dis_weights = new_dis_weights

        x_source, x_target, y_true, = (
            torch.Tensor(x_source),
            torch.Tensor(x_target),
            torch.Tensor(y_true),
        )
        x_source, x_target, y_true, = (
            x_source.to(device),
            x_target.to(device),
            y_true.to(device),
        )

        x = torch.cat((x_source, x_target))

        # save for discriminator later
        x_d = x.detach()

        # y_dis is the REAL one
        y_dis = torch.cat(
            [
                torch.zeros(x_source.shape[0], device=device, dtype=torch.long),
                torch.ones(x_target.shape[0], device=device, dtype=torch.long),
            ]
        )
        y_dis_flipped = 1 - y_dis.detach()

        emb = model.source_encoder(x).view(x.shape[0], -1)

        y_dis_pred = model.dis(emb)
        y_clf_pred = model.clf(emb)

        # we use flipped because we want to confuse discriminator
        loss_dis = criterion_dis(y_dis_pred, y_dis_flipped)

        # Set true = predicted for target samples since we don't know what it is
        y_clf_true = torch.cat((y_true, y_clf_pred[-x_target.shape[0] :].detach()))
        # Loss fn does mean over all samples including target
        loss_clf = criterion_clf(y_clf_pred, y_clf_true)

        # Scale back up loss so mean doesn't include target
        loss = (x.shape[0] / x_source.shape[0]) * loss_clf + ALPHA * loss_dis

        # loss = loss_clf + ALPHA * loss_dis

        enc_optimizer.zero_grad()
        loss.backward()
        enc_optimizer.step()

        model.dis.load_state_dict(dis_weights, strict=False)

        ## Train discriminator
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.clf, True)
        set_requires_grad(model.dis, True)

        # save encoder and clf weights
        source_encoder_weights = deepcopy(model.source_encoder.state_dict())
        clf_weights = deepcopy(model.clf.state_dict())

        new_clf_weights = {}
        for k in clf_weights:
            if "num_batches_tracked" not in k:
                new_clf_weights[k] = clf_weights[k]

        clf_weights = new_clf_weights

        new_source_encoder_weights = {}
        for k in source_encoder_weights:
            if "num_batches_tracked" not in k:
                new_source_encoder_weights[k] = source_encoder_weights[k]

        source_encoder_weights = new_source_encoder_weights

        emb = model.source_encoder(x_d).view(x_d.shape[0], -1)
        y_pred = model.dis(emb)

        # we use the real domain labels to train discriminator
        loss = criterion_dis(y_pred, y_dis)

        dis_optimizer.zero_grad()
        loss.backward()
        dis_optimizer.step()

        model.clf.load_state_dict(clf_weights, strict=False)
        model.source_encoder.load_state_dict(source_encoder_weights, strict=False)

        # Save checkpoint every 100
        if iters % 100 == 99 or iters >= N_ITER - 1:
            torch.save(checkpoint, os.path.join(save_folder, f"checkpt{iters}.pth"))

            model.eval()
            source_loss = compute_acc(dataloader_source_train_eval, model)
            _, dis_accu = compute_acc_dis(
                dataloader_source_train_eval, dataloader_target_train_eval, model
            )

            # Print the results
            print(
                "iter:",
                iters,
                "source loss:",
                round(source_loss, 6),
                "dis accu:",
                round(dis_accu, 6),
            )

        outer.update(1)

    torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [18]:
# st_sample_id_l = [SAMPLE_ID_N]


In [19]:
if TRAIN_USING_ALL_ST_SAMPLES:
    print(f"Adversarial training for all ST slides")
    save_folder = advtrain_folder

    best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
    model = best_checkpoint["model"]
    model.to(device)
    model.advtraining()

    train_adversarial(
        model,
        save_folder,
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train_s,
        dataloader_source_train,
        dataloader_target_train,
    )

else:
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        save_folder = os.path.join(advtrain_folder, sample_id)
        if not os.path.isdir(save_folder):
            os.makedirs(save_folder)

        best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
        model = best_checkpoint["model"]
        model.to(device)
        model.advtraining()

        train_adversarial(
            model,
            save_folder,
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d["train"][sample_id],
            dataloader_source_train,
            dataloader_target_train_d[sample_id],
        )


Adversarial training for ST slide 151507: 
Start adversarial training...


Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.457244 dis accu: 0.23029
iter: 199 source loss: 2.176082 dis accu: 0.001321
iter: 299 source loss: 1.390336 dis accu: 0.239619
iter: 399 source loss: 1.072204 dis accu: 0.174441
iter: 499 source loss: 0.927285 dis accu: 0.999959
iter: 599 source loss: 0.881081 dis accu: 0.18798
iter: 699 source loss: 0.927853 dis accu: 0.315322
iter: 799 source loss: 1.027221 dis accu: 0.0
iter: 899 source loss: 0.98781 dis accu: 0.174523
iter: 999 source loss: 0.897621 dis accu: 0.008214
iter: 1099 source loss: 0.95897 dis accu: 0.175679
iter: 1199 source loss: 0.778132 dis accu: 0.197061
iter: 1299 source loss: 0.739123 dis accu: 0.255676
iter: 1399 source loss: 1.001235 dis accu: 0.160117
iter: 1499 source loss: 0.859048 dis accu: 0.039586
iter: 1599 source loss: 0.75684 dis accu: 0.998844
iter: 1699 source loss: 0.839493 dis accu: 0.130975
iter: 1799 source loss: 0.716023 dis accu: 0.878643
iter: 1899 source loss: 1.354906 dis accu: 0.410798
iter: 1999 source loss: 0.688942 

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.040961 dis accu: 8.5e-05
iter: 199 source loss: 1.718421 dis accu: 0.148608
iter: 299 source loss: 0.935595 dis accu: 0.148864
iter: 399 source loss: 0.90574 dis accu: 0.148864
iter: 499 source loss: 0.847886 dis accu: 0.952209
iter: 599 source loss: 0.902912 dis accu: 0.042429
iter: 699 source loss: 0.93596 dis accu: 0.16495
iter: 799 source loss: 0.770101 dis accu: 0.144097
iter: 899 source loss: 0.823657 dis accu: 8.5e-05
iter: 999 source loss: 0.662517 dis accu: 0.148864
iter: 1099 source loss: 0.665636 dis accu: 0.148779
iter: 1199 source loss: 0.620418 dis accu: 0.147842
iter: 1299 source loss: 0.566893 dis accu: 0.364286
iter: 1399 source loss: 0.547967 dis accu: 0.100519
iter: 1499 source loss: 0.52172 dis accu: 0.620904
iter: 1599 source loss: 0.526171 dis accu: 0.822964
iter: 1699 source loss: 0.559354 dis accu: 0.182867
iter: 1799 source loss: 0.531006 dis accu: 0.164482
iter: 1899 source loss: 0.474771 dis accu: 0.6275
iter: 1999 source loss: 0.46991

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.662996 dis accu: 0.111182
iter: 199 source loss: 1.163008 dis accu: 1.0
iter: 299 source loss: 1.124818 dis accu: 0.133859
iter: 399 source loss: 0.768076 dis accu: 0.98597
iter: 499 source loss: 0.99573 dis accu: 0.141997
iter: 599 source loss: 0.760078 dis accu: 0.212784
iter: 699 source loss: 0.774479 dis accu: 0.136868
iter: 799 source loss: 0.907493 dis accu: 4.2e-05
iter: 899 source loss: 0.688695 dis accu: 0.03001
iter: 999 source loss: 0.625597 dis accu: 0.999873
iter: 1099 source loss: 0.578874 dis accu: 0.925568
iter: 1199 source loss: 0.574344 dis accu: 0.152721
iter: 1299 source loss: 0.545951 dis accu: 0.119871
iter: 1399 source loss: 0.525727 dis accu: 0.612326
iter: 1499 source loss: 0.551294 dis accu: 0.954985
iter: 1599 source loss: 0.535855 dis accu: 0.853934
iter: 1699 source loss: 0.505088 dis accu: 0.841938
iter: 1799 source loss: 0.50261 dis accu: 0.554425
iter: 1899 source loss: 0.504375 dis accu: 0.796796
iter: 1999 source loss: 0.41746 d

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.636811 dis accu: 0.153941
iter: 199 source loss: 1.485481 dis accu: 0.082914
iter: 299 source loss: 1.07504 dis accu: 0.151995
iter: 399 source loss: 1.124022 dis accu: 0.098651
iter: 499 source loss: 1.005144 dis accu: 0.370532
iter: 599 source loss: 0.786279 dis accu: 0.847752
iter: 699 source loss: 0.826839 dis accu: 0.030289
iter: 799 source loss: 0.741606 dis accu: 0.144592
iter: 899 source loss: 0.639962 dis accu: 0.449681
iter: 999 source loss: 0.629115 dis accu: 0.630991
iter: 1099 source loss: 0.626453 dis accu: 0.027751
iter: 1199 source loss: 0.586736 dis accu: 0.021786
iter: 1299 source loss: 0.560813 dis accu: 0.314607
iter: 1399 source loss: 0.569869 dis accu: 0.22391
iter: 1499 source loss: 0.537148 dis accu: 0.134016
iter: 1599 source loss: 0.539689 dis accu: 0.297644
iter: 1699 source loss: 0.478948 dis accu: 0.941326
iter: 1799 source loss: 0.45752 dis accu: 0.90173
iter: 1899 source loss: 0.467826 dis accu: 0.958035
iter: 1999 source loss: 0.4

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.970336 dis accu: 0.0
iter: 199 source loss: 1.373427 dis accu: 0.155156
iter: 299 source loss: 0.932649 dis accu: 0.167575
iter: 399 source loss: 0.869368 dis accu: 0.999324
iter: 499 source loss: 0.945422 dis accu: 0.155156
iter: 599 source loss: 0.857539 dis accu: 0.87416
iter: 699 source loss: 0.720238 dis accu: 0.950154
iter: 799 source loss: 0.786003 dis accu: 0.000338
iter: 899 source loss: 0.668091 dis accu: 0.089216
iter: 999 source loss: 0.843169 dis accu: 0.008364
iter: 1099 source loss: 0.879007 dis accu: 0.02826
iter: 1199 source loss: 0.623526 dis accu: 0.652304
iter: 1299 source loss: 0.615605 dis accu: 0.999958
iter: 1399 source loss: 0.587035 dis accu: 0.999366
iter: 1499 source loss: 0.624774 dis accu: 0.056943
iter: 1599 source loss: 0.561983 dis accu: 0.924513
iter: 1699 source loss: 0.576307 dis accu: 0.014024
iter: 1799 source loss: 0.545979 dis accu: 0.032738
iter: 1899 source loss: 0.487814 dis accu: 0.957251
iter: 1999 source loss: 0.5043

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.528713 dis accu: 0.012918
iter: 199 source loss: 1.975197 dis accu: 0.00082
iter: 299 source loss: 1.635268 dis accu: 0.17979
iter: 399 source loss: 1.755752 dis accu: 0.178888
iter: 499 source loss: 0.954567 dis accu: 0.17979
iter: 599 source loss: 1.012953 dis accu: 0.017101
iter: 699 source loss: 0.937869 dis accu: 0.17979
iter: 799 source loss: 0.976207 dis accu: 0.000533
iter: 899 source loss: 0.827159 dis accu: 0.17979
iter: 999 source loss: 0.817565 dis accu: 0.003158
iter: 1099 source loss: 0.852921 dis accu: 0.17979
iter: 1199 source loss: 0.648631 dis accu: 1.0
iter: 1299 source loss: 0.790146 dis accu: 0.022023
iter: 1399 source loss: 0.631234 dis accu: 0.998811
iter: 1499 source loss: 0.729573 dis accu: 0.876722
iter: 1599 source loss: 0.738877 dis accu: 0.912361
iter: 1699 source loss: 0.714719 dis accu: 0.959359
iter: 1799 source loss: 0.62514 dis accu: 1.0
iter: 1899 source loss: 1.524789 dis accu: 0.0
iter: 1999 source loss: 1.301556 dis accu: 0.

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.001231 dis accu: 0.00503
iter: 199 source loss: 1.609718 dis accu: 0.136104
iter: 299 source loss: 2.316307 dis accu: 0.147485
iter: 399 source loss: 1.231184 dis accu: 0.049233
iter: 499 source loss: 0.861534 dis accu: 0.047016
iter: 599 source loss: 1.238268 dis accu: 0.134101
iter: 699 source loss: 0.894793 dis accu: 0.17289
iter: 799 source loss: 1.049727 dis accu: 0.028431
iter: 899 source loss: 0.937335 dis accu: 0.132566
iter: 999 source loss: 0.71184 dis accu: 0.998551
iter: 1099 source loss: 1.108864 dis accu: 0.147315
iter: 1199 source loss: 1.048858 dis accu: 0.999787
iter: 1299 source loss: 0.808818 dis accu: 0.022379
iter: 1399 source loss: 0.81971 dis accu: 0.286573
iter: 1499 source loss: 0.711633 dis accu: 0.991986
iter: 1599 source loss: 0.848345 dis accu: 0.102131
iter: 1699 source loss: 0.760874 dis accu: 0.434143
iter: 1799 source loss: 0.915039 dis accu: 0.012916
iter: 1899 source loss: 1.104184 dis accu: 0.145737
iter: 1999 source loss: 0.6

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.735597 dis accu: 0.838018
iter: 199 source loss: 1.653123 dis accu: 0.156361
iter: 299 source loss: 1.70132 dis accu: 0.008495
iter: 399 source loss: 0.963798 dis accu: 0.999001
iter: 499 source loss: 0.953492 dis accu: 0.999958
iter: 599 source loss: 0.847248 dis accu: 0.999334
iter: 699 source loss: 0.892877 dis accu: 0.016989
iter: 799 source loss: 0.764283 dis accu: 0.14924
iter: 899 source loss: 0.870601 dis accu: 0.774558
iter: 999 source loss: 0.815209 dis accu: 0.155778
iter: 1099 source loss: 0.667918 dis accu: 0.439517
iter: 1199 source loss: 0.781436 dis accu: 0.227566
iter: 1299 source loss: 0.616212 dis accu: 0.992921
iter: 1399 source loss: 0.6024 dis accu: 0.96802
iter: 1499 source loss: 0.671069 dis accu: 0.289944
iter: 1599 source loss: 0.832024 dis accu: 0.000375
iter: 1699 source loss: 0.774975 dis accu: 0.050135
iter: 1799 source loss: 0.956298 dis accu: 0.13325
iter: 1899 source loss: 0.698365 dis accu: 0.569394
iter: 1999 source loss: 0.720

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.191894 dis accu: 0.154727
iter: 199 source loss: 1.935134 dis accu: 0.154727
iter: 299 source loss: 1.28509 dis accu: 0.154727
iter: 399 source loss: 1.144855 dis accu: 0.085119
iter: 499 source loss: 1.200171 dis accu: 0.003466
iter: 599 source loss: 0.891091 dis accu: 0.613499
iter: 699 source loss: 1.144835 dis accu: 0.832044
iter: 799 source loss: 0.894585 dis accu: 0.0
iter: 899 source loss: 0.826559 dis accu: 0.154727
iter: 999 source loss: 0.861211 dis accu: 0.846203
iter: 1099 source loss: 0.776544 dis accu: 0.845231
iter: 1199 source loss: 0.715537 dis accu: 0.683403
iter: 1299 source loss: 0.809557 dis accu: 0.414353
iter: 1399 source loss: 0.697082 dis accu: 0.951904
iter: 1499 source loss: 1.228901 dis accu: 0.040531
iter: 1599 source loss: 0.789425 dis accu: 0.96285
iter: 1699 source loss: 0.920047 dis accu: 0.154727
iter: 1799 source loss: 0.768986 dis accu: 0.800685
iter: 1899 source loss: 0.832726 dis accu: 0.242593
iter: 1999 source loss: 0.8944

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.911457 dis accu: 0.163258
iter: 199 source loss: 1.571693 dis accu: 0.175078
iter: 299 source loss: 1.268293 dis accu: 0.008673
iter: 399 source loss: 1.12767 dis accu: 0.193191
iter: 499 source loss: 0.95201 dis accu: 0.18008
iter: 599 source loss: 0.823326 dis accu: 0.163903
iter: 699 source loss: 0.91927 dis accu: 0.193191
iter: 799 source loss: 0.719863 dis accu: 0.783573
iter: 899 source loss: 0.897129 dis accu: 0.130582
iter: 999 source loss: 0.700496 dis accu: 0.990439
iter: 1099 source loss: 0.708731 dis accu: 0.192747
iter: 1199 source loss: 0.626689 dis accu: 0.998306
iter: 1299 source loss: 0.637383 dis accu: 0.999879
iter: 1399 source loss: 0.738185 dis accu: 0.096373
iter: 1499 source loss: 0.797445 dis accu: 0.003066
iter: 1599 source loss: 0.634585 dis accu: 0.886885
iter: 1699 source loss: 0.626099 dis accu: 0.965711
iter: 1799 source loss: 0.862099 dis accu: 0.054258
iter: 1899 source loss: 0.620082 dis accu: 0.230465
iter: 1999 source loss: 0.6

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.755419 dis accu: 0.031564
iter: 199 source loss: 2.067189 dis accu: 0.170469
iter: 299 source loss: 1.138528 dis accu: 0.781999
iter: 399 source loss: 1.298213 dis accu: 0.565657
iter: 499 source loss: 1.047442 dis accu: 0.166902
iter: 599 source loss: 0.924717 dis accu: 0.023559
iter: 699 source loss: 0.915599 dis accu: 0.489589
iter: 799 source loss: 1.015416 dis accu: 0.657694
iter: 899 source loss: 0.901472 dis accu: 0.831024
iter: 999 source loss: 0.838283 dis accu: 0.375819
iter: 1099 source loss: 0.754088 dis accu: 0.855703
iter: 1199 source loss: 0.96189 dis accu: 0.999129
iter: 1299 source loss: 1.174147 dis accu: 0.15392
iter: 1399 source loss: 0.728394 dis accu: 0.990709
iter: 1499 source loss: 1.171886 dis accu: 0.854293
iter: 1599 source loss: 0.766221 dis accu: 0.878681
iter: 1699 source loss: 0.839841 dis accu: 0.176566
iter: 1799 source loss: 0.694628 dis accu: 0.978764
iter: 1899 source loss: 0.828159 dis accu: 0.300539
iter: 1999 source loss: 0

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.925198 dis accu: 0.721158
iter: 199 source loss: 1.491902 dis accu: 0.18718
iter: 299 source loss: 1.364774 dis accu: 0.168304
iter: 399 source loss: 1.191044 dis accu: 0.188114
iter: 499 source loss: 0.906314 dis accu: 0.188114
iter: 599 source loss: 0.988928 dis accu: 0.188114
iter: 699 source loss: 1.023653 dis accu: 0.193472
iter: 799 source loss: 0.844309 dis accu: 0.991759
iter: 899 source loss: 0.908751 dis accu: 0.007835
iter: 999 source loss: 0.765124 dis accu: 0.871154
iter: 1099 source loss: 0.953149 dis accu: 0.265609
iter: 1199 source loss: 0.71403 dis accu: 0.877527
iter: 1299 source loss: 0.915 dis accu: 0.772713
iter: 1399 source loss: 1.042673 dis accu: 0.812008
iter: 1499 source loss: 0.70446 dis accu: 0.977186
iter: 1599 source loss: 0.841118 dis accu: 0.188114
iter: 1699 source loss: 0.777579 dis accu: 0.999553
iter: 1799 source loss: 0.925994 dis accu: 0.13798
iter: 1899 source loss: 0.756252 dis accu: 0.926443
iter: 1999 source loss: 0.7858

In [20]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     best_checkpoint = torch.load(os.path.join(advtrain_folder, f"final_model.pth"))
# else:
#     best_checkpoint = torch.load(
#         os.path.join(advtrain_folder, SAMPLE_ID_N, f"final_model.pth")
#     )

# model = best_checkpoint["model"]
# model.to(device)

# model.eval()
# model.set_encoder("source")

# pred_mix = (
#     torch.exp(model(torch.Tensor(sc_mix_test_s).to(device)))
#     .detach()
#     .cpu()
#     .numpy()
# )

# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(25, 5 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, visnum in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, visnum],
#         y=lab_mix_d["test"][:, visnum],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[visnum])

#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlabel("Predicted Proportion")

#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = (
#         f"MSE: {mean_squared_error(pred_mix[:,visnum], lab_mix_d["test"][:,visnum]):.5f}"
#     )

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()
