 # ADDA for ST

In [1]:
import os
import datetime
from itertools import chain
from copy import deepcopy
import warnings

from tqdm.autonotebook import tqdm

import h5py
import pickle
import numpy as np

import torch
from torch import nn

from src.da_models.adda import ADDAST
from src.da_models.datasets import SpotDataset
from src.da_models.utils import set_requires_grad
from src.utils.dupstdout import DupStdout
from src.utils.data_loading import load_spatial, load_sc, get_selected_dir
from src.utils.evaluation import format_iters

# datetime object containing current date and time
script_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%Hh%Mm%S")


  from tqdm.autonotebook import tqdm


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device == "cpu":
    warnings.warn("Using CPU", stacklevel=2)


In [3]:
DATA_DIR = "data"
TRAIN_USING_ALL_ST_SAMPLES = False
N_MARKERS = 20
ALL_GENES = False

N_SPOTS = 20000
N_MIX = 8

ST_SPLIT = False

SCALER_NAME = "celldart"

SAMPLE_ID_N = "151673"

BATCH_SIZE = 512
NUM_WORKERS = 4
INITIAL_TRAIN_EPOCHS = 10

EARLY_STOP_CRIT = 100
MIN_EPOCHS = INITIAL_TRAIN_EPOCHS


EARLY_STOP_CRIT_ADV = 10
MIN_EPOCHS_ADV = 10


MODEL_NAME = "CellDART"
MODEL_VERSION = "celldart1_nobnfix"

celldart_kwargs = {
    "emb_dim": 64,
    "bn_momentum": 0.99,
}


In [4]:
N_ITER = 3000
ALPHA_LR = 5
ALPHA = 0.6


In [5]:
model_folder = os.path.join("model", MODEL_NAME, script_start_time)

model_folder = os.path.join("model", MODEL_NAME, MODEL_VERSION)

if not os.path.isdir(model_folder):
    os.makedirs(model_folder)
    print(model_folder)

model/CellDART/celldart1_nobnfix


 # Data load


In [6]:
# Load spatial data
mat_sp_d, mat_sp_train, st_sample_id_l = load_spatial(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    train_using_all_st_samples=TRAIN_USING_ALL_ST_SAMPLES,
    st_split=ST_SPLIT,
)

# Load sc data
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
)



 # Training: Adversarial domain adaptation for cell fraction estimation

 ## Prepare dataloaders

In [7]:
### source dataloaders
source_train_set = SpotDataset(
    deepcopy(sc_mix_d["train"]), deepcopy(lab_mix_d["train"])
)
source_val_set = SpotDataset(deepcopy(sc_mix_d["val"]), deepcopy(lab_mix_d["val"]))
source_test_set = SpotDataset(deepcopy(sc_mix_d["test"]), deepcopy(lab_mix_d["test"]))

dataloader_source_train = torch.utils.data.DataLoader(
    source_train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader_source_val = torch.utils.data.DataLoader(
    source_val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)
dataloader_source_test = torch.utils.data.DataLoader(
    source_test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

# ### target dataloaders
# target_test_set_d = {}
# for sample_id in st_sample_id_l:
#     target_test_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d["test"][sample_id]))

# dataloader_target_test_d = {}
# for sample_id in st_sample_id_l:
#     dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
#         target_test_set_d[sample_id],
#         batch_size=BATCH_SIZE,
#         shuffle=False,
#         num_workers=NUM_WORKERS,
#         pin_memory=False,
#     )

# if TRAIN_USING_ALL_ST_SAMPLES:
#     target_train_set = SpotDataset(deepcopy(mat_sp_train_s))
#     dataloader_target_train = torch.utils.data.DataLoader(
#         target_train_set,
#         batch_size=BATCH_SIZE,
#         shuffle=True,
#         num_workers=NUM_WORKERS,
#         pin_memory=True,
#     )
# else:
#     target_train_set_d = {}
#     dataloader_target_train_d = {}
#     for sample_id in st_sample_id_l:
#         target_train_set_d[sample_id] = SpotDataset(
#             deepcopy(mat_sp_d["train"][sample_id])
#         )
#         dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
#             target_train_set_d[sample_id],
#             batch_size=BATCH_SIZE,
#             shuffle=True,
#             num_workers=NUM_WORKERS,
#             pin_memory=True,
#         )

# if TRAIN_USING_ALL_ST_SAMPLES:
#     target_train_set = SpotDataset(deepcopy(mat_sp_train))
#     dataloader_target_train = torch.utils.data.DataLoader(
#         target_train_set,
#         batch_size=BATCH_SIZE,
#         shuffle=True,
#         num_workers=NUM_WORKERS,
#         pin_memory=True,
#     )

### target dataloaders
target_train_set_d = {}
dataloader_target_train_d = {}
if ST_SPLIT:
    target_val_set_d = {}
    target_test_set_d = {}

    dataloader_target_val_d = {}
    dataloader_target_test_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["train"]))
        target_val_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["val"]))
        target_test_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["test"]))

        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )
        dataloader_target_val_d[sample_id] = torch.utils.data.DataLoader(
            target_val_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )
        dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
            target_test_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )

else:
    target_test_set_d = {}
    dataloader_target_test_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["train"]))
        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )

        target_test_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d[sample_id]["test"])
        )
        dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
            target_test_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )


if TRAIN_USING_ALL_ST_SAMPLES:
    target_train_set = SpotDataset(mat_sp_train)
    dataloader_target_train = torch.utils.data.DataLoader(
        target_train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    

 ## Define Model

In [8]:
model = ADDAST(
    inp_dim=sc_mix_d["train"].shape[1],
    ncls_source=lab_mix_d["train"].shape[1],
    **celldart_kwargs
)

## CellDART uses just one encoder!
model.target_encoder = model.source_encoder
model.to(device)


ADDAST(
  (source_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=360, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (target_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=360, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (dis): Discriminator(
    (head): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1):

 ## Pretrain

In [9]:
pretrain_folder = os.path.join(model_folder, "pretrain")

if not os.path.isdir(pretrain_folder):
    os.makedirs(pretrain_folder)


In [10]:
pre_optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-07
)

criterion_clf = nn.KLDivLoss(reduction="batchmean")


In [11]:
def model_loss(x, y_true, model):
    x = x.to(torch.float32).to(device)
    y_true = y_true.to(torch.float32).to(device)

    y_pred = model(x)

    loss = criterion_clf(y_pred, y_true)

    return loss


def compute_acc(dataloader, model):
    loss_running = []
    mean_weights = []
    model.eval()
    with torch.no_grad():
        for _, batch in enumerate(dataloader):

            loss = model_loss(*batch, model)

            loss_running.append(loss.item())

            # we will weight average by batch size later
            mean_weights.append(len(batch))

    return np.average(loss_running, weights=mean_weights)


In [12]:
model.pretraining()


In [13]:
# Initialize lists to store loss and accuracy values
loss_history = []
loss_history_val = []

loss_history_running = []

# Early Stopping
best_loss_val = np.inf
early_stop_count = 0
with DupStdout().dup_to_file(os.path.join(pretrain_folder, "log.txt"), "w") as f_log:
    # Train
    print("Start pretrain...")
    outer = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)
    inner = tqdm(total=len(dataloader_source_train), desc=f"Batch", position=1)

    checkpoint = {
        "epoch": -1,
        "model": model,
        "optimizer": pre_optimizer,
    }
    for epoch in range(INITIAL_TRAIN_EPOCHS):
        checkpoint["epoch"] = epoch

        # Train mode
        model.train()
        loss_running = []
        mean_weights = []

        inner.refresh()  # force print final state
        inner.reset()  # reuse bar
        for _, batch in enumerate(dataloader_source_train):

            pre_optimizer.zero_grad()
            loss = model_loss(*batch, model)
            loss_running.append(loss.item())
            mean_weights.append(len(batch))  # we will weight average by batch size later

            loss.backward()
            pre_optimizer.step()

            inner.update(1)

        loss_history.append(np.average(loss_running, weights=mean_weights))
        loss_history_running.append(loss_running)

        # Evaluate mode
        model.eval()
        with torch.no_grad():
            curr_loss_val = compute_acc(dataloader_source_val, model)
            loss_history_val.append(curr_loss_val)

        # Print the results
        outer.update(1)
        print(
            "epoch:",
            epoch,
            "train loss:",
            round(loss_history[-1], 6),
            "validation loss:",
            round(loss_history_val[-1], 6),
            end=" ",
        )

        # Save the best weights
        if curr_loss_val < best_loss_val:
            best_loss_val = curr_loss_val
            torch.save(checkpoint, os.path.join(pretrain_folder, f"best_model.pth"))
            early_stop_count = 0

            print("<-- new best val loss")
        else:
            print("")

        # Save checkpoint every 10
        if epoch % 10 == 0 or epoch >= INITIAL_TRAIN_EPOCHS - 1:
            torch.save(checkpoint, os.path.join(pretrain_folder, f"checkpt{epoch}.pth"))

        # check to see if validation loss has plateau'd
        if early_stop_count >= EARLY_STOP_CRIT and epoch >= MIN_EPOCHS - 1:
            print(f"Validation loss plateaued after {early_stop_count} at epoch {epoch}")
            torch.save(checkpoint, os.path.join(pretrain_folder, f"earlystop{epoch}.pth"))
            break

        early_stop_count += 1


    # Save final model
    torch.save(checkpoint, os.path.join(pretrain_folder, f"final_model.pth"))


Start pretrain...


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batch:   0%|          | 0/40 [00:00<?, ?it/s]

epoch: 0 train loss: 1.208384 validation loss: 1.169877 <-- new best val loss
epoch: 1 train loss: 0.864314 validation loss: 0.790085 <-- new best val loss
epoch: 2 train loss: 0.695696 validation loss: 0.741301 <-- new best val loss
epoch: 3 train loss: 0.612641 validation loss: 0.69476 <-- new best val loss
epoch: 4 train loss: 0.57914 validation loss: 0.667346 <-- new best val loss
epoch: 5 train loss: 0.556197 validation loss: 0.67858 
epoch: 6 train loss: 0.53698 validation loss: 0.677397 
epoch: 7 train loss: 0.521874 validation loss: 0.663509 <-- new best val loss
epoch: 8 train loss: 0.507593 validation loss: 0.672562 
epoch: 9 train loss: 0.495928 validation loss: 0.671366 


 ## Adversarial Adaptation

In [14]:
advtrain_folder = os.path.join(model_folder, "advtrain")

if not os.path.isdir(advtrain_folder):
    os.makedirs(advtrain_folder)


In [15]:
def batch_generator(data, batch_size):
    """Generate batches of data.
    
    Given a list of numpy data, it iterates over the list and returns batches of
    the same size.
    """
    all_examples_indices = len(data[0])
    while True:
        mini_batch_indices = np.random.choice(
            all_examples_indices, size=batch_size, replace=False
        )
        tbr = [k[mini_batch_indices] for k in data]
        yield tbr


In [16]:
criterion_dis = nn.CrossEntropyLoss()


def discrim_loss_accu(x, y_dis, model):
    x = x.to(torch.float32).to(device)

    emb = model.source_encoder(x)
    y_pred = model.dis(emb)

    loss = criterion_dis(y_pred, y_dis)

    accu = torch.mean(
        (torch.flatten(torch.argmax(y_pred, dim=1)) == y_dis).to(torch.float32)
    ).cpu()

    return loss, accu


def compute_acc_dis(dataloader_source, dataloader_target, model):
    len_target = len(dataloader_target)
    len_source = len(dataloader_source)

    loss_running = []
    accu_running = []
    mean_weights = []
    model.eval()
    model.source_encoder.eval()
    model.dis.eval()
    with torch.no_grad():
        for y_val, dl in zip([1, 0], [dataloader_target, dataloader_source]):
            for _, (x, _) in enumerate(dl):

                y_dis = torch.full(
                    (x.shape[0],), y_val, device=device, dtype=torch.long
                )

                loss, accu = discrim_loss_accu(x, y_dis, model)

                accu_running.append(accu)
                loss_running.append(loss.item())

                # we will weight average by batch size later
                mean_weights.append(len(x))

    return (
        np.average(loss_running, weights=mean_weights),
        np.average(accu_running, weights=mean_weights),
    )


In [17]:
def train_adversarial(
    model,
    save_folder,
    sc_mix_train_s,
    lab_mix_train_s,
    mat_sp_train_s,
    dataloader_source_train_eval,
    dataloader_target_train_eval,
):

    model.to(device)
    model.advtraining()
    model.set_encoder("source")

    S_batches = batch_generator(
        [sc_mix_train_s.copy(), lab_mix_train_s.copy()], BATCH_SIZE
    )
    T_batches = batch_generator(
        [mat_sp_train_s.copy(), np.ones(shape=(len(mat_sp_train_s), 2))], BATCH_SIZE
    )

    enc_optimizer = torch.optim.Adam(
        chain(
            model.source_encoder.parameters(),
            model.dis.parameters(),
            model.clf.parameters(),
        ),
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    dis_optimizer = torch.optim.Adam(
        chain(model.source_encoder.parameters(), model.dis.parameters()),
        lr=ALPHA_LR * 0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    # Initialize lists to store loss and accuracy values
    loss_history_running = []
    with DupStdout().dup_to_file(os.path.join(save_folder, "log.txt"), "w") as f_log:
        # Train
        print("Start adversarial training...")
        outer = tqdm(total=N_ITER, desc="Iterations", position=0)

        checkpoint = {
            "epoch": -1,
            "model": model,
            "dis_optimizer": dis_optimizer,
            "enc_optimizer": enc_optimizer,
        }
        for iters in range(N_ITER):
            checkpoint["epoch"] = iters

            model.train()
            model.dis.train()
            model.clf.train()
            model.source_encoder.train()

            x_source, y_true = next(S_batches)
            x_target, _ = next(T_batches)

            ## Train encoder
            set_requires_grad(model.source_encoder, True)
            set_requires_grad(model.clf, True)
            set_requires_grad(model.dis, True)

            # save discriminator weights
            dis_weights = deepcopy(model.dis.state_dict())
            new_dis_weights = {}
            for k in dis_weights:
                if "num_batches_tracked" not in k:
                    new_dis_weights[k] = dis_weights[k]

            dis_weights = new_dis_weights

            x_source, x_target, y_true, = (
                torch.Tensor(x_source),
                torch.Tensor(x_target),
                torch.Tensor(y_true),
            )
            x_source, x_target, y_true, = (
                x_source.to(device),
                x_target.to(device),
                y_true.to(device),
            )

            x = torch.cat((x_source, x_target))

            # save for discriminator later
            x_d = x.detach()

            # y_dis is the REAL one
            y_dis = torch.cat(
                [
                    torch.zeros(x_source.shape[0], device=device, dtype=torch.long),
                    torch.ones(x_target.shape[0], device=device, dtype=torch.long),
                ]
            )
            y_dis_flipped = 1 - y_dis.detach()

            emb = model.source_encoder(x).view(x.shape[0], -1)

            y_dis_pred = model.dis(emb)
            y_clf_pred = model.clf(emb)

            # we use flipped because we want to confuse discriminator
            loss_dis = criterion_dis(y_dis_pred, y_dis_flipped)

            # Set true = predicted for target samples since we don't know what it is
            y_clf_true = torch.cat((y_true, y_clf_pred[-x_target.shape[0] :].detach()))
            # Loss fn does mean over all samples including target
            loss_clf = criterion_clf(y_clf_pred, y_clf_true)

            # Scale back up loss so mean doesn't include target
            loss = (x.shape[0] / x_source.shape[0]) * loss_clf + ALPHA * loss_dis

            # loss = loss_clf + ALPHA * loss_dis

            enc_optimizer.zero_grad()
            loss.backward()
            enc_optimizer.step()

            model.dis.load_state_dict(dis_weights, strict=False)

            ## Train discriminator
            set_requires_grad(model.source_encoder, True)
            set_requires_grad(model.clf, True)
            set_requires_grad(model.dis, True)

            # save encoder and clf weights
            source_encoder_weights = deepcopy(model.source_encoder.state_dict())
            clf_weights = deepcopy(model.clf.state_dict())

            new_clf_weights = {}
            for k in clf_weights:
                if "num_batches_tracked" not in k:
                    new_clf_weights[k] = clf_weights[k]

            clf_weights = new_clf_weights

            new_source_encoder_weights = {}
            for k in source_encoder_weights:
                if "num_batches_tracked" not in k:
                    new_source_encoder_weights[k] = source_encoder_weights[k]

            source_encoder_weights = new_source_encoder_weights

            emb = model.source_encoder(x_d).view(x_d.shape[0], -1)
            y_pred = model.dis(emb)

            # we use the real domain labels to train discriminator
            loss = criterion_dis(y_pred, y_dis)

            dis_optimizer.zero_grad()
            loss.backward()
            dis_optimizer.step()

            model.clf.load_state_dict(clf_weights, strict=False)
            model.source_encoder.load_state_dict(source_encoder_weights, strict=False)

            # Save checkpoint every 100
            if iters % 100 == 99 or iters >= N_ITER - 1:
                torch.save(checkpoint, os.path.join(save_folder, f"checkpt{iters}.pth"))

                model.eval()
                source_loss = compute_acc(dataloader_source_train_eval, model)
                _, dis_accu = compute_acc_dis(
                    dataloader_source_train_eval, dataloader_target_train_eval, model
                )

                # Print the results
                print(
                    "iter:",
                    iters,
                    "source loss:",
                    round(source_loss, 6),
                    "dis accu:",
                    round(dis_accu, 6),
                )

            outer.update(1)

    torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [18]:
# st_sample_id_l = [SAMPLE_ID_N]


In [19]:
if TRAIN_USING_ALL_ST_SAMPLES:
    print(f"Adversarial training for all ST slides")
    save_folder = advtrain_folder

    best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
    model = best_checkpoint["model"]
    model.to(device)
    model.advtraining()

    train_adversarial(
        model,
        save_folder,
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train,
        dataloader_source_train,
        dataloader_target_train,
    )

else:
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        save_folder = os.path.join(advtrain_folder, sample_id)
        if not os.path.isdir(save_folder):
            os.makedirs(save_folder)

        best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
        model = best_checkpoint["model"]
        model.to(device)
        model.advtraining()

        train_adversarial(
            model,
            save_folder,
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d[sample_id]["train"],
            dataloader_source_train,
            dataloader_target_train_d[sample_id],
        )


Adversarial training for ST slide 151507: 
Start adversarial training...


Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.987163 dis accu: 0.846941
iter: 199 source loss: 0.929064 dis accu: 0.541402
iter: 299 source loss: 0.871016 dis accu: 0.492776
iter: 399 source loss: 0.82102 dis accu: 0.914348
iter: 499 source loss: 0.843393 dis accu: 0.379757
iter: 599 source loss: 0.807728 dis accu: 0.03133
iter: 699 source loss: 0.806839 dis accu: 0.257038
iter: 799 source loss: 0.786054 dis accu: 0.199538
iter: 899 source loss: 0.728812 dis accu: 0.629943
iter: 999 source loss: 0.705629 dis accu: 0.997606
iter: 1099 source loss: 0.756152 dis accu: 0.361554
iter: 1199 source loss: 0.694573 dis accu: 0.84884
iter: 1299 source loss: 0.719585 dis accu: 0.034508
iter: 1399 source loss: 0.734894 dis accu: 0.9044
iter: 1499 source loss: 0.749051 dis accu: 0.760505
iter: 1599 source loss: 0.677974 dis accu: 0.343763
iter: 1699 source loss: 0.618616 dis accu: 0.998927
iter: 1799 source loss: 0.77521 dis accu: 0.205606
iter: 1899 source loss: 0.646458 dis accu: 0.924915
iter: 1999 source loss: 0.650

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.772525 dis accu: 0.999385
iter: 199 source loss: 0.946078 dis accu: 0.447712
iter: 299 source loss: 0.871639 dis accu: 0.522884
iter: 399 source loss: 0.781804 dis accu: 0.864255
iter: 499 source loss: 0.746321 dis accu: 0.995612
iter: 599 source loss: 0.822482 dis accu: 0.881972
iter: 699 source loss: 0.759046 dis accu: 0.07669
iter: 799 source loss: 0.788363 dis accu: 0.586983
iter: 899 source loss: 0.68845 dis accu: 0.532685
iter: 999 source loss: 0.756128 dis accu: 0.364624
iter: 1099 source loss: 0.717931 dis accu: 0.503937
iter: 1199 source loss: 0.710001 dis accu: 0.840428
iter: 1299 source loss: 0.701363 dis accu: 0.590961
iter: 1399 source loss: 0.661991 dis accu: 0.999139
iter: 1499 source loss: 0.648142 dis accu: 0.820579
iter: 1599 source loss: 0.695152 dis accu: 0.910925
iter: 1699 source loss: 0.761645 dis accu: 0.238517
iter: 1799 source loss: 0.665953 dis accu: 0.907767
iter: 1899 source loss: 0.831585 dis accu: 0.778912
iter: 1999 source loss: 0

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.886212 dis accu: 0.94792
iter: 199 source loss: 0.851239 dis accu: 0.877244
iter: 299 source loss: 0.784407 dis accu: 0.267901
iter: 399 source loss: 0.759543 dis accu: 0.183589
iter: 499 source loss: 0.703336 dis accu: 0.062205
iter: 599 source loss: 0.685696 dis accu: 0.266287
iter: 699 source loss: 0.641882 dis accu: 0.16237
iter: 799 source loss: 0.637661 dis accu: 0.097342
iter: 899 source loss: 0.656361 dis accu: 0.783372
iter: 999 source loss: 0.632556 dis accu: 0.759692
iter: 1099 source loss: 0.629674 dis accu: 0.257735
iter: 1199 source loss: 0.631924 dis accu: 0.446246
iter: 1299 source loss: 0.565997 dis accu: 0.999919
iter: 1399 source loss: 0.578601 dis accu: 0.020453
iter: 1499 source loss: 0.541368 dis accu: 0.965267
iter: 1599 source loss: 0.56871 dis accu: 0.807737
iter: 1699 source loss: 0.560466 dis accu: 0.97394
iter: 1799 source loss: 0.548946 dis accu: 0.982129
iter: 1899 source loss: 0.579157 dis accu: 0.582476
iter: 1999 source loss: 0.5

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.925534 dis accu: 0.519851
iter: 199 source loss: 0.879942 dis accu: 0.422993
iter: 299 source loss: 0.914214 dis accu: 0.05736
iter: 399 source loss: 0.769398 dis accu: 0.002314
iter: 499 source loss: 0.709683 dis accu: 0.000487
iter: 599 source loss: 0.739952 dis accu: 0.00479
iter: 699 source loss: 0.630766 dis accu: 0.067792
iter: 799 source loss: 0.606252 dis accu: 0.05261
iter: 899 source loss: 0.577121 dis accu: 0.000122
iter: 999 source loss: 0.564795 dis accu: 0.006373
iter: 1099 source loss: 0.538764 dis accu: 0.008768
iter: 1199 source loss: 0.503509 dis accu: 0.177559
iter: 1299 source loss: 0.485708 dis accu: 0.204473
iter: 1399 source loss: 0.473228 dis accu: 0.472518
iter: 1499 source loss: 0.469606 dis accu: 0.483519
iter: 1599 source loss: 0.474332 dis accu: 0.480677
iter: 1699 source loss: 0.44426 dis accu: 0.453966
iter: 1799 source loss: 0.439124 dis accu: 0.312495
iter: 1899 source loss: 0.410535 dis accu: 0.469838
iter: 1999 source loss: 0.4

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.025688 dis accu: 0.115211
iter: 199 source loss: 0.873875 dis accu: 0.658129
iter: 299 source loss: 0.888087 dis accu: 0.612485
iter: 399 source loss: 0.860045 dis accu: 0.013609
iter: 499 source loss: 0.763745 dis accu: 0.154305
iter: 599 source loss: 0.698494 dis accu: 0.273995
iter: 699 source loss: 0.653691 dis accu: 0.019737
iter: 799 source loss: 0.621213 dis accu: 0.167787
iter: 899 source loss: 0.581006 dis accu: 0.039559
iter: 999 source loss: 0.549078 dis accu: 0.001183
iter: 1099 source loss: 0.517015 dis accu: 0.870758
iter: 1199 source loss: 0.505947 dis accu: 0.499429
iter: 1299 source loss: 0.490193 dis accu: 0.660074
iter: 1399 source loss: 0.485607 dis accu: 0.023921
iter: 1499 source loss: 0.466372 dis accu: 0.203499
iter: 1599 source loss: 0.463984 dis accu: 0.742107
iter: 1699 source loss: 0.464256 dis accu: 0.464055
iter: 1799 source loss: 0.445586 dis accu: 0.3428
iter: 1899 source loss: 0.420139 dis accu: 0.534086
iter: 1999 source loss: 0

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.933133 dis accu: 0.06924
iter: 199 source loss: 0.917034 dis accu: 0.128479
iter: 299 source loss: 0.798304 dis accu: 0.104945
iter: 399 source loss: 0.770476 dis accu: 0.6558
iter: 499 source loss: 0.737161 dis accu: 0.39531
iter: 599 source loss: 0.703902 dis accu: 0.234871
iter: 699 source loss: 0.646539 dis accu: 0.66444
iter: 799 source loss: 0.607428 dis accu: 0.988978
iter: 899 source loss: 0.594819 dis accu: 0.01515
iter: 999 source loss: 0.624458 dis accu: 0.037748
iter: 1099 source loss: 0.548845 dis accu: 0.995702
iter: 1199 source loss: 0.53191 dis accu: 0.95689
iter: 1299 source loss: 0.519419 dis accu: 0.854285
iter: 1399 source loss: 0.510201 dis accu: 0.083411
iter: 1499 source loss: 0.488095 dis accu: 0.365818
iter: 1599 source loss: 0.478415 dis accu: 0.174057
iter: 1699 source loss: 0.471917 dis accu: 0.169334
iter: 1799 source loss: 0.460904 dis accu: 0.560941
iter: 1899 source loss: 0.432174 dis accu: 0.69457
iter: 1999 source loss: 0.423436

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.967231 dis accu: 0.363293
iter: 199 source loss: 0.910498 dis accu: 0.156947
iter: 299 source loss: 0.821127 dis accu: 0.036831
iter: 399 source loss: 0.82612 dis accu: 0.950518
iter: 499 source loss: 0.859278 dis accu: 0.018333
iter: 599 source loss: 0.888861 dis accu: 0.236748
iter: 699 source loss: 0.7375 dis accu: 0.568022
iter: 799 source loss: 0.716786 dis accu: 0.993986
iter: 899 source loss: 0.717185 dis accu: 0.354168
iter: 999 source loss: 0.64422 dis accu: 0.995438
iter: 1099 source loss: 0.729165 dis accu: 0.002489
iter: 1199 source loss: 0.713208 dis accu: 0.453961
iter: 1299 source loss: 0.662644 dis accu: 0.783907
iter: 1399 source loss: 0.655381 dis accu: 0.038366
iter: 1499 source loss: 0.612427 dis accu: 0.916549
iter: 1599 source loss: 0.639299 dis accu: 0.736002
iter: 1699 source loss: 0.670534 dis accu: 0.148196
iter: 1799 source loss: 0.614143 dis accu: 0.809498
iter: 1899 source loss: 0.611004 dis accu: 0.754873
iter: 1999 source loss: 0.6

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.06253 dis accu: 0.00025
iter: 199 source loss: 0.934737 dis accu: 0.120924
iter: 299 source loss: 0.878845 dis accu: 0.615074
iter: 399 source loss: 0.742251 dis accu: 0.956194
iter: 499 source loss: 0.727797 dis accu: 0.976723
iter: 599 source loss: 0.690168 dis accu: 0.729253
iter: 699 source loss: 0.690184 dis accu: 0.039517
iter: 799 source loss: 0.688676 dis accu: 0.005247
iter: 899 source loss: 0.709747 dis accu: 0.056756
iter: 999 source loss: 0.67936 dis accu: 0.18351
iter: 1099 source loss: 0.615073 dis accu: 0.342661
iter: 1199 source loss: 0.590831 dis accu: 0.939579
iter: 1299 source loss: 0.596784 dis accu: 0.9863
iter: 1399 source loss: 0.675334 dis accu: 0.001416
iter: 1499 source loss: 0.594673 dis accu: 0.999625
iter: 1599 source loss: 0.625465 dis accu: 0.017489
iter: 1699 source loss: 0.586103 dis accu: 0.998459
iter: 1799 source loss: 0.658754 dis accu: 0.859046
iter: 1899 source loss: 0.624789 dis accu: 0.33529
iter: 1999 source loss: 0.6402

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.939206 dis accu: 0.297094
iter: 199 source loss: 0.882794 dis accu: 0.586785
iter: 299 source loss: 0.810105 dis accu: 0.743517
iter: 399 source loss: 0.745081 dis accu: 0.931596
iter: 499 source loss: 0.749572 dis accu: 0.724777
iter: 599 source loss: 0.779062 dis accu: 0.579889
iter: 699 source loss: 0.738082 dis accu: 0.995558
iter: 799 source loss: 0.801756 dis accu: 0.649647
iter: 899 source loss: 0.757102 dis accu: 0.607809
iter: 999 source loss: 0.70169 dis accu: 0.792715
iter: 1099 source loss: 0.684356 dis accu: 0.912814
iter: 1199 source loss: 0.697159 dis accu: 0.217733
iter: 1299 source loss: 0.639676 dis accu: 0.995812
iter: 1399 source loss: 0.64344 dis accu: 0.535344
iter: 1499 source loss: 0.623682 dis accu: 0.960827
iter: 1599 source loss: 0.687744 dis accu: 0.713017
iter: 1699 source loss: 0.615287 dis accu: 0.75735
iter: 1799 source loss: 0.621258 dis accu: 0.852743
iter: 1899 source loss: 0.629088 dis accu: 0.705529
iter: 1999 source loss: 0.

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.942204 dis accu: 0.969332
iter: 199 source loss: 0.946177 dis accu: 0.097495
iter: 299 source loss: 0.890763 dis accu: 0.60795
iter: 399 source loss: 0.84519 dis accu: 0.594559
iter: 499 source loss: 0.810966 dis accu: 0.226249
iter: 599 source loss: 0.705212 dis accu: 0.426055
iter: 699 source loss: 0.721166 dis accu: 0.367634
iter: 799 source loss: 0.648708 dis accu: 0.997381
iter: 899 source loss: 0.771955 dis accu: 0.28467
iter: 999 source loss: 0.653225 dis accu: 0.482279
iter: 1099 source loss: 0.63925 dis accu: 0.904448
iter: 1199 source loss: 0.655374 dis accu: 0.330292
iter: 1299 source loss: 0.646943 dis accu: 0.020023
iter: 1399 source loss: 0.598847 dis accu: 0.99831
iter: 1499 source loss: 0.613607 dis accu: 0.148904
iter: 1599 source loss: 0.594862 dis accu: 0.957716
iter: 1699 source loss: 0.579323 dis accu: 0.987412
iter: 1799 source loss: 0.691423 dis accu: 0.077937
iter: 1899 source loss: 0.630647 dis accu: 0.659317
iter: 1999 source loss: 0.64

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.90995 dis accu: 0.390387
iter: 199 source loss: 0.890558 dis accu: 0.677433
iter: 299 source loss: 0.917241 dis accu: 0.336258
iter: 399 source loss: 0.912211 dis accu: 0.854612
iter: 499 source loss: 0.765533 dis accu: 0.579603
iter: 599 source loss: 0.761247 dis accu: 0.770176
iter: 699 source loss: 0.776101 dis accu: 0.621651
iter: 799 source loss: 0.731634 dis accu: 0.806714
iter: 899 source loss: 0.682967 dis accu: 0.956765
iter: 999 source loss: 0.826407 dis accu: 0.095032
iter: 1099 source loss: 0.659661 dis accu: 0.996991
iter: 1199 source loss: 0.740432 dis accu: 0.058622
iter: 1299 source loss: 0.691746 dis accu: 0.949941
iter: 1399 source loss: 0.658886 dis accu: 0.999407
iter: 1499 source loss: 0.750399 dis accu: 0.25
iter: 1599 source loss: 0.686818 dis accu: 0.888013
iter: 1699 source loss: 0.738681 dis accu: 0.717531
iter: 1799 source loss: 0.711268 dis accu: 0.855756
iter: 1899 source loss: 0.65895 dis accu: 0.988004
iter: 1999 source loss: 0.652

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 0.983511 dis accu: 0.592882
iter: 199 source loss: 0.935838 dis accu: 0.189429
iter: 299 source loss: 0.900617 dis accu: 0.626641
iter: 399 source loss: 0.896808 dis accu: 0.989514
iter: 499 source loss: 0.774078 dis accu: 0.986573
iter: 599 source loss: 0.842533 dis accu: 0.225064
iter: 699 source loss: 0.749913 dis accu: 0.933333
iter: 799 source loss: 0.796364 dis accu: 0.959676
iter: 899 source loss: 0.768189 dis accu: 0.114408
iter: 999 source loss: 0.743306 dis accu: 0.832225
iter: 1099 source loss: 0.71562 dis accu: 0.178772
iter: 1199 source loss: 0.673284 dis accu: 0.979071
iter: 1299 source loss: 0.70627 dis accu: 0.92191
iter: 1399 source loss: 0.675174 dis accu: 0.973956
iter: 1499 source loss: 0.636279 dis accu: 0.999744
iter: 1599 source loss: 0.676284 dis accu: 0.619309
iter: 1699 source loss: 0.664814 dis accu: 0.814066
iter: 1799 source loss: 0.623471 dis accu: 0.943052
iter: 1899 source loss: 0.626794 dis accu: 0.629497
iter: 1999 source loss: 0.

In [20]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     best_checkpoint = torch.load(os.path.join(advtrain_folder, f"final_model.pth"))
# else:
#     best_checkpoint = torch.load(
#         os.path.join(advtrain_folder, SAMPLE_ID_N, f"final_model.pth")
#     )

# model = best_checkpoint["model"]
# model.to(device)

# model.eval()
# model.set_encoder("source")

# pred_mix = (
#     torch.exp(model(torch.Tensor(sc_mix_test_s).to(device)))
#     .detach()
#     .cpu()
#     .numpy()
# )

# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(25, 5 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, visnum in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, visnum],
#         y=lab_mix_d["test"][:, visnum],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[visnum])

#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlabel("Predicted Proportion")

#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = (
#         f"MSE: {mean_squared_error(pred_mix[:,visnum], lab_mix_d["test"][:,visnum]):.5f}"
#     )

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()
