 # ADDA for ST

In [1]:
import os
import datetime
from itertools import chain
from copy import deepcopy
import warnings

from tqdm.autonotebook import tqdm

import h5py
import pickle
import numpy as np

import torch
from torch import nn

from src.da_models.adda import ADDAST
from src.da_models.datasets import SpotDataset
from src.da_models.utils import set_requires_grad
from src.utils.dupstdout import DupStdout
from src.utils.data_loading import (
    load_spatial,
    load_sc,
    get_selected_dir,
    get_model_rel_path,
)
from src.utils.evaluation import format_iters

script_start_time = datetime.datetime.now(datetime.timezone.utc)


  from tqdm.autonotebook import tqdm


In [2]:
TORCH_MANUAL_SEED = 1205
CUDA_DEVICE = 0


In [3]:
# Data path and parameters
DATA_DIR = "data"
TRAIN_USING_ALL_ST_SAMPLES = False
N_MARKERS = 20
ALL_GENES = False

# Pseudo-spot parameters
N_SPOTS = 20000
N_MIX = 8

# ST spot parameters
ST_SPLIT = False
SAMPLE_ID_N = "151673"

# Scaler parameter
SCALER_NAME = "celldart"


In [4]:
# Model parameters
MODEL_NAME = "CellDART"
MODEL_VERSION = "celldart1_bnfix"

celldart_kwargs = {
    "emb_dim": 64,
    "bn_momentum": 0.01,
}


In [5]:
BATCH_SIZE = 512
NUM_WORKERS = 16
# Pretraining parameters
# SAMPLE_ID_N = "151673"

INITIAL_TRAIN_EPOCHS = 10

EARLY_STOP_CRIT = 100
MIN_EPOCHS = INITIAL_TRAIN_EPOCHS

# Adversarial training parameters
EARLY_STOP_CRIT_ADV = 10
MIN_EPOCHS_ADV = 10

N_ITER = 3000
ALPHA_LR = 5
ALPHA = 0.6


In [6]:
if CUDA_DEVICE:
    device = torch.device(f"cuda:{CUDA_DEVICE}" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(f"cuda" if torch.cuda.is_available() else "cpu")
if device == "cpu":
    warnings.warn("Using CPU", stacklevel=2)


In [7]:
if TORCH_MANUAL_SEED or TORCH_MANUAL_SEED == 0:
    torch_seed = TORCH_MANUAL_SEED
    torch_seed_path = str(TORCH_MANUAL_SEED)
else:
    torch_seed = int(script_start_time.timestamp())
    torch_seed_path = script_start_time.strftime("%Y-%m-%d_%Hh%Mm%Ss")

torch.manual_seed(torch_seed)
np.random.seed(torch_seed)


In [8]:
model_folder = get_model_rel_path(
    MODEL_NAME,
    MODEL_VERSION,
    scaler_name=SCALER_NAME,
    n_markers=N_MARKERS,
    all_genes=ALL_GENES,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
    st_split=ST_SPLIT,
    torch_seed_path=torch_seed_path,
)
model_folder = os.path.join("model", model_folder)

if not os.path.isdir(model_folder):
    os.makedirs(model_folder)
    print(model_folder)


 # Data load


In [9]:
# Load spatial data
mat_sp_d, mat_sp_train, st_sample_id_l = load_spatial(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    train_using_all_st_samples=TRAIN_USING_ALL_ST_SAMPLES,
    st_split=ST_SPLIT,
)

# Load sc data
sc_mix_d, lab_mix_d, sc_sub_dict, sc_sub_dict2 = load_sc(
    get_selected_dir(DATA_DIR, N_MARKERS, ALL_GENES),
    SCALER_NAME,
    n_mix=N_MIX,
    n_spots=N_SPOTS,
)


 # Training: Adversarial domain adaptation for cell fraction estimation

 ## Prepare dataloaders

In [10]:
### source dataloaders
source_train_set = SpotDataset(
    deepcopy(sc_mix_d["train"]), deepcopy(lab_mix_d["train"])
)
source_val_set = SpotDataset(deepcopy(sc_mix_d["val"]), deepcopy(lab_mix_d["val"]))
source_test_set = SpotDataset(deepcopy(sc_mix_d["test"]), deepcopy(lab_mix_d["test"]))

dataloader_source_train = torch.utils.data.DataLoader(
    source_train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
dataloader_source_val = torch.utils.data.DataLoader(
    source_val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)
dataloader_source_test = torch.utils.data.DataLoader(
    source_test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=False,
)

# ### target dataloaders
# target_test_set_d = {}
# for sample_id in st_sample_id_l:
#     target_test_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d["test"][sample_id]))

# dataloader_target_test_d = {}
# for sample_id in st_sample_id_l:
#     dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
#         target_test_set_d[sample_id],
#         batch_size=BATCH_SIZE,
#         shuffle=False,
#         num_workers=NUM_WORKERS,
#         pin_memory=False,
#     )

# if TRAIN_USING_ALL_ST_SAMPLES:
#     target_train_set = SpotDataset(deepcopy(mat_sp_train_s))
#     dataloader_target_train = torch.utils.data.DataLoader(
#         target_train_set,
#         batch_size=BATCH_SIZE,
#         shuffle=True,
#         num_workers=NUM_WORKERS,
#         pin_memory=True,
#     )
# else:
#     target_train_set_d = {}
#     dataloader_target_train_d = {}
#     for sample_id in st_sample_id_l:
#         target_train_set_d[sample_id] = SpotDataset(
#             deepcopy(mat_sp_d["train"][sample_id])
#         )
#         dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
#             target_train_set_d[sample_id],
#             batch_size=BATCH_SIZE,
#             shuffle=True,
#             num_workers=NUM_WORKERS,
#             pin_memory=True,
#         )

# if TRAIN_USING_ALL_ST_SAMPLES:
#     target_train_set = SpotDataset(deepcopy(mat_sp_train))
#     dataloader_target_train = torch.utils.data.DataLoader(
#         target_train_set,
#         batch_size=BATCH_SIZE,
#         shuffle=True,
#         num_workers=NUM_WORKERS,
#         pin_memory=True,
#     )

### target dataloaders
target_train_set_d = {}
dataloader_target_train_d = {}
if ST_SPLIT:
    target_val_set_d = {}
    target_test_set_d = {}

    dataloader_target_val_d = {}
    dataloader_target_test_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d[sample_id]["train"])
        )
        target_val_set_d[sample_id] = SpotDataset(deepcopy(mat_sp_d[sample_id]["val"]))
        target_test_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d[sample_id]["test"])
        )

        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )
        dataloader_target_val_d[sample_id] = torch.utils.data.DataLoader(
            target_val_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )
        dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
            target_test_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )

else:
    target_test_set_d = {}
    dataloader_target_test_d = {}
    for sample_id in st_sample_id_l:
        target_train_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d[sample_id]["train"])
        )
        dataloader_target_train_d[sample_id] = torch.utils.data.DataLoader(
            target_train_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )

        target_test_set_d[sample_id] = SpotDataset(
            deepcopy(mat_sp_d[sample_id]["test"])
        )
        dataloader_target_test_d[sample_id] = torch.utils.data.DataLoader(
            target_test_set_d[sample_id],
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=NUM_WORKERS,
            pin_memory=False,
        )


if TRAIN_USING_ALL_ST_SAMPLES:
    target_train_set = SpotDataset(mat_sp_train)
    dataloader_target_train = torch.utils.data.DataLoader(
        target_train_set,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )


 ## Define Model

In [11]:
model = ADDAST(
    inp_dim=sc_mix_d["train"].shape[1],
    ncls_source=lab_mix_d["train"].shape[1],
    **celldart_kwargs
)

## CellDART uses just one encoder!
model.target_encoder = model.source_encoder
model.to(device)


ADDAST(
  (source_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=360, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (target_encoder): MLPEncoder(
    (encoder): Sequential(
      (0): Linear(in_features=360, out_features=1024, bias=True)
      (1): BatchNorm1d(1024, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ELU(alpha=1.0)
      (3): Linear(in_features=1024, out_features=64, bias=True)
      (4): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (5): ELU(alpha=1.0)
    )
  )
  (dis): Discriminator(
    (head): Sequential(
      (0): Linear(in_features=64, out_features=32, bias=True)
      (1):

 ## Pretrain

In [12]:
pretrain_folder = os.path.join(model_folder, "pretrain")

if not os.path.isdir(pretrain_folder):
    os.makedirs(pretrain_folder)


In [13]:
pre_optimizer = torch.optim.Adam(
    model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-07
)

criterion_clf = nn.KLDivLoss(reduction="batchmean")


In [14]:
def model_loss(x, y_true, model):
    x = x.to(torch.float32).to(device)
    y_true = y_true.to(torch.float32).to(device)

    y_pred = model(x)

    loss = criterion_clf(y_pred, y_true)

    return loss


def run_pretrain_epoch(model, dataloader, optimizer=None, inner=None):
    loss_running = []
    mean_weights = []

    is_training = model.training and optimizer

    for _, batch in enumerate(dataloader):
        loss = model_loss(*batch, model)
        loss_running.append(loss.item())
        mean_weights.append(len(batch))  # we will weight average by batch size later

        if is_training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if inner:
            inner.update(1)
    return loss_running, mean_weights


def compute_acc(dataloader, model):

    model.eval()
    with torch.no_grad():
        loss_running, mean_weights = run_pretrain_epoch(model, dataloader)

    return np.average(loss_running, weights=mean_weights)


In [15]:
model.pretraining()


In [16]:
# Initialize lists to store loss and accuracy values
loss_history = []
loss_history_val = []

loss_history_running = []

# Early Stopping
best_loss_val = np.inf
early_stop_count = 0


with DupStdout().dup_to_file(os.path.join(pretrain_folder, "log.txt"), "w") as f_log:
    # Train
    print("Start pretrain...")
    outer = tqdm(total=INITIAL_TRAIN_EPOCHS, desc="Epochs", position=0)
    inner = tqdm(total=len(dataloader_source_train), desc=f"Batch", position=1)

    print(" Epoch | Train Loss | Val Loss   ")
    print("---------------------------------")
    checkpoint = {
        "epoch": -1,
        "model": model,
        "optimizer": pre_optimizer,
    }
    for epoch in range(INITIAL_TRAIN_EPOCHS):
        inner.refresh()  # force print final state
        inner.reset()  # reuse bar

        checkpoint["epoch"] = epoch

        # Train mode
        model.train()

        loss_running, mean_weights = run_pretrain_epoch(
            model, dataloader_source_train, optimizer=pre_optimizer, inner=inner
        )

        loss_history.append(np.average(loss_running, weights=mean_weights))
        loss_history_running.append(loss_running)

        # Evaluate mode
        model.eval()
        with torch.no_grad():
            curr_loss_val = compute_acc(dataloader_source_val, model)
            loss_history_val.append(curr_loss_val)

        # Print the results
        outer.update(1)
        print(
            f" {epoch:5d}",
            f"| {loss_history[-1]:<10.8f}",
            f"| {curr_loss_val:<10.8f}",
            end=" ",
        )

        # Save the best weights
        if curr_loss_val < best_loss_val:
            best_loss_val = curr_loss_val
            torch.save(checkpoint, os.path.join(pretrain_folder, f"best_model.pth"))
            early_stop_count = 0

            print("<-- new best val loss")
        else:
            print("")

        # Save checkpoint every 10
        if epoch % 10 == 0 or epoch >= INITIAL_TRAIN_EPOCHS - 1:
            torch.save(checkpoint, os.path.join(pretrain_folder, f"checkpt{epoch}.pth"))

        # check to see if validation loss has plateau'd
        if early_stop_count >= EARLY_STOP_CRIT and epoch >= MIN_EPOCHS - 1:
            print(
                f"Validation loss plateaued after {early_stop_count} at epoch {epoch}"
            )
            torch.save(
                checkpoint, os.path.join(pretrain_folder, f"earlystop{epoch}.pth")
            )
            break

        early_stop_count += 1

    model.eval()
    with torch.no_grad():
        curr_loss_train = compute_acc(dataloader_source_train, model)

    print(f"Final train loss: {curr_loss_train}")
    # Save final model
    torch.save(checkpoint, os.path.join(pretrain_folder, f"final_model.pth"))


Start pretrain...


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batch:   0%|          | 0/40 [00:00<?, ?it/s]

 Epoch | Train Loss | Val Loss   
---------------------------------
     0 | 1.19777877 | 1.49447010 <-- new best val loss
     1 | 0.84543390 | 1.23260732 <-- new best val loss
     2 | 0.68179746 | 1.02762511 <-- new best val loss
     3 | 0.60963448 | 0.97713646 <-- new best val loss
     4 | 0.57786487 | 0.88541307 <-- new best val loss
     5 | 0.55358100 | 0.85641049 <-- new best val loss
     6 | 0.53845328 | 0.83081232 <-- new best val loss
     7 | 0.52159256 | 0.86795192 
     8 | 0.50930184 | 0.76681435 <-- new best val loss
     9 | 0.49394928 | 1.05101705 
Final train loss: 0.9653496757149697


 ## Adversarial Adaptation

In [17]:
advtrain_folder = os.path.join(model_folder, "advtrain")

if not os.path.isdir(advtrain_folder):
    os.makedirs(advtrain_folder)


In [18]:
def batch_generator(data, batch_size):
    """Generate batches of data.

    Given a list of numpy data, it iterates over the list and returns batches of
    the same size.
    """
    all_examples_indices = len(data[0])
    while True:
        mini_batch_indices = np.random.choice(
            all_examples_indices, size=batch_size, replace=False
        )
        tbr = [k[mini_batch_indices] for k in data]
        yield tbr


In [19]:
criterion_dis = nn.CrossEntropyLoss()


def discrim_loss_accu(x, y_dis, model):
    x = x.to(torch.float32).to(device)

    emb = model.source_encoder(x)
    y_pred = model.dis(emb)

    loss = criterion_dis(y_pred, y_dis)

    accu = torch.mean(
        (torch.flatten(torch.argmax(y_pred, dim=1)) == y_dis).to(torch.float32)
    ).cpu()

    return loss, accu


def compute_acc_dis(dataloader_source, dataloader_target, model):
    len_target = len(dataloader_target)
    len_source = len(dataloader_source)

    loss_running = []
    accu_running = []
    mean_weights = []
    model.eval()
    model.source_encoder.eval()
    model.dis.eval()
    with torch.no_grad():
        for y_val, dl in zip([1, 0], [dataloader_target, dataloader_source]):
            for _, (x, _) in enumerate(dl):

                y_dis = torch.full(
                    (x.shape[0],), y_val, device=device, dtype=torch.long
                )

                loss, accu = discrim_loss_accu(x, y_dis, model)

                accu_running.append(accu)
                loss_running.append(loss.item())

                # we will weight average by batch size later
                mean_weights.append(len(x))

    return (
        np.average(loss_running, weights=mean_weights),
        np.average(accu_running, weights=mean_weights),
    )


In [20]:
def train_adversarial(
    model,
    save_folder,
    sc_mix_train_s,
    lab_mix_train_s,
    mat_sp_train_s,
    dataloader_source_train_eval,
    dataloader_target_train_eval,
):

    model.to(device)
    model.advtraining()
    model.set_encoder("source")

    S_batches = batch_generator(
        [sc_mix_train_s.copy(), lab_mix_train_s.copy()], BATCH_SIZE
    )
    T_batches = batch_generator(
        [mat_sp_train_s.copy(), np.ones(shape=(len(mat_sp_train_s), 2))], BATCH_SIZE
    )

    enc_optimizer = torch.optim.Adam(
        chain(
            model.source_encoder.parameters(),
            model.dis.parameters(),
            model.clf.parameters(),
        ),
        lr=0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    dis_optimizer = torch.optim.Adam(
        chain(model.source_encoder.parameters(), model.dis.parameters()),
        lr=ALPHA_LR * 0.001,
        betas=(0.9, 0.999),
        eps=1e-07,
    )

    # Initialize lists to store loss and accuracy values
    loss_history_running = []
    with DupStdout().dup_to_file(os.path.join(save_folder, "log.txt"), "w") as f_log:
        # Train
        print("Start adversarial training...")
        outer = tqdm(total=N_ITER, desc="Iterations", position=0)

        checkpoint = {
            "epoch": -1,
            "model": model,
            "dis_optimizer": dis_optimizer,
            "enc_optimizer": enc_optimizer,
        }
        for iters in range(N_ITER):
            checkpoint["epoch"] = iters

            model.train()
            model.dis.train()
            model.clf.train()
            model.source_encoder.train()

            x_source, y_true = next(S_batches)
            x_target, _ = next(T_batches)

            ## Train encoder
            set_requires_grad(model.source_encoder, True)
            set_requires_grad(model.clf, True)
            set_requires_grad(model.dis, True)

            # save discriminator weights
            dis_weights = deepcopy(model.dis.state_dict())
            new_dis_weights = {}
            for k in dis_weights:
                if "num_batches_tracked" not in k:
                    new_dis_weights[k] = dis_weights[k]

            dis_weights = new_dis_weights

            x_source, x_target, y_true, = (
                torch.Tensor(x_source),
                torch.Tensor(x_target),
                torch.Tensor(y_true),
            )
            x_source, x_target, y_true, = (
                x_source.to(device),
                x_target.to(device),
                y_true.to(device),
            )

            x = torch.cat((x_source, x_target))

            # save for discriminator later
            x_d = x.detach()

            # y_dis is the REAL one
            y_dis = torch.cat(
                [
                    torch.zeros(x_source.shape[0], device=device, dtype=torch.long),
                    torch.ones(x_target.shape[0], device=device, dtype=torch.long),
                ]
            )
            y_dis_flipped = 1 - y_dis.detach()

            emb = model.source_encoder(x).view(x.shape[0], -1)

            y_dis_pred = model.dis(emb)
            y_clf_pred = model.clf(emb)

            # we use flipped because we want to confuse discriminator
            loss_dis = criterion_dis(y_dis_pred, y_dis_flipped)

            # Set true = predicted for target samples since we don't know what it is
            y_clf_true = torch.cat((y_true, y_clf_pred[-x_target.shape[0] :].detach()))
            # Loss fn does mean over all samples including target
            loss_clf = criterion_clf(y_clf_pred, y_clf_true)

            # Scale back up loss so mean doesn't include target
            # loss = (x.shape[0] / x_source.shape[0]) * loss_clf + ALPHA * loss_dis
            loss = loss_clf + ALPHA * loss_dis

            # loss = loss_clf + ALPHA * loss_dis

            enc_optimizer.zero_grad()
            loss.backward()
            enc_optimizer.step()

            model.dis.load_state_dict(dis_weights, strict=False)

            ## Train discriminator
            set_requires_grad(model.source_encoder, True)
            set_requires_grad(model.clf, True)
            set_requires_grad(model.dis, True)

            # save encoder and clf weights
            source_encoder_weights = deepcopy(model.source_encoder.state_dict())
            clf_weights = deepcopy(model.clf.state_dict())

            new_clf_weights = {}
            for k in clf_weights:
                if "num_batches_tracked" not in k:
                    new_clf_weights[k] = clf_weights[k]

            clf_weights = new_clf_weights

            new_source_encoder_weights = {}
            for k in source_encoder_weights:
                if "num_batches_tracked" not in k:
                    new_source_encoder_weights[k] = source_encoder_weights[k]

            source_encoder_weights = new_source_encoder_weights

            emb = model.source_encoder(x_d).view(x_d.shape[0], -1)
            y_pred = model.dis(emb)

            # we use the real domain labels to train discriminator
            loss = criterion_dis(y_pred, y_dis)

            dis_optimizer.zero_grad()
            loss.backward()
            dis_optimizer.step()

            model.clf.load_state_dict(clf_weights, strict=False)
            model.source_encoder.load_state_dict(source_encoder_weights, strict=False)

            # Save checkpoint every 100
            if iters % 100 == 99 or iters >= N_ITER - 1:
                torch.save(checkpoint, os.path.join(save_folder, f"checkpt{iters}.pth"))

                model.eval()
                source_loss = compute_acc(dataloader_source_train_eval, model)
                _, dis_accu = compute_acc_dis(
                    dataloader_source_train_eval, dataloader_target_train_eval, model
                )

                # Print the results
                print(
                    "iter:",
                    iters,
                    "source loss:",
                    round(source_loss, 6),
                    "dis accu:",
                    round(dis_accu, 6),
                )

            outer.update(1)

    torch.save(checkpoint, os.path.join(save_folder, f"final_model.pth"))


In [21]:
# st_sample_id_l = [SAMPLE_ID_N]


In [22]:
if TRAIN_USING_ALL_ST_SAMPLES:
    print(f"Adversarial training for all ST slides")
    save_folder = advtrain_folder

    best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
    model = best_checkpoint["model"]
    model.to(device)
    model.advtraining()

    train_adversarial(
        model,
        save_folder,
        sc_mix_d["train"],
        lab_mix_d["train"],
        mat_sp_train,
        dataloader_source_train,
        dataloader_target_train,
    )

else:
    for sample_id in st_sample_id_l:
        print(f"Adversarial training for ST slide {sample_id}: ")

        save_folder = os.path.join(advtrain_folder, sample_id)
        if not os.path.isdir(save_folder):
            os.makedirs(save_folder)

        best_checkpoint = torch.load(os.path.join(pretrain_folder, f"final_model.pth"))
        model = best_checkpoint["model"]
        model.to(device)
        model.advtraining()

        train_adversarial(
            model,
            save_folder,
            sc_mix_d["train"],
            lab_mix_d["train"],
            mat_sp_d[sample_id]["train"],
            dataloader_source_train,
            dataloader_target_train_d[sample_id],
        )


Adversarial training for ST slide 151507: 
Start adversarial training...


Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.4017 dis accu: 0.174399
iter: 199 source loss: 3.0693 dis accu: 0.174482
iter: 299 source loss: 1.633372 dis accu: 0.869355
iter: 399 source loss: 1.401322 dis accu: 0.972468
iter: 499 source loss: 1.726897 dis accu: 0.175555
iter: 599 source loss: 1.257941 dis accu: 0.861884
iter: 699 source loss: 1.128703 dis accu: 0.976224
iter: 799 source loss: 1.147126 dis accu: 0.462396
iter: 899 source loss: 1.043894 dis accu: 0.296789
iter: 999 source loss: 0.910239 dis accu: 0.652976
iter: 1099 source loss: 0.887649 dis accu: 0.929951
iter: 1199 source loss: 0.860319 dis accu: 0.710105
iter: 1299 source loss: 0.885945 dis accu: 0.169859
iter: 1399 source loss: 0.80102 dis accu: 0.43573
iter: 1499 source loss: 0.779277 dis accu: 0.72158
iter: 1599 source loss: 0.724345 dis accu: 0.782919
iter: 1699 source loss: 0.698748 dis accu: 0.516924
iter: 1799 source loss: 0.704437 dis accu: 0.720548
iter: 1899 source loss: 0.82637 dis accu: 0.826798
iter: 1999 source loss: 0.73510

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 4.504247 dis accu: 0.823163
iter: 199 source loss: 2.324254 dis accu: 0.853593
iter: 299 source loss: 1.566339 dis accu: 0.831775
iter: 399 source loss: 1.317972 dis accu: 0.835671
iter: 499 source loss: 1.453907 dis accu: 0.873401
iter: 599 source loss: 1.620304 dis accu: 0.893865
iter: 699 source loss: 1.266482 dis accu: 0.999795
iter: 799 source loss: 1.72018 dis accu: 0.23659
iter: 899 source loss: 1.317953 dis accu: 0.889149
iter: 999 source loss: 1.096587 dis accu: 0.98823
iter: 1099 source loss: 1.233276 dis accu: 0.160843
iter: 1199 source loss: 1.098925 dis accu: 0.46957
iter: 1299 source loss: 1.032381 dis accu: 0.393619
iter: 1399 source loss: 0.860727 dis accu: 0.692995
iter: 1499 source loss: 0.833525 dis accu: 0.938443
iter: 1599 source loss: 0.868044 dis accu: 0.655143
iter: 1699 source loss: 0.742541 dis accu: 0.730848
iter: 1799 source loss: 0.871275 dis accu: 0.212147
iter: 1899 source loss: 0.70027 dis accu: 0.876927
iter: 1999 source loss: 0.70

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.93655 dis accu: 0.193191
iter: 199 source loss: 3.225068 dis accu: 0.193191
iter: 299 source loss: 2.510133 dis accu: 0.193191
iter: 399 source loss: 2.037629 dis accu: 0.856348
iter: 499 source loss: 1.382378 dis accu: 0.822744
iter: 599 source loss: 1.527784 dis accu: 0.286296
iter: 699 source loss: 1.509835 dis accu: 0.952681
iter: 799 source loss: 1.07183 dis accu: 0.963452
iter: 899 source loss: 1.077785 dis accu: 0.709629
iter: 999 source loss: 0.98958 dis accu: 0.590746
iter: 1099 source loss: 0.921947 dis accu: 0.495099
iter: 1199 source loss: 0.841788 dis accu: 0.627819
iter: 1299 source loss: 0.81596 dis accu: 0.734156
iter: 1399 source loss: 0.763839 dis accu: 0.753237
iter: 1499 source loss: 0.808622 dis accu: 0.436242
iter: 1599 source loss: 0.813614 dis accu: 0.38981
iter: 1699 source loss: 0.899521 dis accu: 0.751019
iter: 1799 source loss: 0.712273 dis accu: 0.874824
iter: 1899 source loss: 0.832751 dis accu: 0.281778
iter: 1999 source loss: 0.62

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.448628 dis accu: 0.188033
iter: 199 source loss: 2.971683 dis accu: 0.192864
iter: 299 source loss: 2.027565 dis accu: 0.963993
iter: 399 source loss: 3.41801 dis accu: 0.197369
iter: 499 source loss: 1.576306 dis accu: 0.866851
iter: 599 source loss: 1.388497 dis accu: 0.790412
iter: 699 source loss: 1.306599 dis accu: 0.989283
iter: 799 source loss: 1.124599 dis accu: 0.893318
iter: 899 source loss: 1.038136 dis accu: 0.166964
iter: 999 source loss: 0.918856 dis accu: 0.863278
iter: 1099 source loss: 0.881821 dis accu: 0.89275
iter: 1199 source loss: 0.837144 dis accu: 0.20151
iter: 1299 source loss: 0.799059 dis accu: 0.45835
iter: 1399 source loss: 0.768453 dis accu: 0.729683
iter: 1499 source loss: 0.739612 dis accu: 0.925185
iter: 1599 source loss: 0.896248 dis accu: 0.468621
iter: 1699 source loss: 0.954024 dis accu: 0.891857
iter: 1799 source loss: 0.823147 dis accu: 0.250386
iter: 1899 source loss: 0.82359 dis accu: 0.093935
iter: 1999 source loss: 0.80

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.533735 dis accu: 0.154854
iter: 199 source loss: 4.541022 dis accu: 0.154727
iter: 299 source loss: 2.266331 dis accu: 0.997802
iter: 399 source loss: 1.472623 dis accu: 0.903935
iter: 499 source loss: 1.840753 dis accu: 0.306031
iter: 599 source loss: 1.433373 dis accu: 0.929758
iter: 699 source loss: 1.111331 dis accu: 0.99738
iter: 799 source loss: 1.237331 dis accu: 0.23917
iter: 899 source loss: 1.062344 dis accu: 0.816618
iter: 999 source loss: 0.92633 dis accu: 0.829255
iter: 1099 source loss: 0.908446 dis accu: 0.632391
iter: 1199 source loss: 0.849388 dis accu: 0.1628
iter: 1299 source loss: 0.771536 dis accu: 0.708634
iter: 1399 source loss: 0.761097 dis accu: 0.969317
iter: 1499 source loss: 0.772021 dis accu: 0.818731
iter: 1599 source loss: 0.761638 dis accu: 0.135751
iter: 1699 source loss: 0.759997 dis accu: 0.258527
iter: 1799 source loss: 0.807825 dis accu: 0.058113
iter: 1899 source loss: 0.680057 dis accu: 0.43802
iter: 1999 source loss: 0.645

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 3.077965 dis accu: 0.172696
iter: 199 source loss: 3.682122 dis accu: 0.148864
iter: 299 source loss: 1.434863 dis accu: 0.816367
iter: 399 source loss: 1.347847 dis accu: 0.98485
iter: 499 source loss: 1.373886 dis accu: 0.705379
iter: 599 source loss: 1.261557 dis accu: 0.88599
iter: 699 source loss: 1.401811 dis accu: 0.798196
iter: 799 source loss: 1.194591 dis accu: 0.748447
iter: 899 source loss: 1.00373 dis accu: 0.714699
iter: 999 source loss: 1.161378 dis accu: 0.350498
iter: 1099 source loss: 0.983043 dis accu: 0.937867
iter: 1199 source loss: 0.976846 dis accu: 0.86318
iter: 1299 source loss: 0.903391 dis accu: 0.258703
iter: 1399 source loss: 0.796254 dis accu: 0.577241
iter: 1499 source loss: 0.776308 dis accu: 0.918333
iter: 1599 source loss: 0.760606 dis accu: 0.791982
iter: 1699 source loss: 0.789654 dis accu: 0.632181
iter: 1799 source loss: 0.782875 dis accu: 0.839986
iter: 1899 source loss: 0.727043 dis accu: 0.816197
iter: 1999 source loss: 0.7

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.70873 dis accu: 0.170303
iter: 199 source loss: 2.62229 dis accu: 0.170469
iter: 299 source loss: 1.24146 dis accu: 0.982621
iter: 399 source loss: 1.284904 dis accu: 0.983409
iter: 499 source loss: 1.347267 dis accu: 0.832269
iter: 599 source loss: 1.890769 dis accu: 0.829614
iter: 699 source loss: 1.512262 dis accu: 0.833306
iter: 799 source loss: 1.39784 dis accu: 0.842472
iter: 899 source loss: 1.187963 dis accu: 0.841269
iter: 999 source loss: 1.141578 dis accu: 0.964869
iter: 1099 source loss: 1.4002 dis accu: 0.980382
iter: 1199 source loss: 1.316463 dis accu: 0.764081
iter: 1299 source loss: 1.160498 dis accu: 0.85421
iter: 1399 source loss: 1.073782 dis accu: 0.904521
iter: 1499 source loss: 0.94332 dis accu: 0.852841
iter: 1599 source loss: 0.932534 dis accu: 0.928246
iter: 1699 source loss: 0.88123 dis accu: 0.848237
iter: 1799 source loss: 0.857424 dis accu: 0.938407
iter: 1899 source loss: 0.860528 dis accu: 0.865533
iter: 1999 source loss: 0.791203

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.493362 dis accu: 0.167146
iter: 199 source loss: 2.542269 dis accu: 0.167187
iter: 299 source loss: 1.324297 dis accu: 0.844597
iter: 399 source loss: 1.320253 dis accu: 0.878867
iter: 499 source loss: 1.235583 dis accu: 0.994587
iter: 599 source loss: 1.289676 dis accu: 0.806704
iter: 699 source loss: 1.178023 dis accu: 0.990714
iter: 799 source loss: 1.123277 dis accu: 0.912846
iter: 899 source loss: 1.00393 dis accu: 0.275744
iter: 999 source loss: 1.045112 dis accu: 0.671039
iter: 1099 source loss: 1.027874 dis accu: 0.593046
iter: 1199 source loss: 1.314303 dis accu: 0.161816
iter: 1299 source loss: 0.826376 dis accu: 0.992213
iter: 1399 source loss: 0.827461 dis accu: 0.998751
iter: 1499 source loss: 0.833654 dis accu: 0.14745
iter: 1599 source loss: 0.868207 dis accu: 0.567312
iter: 1699 source loss: 0.831329 dis accu: 0.829273
iter: 1799 source loss: 0.738665 dis accu: 0.845721
iter: 1899 source loss: 0.860511 dis accu: 0.896773
iter: 1999 source loss: 1

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 5.119978 dis accu: 0.153941
iter: 199 source loss: 2.147019 dis accu: 0.168958
iter: 299 source loss: 1.614025 dis accu: 0.39236
iter: 399 source loss: 1.59496 dis accu: 0.894454
iter: 499 source loss: 1.480036 dis accu: 0.849275
iter: 599 source loss: 1.265024 dis accu: 0.850967
iter: 699 source loss: 1.576469 dis accu: 0.865646
iter: 799 source loss: 1.4425 dis accu: 0.888701
iter: 899 source loss: 1.263095 dis accu: 0.444139
iter: 999 source loss: 1.195137 dis accu: 0.997927
iter: 1099 source loss: 1.291001 dis accu: 0.826431
iter: 1199 source loss: 1.187274 dis accu: 0.65485
iter: 1299 source loss: 1.169618 dis accu: 0.98126
iter: 1399 source loss: 0.955157 dis accu: 0.991624
iter: 1499 source loss: 0.961021 dis accu: 0.29299
iter: 1599 source loss: 0.882833 dis accu: 0.358433
iter: 1699 source loss: 0.918198 dis accu: 0.694699
iter: 1799 source loss: 0.843339 dis accu: 0.664241
iter: 1899 source loss: 0.82003 dis accu: 0.542747
iter: 1999 source loss: 0.81332

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 1.886389 dis accu: 0.913868
iter: 199 source loss: 3.226952 dis accu: 0.155198
iter: 299 source loss: 1.594543 dis accu: 0.895493
iter: 399 source loss: 1.194303 dis accu: 0.868627
iter: 499 source loss: 1.868385 dis accu: 0.986567
iter: 599 source loss: 1.334354 dis accu: 0.997465
iter: 699 source loss: 1.881617 dis accu: 0.87923
iter: 799 source loss: 1.706665 dis accu: 0.847759
iter: 899 source loss: 1.401037 dis accu: 0.562159
iter: 999 source loss: 1.191573 dis accu: 0.902927
iter: 1099 source loss: 1.050078 dis accu: 0.99755
iter: 1199 source loss: 1.032 dis accu: 0.431631
iter: 1299 source loss: 1.020381 dis accu: 0.06104
iter: 1399 source loss: 0.884693 dis accu: 0.705572
iter: 1499 source loss: 0.868661 dis accu: 0.890593
iter: 1599 source loss: 0.860344 dis accu: 0.896887
iter: 1699 source loss: 0.808905 dis accu: 0.928146
iter: 1799 source loss: 0.84369 dis accu: 0.961982
iter: 1899 source loss: 0.804228 dis accu: 0.728467
iter: 1999 source loss: 0.7453

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 4.032641 dis accu: 0.857918
iter: 199 source loss: 3.114776 dis accu: 0.865463
iter: 299 source loss: 2.323483 dis accu: 0.867074
iter: 399 source loss: 1.841951 dis accu: 0.876272
iter: 499 source loss: 1.621158 dis accu: 0.883647
iter: 599 source loss: 1.555888 dis accu: 0.880214
iter: 699 source loss: 1.390868 dis accu: 0.868769
iter: 799 source loss: 1.320569 dis accu: 0.878772
iter: 899 source loss: 1.280394 dis accu: 0.905307
iter: 999 source loss: 1.361798 dis accu: 0.954773
iter: 1099 source loss: 1.338865 dis accu: 0.999449
iter: 1199 source loss: 1.573441 dis accu: 0.995041
iter: 1299 source loss: 2.425276 dis accu: 0.140853
iter: 1399 source loss: 1.223671 dis accu: 0.845075
iter: 1499 source loss: 1.096166 dis accu: 0.728467
iter: 1599 source loss: 1.083995 dis accu: 0.882884
iter: 1699 source loss: 0.995068 dis accu: 0.80078
iter: 1799 source loss: 0.966173 dis accu: 0.820617
iter: 1899 source loss: 0.930025 dis accu: 0.774161
iter: 1999 source loss: 

Iterations:   0%|          | 0/3000 [00:00<?, ?it/s]

iter: 99 source loss: 4.725677 dis accu: 0.147485
iter: 199 source loss: 2.547169 dis accu: 0.147485
iter: 299 source loss: 1.766173 dis accu: 0.462958
iter: 399 source loss: 2.169179 dis accu: 0.321441
iter: 499 source loss: 1.311489 dis accu: 0.869565
iter: 599 source loss: 1.281356 dis accu: 0.663598
iter: 699 source loss: 1.226743 dis accu: 0.973274
iter: 799 source loss: 1.00858 dis accu: 0.977579
iter: 899 source loss: 0.998201 dis accu: 0.202089
iter: 999 source loss: 0.924722 dis accu: 0.209122
iter: 1099 source loss: 0.879695 dis accu: 0.740239
iter: 1199 source loss: 0.876242 dis accu: 0.335976
iter: 1299 source loss: 0.80428 dis accu: 0.363811
iter: 1399 source loss: 0.782978 dis accu: 0.107332
iter: 1499 source loss: 0.789074 dis accu: 0.63035
iter: 1599 source loss: 0.748172 dis accu: 0.844672
iter: 1699 source loss: 0.774682 dis accu: 0.175661
iter: 1799 source loss: 0.718089 dis accu: 0.250085
iter: 1899 source loss: 0.932419 dis accu: 0.333632
iter: 1999 source loss: 0.