 # ADDA for ST

In [1]:
import os
import datetime
from itertools import chain
from copy import deepcopy
import warnings

from tqdm.autonotebook import tqdm

import h5py
import pickle
import numpy as np

import torch
from torch import nn

from src.da_models.adda import ADDAST
from src.da_models.datasets import SpotDataset
from src.da_models.utils import set_requires_grad
from src.utils.dupstdout import DupStdout
from src.utils.data_loading import (
    load_spatial,
    load_sc,
    get_selected_dir,
    get_model_rel_path,
)
from src.utils.evaluation import format_iters

# datetime object containing current date and time
script_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%Hh%Mm%S")


  from tqdm.autonotebook import tqdm


In [2]:
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
if device == "cpu":
    warnings.warn("Using CPU", stacklevel=2)


In [3]:
DATA_DIR = "data"
TRAIN_USING_ALL_ST_SAMPLES = False
N_MARKERS = 20
ALL_GENES = False

N_SPOTS = 20000
N_MIX = 8

ST_SPLIT = False

SCALER_NAME = "celldart"

SAMPLE_ID_N = "151673"

BATCH_SIZE = 512
NUM_WORKERS = 4
INITIAL_TRAIN_EPOCHS = 10

EARLY_STOP_CRIT = 100
MIN_EPOCHS = INITIAL_TRAIN_EPOCHS


EARLY_STOP_CRIT_ADV = 10
MIN_EPOCHS_ADV = 10


MODEL_NAME = "CellDART"
MODEL_VERSION = "celldart1_bnfix"

celldart_kwargs = {
    "emb_dim": 64,
    "bn_momentum": 0.01,
}


In [4]:
N_ITER = 3000
ALPHA_LR = 5
ALPHA = 0.6


In [5]:
model_folder = get_model_rel_path(
    MODEL_NAME,
    MODEL_VERSION,
    scaler_name=SCALER_NAME,
    n_markers=N_MARKERS,
    all_genes=ALL_GENES,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
    st_split=ST_SPLIT,
)
model_folder = os.path.join("model", model_folder)

if not os.path.isdir(model_folder):
    os.makedirs(model_folder)
    print(model_folder)


 # Data load


In [6]:
# Load spatial data
mat_sp_d, mat_sp_train, st_sample_id_l = load_spatial(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    train_using_all_st_samples=TRAIN_USING_ALL_ST_SAMPLES,
    st_split=ST_SPLIT,
)

# Load sc data
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
)



 # Training: Adversarial domain adaptation for cell fraction estimation

 ## Prepare dataloaders

In [7]:
### source dataloaders
source_train_set = SpotDataset(
    deepcopy(sc_mix_d["train"]), deepcopy(lab_mix_d["train"])
)
source_val_set = SpotDataset(deepcopy(sc_mix_d["val"]), deepcopy(lab_mix_d["val"]))
source_test_set = SpotDataset(deepcopy(sc_mix_d["test"]), deepcopy(lab_mix_d["test"]))

dataloader_source_train = torch.utils.data.DataLoader(
    source_train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader_source_val = torch.utils.data.DataLoader(
    source_val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)
dataloader_source_test = torch.utils.data.DataLoader(
    source_test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

# ### target dataloaders
# target_test_set_d = {}
# for sample_id in st_sample_id_l:
#     target_test_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d["test"][sample_id]))

# dataloader_target_test_d = {}
# for sample_id in st_sample_id_l:
#     dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
#         target_test_set_d[sample_id],
#         batch_size=BATCH_SIZE,
#         shuffle=False,
#         num_workers=NUM_WORKERS,
#         pin_memory=False,
#     )

# if TRAIN_USING_ALL_ST_SAMPLES:
#     target_train_set = SpotDataset(deepcopy(mat_sp_train_s))
#     dataloader_target_train = torch.utils.data.DataLoader(
#         target_train_set,
#         batch_size=BATCH_SIZE,
#         shuffle=True,
#         num_workers=NUM_WORKERS,
#         pin_memory=True,
#     )
# else:
#     target_train_set_d = {}
#     dataloader_target_train_d = {}
#     for sample_id in st_sample_id_l:
#         target_train_set_d[sample_id] = SpotDataset(
#             deepcopy(mat_sp_d["train"][sample_id])
#         )
#         dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
#             target_train_set_d[sample_id],
#             batch_size=BATCH_SIZE,
#             shuffle=True,
#             num_workers=NUM_WORKERS,
#             pin_memory=True,
#         )

# if TRAIN_USING_ALL_ST_SAMPLES:
#     target_train_set = SpotDataset(deepcopy(mat_sp_train))
#     dataloader_target_train = torch.utils.data.DataLoader(
#         target_train_set,
#         batch_size=BATCH_SIZE,
#         shuffle=True,
#         num_workers=NUM_WORKERS,
#         pin_memory=True,
#     )

### target dataloaders
target_train_set_d = {}
dataloader_target_train_d = {}
if ST_SPLIT:
    target_val_set_d = {}
    target_test_set_d = {}

    dataloader_target_val_d = {}
    dataloader_target_test_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["train"]))
        target_val_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["val"]))
        target_test_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["test"]))

        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )
        dataloader_target_val_d[sample_id] = torch.utils.data.DataLoader(
            target_val_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )
        dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
            target_test_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )

else:
    target_test_set_d = {}
    dataloader_target_test_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["train"]))
        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )

        target_test_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d[sample_id]["test"])
        )
        dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
            target_test_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )


if TRAIN_USING_ALL_ST_SAMPLES:
    target_train_set = SpotDataset(mat_sp_train)
    dataloader_target_train = torch.utils.data.DataLoader(
        target_train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    

 ## Define Model

In [8]:
model = ADDAST(
    inp_dim=sc_mix_d["train"].shape[1],
    ncls_source=lab_mix_d["train"].shape[1],
    **celldart_kwargs
)

## CellDART uses just one encoder!
model.target_encoder = model.source_encoder
model.to(device)


ADDAST(
  (source_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=360, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (target_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=360, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (dis): Discriminator(
    (head): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1):

 ## Pretrain

In [9]:
pretrain_folder = os.path.join(model_folder, "pretrain")

if not os.path.isdir(pretrain_folder):
    os.makedirs(pretrain_folder)


In [10]:
pre_optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-07
)

criterion_clf = nn.KLDivLoss(reduction="batchmean")


In [11]:
def model_loss(x, y_true, model):
    x = x.to(torch.float32).to(device)
    y_true = y_true.to(torch.float32).to(device)

    y_pred = model(x)

    loss = criterion_clf(y_pred, y_true)

    return loss


def compute_acc(dataloader, model):
    loss_running = []
    mean_weights = []
    model.eval()
    with torch.no_grad():
        for _, batch in enumerate(dataloader):

            loss = model_loss(*batch, model)

            loss_running.append(loss.item())

            # we will weight average by batch size later
            mean_weights.append(len(batch))

    return np.average(loss_running, weights=mean_weights)


In [12]:
model.pretraining()


In [13]:
# Initialize lists to store loss and accuracy values
loss_history = []
loss_history_val = []

loss_history_running = []

# Early Stopping
best_loss_val = np.inf
early_stop_count = 0
with DupStdout().dup_to_file(os.path.join(pretrain_folder, "log.txt"), "w") as f_log:
    # Train
    print("Start pretrain...")
    outer = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)
    inner = tqdm(total=len(dataloader_source_train), desc=f"Batch", position=1)

    checkpoint = {
        "epoch": -1,
        "model": model,
        "optimizer": pre_optimizer,
    }
    for epoch in range(INITIAL_TRAIN_EPOCHS):
        checkpoint["epoch"] = epoch

        # Train mode
        model.train()
        loss_running = []
        mean_weights = []

        inner.refresh()  # force print final state
        inner.reset()  # reuse bar
        for _, batch in enumerate(dataloader_source_train):

            pre_optimizer.zero_grad()
            loss = model_loss(*batch, model)
            loss_running.append(loss.item())
            mean_weights.append(len(batch))  # we will weight average by batch size later

            loss.backward()
            pre_optimizer.step()

            inner.update(1)

        loss_history.append(np.average(loss_running, weights=mean_weights))
        loss_history_running.append(loss_running)

        # Evaluate mode
        model.eval()
        with torch.no_grad():
            curr_loss_val = compute_acc(dataloader_source_val, model)
            loss_history_val.append(curr_loss_val)

        # Print the results
        outer.update(1)
        print(
            "epoch:",
            epoch,
            "train loss:",
            round(loss_history[-1], 6),
            "validation loss:",
            round(loss_history_val[-1], 6),
            end=" ",
        )

        # Save the best weights
        if curr_loss_val < best_loss_val:
            best_loss_val = curr_loss_val
            torch.save(checkpoint, os.path.join(pretrain_folder, f"best_model.pth"))
            early_stop_count = 0

            print("<-- new best val loss")
        else:
            print("")

        # Save checkpoint every 10
        if epoch % 10 == 0 or epoch >= INITIAL_TRAIN_EPOCHS - 1:
            torch.save(checkpoint, os.path.join(pretrain_folder, f"checkpt{epoch}.pth"))

        # check to see if validation loss has plateau'd
        if early_stop_count >= EARLY_STOP_CRIT and epoch >= MIN_EPOCHS - 1:
            print(f"Validation loss plateaued after {early_stop_count} at epoch {epoch}")
            torch.save(checkpoint, os.path.join(pretrain_folder, f"earlystop{epoch}.pth"))
            break

        early_stop_count += 1


    # Save final model
    torch.save(checkpoint, os.path.join(pretrain_folder, f"final_model.pth"))


Start pretrain...


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batch:   0%|          | 0/40 [00:00<?, ?it/s]

epoch: 0 train loss: 1.191773 validation loss: 1.494227 <-- new best val loss
epoch: 1 train loss: 0.846174 validation loss: 1.264912 <-- new best val loss
epoch: 2 train loss: 0.681539 validation loss: 1.038583 <-- new best val loss
epoch: 3 train loss: 0.608706 validation loss: 0.962192 <-- new best val loss
epoch: 4 train loss: 0.577715 validation loss: 1.043455 
epoch: 5 train loss: 0.554479 validation loss: 0.888233 <-- new best val loss
epoch: 6 train loss: 0.538921 validation loss: 0.831037 <-- new best val loss
epoch: 7 train loss: 0.525295 validation loss: 0.775619 <-- new best val loss
epoch: 8 train loss: 0.511206 validation loss: 0.767794 <-- new best val loss
epoch: 9 train loss: 0.494654 validation loss: 0.807821 


 ## Adversarial Adaptation

In [14]:
advtrain_folder = os.path.join(model_folder, "advtrain")

if not os.path.isdir(advtrain_folder):
    os.makedirs(advtrain_folder)


In [15]:
def batch_generator(data, batch_size):
    """Generate batches of data.
    
    Given a list of numpy data, it iterates over the list and returns batches of
    the same size.
    """
    all_examples_indices = len(data[0])
    while True:
        mini_batch_indices = np.random.choice(
            all_examples_indices, size=batch_size, replace=False
        )
        tbr = [k[mini_batch_indices] for k in data]
        yield tbr


In [16]:
criterion_dis = nn.CrossEntropyLoss()


def discrim_loss_accu(x, y_dis, model):
    x = x.to(torch.float32).to(device)

    emb = model.source_encoder(x)
    y_pred = model.dis(emb)

    loss = criterion_dis(y_pred, y_dis)

    accu = torch.mean(
        (torch.flatten(torch.argmax(y_pred, dim=1)) == y_dis).to(torch.float32)
    ).cpu()

    return loss, accu


def compute_acc_dis(dataloader_source, dataloader_target, model):
    len_target = len(dataloader_target)
    len_source = len(dataloader_source)

    loss_running = []
    accu_running = []
    mean_weights = []
    model.eval()
    model.source_encoder.eval()
    model.dis.eval()
    with torch.no_grad():
        for y_val, dl in zip([1, 0], [dataloader_target, dataloader_source]):
            for _, (x, _) in enumerate(dl):

                y_dis = torch.full(
                    (x.shape[0],), y_val, device=device, dtype=torch.long
                )

                loss, accu = discrim_loss_accu(x, y_dis, model)

                accu_running.append(accu)
                loss_running.append(loss.item())

                # we will weight average by batch size later
                mean_weights.append(len(x))

    return (
        np.average(loss_running, weights=mean_weights),
        np.average(accu_running, weights=mean_weights),
    )


In [17]:
def train_adversarial(
    model,
    save_folder,
    sc_mix_train_s,
    lab_mix_train_s,
    mat_sp_train_s,
    dataloader_source_train_eval,
    dataloader_target_train_eval,
):

    model.to(device)
    model.advtraining()
    model.set_encoder("source")

    S_batches = batch_generator(
        [sc_mix_train_s.copy(), lab_mix_train_s.copy()], BATCH_SIZE
    )
    T_batches = batch_generator(
        [mat_sp_train_s.copy(), np.ones(shape=(len(mat_sp_train_s), 2))], BATCH_SIZE
    )

    enc_optimizer = torch.optim.Adam(
        chain(
            model.source_encoder.parameters(),
            model.dis.parameters(),
            model.clf.parameters(),
        ),
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    dis_optimizer = torch.optim.Adam(
        chain(model.source_encoder.parameters(), model.dis.parameters()),
        lr=ALPHA_LR * 0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    # Initialize lists to store loss and accuracy values
    loss_history_running = []
    with DupStdout().dup_to_file(os.path.join(save_folder, "log.txt"), "w") as f_log:
        # Train
        print("Start adversarial training...")
        outer = tqdm(total=N_ITER, desc="Iterations", position=0)

        checkpoint = {
            "epoch": -1,
            "model": model,
            "dis_optimizer": dis_optimizer,
            "enc_optimizer": enc_optimizer,
        }
        for iters in range(N_ITER):
            checkpoint["epoch"] = iters

            model.train()
            model.dis.train()
            model.clf.train()
            model.source_encoder.train()

            x_source, y_true = next(S_batches)
            x_target, _ = next(T_batches)

            ## Train encoder
            set_requires_grad(model.source_encoder, True)
            set_requires_grad(model.clf, True)
            set_requires_grad(model.dis, True)

            # save discriminator weights
            dis_weights = deepcopy(model.dis.state_dict())
            new_dis_weights = {}
            for k in dis_weights:
                if "num_batches_tracked" not in k:
                    new_dis_weights[k] = dis_weights[k]

            dis_weights = new_dis_weights

            x_source, x_target, y_true, = (
                torch.Tensor(x_source),
                torch.Tensor(x_target),
                torch.Tensor(y_true),
            )
            x_source, x_target, y_true, = (
                x_source.to(device),
                x_target.to(device),
                y_true.to(device),
            )

            x = torch.cat((x_source, x_target))

            # save for discriminator later
            x_d = x.detach()

            # y_dis is the REAL one
            y_dis = torch.cat(
                [
                    torch.zeros(x_source.shape[0], device=device, dtype=torch.long),
                    torch.ones(x_target.shape[0], device=device, dtype=torch.long),
                ]
            )
            y_dis_flipped = 1 - y_dis.detach()

            emb = model.source_encoder(x).view(x.shape[0], -1)

            y_dis_pred = model.dis(emb)
            y_clf_pred = model.clf(emb)

            # we use flipped because we want to confuse discriminator
            loss_dis = criterion_dis(y_dis_pred, y_dis_flipped)

            # Set true = predicted for target samples since we don't know what it is
            y_clf_true = torch.cat((y_true, y_clf_pred[-x_target.shape[0] :].detach()))
            # Loss fn does mean over all samples including target
            loss_clf = criterion_clf(y_clf_pred, y_clf_true)

            # Scale back up loss so mean doesn't include target
            loss = (x.shape[0] / x_source.shape[0]) * loss_clf + ALPHA * loss_dis

            # loss = loss_clf + ALPHA * loss_dis

            enc_optimizer.zero_grad()
            loss.backward()
            enc_optimizer.step()

            model.dis.load_state_dict(dis_weights, strict=False)

            ## Train discriminator
            set_requires_grad(model.source_encoder, True)
            set_requires_grad(model.clf, True)
            set_requires_grad(model.dis, True)

            # save encoder and clf weights
            source_encoder_weights = deepcopy(model.source_encoder.state_dict())
            clf_weights = deepcopy(model.clf.state_dict())

            new_clf_weights = {}
            for k in clf_weights:
                if "num_batches_tracked" not in k:
                    new_clf_weights[k] = clf_weights[k]

            clf_weights = new_clf_weights

            new_source_encoder_weights = {}
            for k in source_encoder_weights:
                if "num_batches_tracked" not in k:
                    new_source_encoder_weights[k] = source_encoder_weights[k]

            source_encoder_weights = new_source_encoder_weights

            emb = model.source_encoder(x_d).view(x_d.shape[0], -1)
            y_pred = model.dis(emb)

            # we use the real domain labels to train discriminator
            loss = criterion_dis(y_pred, y_dis)

            dis_optimizer.zero_grad()
            loss.backward()
            dis_optimizer.step()

            model.clf.load_state_dict(clf_weights, strict=False)
            model.source_encoder.load_state_dict(source_encoder_weights, strict=False)

            # Save checkpoint every 100
            if iters % 100 == 99 or iters >= N_ITER - 1:
                torch.save(checkpoint, os.path.join(save_folder, f"checkpt{iters}.pth"))

                model.eval()
                source_loss = compute_acc(dataloader_source_train_eval, model)
                _, dis_accu = compute_acc_dis(
                    dataloader_source_train_eval, dataloader_target_train_eval, model
                )

                # Print the results
                print(
                    "iter:",
                    iters,
                    "source loss:",
                    round(source_loss, 6),
                    "dis accu:",
                    round(dis_accu, 6),
                )

            outer.update(1)

    torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [18]:
# st_sample_id_l = [SAMPLE_ID_N]


In [19]:
if TRAIN_USING_ALL_ST_SAMPLES:
    print(f"Adversarial training for all ST slides")
    save_folder = advtrain_folder

    best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
    model = best_checkpoint["model"]
    model.to(device)
    model.advtraining()

    train_adversarial(
        model,
        save_folder,
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train,
        dataloader_source_train,
        dataloader_target_train,
    )

else:
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        save_folder = os.path.join(advtrain_folder, sample_id)
        if not os.path.isdir(save_folder):
            os.makedirs(save_folder)

        best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
        model = best_checkpoint["model"]
        model.to(device)
        model.advtraining()

        train_adversarial(
            model,
            save_folder,
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d[sample_id]["train"],
            dataloader_source_train,
            dataloader_target_train_d[sample_id],
        )


Adversarial training for ST slide 151507: 
Start adversarial training...


Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.901114 dis accu: 0.174441
iter: 199 source loss: 2.55145 dis accu: 0.174441
iter: 299 source loss: 1.346753 dis accu: 0.126476
iter: 399 source loss: 1.09311 dis accu: 0.540246
iter: 499 source loss: 1.083525 dis accu: 0.520185
iter: 599 source loss: 1.544477 dis accu: 0.172666
iter: 699 source loss: 1.058488 dis accu: 0.174441
iter: 799 source loss: 0.856215 dis accu: 0.841575
iter: 899 source loss: 0.977106 dis accu: 0.175184
iter: 999 source loss: 0.946962 dis accu: 0.174441
iter: 1099 source loss: 1.078436 dis accu: 0.174441
iter: 1199 source loss: 1.467818 dis accu: 0.174441
iter: 1299 source loss: 0.969801 dis accu: 0.994221
iter: 1399 source loss: 0.823752 dis accu: 0.999298
iter: 1499 source loss: 0.769578 dis accu: 0.996285
iter: 1599 source loss: 1.03458 dis accu: 0.999876
iter: 1699 source loss: 0.849736 dis accu: 0.660241
iter: 1799 source loss: 0.869032 dis accu: 0.967886
iter: 1899 source loss: 0.910772 dis accu: 0.442376
iter: 1999 source loss: 0.

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.18651 dis accu: 0.17979
iter: 199 source loss: 2.162461 dis accu: 0.179708
iter: 299 source loss: 1.471877 dis accu: 0.003158
iter: 399 source loss: 2.340784 dis accu: 0.006972
iter: 499 source loss: 0.906347 dis accu: 0.056225
iter: 599 source loss: 1.10008 dis accu: 0.823327
iter: 699 source loss: 0.987241 dis accu: 0.64083
iter: 799 source loss: 0.900314 dis accu: 0.179257
iter: 899 source loss: 0.855083 dis accu: 0.287197
iter: 999 source loss: 0.954081 dis accu: 0.261811
iter: 1099 source loss: 1.156461 dis accu: 0.839526
iter: 1199 source loss: 0.781923 dis accu: 0.993151
iter: 1299 source loss: 0.8485 dis accu: 0.156414
iter: 1399 source loss: 0.703679 dis accu: 0.240609
iter: 1499 source loss: 1.016651 dis accu: 0.136934
iter: 1599 source loss: 0.700056 dis accu: 0.941027
iter: 1699 source loss: 0.987151 dis accu: 0.045276
iter: 1799 source loss: 0.787172 dis accu: 0.163017
iter: 1899 source loss: 0.725481 dis accu: 0.478305
iter: 1999 source loss: 1.113

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.088117 dis accu: 0.193191
iter: 199 source loss: 1.871674 dis accu: 0.193191
iter: 299 source loss: 1.67594 dis accu: 0.193191
iter: 399 source loss: 1.388784 dis accu: 0.008552
iter: 499 source loss: 1.355279 dis accu: 0.192505
iter: 599 source loss: 0.927332 dis accu: 0.031224
iter: 699 source loss: 0.876851 dis accu: 1.0
iter: 799 source loss: 1.137174 dis accu: 0.185687
iter: 899 source loss: 1.050157 dis accu: 0.072411
iter: 999 source loss: 0.75729 dis accu: 1.0
iter: 1099 source loss: 1.171724 dis accu: 0.113276
iter: 1199 source loss: 0.879211 dis accu: 0.985276
iter: 1299 source loss: 0.823198 dis accu: 0.979184
iter: 1399 source loss: 1.087545 dis accu: 0.137319
iter: 1499 source loss: 0.771358 dis accu: 0.99762
iter: 1599 source loss: 1.42991 dis accu: 0.07225
iter: 1699 source loss: 0.762211 dis accu: 1.0
iter: 1799 source loss: 1.181169 dis accu: 0.882609
iter: 1899 source loss: 1.186175 dis accu: 0.300335
iter: 1999 source loss: 0.759336 dis accu: 

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.864358 dis accu: 0.188114
iter: 199 source loss: 1.232486 dis accu: 0.188114
iter: 299 source loss: 1.815542 dis accu: 0.187424
iter: 399 source loss: 1.174423 dis accu: 0.188114
iter: 499 source loss: 1.289071 dis accu: 0.03211
iter: 599 source loss: 1.131901 dis accu: 0.178128
iter: 699 source loss: 0.938189 dis accu: 0.188114
iter: 799 source loss: 0.883053 dis accu: 0.87091
iter: 899 source loss: 0.989606 dis accu: 0.378461
iter: 999 source loss: 0.987323 dis accu: 0.187708
iter: 1099 source loss: 0.847464 dis accu: 0.982504
iter: 1199 source loss: 0.882684 dis accu: 0.117561
iter: 1299 source loss: 0.908201 dis accu: 0.995778
iter: 1399 source loss: 1.082566 dis accu: 0.162093
iter: 1499 source loss: 0.868711 dis accu: 0.972924
iter: 1599 source loss: 0.884467 dis accu: 0.130795
iter: 1699 source loss: 0.828007 dis accu: 1.0
iter: 1799 source loss: 1.005369 dis accu: 0.12844
iter: 1899 source loss: 0.673011 dis accu: 0.999756
iter: 1999 source loss: 0.82319

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.756627 dis accu: 0.153079
iter: 199 source loss: 2.308808 dis accu: 0.154727
iter: 299 source loss: 1.132058 dis accu: 0.583872
iter: 399 source loss: 1.340649 dis accu: 0.151219
iter: 499 source loss: 0.994406 dis accu: 0.999958
iter: 599 source loss: 1.032518 dis accu: 0.02764
iter: 699 source loss: 1.030172 dis accu: 0.845357
iter: 799 source loss: 1.066156 dis accu: 0.141161
iter: 899 source loss: 0.997199 dis accu: 0.155995
iter: 999 source loss: 0.820106 dis accu: 0.76979
iter: 1099 source loss: 0.892059 dis accu: 0.846498
iter: 1199 source loss: 1.006453 dis accu: 0.241494
iter: 1299 source loss: 0.822314 dis accu: 0.999873
iter: 1399 source loss: 0.86705 dis accu: 0.923925
iter: 1499 source loss: 0.986514 dis accu: 0.845273
iter: 1599 source loss: 0.917204 dis accu: 0.999324
iter: 1699 source loss: 0.872204 dis accu: 0.948819
iter: 1799 source loss: 0.823871 dis accu: 0.999789
iter: 1899 source loss: 0.91714 dis accu: 0.141118
iter: 1999 source loss: 0.7

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.086325 dis accu: 0.148864
iter: 199 source loss: 2.170891 dis accu: 0.148864
iter: 299 source loss: 1.31186 dis accu: 0.148864
iter: 399 source loss: 1.29529 dis accu: 0.012129
iter: 499 source loss: 1.095165 dis accu: 0.148651
iter: 599 source loss: 1.207594 dis accu: 0.148864
iter: 699 source loss: 0.949529 dis accu: 0.0
iter: 799 source loss: 0.839558 dis accu: 0.148864
iter: 899 source loss: 0.80704 dis accu: 0.0
iter: 999 source loss: 0.730774 dis accu: 0.032386
iter: 1099 source loss: 0.677497 dis accu: 0.624564
iter: 1199 source loss: 0.652228 dis accu: 0.998085
iter: 1299 source loss: 0.651552 dis accu: 0.949868
iter: 1399 source loss: 0.595578 dis accu: 0.005617
iter: 1499 source loss: 0.585707 dis accu: 0.443825
iter: 1599 source loss: 0.625767 dis accu: 0.343221
iter: 1699 source loss: 0.567345 dis accu: 0.941272
iter: 1799 source loss: 0.502041 dis accu: 0.795855
iter: 1899 source loss: 0.510021 dis accu: 0.159077
iter: 1999 source loss: 0.513635 dis

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.831828 dis accu: 0.170261
iter: 199 source loss: 2.516052 dis accu: 0.170469
iter: 299 source loss: 1.49067 dis accu: 0.917545
iter: 399 source loss: 1.237836 dis accu: 0.170469
iter: 499 source loss: 1.253386 dis accu: 0.170469
iter: 599 source loss: 1.099978 dis accu: 0.164579
iter: 699 source loss: 0.903678 dis accu: 0.988511
iter: 799 source loss: 0.925019 dis accu: 0.997387
iter: 899 source loss: 1.07308 dis accu: 0.163003
iter: 999 source loss: 0.792465 dis accu: 0.854168
iter: 1099 source loss: 1.190881 dis accu: 0.170552
iter: 1199 source loss: 0.833517 dis accu: 0.210411
iter: 1299 source loss: 0.874593 dis accu: 0.537536
iter: 1399 source loss: 0.795067 dis accu: 0.999959
iter: 1499 source loss: 1.437807 dis accu: 0.978764
iter: 1599 source loss: 0.875958 dis accu: 0.171091
iter: 1699 source loss: 0.853897 dis accu: 0.848652
iter: 1799 source loss: 0.998897 dis accu: 0.845541
iter: 1899 source loss: 0.796286 dis accu: 0.831315
iter: 1999 source loss: 1

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.371819 dis accu: 0.196835
iter: 199 source loss: 2.035225 dis accu: 0.167187
iter: 299 source loss: 1.398811 dis accu: 0.149032
iter: 399 source loss: 1.200216 dis accu: 0.166937
iter: 499 source loss: 1.037758 dis accu: 0.165688
iter: 599 source loss: 1.056965 dis accu: 0.167187
iter: 699 source loss: 0.969409 dis accu: 0.752655
iter: 799 source loss: 1.081769 dis accu: 0.057339
iter: 899 source loss: 0.934507 dis accu: 0.178305
iter: 999 source loss: 1.051774 dis accu: 0.163898
iter: 1099 source loss: 0.980958 dis accu: 0.167146
iter: 1199 source loss: 0.865609 dis accu: 0.779971
iter: 1299 source loss: 0.777756 dis accu: 0.88299
iter: 1399 source loss: 0.864095 dis accu: 0.155611
iter: 1499 source loss: 0.736114 dis accu: 0.937706
iter: 1599 source loss: 0.792013 dis accu: 0.572392
iter: 1699 source loss: 0.694331 dis accu: 0.997127
iter: 1799 source loss: 0.876579 dis accu: 0.152363
iter: 1899 source loss: 0.717604 dis accu: 0.999917
iter: 1999 source loss: 

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.802146 dis accu: 0.008207
iter: 199 source loss: 1.384901 dis accu: 0.153941
iter: 299 source loss: 1.21253 dis accu: 0.153941
iter: 399 source loss: 1.176651 dis accu: 0.017006
iter: 499 source loss: 1.151485 dis accu: 0.025889
iter: 599 source loss: 0.967511 dis accu: 0.131435
iter: 699 source loss: 0.776276 dis accu: 1.0
iter: 799 source loss: 1.194076 dis accu: 0.999535
iter: 899 source loss: 0.788542 dis accu: 0.179957
iter: 999 source loss: 0.740451 dis accu: 0.999365
iter: 1099 source loss: 0.887802 dis accu: 0.110495
iter: 1199 source loss: 0.723534 dis accu: 0.999915
iter: 1299 source loss: 0.833613 dis accu: 0.989805
iter: 1399 source loss: 0.926896 dis accu: 0.009687
iter: 1499 source loss: 0.817839 dis accu: 0.893946
iter: 1599 source loss: 0.66973 dis accu: 0.976226
iter: 1699 source loss: 0.839837 dis accu: 4.2e-05
iter: 1799 source loss: 0.70814 dis accu: 0.971361
iter: 1899 source loss: 0.885042 dis accu: 0.967934
iter: 1999 source loss: 0.790989

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.278919 dis accu: 0.616948
iter: 199 source loss: 1.489486 dis accu: 0.031301
iter: 299 source loss: 1.169131 dis accu: 0.0
iter: 399 source loss: 1.155053 dis accu: 0.147763
iter: 499 source loss: 0.968496 dis accu: 0.131373
iter: 599 source loss: 0.837983 dis accu: 0.139611
iter: 699 source loss: 0.841676 dis accu: 0.155156
iter: 799 source loss: 0.799772 dis accu: 0.036286
iter: 899 source loss: 0.753992 dis accu: 0.972162
iter: 999 source loss: 0.802711 dis accu: 1.0
iter: 1099 source loss: 0.956698 dis accu: 0.929836
iter: 1199 source loss: 0.783592 dis accu: 0.244667
iter: 1299 source loss: 0.68296 dis accu: 0.549022
iter: 1399 source loss: 0.701368 dis accu: 0.994424
iter: 1499 source loss: 0.890631 dis accu: 0.380264
iter: 1599 source loss: 0.637186 dis accu: 1.0
iter: 1699 source loss: 0.961161 dis accu: 0.0
iter: 1799 source loss: 0.662279 dis accu: 1.0
iter: 1899 source loss: 0.72475 dis accu: 0.998564
iter: 1999 source loss: 0.858659 dis accu: 0.13344

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.051417 dis accu: 0.152255
iter: 199 source loss: 1.1248 dis accu: 0.153823
iter: 299 source loss: 1.435042 dis accu: 0.151238
iter: 399 source loss: 1.072963 dis accu: 0.155434
iter: 499 source loss: 1.081445 dis accu: 0.761784
iter: 599 source loss: 1.0012 dis accu: 0.100712
iter: 699 source loss: 0.906696 dis accu: 1.0
iter: 799 source loss: 0.903802 dis accu: 0.152255
iter: 899 source loss: 1.230964 dis accu: 0.152255
iter: 999 source loss: 1.230209 dis accu: 0.937436
iter: 1099 source loss: 1.414929 dis accu: 0.152255
iter: 1199 source loss: 0.982625 dis accu: 0.101094
iter: 1299 source loss: 0.902011 dis accu: 0.866311
iter: 1399 source loss: 0.904559 dis accu: 0.191378
iter: 1499 source loss: 0.811179 dis accu: 0.994405
iter: 1599 source loss: 0.922483 dis accu: 0.216048
iter: 1699 source loss: 0.814574 dis accu: 0.992455
iter: 1799 source loss: 0.899169 dis accu: 0.45011
iter: 1899 source loss: 0.817076 dis accu: 0.939344
iter: 1999 source loss: 0.719171 

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 2.043903 dis accu: 0.035166
iter: 199 source loss: 1.473627 dis accu: 0.147485
iter: 299 source loss: 1.224874 dis accu: 0.147485
iter: 399 source loss: 1.281021 dis accu: 0.147485
iter: 499 source loss: 1.077772 dis accu: 0.667945
iter: 599 source loss: 1.649837 dis accu: 0.021441
iter: 699 source loss: 1.020853 dis accu: 0.110315
iter: 799 source loss: 0.782179 dis accu: 1.0
iter: 899 source loss: 1.069934 dis accu: 0.025746
iter: 999 source loss: 0.904697 dis accu: 0.115516
iter: 1099 source loss: 0.797595 dis accu: 0.341986
iter: 1199 source loss: 0.879623 dis accu: 0.744842
iter: 1299 source loss: 0.907659 dis accu: 8.5e-05
iter: 1399 source loss: 0.791602 dis accu: 0.135294
iter: 1499 source loss: 0.742956 dis accu: 0.863512
iter: 1599 source loss: 1.057359 dis accu: 0.859335
iter: 1699 source loss: 0.902263 dis accu: 0.123743
iter: 1799 source loss: 0.769719 dis accu: 0.170673
iter: 1899 source loss: 0.939003 dis accu: 0.956436
iter: 1999 source loss: 1.065

In [20]:
# if TRAIN_USING_ALL_ST_SAMPLES:
#     best_checkpoint = torch.load(os.path.join(advtrain_folder, f"final_model.pth"))
# else:
#     best_checkpoint = torch.load(
#         os.path.join(advtrain_folder, SAMPLE_ID_N, f"final_model.pth")
#     )

# model = best_checkpoint["model"]
# model.to(device)

# model.eval()
# model.set_encoder("source")

# pred_mix = (
#     torch.exp(model(torch.Tensor(sc_mix_test_s).to(device)))
#     .detach()
#     .cpu()
#     .numpy()
# )

# cell_type_nums = sc_sub_dict.keys()
# nrows = ceil(len(cell_type_nums) / 5)

# line_kws = {"color": "tab:orange"}
# scatter_kws = {"s": 5}

# props = dict(facecolor="w", alpha=0.5)

# fig, ax = plt.subplots(
#     nrows,
#     5,
#     figsize=(25, 5 * nrows),
#     constrained_layout=True,
#     sharex=False,
#     sharey=True,
# )
# for i, visnum in enumerate(cell_type_nums):
#     sns.regplot(
#         x=pred_mix[:, visnum],
#         y=lab_mix_d["test"][:, visnum],
#         line_kws=line_kws,
#         scatter_kws=scatter_kws,
#         ax=ax.flat[i],
#     ).set_title(sc_sub_dict[visnum])

#     ax.flat[i].set_aspect("equal")
#     ax.flat[i].set_xlabel("Predicted Proportion")

#     if i % 5 == 0:
#         ax.flat[i].set_ylabel("True Proportion")
#     else:
#         ax.flat[i].set_ylabel("")
#     ax.flat[i].set_xlim([0, 1])
#     ax.flat[i].set_ylim([0, 1])

#     textstr = (
#         f"MSE: {mean_squared_error(pred_mix[:,visnum], lab_mix_d["test"][:,visnum]):.5f}"
#     )

#     # place a text box in upper left in axes coords
#     ax.flat[i].text(
#         0.95,
#         0.05,
#         textstr,
#         transform=ax.flat[i].transAxes,
#         verticalalignment="bottom",
#         horizontalalignment="right",
#         bbox=props,
#     )

# for i in range(len(cell_type_nums), nrows * 5):
#     ax.flat[i].axis("off")

# plt.show()
