In [3]:
%load_ext autoreload
%autoreload 2


import numpy as np
import pandas as pd
import anndata as ad
import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax
jax.config.update("jax_default_device", jax.devices()[0])

In [4]:
def create_synthetic_data(
    n_genes=50,
    n_control_cells=40,
    n_drugs=6,
    dosages_per_drug=3,
    cells_per_condition=50,
    n_cell_types=30,
    cell_type_embed_dim=50,
    drug_embed_dim=50,
):
    """
    Creates a synthetic AnnData object with multiple dosages per drug.

    Parameters
    ----------
    n_genes : int
        Number of genes to simulate
    n_control_cells : int
        Number of control cells per cell type
    n_drugs : int
        Total number of distinct drugs
    dosages_per_drug : int
        Number of different dosages for each drug
    cells_per_condition : int
        Number of cells per drug-dosage condition
    n_cell_types : int
        Number of cell types
    cell_type_embed_dim : int
        Embedding dimension for cell types
    drug_embed_dim : int
        Embedding dimension for drugs

    Returns
    -------
    dict
        Dictionary containing all DataManager parameters
    """

    # Create cell type names
    n_batches = n_cell_types
    cell_type_names = [f"cell_line_{chr(97 + i)}" for i in range(n_cell_types)]
    batch_names = [f"batch_{i+1}" for i in range(n_batches)]  # New: create batch names

    # Calculate total cells
    total_conditions = n_drugs * dosages_per_drug  # Total conditions excluding control
    total_cells_per_type = n_control_cells + (total_conditions * cells_per_condition)
    n_cells = n_cell_types * total_cells_per_type

    # Initialize lists for observation data
    cell_type_list = []
    control_list = []
    drug_list = []
    dosage_list = []
    batch_list = []

    # Generate data for each cell type
    # shuffle batch_names
    batch_names = np.random.permutation(batch_names)

    for i,cell_type in enumerate(cell_type_names):
        # Add control cells for this cell type
        cell_type_list.extend([cell_type] * total_cells_per_type)
        # Controls
        control_list.extend([True] * n_control_cells)
        drug_list.extend(["control"] * n_control_cells)
        dosage_list.extend([0.0] * n_control_cells)
        # batch_assignments = np.random.choice(batch_names[i], size=total_cells_per_type)
        batch_list.extend([batch_names[i]] * total_cells_per_type)
        # ensure that at lease one type is present 


        # Add perturbed cells for each drug-dosage combination
        control_list.extend([False] * (total_conditions * cells_per_condition))

        # Add drug-dosage combinations
        for drug_idx in range(1, n_drugs + 1):
            for dosage_idx in range(1, dosages_per_drug + 1):
                # Calculate dosage value (e.g., 0.1, 0.5, 1.0)
                dosage_value = dosage_idx / dosages_per_drug  # Normalize to [0,1] range

                # Add this drug-dosage combination
                drug_list.extend([f"drug{drug_idx}"] * cells_per_condition)
                dosage_list.extend([dosage_value] * cells_per_condition)

    # Generate random expression data
    X = np.random.normal(size=(n_cells, n_genes))

    # Create observation DataFrame
    obs = pd.DataFrame(
        {
            "control": control_list,
            "cell_type": pd.Categorical(cell_type_list),
            "drug": pd.Categorical(drug_list),
            "dosage": dosage_list,
            # "batch": pd.Categorical(batch_list),
        }
    )



    # Create AnnData object
    adata = ad.AnnData(X, obs=obs)

    # Add representations to uns (for covariate embeddings)
    adata.uns["drug"] = {
        "control": np.zeros(drug_embed_dim),
    }

    # Add drug embeddings
    for i in range(1, n_drugs + 1):
        adata.uns["drug"][f"drug{i}"] = np.random.normal(size=(drug_embed_dim,))

    # Add cell type embeddings
    adata.uns["cell_type"] = {}
    # adata.uns["batch"] = {}
    for cell_type in cell_type_names:
        adata.uns["cell_type"][cell_type] = np.random.normal(
            size=(cell_type_embed_dim,)
        )
    # for batch in batch_names:
    #     adata.uns["batch"][batch] = np.random.normal(
    #         size=(cell_type_embed_dim,)
    #     )

    # Define parameters for DataManager
    sample_rep = "X"
    control_key = "control"
    split_covariates = ["cell_type",]

    # Here we use a simpler structure with just one drug and dosage column
    perturbation_covariates = {"drug": ["drug"], "dosage": ["dosage"]}
    perturbation_covariate_reps = {"drug": "drug"}
    sample_covariates = ["cell_type", ]
    sample_covariate_reps = {"cell_type": "cell_type", }



    # Return a dictionary with all required parameters
    return {
        "adata": adata,
        "sample_rep": sample_rep,
        "control_key": control_key,
        "split_covariates": split_covariates,
        "perturbation_covariates": perturbation_covariates,
        "perturbation_covariate_reps": perturbation_covariate_reps,
        "sample_covariates": sample_covariates,
        "sample_covariate_reps": sample_covariate_reps,
    }


In [5]:
# Now you can initialize the DataManager
from cellflow.data._datamanager import DataManager

dm_args = create_synthetic_data()

dm = DataManager(
    **dm_args    
)



In [6]:
from cellflow.model._cellflow import CellFlow

cf = CellFlow(adata=dm_args["adata"],) 

In [7]:
import functools

import cellflow
import scanpy as sc
import numpy as np
import functools
from ott.solvers import utils as solver_utils
import optax
import anndata as ad

In [8]:
dm_args

{'adata': AnnData object with n_obs × n_vars = 28200 × 50
     obs: 'control', 'cell_type', 'drug', 'dosage'
     uns: 'drug', 'cell_type',
 'sample_rep': 'X',
 'control_key': 'control',
 'split_covariates': ['cell_type'],
 'perturbation_covariates': {'drug': ['drug'], 'dosage': ['dosage']},
 'perturbation_covariate_reps': {'drug': 'drug'},
 'sample_covariates': ['cell_type'],
 'sample_covariate_reps': {'cell_type': 'cell_type'}}

In [9]:
perturbation_covariates = dm_args["perturbation_covariates"]
perturbation_covariate_reps = dm_args["perturbation_covariate_reps"]
sample_covariates = dm_args["sample_covariates"]
sample_covariate_reps = dm_args["sample_covariate_reps"]


In [10]:
adata = dm_args["adata"]

In [11]:
cf.prepare_data(
    sample_rep="X",
    control_key="control",
    perturbation_covariates=perturbation_covariates,
    perturbation_covariate_reps=perturbation_covariate_reps,
    sample_covariates=sample_covariates,
    sample_covariate_reps=sample_covariate_reps,
    split_covariates=dm_args["split_covariates"],
)
print("Finished preparing data")

[########################################] | 100% Completed | 105.83 ms
[########################################] | 100% Completed | 105.76 ms
[########################################] | 100% Completed | 313.10 ms
Finished preparing data


In [12]:
# condition_embedding_dim: 256
# time_encoder_dims: [2048, 2048, 2048]
# time_encoder_dropout: 0.0
# hidden_dims: [2048, 2048, 2048]
# hidden_dropout: 0.0
# decoder_dims: [4096, 4096, 4096]
# decoder_dropout: 0.2
# pooling: "mean"
# layers_before_pool: 
#   drugs:
#     layer_type: mlp
#     dims: [1024, 1024]
#     dropout_rate: 0.5
#   dose:
#     layer_type: mlp
#     dims: [256, 256]
#     dropout_rate: 0.0
#   cell_line:
#     layer_type: mlp
#     dims: [1024, 1024]
#     dropout_rate: 0.2
# layers_after_pool:
#   layer_type: mlp
#   dims: [1024, 1024]
#   dropout_rate: 0.2
# cond_output_dropout: 0.9
# time_freqs: 1024
# flow_noise: 1.0
# learning_rate: 0.00005
# multi_steps: 20
# epsilon: 1.0
# tau_a: 1.0
# tau_b: 1.0
# flow_type: "constant_noise"
# linear_projection_before_concatenation: False
# layer_norm_before_concatenation: False

In [13]:



match_fn = functools.partial(
    solver_utils.match_linear,
    epsilon=1.0,
    scale_cost="mean",
    tau_a=1.0,
    tau_b=1.0,
)
optimizer = optax.MultiSteps(optax.adam(0.00005), 20)
flow = {
        "constant_noise": 1.0}
layers_before_pool = {
    "drug": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.5,
    },
    "dosage": {
        "layer_type": "mlp",
        "dims": [256, 256],
        "dropout_rate": 0.0,
    },
    "cell_type": {
        "layer_type": "mlp",
        "dims": [1024, 1024],
        "dropout_rate": 0.2,
    },
}
layers_after_pool = {
    "layer_type": "mlp",
    "dims": [1024, 1024],
    "dropout_rate": 0.2,
}
condition_embedding_dim = 256
pooling = "mean"
time_encoder_dims = [2048, 2048, 2048]
time_encoder_dropout = 0.0
hidden_dims = [2048, 2048, 2048]
hidden_dropout = 0.0
decoder_dims = [4096, 4096, 4096]
decoder_dropout = 0.2
cond_output_dropout = 0.9
time_freqs = 1024
layer_norm_before_concatenation = False
linear_projection_before_concatenation = False
# Prepare the model
print("Preparing model...")
cf.prepare_model(
    encode_conditions=True,
    condition_embedding_dim=condition_embedding_dim,
    pooling=pooling,
    time_encoder_dims=time_encoder_dims,
    time_encoder_dropout=time_encoder_dropout,
    hidden_dims=hidden_dims,
    hidden_dropout=hidden_dropout,
    decoder_dims=decoder_dims,
    decoder_dropout=decoder_dropout,
    layers_before_pool=layers_before_pool,
    layers_after_pool=layers_after_pool,
    cond_output_dropout=cond_output_dropout,
    time_freqs=time_freqs,
    match_fn=match_fn,
    optimizer=optimizer,
    flow=flow,
    layer_norm_before_concatenation=False,
    linear_projection_before_concatenation=False,
)
print("Begin training")


Preparing model...
Begin training


In [14]:

cf.prepare_validation_data(
    adata,
    name="train",
    n_conditions_on_log_iteration=10,
    n_conditions_on_train_end=10,
)

cf.prepare_validation_data(
    adata,
    name="test",
    n_conditions_on_log_iteration=None,
    n_conditions_on_train_end=None,
)

[########################################] | 100% Completed | 105.86 ms
[########################################] | 100% Completed | 310.02 ms
[########################################] | 100% Completed | 105.81 ms
[########################################] | 100% Completed | 310.15 ms


In [15]:
import cellflow

In [16]:
metrics_callback = cellflow.training.Metrics(metrics=["r_squared", "mmd", "e_distance"])

# we don't pass the wandb_callback as it requires a user-specific account, but recommend setting it up
callbacks = [metrics_callback]

In [17]:
cf.solver.vf_state.params['condition_encoder']['after_pool_modules_mean_0']['kernel'].devices()

{CpuDevice(id=0)}

In [18]:
cf.train_data.control_to_perturbation

{0: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17], dtype=int32),
 1: array([18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
        35], dtype=int32),
 2: array([36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52,
        53], dtype=int32),
 3: array([54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70,
        71], dtype=int32),
 4: array([72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88,
        89], dtype=int32),
 5: array([ 90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102,
        103, 104, 105, 106, 107], dtype=int32),
 6: array([108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
        121, 122, 123, 124, 125], dtype=int32),
 7: array([126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138,
        139, 140, 141, 142, 143], dtype=int32),
 8: array([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156,
        157, 158, 159,

In [19]:
cf.train_data.split_idx_to_covariates

{0: (np.str_('cell_line_a'),),
 1: (np.str_('cell_line_b'),),
 2: (np.str_('cell_line_c'),),
 3: (np.str_('cell_line_d'),),
 4: (np.str_('cell_line_e'),),
 5: (np.str_('cell_line_f'),),
 6: (np.str_('cell_line_g'),),
 7: (np.str_('cell_line_h'),),
 8: (np.str_('cell_line_i'),),
 9: (np.str_('cell_line_j'),),
 10: (np.str_('cell_line_k'),),
 11: (np.str_('cell_line_l'),),
 12: (np.str_('cell_line_m'),),
 13: (np.str_('cell_line_n'),),
 14: (np.str_('cell_line_o'),),
 15: (np.str_('cell_line_p'),),
 16: (np.str_('cell_line_q'),),
 17: (np.str_('cell_line_r'),),
 18: (np.str_('cell_line_s'),),
 19: (np.str_('cell_line_t'),),
 20: (np.str_('cell_line_u'),),
 21: (np.str_('cell_line_v'),),
 22: (np.str_('cell_line_w'),),
 23: (np.str_('cell_line_x'),),
 24: (np.str_('cell_line_y'),),
 25: (np.str_('cell_line_z'),),
 26: (np.str_('cell_line_{'),),
 27: (np.str_('cell_line_|'),),
 28: (np.str_('cell_line_}'),),
 29: (np.str_('cell_line_~'),)}

In [20]:
from cellflow.data._dataloader import CpuTrainSampler
dataloader = CpuTrainSampler(data=cf.train_data, batch_size=100)

In [21]:
import queue
import threading
import time  # Add this import at the top of the file
from collections.abc import Sequence
from typing import Any, Literal
import jax.numpy as jnp
import jax
import numpy as np
from numpy.typing import ArrayLike
from tqdm import tqdm

from cellflow.data._dataloader import CpuTrainSampler, ValidationSampler
from cellflow.solvers import _genot, _otfm
from cellflow.training._callbacks import BaseCallback, CallbackRunner


def prefetch_to_device(sampler, num_iterations, prefetch_factor=2, num_workers=4):
    seed = 42  # Set a fixed seed for reproducibility
    seq = np.random.SeedSequence(seed)
    random_generators = [np.random.default_rng(s) for s in seq.spawn(num_workers)]

    q = queue.Queue(maxsize=prefetch_factor*num_workers)
    sem = threading.Semaphore(num_iterations)
    stop_event = threading.Event()
    def worker(rng):
        while not stop_event.is_set() and sem.acquire(blocking=False):
            batch = sampler.sample(rng)
            batch = jax.device_put(batch, jax.devices()[0], donate=True)
            jax.block_until_ready(batch)
            while not stop_event.is_set():
                try:
                    q.put(batch, timeout=1.0)
                    break  # Batch successfully put into the queue; break out of retry loop
                except queue.Full:
                    continue

        return

    # Start multiple worker threads
    ts = []
    for i in range(num_workers):
        t = threading.Thread(target=worker, daemon=True, name=f"worker-{i}", args=(random_generators[i], ))
        t.start()
        ts.append(t)

    try:
        for _ in range(num_iterations):
            # Yield batches from the queue; will block waiting for available batch
            yield q.get()
    finally:
        # When the generator is closed or garbage collected, clean up the worker threads
        stop_event.set()  # Signal all workers to exit
        for t in ts:
            t.join()  # Wait for all worker threads to finish




In [22]:
cf.train_data.condition_data

{'drug': array([[[ 0.92979765, -0.80951226,  1.6539406 , ...,  1.7871232 ,
          -0.8401395 , -1.0057048 ]],
 
        [[ 0.92979765, -0.80951226,  1.6539406 , ...,  1.7871232 ,
          -0.8401395 , -1.0057048 ]],
 
        [[ 0.92979765, -0.80951226,  1.6539406 , ...,  1.7871232 ,
          -0.8401395 , -1.0057048 ]],
 
        ...,
 
        [[-0.808961  , -0.11472495, -1.0744399 , ...,  1.3312039 ,
           0.9897936 , -0.89862764]],
 
        [[-0.808961  , -0.11472495, -1.0744399 , ...,  1.3312039 ,
           0.9897936 , -0.89862764]],
 
        [[-0.808961  , -0.11472495, -1.0744399 , ...,  1.3312039 ,
           0.9897936 , -0.89862764]]], shape=(540, 1, 50), dtype=float32),
 'dosage': array([[[0.33333334]],
 
        [[0.6666667 ]],
 
        [[1.        ]],
 
        [[0.33333334]],
 
        [[0.6666667 ]],
 
        [[1.        ]],
 
        [[0.33333334]],
 
        [[0.6666667 ]],
 
        [[1.        ]],
 
        [[0.33333334]],
 
        [[0.6666667 ]],
 
    

In [23]:
# dl = prefetch_to_device(dataloader, num_iterations=100, prefetch_factor=2, num_workers=32)

In [24]:
# batch = next(dl)

In [25]:
rng = np.random.default_rng(42)

In [26]:
dataloader.sample(rng)['src_cell_data']

Source: 2 out of 30, Target: 39 out of 540


array([[ 1.3571423 , -0.7638487 ,  1.6204927 , ...,  0.02975455,
         1.6972175 ,  2.96692   ],
       [ 0.26846367,  0.43569034, -0.44510853, ..., -0.65811443,
        -0.5340969 , -0.03285657],
       [-0.2044165 ,  0.02244166, -0.27570003, ...,  1.9536748 ,
         0.73222953,  0.6051152 ],
       ...,
       [-0.2769228 , -1.4887941 ,  2.01487   , ..., -0.14210315,
         0.49233574, -0.47111952],
       [ 0.9451178 ,  0.62461495,  1.345472  , ..., -0.54375166,
         0.3643677 , -0.35084444],
       [ 0.6104594 , -0.6219473 , -0.07622625, ...,  0.6528209 ,
         1.6353574 , -0.5036002 ]], shape=(100, 50), dtype=float32)

In [27]:
next(dl)['src_cell_data']

NameError: name 'dl' is not defined

In [None]:
seed = 1
rng = np.random.default_rng(np.random.SeedSequence(seed + 10000))

In [None]:
dataloader.sample(rng)

Source: 13 out of 30, Target: 235 out of 540


{'src_cell_data': array([[ 0.18927507,  0.9769641 , -1.0818534 , ...,  0.21130458,
         -0.89402765,  0.9413086 ],
        [ 0.991466  , -0.90771914, -0.3253671 , ...,  1.8505851 ,
          0.9478476 ,  2.402185  ],
        [ 0.30937192,  0.6687417 , -0.4018558 , ..., -1.9595044 ,
         -0.51756215,  0.2516776 ],
        ...,
        [ 0.47926587,  0.5996488 , -0.65116936, ..., -0.40527347,
         -0.23594505,  0.66677547],
        [-0.06974106, -0.8401346 , -0.9127295 , ...,  0.5225384 ,
         -1.5457108 ,  1.3589613 ],
        [ 0.18927507,  0.9769641 , -1.0818534 , ...,  0.21130458,
         -0.89402765,  0.9413086 ]], shape=(100, 50), dtype=float32),
 'tgt_cell_data': array([[-2.0382755e+00,  3.0630982e-01,  1.5062246e+00, ...,
          9.5881987e-01,  3.7822732e-01, -8.7607884e-01],
        [ 2.2199468e-01, -8.5328233e-01, -5.3705686e-01, ...,
         -4.7658074e-01, -3.7956384e-01,  3.5889629e-02],
        [ 9.5031381e-01, -8.6416805e-01,  1.1294960e+00, ...,
     

In [None]:
from cellflow.data._dataloader import ValidationSampler
validation_loaders = {k: ValidationSampler(v) for k, v in cf.validation_data.items()}

In [None]:
validation_loaders['train']

<cellflow.data._dataloader.ValidationSampler at 0x34f0d9150>

In [28]:
cf.train(
    num_iterations=1000,
    batch_size=1024,
    callbacks=callbacks,
    valid_freq=100,
    prefetch_factor=3,
    num_workers=8,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

Source: 2 out of 30, Target: 46 out of 540Source: 14 out of 30, Target: 262 out of 540

Source: 21 out of 30, Target: 385 out of 540
Source: 28 out of 30, Target: 510 out of 540
Source: 26 out of 30, Target: 484 out of 540
Source: 20 out of 30, Target: 364 out of 540
Source: 19 out of 30, Target: 353 out of 540
Source: 16 out of 30, Target: 292 out of 540
Source: 27 out of 30, Target: 492 out of 540
Source: 23 out of 30, Target: 418 out of 540
Source: 13 out of 30, Target: 240 out of 540
Source: 18 out of 30, Target: 327 out of 540
Source: 29 out of 30, Target: 522 out of 540
Source: 22 out of 30, Target: 413 out of 540
Source: 1 out of 30, Target: 19 out of 540
Source: 11 out of 30, Target: 208 out of 540
Source: 27 out of 30, Target: 494 out of 540
Source: 19 out of 30, Target: 344 out of 540
Source: 18 out of 30, Target: 341 out of 540
Source: 2 out of 30, Target: 51 out of 540
Source: 20 out of 30, Target: 370 out of 540
Source: 20 out of 30, Target: 366 out of 540
Source: 14 out o

  0%|          | 1/1000 [00:02<38:55,  2.34s/it]

Source: 7 out of 30, Target: 134 out of 540


  0%|          | 2/1000 [00:03<26:07,  1.57s/it]

Source: 3 out of 30, Target: 62 out of 540


  0%|          | 3/1000 [00:04<22:02,  1.33s/it]

Source: 2 out of 30, Target: 47 out of 540


  0%|          | 4/1000 [00:05<20:09,  1.21s/it]

Source: 15 out of 30, Target: 282 out of 540


  0%|          | 5/1000 [00:06<19:16,  1.16s/it]

Source: 23 out of 30, Target: 419 out of 540


  1%|          | 6/1000 [00:07<18:53,  1.14s/it]

Source: 2 out of 30, Target: 49 out of 540


  1%|          | 7/1000 [00:08<18:23,  1.11s/it]

Source: 20 out of 30, Target: 376 out of 540


  1%|          | 8/1000 [00:09<18:13,  1.10s/it]

Source: 3 out of 30, Target: 70 out of 540


  1%|          | 9/1000 [00:10<18:01,  1.09s/it]

Source: 7 out of 30, Target: 131 out of 540


  1%|          | 10/1000 [00:11<18:03,  1.09s/it]

Source: 12 out of 30, Target: 224 out of 540


  1%|          | 11/1000 [00:13<18:08,  1.10s/it]

Source: 12 out of 30, Target: 226 out of 540


  1%|          | 12/1000 [00:14<18:31,  1.12s/it]

Source: 5 out of 30, Target: 93 out of 540


  1%|▏         | 13/1000 [00:15<18:23,  1.12s/it]

Source: 1 out of 30, Target: 21 out of 540


  1%|▏         | 14/1000 [00:16<18:13,  1.11s/it]

Source: 7 out of 30, Target: 130 out of 540


  2%|▏         | 15/1000 [00:17<18:16,  1.11s/it]

Source: 4 out of 30, Target: 81 out of 540


  2%|▏         | 16/1000 [00:18<18:14,  1.11s/it]

Source: 0 out of 30, Target: 8 out of 540


  2%|▏         | 17/1000 [00:19<18:24,  1.12s/it]

Source: 14 out of 30, Target: 257 out of 540


  2%|▏         | 18/1000 [00:20<18:14,  1.11s/it]

Source: 7 out of 30, Target: 139 out of 540


  2%|▏         | 19/1000 [00:22<18:16,  1.12s/it]

Source: 6 out of 30, Target: 117 out of 540


  2%|▏         | 20/1000 [00:23<18:17,  1.12s/it]

Source: 20 out of 30, Target: 368 out of 540


  2%|▏         | 21/1000 [00:24<18:22,  1.13s/it]

Source: 9 out of 30, Target: 175 out of 540


  2%|▏         | 22/1000 [00:25<18:24,  1.13s/it]

Source: 14 out of 30, Target: 258 out of 540


  2%|▏         | 23/1000 [00:26<18:17,  1.12s/it]

Source: 25 out of 30, Target: 464 out of 540


  2%|▏         | 24/1000 [00:27<18:19,  1.13s/it]

Source: 20 out of 30, Target: 366 out of 540


  2%|▎         | 25/1000 [00:28<18:26,  1.13s/it]

Source: 14 out of 30, Target: 258 out of 540


  3%|▎         | 26/1000 [00:29<18:23,  1.13s/it]

Source: 9 out of 30, Target: 173 out of 540


  3%|▎         | 27/1000 [00:31<18:23,  1.13s/it]

Source: 26 out of 30, Target: 484 out of 540


  3%|▎         | 28/1000 [00:32<18:15,  1.13s/it]

Source: 21 out of 30, Target: 380 out of 540


  3%|▎         | 29/1000 [00:33<18:12,  1.13s/it]

Source: 25 out of 30, Target: 458 out of 540


  3%|▎         | 30/1000 [00:34<18:16,  1.13s/it]

Source: 10 out of 30, Target: 196 out of 540


  3%|▎         | 31/1000 [00:35<18:10,  1.12s/it]

Source: 29 out of 30, Target: 532 out of 540


  3%|▎         | 32/1000 [00:36<18:05,  1.12s/it]

Source: 19 out of 30, Target: 355 out of 540


  3%|▎         | 33/1000 [00:37<18:02,  1.12s/it]

Source: 17 out of 30, Target: 318 out of 540


  3%|▎         | 34/1000 [00:38<18:00,  1.12s/it]

Source: 27 out of 30, Target: 489 out of 540


  4%|▎         | 35/1000 [00:40<18:10,  1.13s/it]

Source: 21 out of 30, Target: 391 out of 540


  4%|▎         | 36/1000 [00:41<18:15,  1.14s/it]

Source: 1 out of 30, Target: 28 out of 540


  4%|▎         | 37/1000 [00:42<18:16,  1.14s/it]

Source: 20 out of 30, Target: 371 out of 540


  4%|▍         | 38/1000 [00:43<19:36,  1.22s/it]

Source: 3 out of 30, Target: 71 out of 540


  4%|▍         | 38/1000 [00:45<19:00,  1.19s/it]


KeyboardInterrupt: 