 # ADDA for ST

In [1]:
import os
import datetime
from itertools import chain
from copy import deepcopy
import warnings

from tqdm.autonotebook import tqdm

import h5py
import pickle
import numpy as np

import torch
from torch import nn

from src.da_models.adda import ADDAST
from src.da_models.datasets import SpotDataset
from src.da_models.utils import set_requires_grad
from src.utils.data_loading import load_spatial, load_sc

# datetime object containing current date and time
script_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%Hh%Mm%S")


  from tqdm.autonotebook import tqdm


In [2]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
if device == "cpu":
    warnings.warn("Using CPU", stacklevel=2)


In [3]:
NUM_MARKERS = 20
N_MIX = 8
N_SPOTS = 20000
TRAIN_USING_ALL_ST_SAMPLES = False

ST_SPLIT = False

SAMPLE_ID_N = "151673"

BATCH_SIZE = 512
NUM_WORKERS = 4
INITIAL_TRAIN_EPOCHS = 10

EARLY_STOP_CRIT = 100
MIN_EPOCHS = INITIAL_TRAIN_EPOCHS


EARLY_STOP_CRIT_ADV = 10
MIN_EPOCHS_ADV = 10

SPATIALLIBD_DIR = "./data/spatialLIBD"
SC_DLPFC_PATH = "./data/sc_dlpfc/adata_sc_dlpfc.h5ad"

PROCESSED_DATA_DIR = "./data/preprocessed_markers_celldart"

MODEL_NAME = "CellDART"


celldart_kwargs = {
    "emb_dim": 64,
    "bn_momentum": 0.01,
}


In [None]:
N_ITER = 3000
ALPHA_LR = 5
ALPHA = 0.6


In [4]:
model_folder = os.path.join("model", MODEL_NAME, script_start_time)

model_folder = os.path.join("model", MODEL_NAME, "bn_fix")

if not os.path.isdir(model_folder):
    os.makedirs(model_folder)
    print(model_folder)


 # Data load


In [5]:
# Load spatial data
mat_sp_d, mat_sp_train_s, st_sample_id_l = load_spatial(
    TRAIN_USING_ALL_ST_SAMPLES, PROCESSED_DATA_DIR, ST_SPLIT
)

# Load sc data
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(PROCESSED_DATA_DIR)


 # Training: Adversarial domain adaptation for cell fraction estimation

 ## Prepare dataloaders

In [6]:
### source dataloaders
source_train_set = SpotDataset(
    deepcopy(sc_mix_d["train"]), deepcopy(lab_mix_d["train"])
)
source_val_set = SpotDataset(deepcopy(sc_mix_d["val"]), deepcopy(lab_mix_d["val"]))
source_test_set = SpotDataset(deepcopy(sc_mix_d["test"]), deepcopy(lab_mix_d["test"]))

dataloader_source_train = torch.utils.data.DataLoader(
    source_train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader_source_val = torch.utils.data.DataLoader(
    source_val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)
dataloader_source_test = torch.utils.data.DataLoader(
    source_test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

### target dataloaders
target_test_set_d = {}
for sample_id in st_sample_id_l:
    target_test_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d["test"][sample_id]))

dataloader_target_test_d = {}
for sample_id in st_sample_id_l:
    dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
        target_test_set_d[sample_id],
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=False,
    )

if TRAIN_USING_ALL_ST_SAMPLES:
    target_train_set = SpotDataset(deepcopy(mat_sp_train_s))
    dataloader_target_train = torch.utils.data.DataLoader(
        target_train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
else:
    target_train_set_d = {}
    dataloader_target_train_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d["train"][sample_id])
        )
        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )


 ## Define Model

In [7]:
model = ADDAST(
    inp_dim=sc_mix_d["train"].shape[1],
    ncls_source=lab_mix_d["train"].shape[1],
    **celldart_kwargs
)

## CellDART uses just one encoder!
model.target_encoder = model.source_encoder
model.to(device)


ADDAST(
  (source_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=367, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (target_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=367, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (dis): Discriminator(
    (head): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1):

 ## Pretrain

In [8]:
pretrain_folder = os.path.join(model_folder, "pretrain")

if not os.path.isdir(pretrain_folder):
    os.makedirs(pretrain_folder)


In [9]:
pre_optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-07
)

criterion_clf = nn.KLDivLoss(reduction="batchmean")


In [10]:
def model_loss(x, y_true, model):
    x = x.to(torch.float32).to(device)
    y_true = y_true.to(torch.float32).to(device)

    y_pred = model(x)

    loss = criterion_clf(y_pred, y_true)

    return loss


def compute_acc(dataloader, model):
    loss_running = []
    mean_weights = []
    model.eval()
    with torch.no_grad():
        for _, batch in enumerate(dataloader):

            loss = model_loss(*batch, model)

            loss_running.append(loss.item())

            # we will weight average by batch size later
            mean_weights.append(len(batch))

    return np.average(loss_running, weights=mean_weights)


In [11]:
model.pretraining()


In [12]:
# Initialize lists to store loss and accuracy values
loss_history = []
loss_history_val = []

loss_history_running = []

# Early Stopping
best_loss_val = np.inf
early_stop_count = 0

# Train
print("Start pretrain...")
outer = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)
inner = tqdm(total=len(dataloader_source_train), desc=f"Batch", position=1)

checkpoint = {
    "epoch": -1,
    "model": model,
    "optimizer": pre_optimizer,
}
for epoch in range(INITIAL_TRAIN_EPOCHS):
    checkpoint["epoch"] = epoch

    # Train mode
    model.train()
    loss_running = []
    mean_weights = []

    inner.refresh()  # force print final state
    inner.reset()  # reuse bar
    for _, batch in enumerate(dataloader_source_train):

        pre_optimizer.zero_grad()
        loss = model_loss(*batch, model)
        loss_running.append(loss.item())
        mean_weights.append(len(batch))  # we will weight average by batch size later

        loss.backward()
        pre_optimizer.step()

        inner.update(1)

    loss_history.append(np.average(loss_running, weights=mean_weights))
    loss_history_running.append(loss_running)

    # Evaluate mode
    model.eval()
    with torch.no_grad():
        curr_loss_val = compute_acc(dataloader_source_val, model)
        loss_history_val.append(curr_loss_val)

    # Print the results
    outer.update(1)
    print(
        "epoch:",
        epoch,
        "train loss:",
        round(loss_history[-1], 6),
        "validation loss:",
        round(loss_history_val[-1], 6),
        end=" ",
    )

    # Save the best weights
    if curr_loss_val < best_loss_val:
        best_loss_val = curr_loss_val
        torch.save(checkpoint, os.path.join(pretrain_folder, f"best_model.pth"))
        early_stop_count = 0

        print("<-- new best val loss")
    else:
        print("")

    # Save checkpoint every 10
    if epoch % 10 == 0 or epoch >= INITIAL_TRAIN_EPOCHS - 1:
        torch.save(checkpoint, os.path.join(pretrain_folder, f"checkpt{epoch}.pth"))

    # check to see if validation loss has plateau'd
    if early_stop_count >= EARLY_STOP_CRIT and epoch >= MIN_EPOCHS - 1:
        print(f"Validation loss plateaued after {early_stop_count} at epoch {epoch}")
        torch.save(checkpoint, os.path.join(pretrain_folder, f"earlystop{epoch}.pth"))
        break

    early_stop_count += 1


# Save final model
torch.save(checkpoint, os.path.join(pretrain_folder, f"final_model.pth"))


Start pretrain...


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batch:   0%|          | 0/40 [00:00<?, ?it/s]

epoch: 0 train loss: 1.254696 validation loss: 1.486211 <-- new best val loss
epoch: 1 train loss: 0.903333 validation loss: 1.219224 <-- new best val loss
epoch: 2 train loss: 0.720916 validation loss: 1.022733 <-- new best val loss
epoch: 3 train loss: 0.625695 validation loss: 0.97935 <-- new best val loss
epoch: 4 train loss: 0.585073 validation loss: 0.92795 <-- new best val loss
epoch: 5 train loss: 0.562747 validation loss: 0.876444 <-- new best val loss
epoch: 6 train loss: 0.541671 validation loss: 0.799501 <-- new best val loss
epoch: 7 train loss: 0.525164 validation loss: 0.820431 
epoch: 8 train loss: 0.512558 validation loss: 0.726743 <-- new best val loss
epoch: 9 train loss: 0.498368 validation loss: 0.691346 <-- new best val loss


 ## Adversarial Adaptation

In [14]:
advtrain_folder = os.path.join(model_folder, "advtrain")

if not os.path.isdir(advtrain_folder):
    os.makedirs(advtrain_folder)


In [15]:
def batch_generator(data, batch_size):
    """Generate batches of data.
    
    Given a list of numpy data, it iterates over the list and returns batches of
    the same size.
    """
    all_examples_indices = len(data[0])
    while True:
        mini_batch_indices = np.random.choice(
            all_examples_indices, size=batch_size, replace=False
        )
        tbr = [k[mini_batch_indices] for k in data]
        yield tbr


In [16]:
criterion_dis = nn.CrossEntropyLoss()


def discrim_loss_accu(x, y_dis, model):
    x = x.to(torch.float32).to(device)

    emb = model.source_encoder(x)
    y_pred = model.dis(emb)

    loss = criterion_dis(y_pred, y_dis)

    accu = torch.mean(
        (torch.flatten(torch.argmax(y_pred, dim=1)) == y_dis).to(torch.float32)
    ).cpu()

    return loss, accu


def compute_acc_dis(dataloader_source, dataloader_target, model):
    len_target = len(dataloader_target)
    len_source = len(dataloader_source)

    loss_running = []
    accu_running = []
    mean_weights = []
    model.eval()
    model.source_encoder.eval()
    model.dis.eval()
    with torch.no_grad():
        for y_val, dl in zip([1, 0], [dataloader_target, dataloader_source]):
            for _, (x, _) in enumerate(dl):

                y_dis = torch.full(
                    (x.shape[0],), y_val, device=device, dtype=torch.long
                )

                loss, accu = discrim_loss_accu(x, y_dis, model)

                accu_running.append(accu)
                loss_running.append(loss.item())

                # we will weight average by batch size later
                mean_weights.append(len(x))

    return (
        np.average(loss_running, weights=mean_weights),
        np.average(accu_running, weights=mean_weights),
    )


In [17]:
def train_adversarial(
    model,
    save_folder,
    sc_mix_train_s,
    lab_mix_train_s,
    mat_sp_train_s,
    dataloader_source_train_eval,
    dataloader_target_train_eval,
):

    model.to(device)
    model.advtraining()
    model.set_encoder("source")

    S_batches = batch_generator(
        [sc_mix_train_s.copy(), lab_mix_train_s.copy()], BATCH_SIZE
    )
    T_batches = batch_generator(
        [mat_sp_train_s.copy(), np.ones(shape=(len(mat_sp_train_s), 2))], BATCH_SIZE
    )

    enc_optimizer = torch.optim.Adam(
        chain(
            model.source_encoder.parameters(),
            model.dis.parameters(),
            model.clf.parameters(),
        ),
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    dis_optimizer = torch.optim.Adam(
        chain(model.source_encoder.parameters(), model.dis.parameters()),
        lr=ALPHA_LR * 0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    # Initialize lists to store loss and accuracy values
    loss_history_running = []

    # Train
    print("Start adversarial training...")
    outer = tqdm(total=N_ITER, desc="Iterations", position=0)

    checkpoint = {
        "epoch": -1,
        "model": model,
        "dis_optimizer": dis_optimizer,
        "enc_optimizer": enc_optimizer,
    }
    for iters in range(N_ITER):
        checkpoint["epoch"] = iters

        model.train()
        model.dis.train()
        model.clf.train()
        model.source_encoder.train()

        x_source, y_true = next(S_batches)
        x_target, _ = next(T_batches)

        ## Train encoder
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.clf, True)
        set_requires_grad(model.dis, True)

        # save discriminator weights
        dis_weights = deepcopy(model.dis.state_dict())
        new_dis_weights = {}
        for k in dis_weights:
            if "num_batches_tracked" not in k:
                new_dis_weights[k] = dis_weights[k]

        dis_weights = new_dis_weights

        x_source, x_target, y_true, = (
            torch.Tensor(x_source),
            torch.Tensor(x_target),
            torch.Tensor(y_true),
        )
        x_source, x_target, y_true, = (
            x_source.to(device),
            x_target.to(device),
            y_true.to(device),
        )

        x = torch.cat((x_source, x_target))

        # save for discriminator later
        x_d = x.detach()

        # y_dis is the REAL one
        y_dis = torch.cat(
            [
                torch.zeros(x_source.shape[0], device=device, dtype=torch.long),
                torch.ones(x_target.shape[0], device=device, dtype=torch.long),
            ]
        )
        y_dis_flipped = 1 - y_dis.detach()

        emb = model.source_encoder(x).view(x.shape[0], -1)

        y_dis_pred = model.dis(emb)
        y_clf_pred = model.clf(emb)

        # we use flipped because we want to confuse discriminator
        loss_dis = criterion_dis(y_dis_pred, y_dis_flipped)

        # Set true = predicted for target samples since we don't know what it is
        y_clf_true = torch.cat((y_true, y_clf_pred[-x_target.shape[0] :].detach()))
        # Loss fn does mean over all samples including target
        loss_clf = criterion_clf(y_clf_pred, y_clf_true)

        # Scale back up loss so mean doesn't include target
        loss = (x.shape[0] / x_source.shape[0]) * loss_clf + ALPHA * loss_dis

        # loss = loss_clf + ALPHA * loss_dis

        enc_optimizer.zero_grad()
        loss.backward()
        enc_optimizer.step()

        model.dis.load_state_dict(dis_weights, strict=False)

        ## Train discriminator
        set_requires_grad(model.source_encoder, True)
        set_requires_grad(model.clf, True)
        set_requires_grad(model.dis, True)

        # save encoder and clf weights
        source_encoder_weights = deepcopy(model.source_encoder.state_dict())
        clf_weights = deepcopy(model.clf.state_dict())

        new_clf_weights = {}
        for k in clf_weights:
            if "num_batches_tracked" not in k:
                new_clf_weights[k] = clf_weights[k]

        clf_weights = new_clf_weights

        new_source_encoder_weights = {}
        for k in source_encoder_weights:
            if "num_batches_tracked" not in k:
                new_source_encoder_weights[k] = source_encoder_weights[k]

        source_encoder_weights = new_source_encoder_weights

        emb = model.source_encoder(x_d).view(x_d.shape[0], -1)
        y_pred = model.dis(emb)

        # we use the real domain labels to train discriminator
        loss = criterion_dis(y_pred, y_dis)

        dis_optimizer.zero_grad()
        loss.backward()
        dis_optimizer.step()

        model.clf.load_state_dict(clf_weights, strict=False)
        model.source_encoder.load_state_dict(source_encoder_weights, strict=False)

        # Save checkpoint every 100
        if iters % 100 == 99 or iters >= N_ITER - 1:
            torch.save(checkpoint, os.path.join(save_folder, f"checkpt{iters}.pth"))

            model.eval()
            source_loss = compute_acc(dataloader_source_train_eval, model)
            _, dis_accu = compute_acc_dis(
                dataloader_source_train_eval, dataloader_target_train_eval, model
            )

            # Print the results
            print(
                "iter:",
                iters,
                "source loss:",
                round(source_loss, 6),
                "dis accu:",
                round(dis_accu, 6),
            )

        outer.update(1)

    torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [18]:
# st_sample_id_l = [SAMPLE_ID_N]


In [19]:
if TRAIN_USING_ALL_ST_SAMPLES:
    print(f"Adversarial training for all ST slides")
    save_folder = advtrain_folder

    best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
    model = best_checkpoint["model"]
    model.to(device)
    model.advtraining()

    train_adversarial(
        model,
        save_folder,
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train_s,
        dataloader_source_train,
        dataloader_target_train,
    )

else:
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        save_folder = os.path.join(advtrain_folder, sample_id)
        if not os.path.isdir(save_folder):
            os.makedirs(save_folder)

        best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
        model = best_checkpoint["model"]
        model.to(device)
        model.advtraining()

        train_adversarial(
            model,
            save_folder,
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d["train"][sample_id],
            dataloader_source_train,
            dataloader_target_train_d[sample_id],
        )


Adversarial training for ST slide 151509: 
Start adversarial training...


Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.329889 dis accu: 0.831296
iter: 199 source loss: 2.493377 dis accu: 0.809512
iter: 299 source loss: 2.24339 dis accu: 0.822421
iter: 399 source loss: 1.727972 dis accu: 0.832748
iter: 499 source loss: 1.399806 dis accu: 0.871798
iter: 599 source loss: 1.563935 dis accu: 0.81996
iter: 699 source loss: 1.272468 dis accu: 0.808181
iter: 799 source loss: 1.92189 dis accu: 0.276615
iter: 899 source loss: 1.336929 dis accu: 0.723224
iter: 999 source loss: 1.197243 dis accu: 0.319618
iter: 1099 source loss: 1.114381 dis accu: 0.488765
iter: 1199 source loss: 1.035202 dis accu: 0.647787
iter: 1299 source loss: 1.01229 dis accu: 0.770906
iter: 1399 source loss: 0.994731 dis accu: 0.68248
iter: 1499 source loss: 0.946541 dis accu: 0.733753
iter: 1599 source loss: 0.91165 dis accu: 0.583646
iter: 1699 source loss: 0.881743 dis accu: 0.528581
iter: 1799 source loss: 0.850467 dis accu: 0.537053
iter: 1899 source loss: 0.97951 dis accu: 0.734439
iter: 1999 source loss: 0.8901

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.770964 dis accu: 0.188073
iter: 199 source loss: 4.385024 dis accu: 0.188114
iter: 299 source loss: 1.801786 dis accu: 0.839368
iter: 399 source loss: 1.650552 dis accu: 0.973695
iter: 499 source loss: 1.617005 dis accu: 0.933994
iter: 599 source loss: 1.440305 dis accu: 0.994073
iter: 699 source loss: 1.147237 dis accu: 0.998376
iter: 799 source loss: 1.21829 dis accu: 0.151539
iter: 899 source loss: 1.005714 dis accu: 0.903548
iter: 999 source loss: 1.027547 dis accu: 0.884793
iter: 1099 source loss: 0.98729 dis accu: 0.597508
iter: 1199 source loss: 0.921883 dis accu: 0.530527
iter: 1299 source loss: 0.880969 dis accu: 0.558253
iter: 1399 source loss: 0.826557 dis accu: 0.72063
iter: 1499 source loss: 0.852296 dis accu: 0.638224
iter: 1599 source loss: 0.857951 dis accu: 0.707356
iter: 1699 source loss: 0.875367 dis accu: 0.827799
iter: 1799 source loss: 0.807194 dis accu: 0.375457
iter: 1899 source loss: 0.744437 dis accu: 0.234513
iter: 1999 source loss: 0.

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 4.223671 dis accu: 0.829739
iter: 199 source loss: 1.597801 dis accu: 0.832974
iter: 299 source loss: 1.8905 dis accu: 0.835421
iter: 399 source loss: 2.251702 dis accu: 0.830859
iter: 499 source loss: 1.538765 dis accu: 0.838698
iter: 599 source loss: 1.609873 dis accu: 0.845292
iter: 699 source loss: 1.465553 dis accu: 0.835172
iter: 799 source loss: 1.30905 dis accu: 0.835462
iter: 899 source loss: 1.299996 dis accu: 0.833555
iter: 999 source loss: 1.250536 dis accu: 0.951597
iter: 1099 source loss: 1.586589 dis accu: 0.743343
iter: 1199 source loss: 1.134781 dis accu: 0.850145
iter: 1299 source loss: 1.129024 dis accu: 0.834094
iter: 1399 source loss: 1.096065 dis accu: 0.732518
iter: 1499 source loss: 1.062745 dis accu: 0.645334
iter: 1599 source loss: 1.062451 dis accu: 0.446786
iter: 1699 source loss: 1.011898 dis accu: 0.53061
iter: 1799 source loss: 0.990317 dis accu: 0.381916
iter: 1899 source loss: 0.959217 dis accu: 0.613314
iter: 1999 source loss: 0.9

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 6.273739 dis accu: 0.180036
iter: 199 source loss: 4.396993 dis accu: 0.17979
iter: 299 source loss: 2.014038 dis accu: 0.254716
iter: 399 source loss: 1.889748 dis accu: 0.179954
iter: 499 source loss: 1.52598 dis accu: 0.844242
iter: 599 source loss: 1.76749 dis accu: 0.844816
iter: 699 source loss: 1.32588 dis accu: 0.63513
iter: 799 source loss: 1.17573 dis accu: 0.470103
iter: 899 source loss: 1.116491 dis accu: 0.620612
iter: 999 source loss: 1.009376 dis accu: 0.738517
iter: 1099 source loss: 0.962529 dis accu: 0.761852
iter: 1199 source loss: 0.922027 dis accu: 0.714977
iter: 1299 source loss: 0.886619 dis accu: 0.645669
iter: 1399 source loss: 0.874757 dis accu: 0.659203
iter: 1499 source loss: 0.805668 dis accu: 0.874303
iter: 1599 source loss: 0.839437 dis accu: 0.295399
iter: 1699 source loss: 0.847576 dis accu: 0.250656
iter: 1799 source loss: 0.809087 dis accu: 0.069267
iter: 1899 source loss: 0.813877 dis accu: 0.350312
iter: 1999 source loss: 0.799

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.880369 dis accu: 0.148864
iter: 199 source loss: 3.81105 dis accu: 0.148864
iter: 299 source loss: 2.37056 dis accu: 0.944378
iter: 399 source loss: 1.613185 dis accu: 0.979913
iter: 499 source loss: 1.448887 dis accu: 0.937739
iter: 599 source loss: 1.54659 dis accu: 0.98319
iter: 699 source loss: 1.296325 dis accu: 0.723253
iter: 799 source loss: 1.113409 dis accu: 0.754149
iter: 899 source loss: 1.103869 dis accu: 0.920972
iter: 999 source loss: 1.01698 dis accu: 0.752149
iter: 1099 source loss: 0.990372 dis accu: 0.631798
iter: 1199 source loss: 0.937003 dis accu: 0.640565
iter: 1299 source loss: 0.984528 dis accu: 0.86135
iter: 1399 source loss: 0.873027 dis accu: 0.416291
iter: 1499 source loss: 0.841154 dis accu: 0.522385
iter: 1599 source loss: 0.783788 dis accu: 0.340114
iter: 1699 source loss: 0.76382 dis accu: 0.742233
iter: 1799 source loss: 0.767339 dis accu: 0.789727
iter: 1899 source loss: 0.733168 dis accu: 0.586433
iter: 1999 source loss: 0.6589

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 4.958266 dis accu: 0.851853
iter: 199 source loss: 2.746567 dis accu: 0.572608
iter: 299 source loss: 2.677324 dis accu: 0.954966
iter: 399 source loss: 2.139 dis accu: 0.175638
iter: 499 source loss: 2.174251 dis accu: 0.206142
iter: 599 source loss: 1.477738 dis accu: 0.8611
iter: 699 source loss: 1.416895 dis accu: 0.870924
iter: 799 source loss: 1.236058 dis accu: 0.796128
iter: 899 source loss: 1.0784 dis accu: 0.949104
iter: 999 source loss: 1.048647 dis accu: 0.877198
iter: 1099 source loss: 1.036816 dis accu: 0.718608
iter: 1199 source loss: 1.032866 dis accu: 0.174152
iter: 1299 source loss: 0.935641 dis accu: 0.369727
iter: 1399 source loss: 0.906779 dis accu: 0.353463
iter: 1499 source loss: 0.849789 dis accu: 0.513869
iter: 1599 source loss: 0.882661 dis accu: 0.197639
iter: 1699 source loss: 0.847478 dis accu: 0.450425
iter: 1799 source loss: 0.84184 dis accu: 0.919756
iter: 1899 source loss: 0.760778 dis accu: 0.572567
iter: 1999 source loss: 0.73001

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.602354 dis accu: 0.16842
iter: 199 source loss: 4.168841 dis accu: 0.155156
iter: 299 source loss: 2.674257 dis accu: 0.727453
iter: 399 source loss: 1.592845 dis accu: 0.156169
iter: 499 source loss: 1.498531 dis accu: 0.845098
iter: 599 source loss: 1.590232 dis accu: 0.800955
iter: 699 source loss: 1.246642 dis accu: 0.851138
iter: 799 source loss: 1.18041 dis accu: 0.282558
iter: 899 source loss: 1.069776 dis accu: 0.418494
iter: 999 source loss: 0.994856 dis accu: 0.880877
iter: 1099 source loss: 0.99097 dis accu: 0.539137
iter: 1199 source loss: 0.888884 dis accu: 0.628395
iter: 1299 source loss: 0.845087 dis accu: 0.851603
iter: 1399 source loss: 0.825884 dis accu: 0.592447
iter: 1499 source loss: 0.829222 dis accu: 0.190597
iter: 1599 source loss: 0.861523 dis accu: 0.842437
iter: 1699 source loss: 0.836888 dis accu: 0.95083
iter: 1799 source loss: 0.862771 dis accu: 0.319731
iter: 1899 source loss: 0.79398 dis accu: 0.911798
iter: 1999 source loss: 0.82

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.982381 dis accu: 0.151577
iter: 199 source loss: 3.611098 dis accu: 0.147869
iter: 299 source loss: 2.503056 dis accu: 0.147485
iter: 399 source loss: 1.714815 dis accu: 0.893691
iter: 499 source loss: 2.10442 dis accu: 0.52191
iter: 599 source loss: 1.578095 dis accu: 0.772208
iter: 699 source loss: 1.105151 dis accu: 0.958653
iter: 799 source loss: 1.06441 dis accu: 0.833419
iter: 899 source loss: 1.033318 dis accu: 0.254561
iter: 999 source loss: 1.006135 dis accu: 0.281202
iter: 1099 source loss: 0.912237 dis accu: 0.602984
iter: 1199 source loss: 0.905776 dis accu: 0.491006
iter: 1299 source loss: 0.885401 dis accu: 0.17954
iter: 1399 source loss: 0.857631 dis accu: 0.577067
iter: 1499 source loss: 0.821352 dis accu: 0.676684
iter: 1599 source loss: 0.837948 dis accu: 0.731586
iter: 1699 source loss: 0.753309 dis accu: 0.318585
iter: 1799 source loss: 0.797195 dis accu: 0.337766
iter: 1899 source loss: 0.908319 dis accu: 0.432609
iter: 1999 source loss: 0.8

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 4.818593 dis accu: 0.152255
iter: 199 source loss: 3.179957 dis accu: 0.961216
iter: 299 source loss: 1.654203 dis accu: 0.97567
iter: 399 source loss: 1.695862 dis accu: 0.152255
iter: 499 source loss: 2.311325 dis accu: 0.152255
iter: 599 source loss: 1.556799 dis accu: 0.450746
iter: 699 source loss: 1.349708 dis accu: 0.970965
iter: 799 source loss: 1.329499 dis accu: 0.580154
iter: 899 source loss: 1.27181 dis accu: 0.879239
iter: 999 source loss: 1.142508 dis accu: 0.407341
iter: 1099 source loss: 1.037918 dis accu: 0.601772
iter: 1199 source loss: 0.972496 dis accu: 0.849356
iter: 1299 source loss: 0.968495 dis accu: 0.769456
iter: 1399 source loss: 0.94617 dis accu: 0.686716
iter: 1499 source loss: 0.881424 dis accu: 0.75462
iter: 1599 source loss: 0.885129 dis accu: 0.526322
iter: 1699 source loss: 0.841185 dis accu: 0.762504
iter: 1799 source loss: 0.858134 dis accu: 0.813793
iter: 1899 source loss: 0.794863 dis accu: 0.725161
iter: 1999 source loss: 0.7

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 4.5874 dis accu: 0.153941
iter: 199 source loss: 3.175754 dis accu: 0.153941
iter: 299 source loss: 2.037589 dis accu: 0.154998
iter: 399 source loss: 1.536415 dis accu: 0.846271
iter: 499 source loss: 1.368904 dis accu: 0.953974
iter: 599 source loss: 1.543801 dis accu: 0.736706
iter: 699 source loss: 1.105701 dis accu: 0.933331
iter: 799 source loss: 1.102988 dis accu: 0.929481
iter: 899 source loss: 1.066866 dis accu: 0.600322
iter: 999 source loss: 1.005063 dis accu: 0.326875
iter: 1099 source loss: 0.968182 dis accu: 0.661872
iter: 1199 source loss: 0.942358 dis accu: 0.312196
iter: 1299 source loss: 0.874894 dis accu: 0.727019
iter: 1399 source loss: 0.945905 dis accu: 0.500275
iter: 1499 source loss: 0.814914 dis accu: 0.406193
iter: 1599 source loss: 0.725077 dis accu: 0.750624
iter: 1699 source loss: 0.744203 dis accu: 0.457972
iter: 1799 source loss: 0.720768 dis accu: 0.589788
iter: 1899 source loss: 0.63961 dis accu: 0.524134
iter: 1999 source loss: 0.

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 7.499743 dis accu: 0.167187
iter: 199 source loss: 1.746726 dis accu: 0.879034
iter: 299 source loss: 1.864487 dis accu: 0.846679
iter: 399 source loss: 1.657976 dis accu: 0.837726
iter: 499 source loss: 1.647039 dis accu: 0.835728
iter: 599 source loss: 1.565954 dis accu: 0.837352
iter: 699 source loss: 1.434672 dis accu: 0.854674
iter: 799 source loss: 1.342416 dis accu: 0.848595
iter: 899 source loss: 1.538955 dis accu: 0.880866
iter: 999 source loss: 1.354575 dis accu: 0.833729
iter: 1099 source loss: 1.116913 dis accu: 0.856257
iter: 1199 source loss: 1.110342 dis accu: 0.979513
iter: 1299 source loss: 1.067139 dis accu: 0.901686
iter: 1399 source loss: 1.05278 dis accu: 0.213991
iter: 1499 source loss: 1.014133 dis accu: 0.830065
iter: 1599 source loss: 1.047283 dis accu: 0.656007
iter: 1699 source loss: 0.999851 dis accu: 0.987091
iter: 1799 source loss: 1.020417 dis accu: 0.256048
iter: 1899 source loss: 0.988247 dis accu: 0.975557
iter: 1999 source loss: 

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 4.526574 dis accu: 0.154685
iter: 199 source loss: 4.648291 dis accu: 0.98563
iter: 299 source loss: 1.783121 dis accu: 0.91467
iter: 399 source loss: 1.545641 dis accu: 0.7569
iter: 499 source loss: 1.591211 dis accu: 0.745911
iter: 599 source loss: 1.35349 dis accu: 0.99459
iter: 699 source loss: 1.205744 dis accu: 0.973754
iter: 799 source loss: 1.26488 dis accu: 0.14209
iter: 899 source loss: 1.06863 dis accu: 0.58924
iter: 999 source loss: 0.96142 dis accu: 0.790753
iter: 1099 source loss: 0.935498 dis accu: 0.725244
iter: 1199 source loss: 0.953202 dis accu: 0.28735
iter: 1299 source loss: 0.850594 dis accu: 0.369046
iter: 1399 source loss: 0.920641 dis accu: 0.532564
iter: 1499 source loss: 0.868406 dis accu: 0.713706
iter: 1599 source loss: 0.784912 dis accu: 0.464013
iter: 1699 source loss: 0.737416 dis accu: 0.62136
iter: 1799 source loss: 0.704578 dis accu: 0.374202
iter: 1899 source loss: 0.684682 dis accu: 0.441909
iter: 1999 source loss: 0.694103 dis

In [20]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     best_checkpoint = torch.load(os.path.join(advtrain_folder, f"final_model.pth"))
# else:
#     best_checkpoint = torch.load(
#         os.path.join(advtrain_folder, SAMPLE_ID_N, f"final_model.pth")
#     )

# model = best_checkpoint["model"]
# model.to(device)

# model.eval()
# model.set_encoder("source")

# pred_mix = (
#     torch.exp(model(torch.Tensor(sc_mix_test_s).to(device)))
#     .detach()
#     .cpu()
#     .numpy()
# )

# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(25, 5 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, visnum in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, visnum],
#         y=lab_mix_d["test"][:, visnum],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[visnum])

#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlabel("Predicted Proportion")

#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = (
#         f"MSE: {mean_squared_error(pred_mix[:,visnum], lab_mix_d["test"][:,visnum]):.5f}"
#     )

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()
