## Multi-task UnitedNet on perturbation prediction with the PBMC dataset

The motivation comes from the potential scButterfly showed for perturbation response prediction on the PBMC dataset stimulated with IFN-beta. scButterfly and UnitedNet share similar architectures, but UnitedNet additionally leverages multi-task learning. We therefore start with a simple dataset containing a single perturbation type and explore UnitedNet on two tasks:

- cross-modal prediction on perturbation (perturbation response prediction task)
- classification task for cell type annotation

We want to investigate whether multi-task learning benefits perturbation modeling. To do so, we will compare the multi-task model with single-task models trained on classification alone and on cross-modal prediction alone.

In [29]:
# prepare dataset
# dataset from https://github.com/theislab/scgen-reproducibility, already preprocessed

import scanpy as sc

train_data = "../data/pbmc/train_pbmc.h5ad"
valid_data = "../data/pbmc/valid_pbmc.h5ad"

# pbmc data from scGen is already split into train and valid, but we will concatenate them to do a custom split
sc_data_train = sc.read_h5ad(train_data)
sc_data_valid = sc.read_h5ad(valid_data)
pbmc = sc_data_train.concatenate(sc_data_valid)

In [30]:
pbmc

AnnData object with n_obs × n_vars = 18868 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type', 'batch'
    var: 'gene_symbol', 'n_cells'
    obsm: 'X_pca', 'X_tsne', 'X_umap'

In [31]:
pbmc.obs.groupby(['cell_type', 'condition']).size()

cell_type    condition 
CD4T         control       2715
             stimulated    3483
CD14+Mono    control       2184
             stimulated     698
B            control        928
             stimulated    1105
CD8T         control        643
             stimulated     594
NK           control        571
             stimulated     733
FCGR3A+Mono  control       1232
             stimulated    2790
Dendritic    control        670
             stimulated     522
dtype: int64

In [32]:
pbmc.obs.groupby(['condition']).size()

condition
control       8943
stimulated    9925
dtype: int64

In [33]:
# simple inspection

# sc.pp.pca(pbmc)
# sc.pp.neighbors(pbmc)
# sc.tl.tsne(pbmc)
# sc.pl.tsne(pbmc, color=['condition', 'cell_type'], legend_loc='on data', legend_fontsize='small')

### Create pairs of control and stimulated data using scButterfly method

scButterfly showed the potential of a dual VAE scheme for perturbation response prediction. It splits the control and stimulated cells by cell type and then, within each cell type group, pairs control and stimulated gene expression profiles using optimal transport.

Note that we cannot measure the same cell before and after perturbation, because current experimental methods do not allow it. Deep learning architectures that rely on paired inputs, such as scButterfly and UnitedNet (which combines modalities), therefore need a strategy for creating pseudo pairs, and need to check whether these pseudo pairs are beneficial.

scButterfly showed that this pairing strategy has potential.
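To make the pairing idea concrete, here is a minimal sketch of optimal-transport pairing within a single cell type. It is not scButterfly's implementation: the function names (`squared_distances`, `sinkhorn_plan`, `pair_cells`), the entropic (Sinkhorn) solver, the uniform cell weights, the squared Euclidean cost, and the argmax pairing rule are all assumptions chosen for illustration, and the data are toy points rather than PBMC profiles.

```python
import numpy as np


def squared_distances(x, y):
    """Pairwise squared Euclidean distances between rows of x and rows of y."""
    x2 = (x ** 2).sum(axis=1)[:, None]
    y2 = (y ** 2).sum(axis=1)[None, :]
    # clip tiny negative values caused by floating point round-off
    return np.maximum(x2 + y2 - 2.0 * x @ y.T, 0.0)


def sinkhorn_plan(cost, epsilon=0.1, n_iters=200):
    """Entropic optimal transport plan between two uniform distributions."""
    n, m = cost.shape
    a = np.full(n, 1.0 / n)  # uniform mass on control cells (assumption)
    b = np.full(m, 1.0 / m)  # uniform mass on stimulated cells (assumption)
    kernel = np.exp(-cost / epsilon)
    u = np.ones(n)
    for _ in range(n_iters):
        v = b / (kernel.T @ u)
        u = a / (kernel @ v)
    return u[:, None] * kernel * v[None, :]


def pair_cells(control, stimulated, epsilon=0.1):
    """Pair each control cell with the stimulated cell receiving most of its mass."""
    cost = squared_distances(control, stimulated)
    cost = cost / cost.max()  # scale to [0, 1] for numerical stability
    plan = sinkhorn_plan(cost, epsilon)
    return plan.argmax(axis=1), plan


# Toy "cells" of one cell type: 2 control and 3 stimulated profiles.
control = np.array([[0.0, 0.0], [10.0, 10.0]])
stimulated = np.array([[10.5, 10.5], [0.5, 0.5], [0.2, 0.1]])

pairs, plan = pair_cells(control, stimulated)
print(pairs)       # index of the stimulated partner for each control cell
print(plan.sum())  # total transported mass, close to 1
```

Because the two groups usually differ in size, an argmax rule like this one can reuse a stimulated cell for several control cells. How scButterfly resolves such cases is not covered by this sketch.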

In [34]:
from scButterfly.split_datasets import unpaired_split_dataset_perturb


# create pairs of data using scButterfly technique
control_data = pbmc[pbmc.obs.condition == 'control']
stimulate_data = pbmc[pbmc.obs.condition == 'stimulated']

control_data.obs.index = [str(i) for i in range(control_data.X.shape[0])]
stimulate_data.obs.index = [str(i) for i in range(stimulate_data.X.shape[0])]

cell_type_list = list(control_data.obs.cell_type.cat.categories)

In [35]:
id_list, id_list_dict = unpaired_split_dataset_perturb(control_data, stimulate_data)

optimal transport array torch.Size([2715, 3483])
CD4T, control num 2715 stimulate num 3483

optimal transport array torch.Size([2184, 698])
CD14+Mono, control num 2184 stimulate num 698

optimal transport array torch.Size([928, 1105])
B, control num 928 stimulate num 1105

optimal transport array torch.Size([643, 594])
CD8T, control num 643 stimulate num 594

optimal transport array torch.Size([571, 733])
NK, control num 571 stimulate num 733

optimal transport array torch.Size([1232, 2790])
FCGR3A+Mono, control num 1232 stimulate num 2790

optimal transport array torch.Size([670, 522])
Dendritic, control num 670 stimulate num 522


Start CD4T
Batch list ['Dendritic', 'B', 'FCGR3A+Mono', 'CD14+Mono', 'CD4T', 'NK', 'CD8T']
Test batch control ['Dendritic']
Validation batch control ['B', 'FCGR3A+Mono', 'CD14+Mono', 'CD4T', 'NK', 'CD8T']
Train batch control ['B', 'FCGR3A+Mono', 'CD14+Mono', 'CD4T', 'NK', 'CD8T']
Test batch stimulated ['Dendritic']
Validation batch stimulated ['B', 'FCGR3A+

In [36]:
#print(len(id_list[0]))
for idx, cell_type in enumerate(cell_type_list):
    print(cell_type)
    train_id_control, train_id_peturb, validation_id_control, validation_id_peturb, test_id_control, test_id_peturb = id_list[idx]

    control = len(train_id_control) + len(validation_id_control) + len(test_id_control)
    peturb = len(train_id_peturb) + len(validation_id_peturb) + len(test_id_peturb) 
    control_dict = id_list_dict[cell_type]["control"]
    peturb_dict = id_list_dict[cell_type]["stimulated"]
    print("control", len(train_id_control), len(validation_id_control), len(test_id_control), control)
    print("control", len(control_dict["train"]), len(control_dict["validation"]), len(control_dict["test"]))
    print("peturb", len(train_id_peturb), len(validation_id_peturb), len(test_id_peturb), peturb)    
    print("peturb", len(peturb_dict["train"]), len(peturb_dict["validation"]), len(peturb_dict["test"]))
    print()
    
# Notes on the structure returned by scButterfly's unpaired_split_dataset_perturb:
# - id_list has one row per cell type, but each row is really one fold of a leave-one-cell-type-out cross validation.
# - the cell type used as the row label is not the one held out in that fold. The held-out types follow the
#   "Batch list" order printed above, e.g. the first row (labelled CD4T) holds out Dendritic (670 control, 522 stimulated).
# - in the test split, control and stimulated can differ in size, but train and validation must contain pairs.

CD4T
control 6616 1657 670 8943
control 6616 1657 670
peturb 6616 1657 522 8795
peturb 6616 1657 522

CD14+Mono
control 6410 1605 928 8943
control 6410 1605 928
peturb 6410 1605 1105 9120
peturb 6410 1605 1105

B
control 6167 1544 1232 8943
control 6167 1544 1232
peturb 6167 1544 2790 10501
peturb 6167 1544 2790

CD8T
control 5405 1354 2184 8943
control 5405 1354 2184
peturb 5405 1354 698 7457
peturb 5405 1354 698

NK
control 4980 1248 2715 8943
control 4980 1248 2715
peturb 4980 1248 3483 9711
peturb 4980 1248 3483

FCGR3A+Mono
control 6696 1676 571 8943
control 6696 1676 571
peturb 6696 1676 733 9105
peturb 6696 1676 733

Dendritic
control 6638 1662 643 8943
control 6638 1662 643
peturb 6638 1662 594 8894
peturb 6638 1662 594
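
Since the row label and the held-out cell type differ, it can help to check explicitly which cell type a fold holds out. The sketch below does this on toy data. The helper name `held_out_counts`, the toy annotation, and the assumption that the test ids are integer positions into the annotation are all hypothetical and not taken from scButterfly.

```python
from collections import Counter


def held_out_counts(test_ids, cell_types):
    """Count the cell types covered by the test ids of one fold."""
    return Counter(cell_types[i] for i in test_ids)


# Toy annotation: position -> cell type (hypothetical data, not the PBMC cells).
cell_types = ["B", "B", "NK", "NK", "NK", "CD4T"]

# Toy folds keyed by a row label that does not match the held-out type.
toy_folds = {"B": [2, 3, 4], "NK": [5], "CD4T": [0, 1]}

held_out = {label: held_out_counts(ids, cell_types) for label, ids in toy_folds.items()}
for label, counts in held_out.items():
    print(f"row label {label}: test cells {dict(counts)}")
```

Applied to the real split, the same check would need the test ids of each fold together with the `cell_type` column they index into.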



In [41]:
# create adatas for pair of control and stimulated
import numpy as np


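# use the first fold; note that the fold keyed by cell_type_list[0] ("CD4T") holds out Dendritic cells for testing
batch = cell_type_list[0]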
id_list_batch = id_list_dict[batch]
control = id_list_batch["control"]
peturb = id_list_batch["stimulated"]
train_id_control = control["train"]
validation_id_control = control["validation"]
test_id_control = control["test"]
train_id_peturb = peturb["train"]
validation_id_peturb = peturb["validation"]
test_id_peturb = peturb["test"]

control_train = pbmc[train_id_control]
control_valid = pbmc[validation_id_control]
control_test = pbmc[test_id_control]

peturb_train = pbmc[train_id_peturb]
peturb_valid = pbmc[validation_id_peturb]
peturb_test = pbmc[test_id_peturb]

all_adatas = [control_train, control_valid, control_test, peturb_train, peturb_valid, peturb_test]
for adata in all_adatas:
    adata.obs["label"] = list(adata.obs["cell_type"])
    adata.X = adata.X.toarray()
    # report the value range of each split as a quick sanity check
    print("max", np.max(adata.X))
    print("min", np.min(adata.X))
    print()

adatas_train = [control_train, peturb_train]
adatas_valid = [control_valid, peturb_valid]
adatas_test = [control_test, peturb_test]

features_num = adatas_train[0].n_vars
adatas_test


max 7.0346947
min 0.0

max 6.822089
min 0.0

max 6.7637987
min 0.0

max 6.879698
min 0.0

max 7.0346947
min 0.0

max 5.961386
min 0.0



[AnnData object with n_obs × n_vars = 670 × 6998
     obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type', 'batch', 'label'
     var: 'gene_symbol', 'n_cells'
     obsm: 'X_pca', 'X_tsne', 'X_umap',
 AnnData object with n_obs × n_vars = 522 × 6998
     obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type', 'batch', 'label'
     var: 'gene_symbol', 'n_cells'
     obsm: 'X_pca', 'X_tsne', 'X_umap']

### Prepare multi-task model

UnitedNet is trained to solve two tasks jointly: cell type annotation and cross-modal prediction.

Metrics overview:

Translation
- R2

Classification
- Confusion matrix
- Adjusted Rand Index
- Accuracy
- Normalized Mutual Information

The parameters of both the classification network and the translation network are updated jointly under the multi-task training scheme orchestrated by UnitedNet.
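
As a reference for reading the metric tables in the training log, the following sketch computes R2, the confusion matrix, accuracy, and the Adjusted Rand Index with plain NumPy. The function names are hypothetical and this is not UnitedNet's evaluation code. In particular, pooling R2 over all matrix entries is an assumption, since the source does not state how UnitedNet aggregates it. The 7 x 7 matrix is the confusion matrix logged alongside `train_epoch_8.pt`.

```python
import numpy as np


def r2_score(y_true, y_pred):
    """Coefficient of determination pooled over all entries (assumption)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = ((y_true - y_pred) ** 2).sum()
    ss_tot = ((y_true - y_true.mean()) ** 2).sum()
    return 1.0 - ss_res / ss_tot


def confusion_matrix(y_true, y_pred, n_classes):
    """Rows are true labels, columns are predicted labels."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def accuracy(cm):
    """Fraction of cells on the diagonal of the confusion matrix."""
    return np.trace(cm) / cm.sum()


def comb2(x):
    """Number of unordered pairs, x choose 2, applied elementwise."""
    return x * (x - 1) / 2.0


def adjusted_rand_index(cm):
    """Adjusted Rand Index computed from a contingency table."""
    sum_cells = comb2(cm).sum()
    sum_rows = comb2(cm.sum(axis=1)).sum()
    sum_cols = comb2(cm.sum(axis=0)).sum()
    expected = sum_rows * sum_cols / comb2(cm.sum())
    max_index = 0.5 * (sum_rows + sum_cols)
    return (sum_cells - expected) / (max_index - expected)


# R2 on a toy expression matrix: predicting the mean gives 0, worse guesses go negative.
expr = np.array([[0.0, 1.0], [2.0, 3.0]])
r2_mean = r2_score(expr, np.full_like(expr, expr.mean()))
r2_zeros = r2_score(expr, np.zeros_like(expr))

# Confusion matrix from toy labels.
cm_toy = confusion_matrix([0, 0, 1, 1, 2, 2], [0, 0, 1, 2, 2, 2], n_classes=3)

# Confusion matrix copied from the training log (saved with train_epoch_8.pt).
cm_epoch8 = np.diag([745, 453, 2349, 389, 339, 1859, 481])
cm_epoch8[3, 2] = 1
acc8 = accuracy(cm_epoch8)
ari8 = adjusted_rand_index(cm_epoch8)
print(r2_mean, r2_zeros, acc8, ari8)
```

A negative R2, as seen throughout the log, means the prediction explains the data worse than a constant prediction at the mean would.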

In [40]:
from unitednet.interface import UnitedNet


pbmc_config = {
    "train_batch_size": 256,
    "finetune_batch_size": 5000,
    "transfer_batch_size": 512,
    "train_epochs": 50,
    "finetune_epochs": 10,
    "transfer_epochs": 20,
    "train_task": "supervised_group_identification", # -> translation, classification
    "finetune_task": None,
    "transfer_task": None,
    "train_loss_weight": None,
    "finetune_loss_weight": None,
    "transfer_loss_weight": None,
    "lr": 0.1,
    "checkpoint": 1,
    "n_head": 1,
    "noise_level":[0,0],
    "fuser_type":"WeightedMean",
    "encoders": [
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
    ],
    "latent_projector": None,
    "decoders": [
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "sigmoid"],  # sigmoid bounds this decoder's output to (0, 1), while the expression values above reach ~7
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", None],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
    ],
    "discriminators": [
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
    ],
    "projectors": {
        "input": 64,
        "hiddens": [],
        "output": 100,
        "use_biases": [True],
        "dropouts": [0],
        "activations": ["relu"],
        "use_batch_norms": [False],
        "use_layer_norms": [True],
    },
    "clusters": {
        "input": 100,
        "hiddens": [],
        "output": len(cell_type_list),
        "use_biases": [False],
        "dropouts": [0],
        "activations": [None],
        "use_batch_norms": [False],
        "use_layer_norms": [False],
    },
}

device = "cuda:0"
test_batch = batch
root_save_path = "saved_results/thesis/pbmc/multi_task"


model = UnitedNet("multitask", f"{root_save_path}/{test_batch}", device=device, technique=pbmc_config)
model.train(adatas_train=adatas_train)
#model.finetune(adatas_train)
#model.transfer(adatas_train, adatas_transfer = adatas_test, verbose=True)
#print(model.evaluate(adatas_test))
# evaluation should be different for perturbation

training


  0%|          | 0/50 [00:00<?, ?it/s]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_1.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05405913]
                                [-1.99397302 -1.99397302]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  435    0    0    0   18    0]
                                [   0    0 2326   23    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   3    0    0    0  332    4    0]
                                [   2   32    1    0    4 1820    0]
                                [   0    0    0    2    0    1  478]]
Accuracy                       0.9863966142684402
Adjusted Rand Index            0.973166750722691
Normalized Mutual Information  0.9592932191659395


  2%|▏         | 1/50 [00:07<05:52,  7.19s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_2.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.13593295 -0.13593295]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   1    0 2330   16    2    0    0]
                                [   0    0    0  388    0    0    2]
                                [   0    0    0    0  335    4    0]
                                [   0    8    0    0    0 1851    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9950120918984281
Adjusted Rand Index            0.9896526460167843
Normalized Mutual Information  0.9829681328421657


  4%|▍         | 2/50 [00:13<05:33,  6.94s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_3.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.01139272 -0.01139272]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  434    0    0    0   19    0]
                                [   0    0 2349    0    0    0    0]
                                [   1    0    0  389    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9969770253929867
Adjusted Rand Index            0.9943063721182498
Normalized Mutual Information  0.990921059835599


  6%|▌         | 3/50 [00:21<05:32,  7.08s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_4.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00344543 -0.00344543]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  435    0    0    0   18    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    2  388    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    1 1858    0]
                                [   0    0    0    2    0    0  479]]
Accuracy                       0.9965235792019347
Adjusted Rand Index            0.9935451229708105
Normalized Mutual Information  0.9887757003241716


  8%|▊         | 4/50 [00:28<05:22,  7.00s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_5.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.0032968  -0.0032968 ]]
Confusion Matrix               [[ 738    0    3    0    3    1    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  338    1    0]
                                [   0    1    0    0    0 1858    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.998639661426844
Adjusted Rand Index            0.9974986703651537
Normalized Mutual Information  0.9942584722301264


 10%|█         | 5/50 [00:35<05:14,  6.98s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_6.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00337104 -0.00337104]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   2    0 2347    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    1    0    0    0 1858    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999546553808948
Adjusted Rand Index            0.9989156227975491
Normalized Mutual Information  0.9979277117890427


 12%|█▏        | 6/50 [00:42<05:20,  7.28s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_7.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00380702 -0.00380702]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   5    1 2342    0    0    1    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9989419588875453
Adjusted Rand Index            0.9971341816602007
Normalized Mutual Information  0.9955061376145117


 14%|█▍        | 7/50 [00:50<05:18,  7.40s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_8.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00436152 -0.00436152]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    1  389    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9996506534238886
Normalized Mutual Information  0.9992854921266813


 16%|█▌        | 8/50 [00:57<05:09,  7.38s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_9.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.0045095  -0.0045095 ]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    1    3    0    1  476]]
Accuracy                       0.9992442563482467
Adjusted Rand Index            0.9990093184492642
Normalized Mutual Information  0.996934971801183


 18%|█▊        | 9/50 [01:05<05:01,  7.35s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_10.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00460219 -0.00460219]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  334    5    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9992442563482467
Adjusted Rand Index            0.9986014934805342
Normalized Mutual Information  0.997244000149065


 20%|██        | 10/50 [01:12<04:56,  7.42s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_11.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00493432 -0.00493432]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    2 2347    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0  212    0    0    3 1644    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9672007255139057
Adjusted Rand Index            0.9406093480255484
Normalized Mutual Information  0.9486363765232089


 22%|██▏       | 11/50 [01:19<04:46,  7.33s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_12.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00519205 -0.00519205]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0   12    0    0    0 1847    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.998186215235792
Adjusted Rand Index            0.9964739853276727
Normalized Mutual Information  0.9941804874314073


 24%|██▍       | 12/50 [01:26<04:33,  7.20s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_13.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00520101 -0.00520101]]
Confusion Matrix               [[ 743    0    2    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2347    1    0    0    1]
                                [   0    0    0  389    0    1    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9992442563482467
Adjusted Rand Index            0.9982144724346461
Normalized Mutual Information  0.996495922746686


 26%|██▌       | 13/50 [01:34<04:27,  7.24s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_14.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00573095 -0.00573095]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2336   13    0    0    0]
                                [   0    0    1  389    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9978839177750907
Adjusted Rand Index            0.9951239769361155
Normalized Mutual Information  0.9930274929504409


 28%|██▊       | 14/50 [01:40<04:15,  7.09s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_15.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.0058851  -0.0058851 ]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2346    3    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999546553808948
Adjusted Rand Index            0.9989521683201368
Normalized Mutual Information  0.9981559923753824


 30%|███       | 15/50 [01:47<04:07,  7.06s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_16.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00589447 -0.00589447]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 32%|███▏      | 16/50 [01:54<04:01,  7.09s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_17.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00618335 -0.00618335]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 34%|███▍      | 17/50 [02:01<03:52,  7.05s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_18.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00643686 -0.00643686]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 36%|███▌      | 18/50 [02:08<03:44,  7.01s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_19.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00686404 -0.00686404]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 38%|███▊      | 19/50 [02:16<03:40,  7.10s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_20.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00668007 -0.00668007]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0   15    0    0    0 1844    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9977327690447401
Adjusted Rand Index            0.9955969909673037
Normalized Mutual Information  0.9930292266833847


 40%|████      | 20/50 [02:23<03:31,  7.06s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_21.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00706654 -0.00706654]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    1  389    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9996506534238886
Normalized Mutual Information  0.9992854921266813


 42%|████▏     | 21/50 [02:30<03:23,  7.02s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_22.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.0074088  -0.0074088 ]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  338    1    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.999719672268272
Normalized Mutual Information  0.999302497231813


 44%|████▍     | 22/50 [02:36<03:15,  6.97s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_23.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00772077 -0.00772077]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 46%|████▌     | 23/50 [02:43<03:07,  6.96s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_24.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00781748 -0.00781748]]
Confusion Matrix               [[ 744    0    1    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9996053487829619
Normalized Mutual Information  0.9992560785476025


 48%|████▊     | 24/50 [02:50<03:01,  6.96s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_25.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00798506 -0.00798506]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 50%|█████     | 25/50 [02:57<02:54,  6.96s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_26.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00802139 -0.00802139]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 52%|█████▏    | 26/50 [03:05<02:49,  7.05s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_27.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00812726 -0.00812726]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 54%|█████▍    | 27/50 [03:12<02:42,  7.07s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_28.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00836856 -0.00836856]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2348    1    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0   58    0    0    0 1801    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9910822249093107
Adjusted Rand Index            0.982874747101983
Normalized Mutual Information  0.979423770278878


 56%|█████▌    | 28/50 [03:18<02:33,  6.98s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_29.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00838576 -0.00838576]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 58%|█████▊    | 29/50 [03:25<02:26,  6.97s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_30.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00874683 -0.00874683]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 60%|██████    | 30/50 [03:33<02:22,  7.10s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_31.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00878869 -0.00878869]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 62%|██████▏   | 31/50 [03:40<02:14,  7.09s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_32.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05406102 -0.05406102]
                                [-0.00914311 -0.00914311]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 64%|██████▍   | 32/50 [03:47<02:07,  7.09s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_33.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05411801 -0.05411801]
                                [-0.0092072  -0.0092072 ]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 66%|██████▌   | 33/50 [03:54<02:00,  7.10s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_34.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05598298 -0.05598298]
                                [-0.00902095 -0.00902095]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2348    1    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9996505603685436
Normalized Mutual Information  0.9992855115365117


 68%|██████▊   | 34/50 [04:01<01:55,  7.19s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_35.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642126 -0.05642126]
                                [-0.00922293 -0.00922293]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 70%|███████   | 35/50 [04:09<01:47,  7.16s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_36.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05687925 -0.05687924]
                                [-0.00719143 -0.00719143]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 72%|███████▏  | 36/50 [04:16<01:39,  7.13s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_37.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05687925 -0.05687911]
                                [ 0.00083175  0.00083175]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  451    0    0    0    2    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  335    4    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999093107617896
Adjusted Rand Index            0.9982905670581783
Normalized Mutual Information  0.9964177605586726


 74%|███████▍  | 37/50 [04:23<01:32,  7.09s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_38.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642126 -0.05642126]
                                [ 0.00636285  0.00636285]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  450    0    0    0    3    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999546553808948
Adjusted Rand Index            0.9991163071878205
Normalized Mutual Information  0.9981674317036964


 76%|███████▌  | 38/50 [04:29<01:24,  7.02s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_39.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642126 -0.0599937 ]
                                [ 0.00614589  0.00614589]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 78%|███████▊  | 39/50 [04:37<01:17,  7.05s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_40.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642126 -0.0582789 ]
                                [ 0.00610899  0.00610899]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 80%|████████  | 40/50 [04:43<01:09,  7.00s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_41.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.0569082  -0.07034057]
                                [ 0.00624936  0.00624936]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 82%|████████▏ | 41/50 [04:50<01:02,  6.94s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_42.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.0569082  -0.05790845]
                                [ 0.00604184  0.00604184]]
Confusion Matrix               [[ 744    0    1    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   4    0 2345    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9992442563482467
Adjusted Rand Index            0.9980284371004242
Normalized Mutual Information  0.9967848227521642


 84%|████████▍ | 42/50 [04:57<00:55,  6.99s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_43.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05664361 -0.06821833]
                                [ 0.00574843  0.00574843]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2347    0    0    0    2]
                                [   0    0    1  384    0    0    5]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    3    0    0  478]]
Accuracy                       0.9983373639661427
Adjusted Rand Index            0.9980483950877972
Normalized Mutual Information  0.9938584614463343


 86%|████████▌ | 43/50 [05:05<00:49,  7.06s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_44.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05664361 -0.05850124]
                                [ 0.0056385   0.0056385 ]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 88%|████████▊ | 44/50 [05:12<00:42,  7.03s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_45.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05786815 -0.07058606]
                                [ 0.00498007  0.00498007]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    1    0    0    0    0  480]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9998809406937492
Normalized Mutual Information  0.9993507393385854


 90%|█████████ | 45/50 [05:19<00:35,  7.05s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_46.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05786815 -0.060869  ]
                                [ 0.00516474  0.00516474]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    3 1856    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999546553808948
Adjusted Rand Index            0.9991594354660043
Normalized Mutual Information  0.9982068829067328


 92%|█████████▏| 46/50 [05:25<00:27,  6.97s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_47.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05786815 -0.06415566]
                                [ 0.00531331  0.00531331]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 94%|█████████▍| 47/50 [05:32<00:20,  6.97s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_48.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05786815 -0.05972582]
                                [ 0.00522006  0.00522006]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 96%|█████████▌| 48/50 [05:39<00:13,  6.90s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_49.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05786815 -0.06501304]
                                [ 0.00497342  0.00497342]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 98%|█████████▊| 49/50 [05:46<00:06,  6.98s/it]

model saved at saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_50.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05786815 -0.07130054]
                                [ 0.00449098  0.00449098]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


100%|██████████| 50/50 [05:53<00:00,  7.07s/it]
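
As a sanity check on the log above, the reported accuracy can be recomputed from a confusion matrix as the sum of the diagonal divided by the total number of cells. The sketch below copies the matrix of the epoch that shows 15 off-diagonal cells; the variable names are illustrative and not part of UnitedNet, and the computation is the generic definition of accuracy rather than the library's own code.

```python
# Confusion matrix copied from the training log (epoch with 15 off-diagonal cells).
confusion = [
    [745, 0, 0, 0, 0, 0, 0],
    [0, 453, 0, 0, 0, 0, 0],
    [0, 0, 2349, 0, 0, 0, 0],
    [0, 0, 0, 390, 0, 0, 0],
    [0, 0, 0, 0, 339, 0, 0],
    [0, 15, 0, 0, 0, 1844, 0],
    [0, 0, 0, 0, 0, 0, 481],
]

# Correctly classified cells sit on the diagonal.
correct = sum(confusion[i][i] for i in range(len(confusion)))
total = sum(sum(row) for row in confusion)
accuracy = correct / total

print(f"{correct} / {total} = {accuracy:.6f}")
```

With 6601 of 6616 cells on the diagonal, this agrees with the logged accuracy of about 0.99773.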


## Single-task translation via UnitedNet 

In this setting, the parameters of the classification network are ignored.

Note that all metrics tracked in the multi-task setting are still reported for the single-task model. Since we are only interested in translation here, the metric to pay attention to is R2.
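
To help read the R2 values in the logs, here is a minimal sketch of the standard coefficient of determination, R2 = 1 - SS_res / SS_tot. The helper `r2` and the toy arrays are assumptions made for illustration; UnitedNet may aggregate R2 differently (for example per modality pair, which would explain the 2x2 layout in the log).

```python
import numpy as np


def r2(y_true, y_pred):
    """Generic coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
    return 1.0 - ss_res / ss_tot


y_true = np.array([1.0, 2.0, 3.0, 4.0])

perfect = r2(y_true, y_true)                           # exact prediction
mean_baseline = r2(y_true, np.full(4, y_true.mean()))  # always predict the mean
worse = r2(y_true, np.zeros(4))                        # worse than the mean baseline

print(perfect, mean_baseline, worse)
```

Under this definition, a value near 0 means the prediction is about as good as always predicting the mean, and a negative value means it is worse than that baseline.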

In [16]:
from unitednet.interface import UnitedNet


pbmc_config_translation = {
    "train_batch_size": 512,
    "finetune_batch_size": 5000,
    "transfer_batch_size": 512,
    "train_epochs": 10,
    "finetune_epochs": 10,
    "transfer_epochs": 20,
    "train_task": "cross_model_prediction", # translation
    "finetune_task": None,
    "transfer_task": None,
    "train_loss_weight": None,
    "finetune_loss_weight": None,
    "transfer_loss_weight": None,
    "lr": 0.1,
    "checkpoint": 1,
    "n_head": 1,
    "noise_level": [0, 0],
    "fuser_type": "WeightedMean",
    "encoders": [
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
    ],
    "latent_projector": None,
    "decoders": [
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "sigmoid"],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", None],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
    ],
    "discriminators": [
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
    ],
    "projectors": {
        "input": 64,
        "hiddens": [],
        "output": 100,
        "use_biases": [True],
        "dropouts": [0],
        "activations": ["relu"],
        "use_batch_norms": [False],
        "use_layer_norms": [True],
    },
    "clusters": {
        "input": 100,
        "hiddens": [],
        "output": len(cell_type_list),
        "use_biases": [False],
        "dropouts": [0],
        "activations": [None],
        "use_batch_norms": [False],
        "use_layer_norms": [False],
    },
}

device = "cuda:0"
test_batch = batch
root_save_path = "saved_results/thesis/pbmc/translation"

model_translation = UnitedNet(f"{root_save_path}/{test_batch}", device=device, technique=pbmc_config_translation)

In [17]:
model_translation.train(adatas_train, verbose=True)


training


  0%|          | 0/10 [00:00<?, ?it/s]



Losses                    Value
-------------------  ----------
discriminator_loss   0.0595083
reconstruction_loss  0.0411343
generator_loss       0.00795011
contrastive_loss     9.3067
translation_loss     0.0411343
best_head            0
best model saved at saved_results/thesis/pbmc/translation/CD4T/CD4T/train_best.pt 

model saved at saved_results/thesis/pbmc/translation/CD4T/CD4T/train_epoch_1.pt 



 10%|█         | 1/10 [00:04<00:43,  4.88s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05620258 -0.05620258]
                                [-4.02247524 -4.02247524]]
Confusion Matrix               [[   7    0  626    0    2   61   49]
                                [   0    0  113    0    0  338    2]
                                [  15    0 1863    1   13  245  212]
                                [   3    0  297    1    1   61   27]
                                [   0    0  188    0    0  128   23]
                                [   0    0  565    0    1 1260   33]
                                [   2    0  365    0    1   67   46]]
Accuracy                       0.4801995163240629
Adjusted Rand Index            0.13020679729601353
Normalized Mutual Information  0.1497576428276985


Losses                   Value
-------------------  ---------
discriminator_loss   0.0555364
reconstruction_loss  0.0382109
generator_los

 20%|██        | 2/10 [00:08<00:31,  3.94s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05619341 -0.05619341]
                                [-1.3747524  -1.3747524 ]]
Confusion Matrix               [[   6    0  380    1    0  358    0]
                                [   1    0  170    0    0  282    0]
                                [  14    0 1407    0    0  928    0]
                                [   0    0  125    0    0  265    0]
                                [   3    0  221    0    0  115    0]
                                [   3    0  893    1    0  962    0]
                                [   1    0  146    0    0  334    0]]
Accuracy                       0.3589782345828295
Adjusted Rand Index            0.01396222126230952
Normalized Mutual Information  0.021216021550803477


Losses                   Value
-------------------  ---------
discriminator_loss   0.0549763
reconstruction_loss  0.0373039
generator_l

 30%|███       | 3/10 [00:12<00:27,  3.94s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05619341 -0.05619341]
                                [-0.3727389  -0.3727389 ]]
Confusion Matrix               [[   0    0  450    0    1  294    0]
                                [   0    0  271    0    1  181    0]
                                [   0    0 1314    0    7 1028    0]
                                [   0    0  278    0    0  112    0]
                                [   0    0  135    0    5  199    0]
                                [   0    0 1017    1   28  813    0]
                                [   0    0  361    0    0  120    0]]
Accuracy                       0.3222490931076179
Adjusted Rand Index            -0.0019820210528122663
Normalized Mutual Information  0.01291382574837221


Losses                   Value
-------------------  ---------
discriminator_loss   0.0560711
reconstruction_loss  0.0370539
generator

 40%|████      | 4/10 [00:15<00:22,  3.76s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05619341 -0.05619341]
                                [-0.10291274 -0.10291274]]
Confusion Matrix               [[   1    0  287    0    0  457    0]
                                [   0    1  119    0    0  333    0]
                                [   2    0  943    0    0 1404    0]
                                [   0    0   98    0    0  292    0]
                                [   0    0  152    0    0  187    0]
                                [  10    0  648    0    0 1201    0]
                                [   0    0  115    0    0  366    0]]
Accuracy                       0.32436517533252723
Adjusted Rand Index            -0.006136651407975774
Normalized Mutual Information  0.009027801315051877


Losses                   Value
-------------------  ---------
discriminator_loss   0.055265
reconstruction_loss  0.0369967
generator

 50%|█████     | 5/10 [00:19<00:18,  3.64s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.0561934  -0.05619341]
                                [-0.02891275 -0.02891275]]
Confusion Matrix               [[   6    1  458    0    0  280    0]
                                [   1    3  244    0    0  205    0]
                                [   9    0 1369    1    0  970    0]
                                [   0    0  289    0    0  101    0]
                                [   5    1  165    0    0  168    0]
                                [  11    5  948    0    0  895    0]
                                [   0    0  368    0    0  113    0]]
Accuracy                       0.34356106408706166
Adjusted Rand Index            0.000526491557714626
Normalized Mutual Information  0.013790615260497675


Losses                   Value
-------------------  ---------
discriminator_loss   0.0557129
reconstruction_loss  0.0369825
generator

 60%|██████    | 6/10 [00:22<00:14,  3.51s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.0561934  -0.05619341]
                                [-0.00906054 -0.00906054]]
Confusion Matrix               [[  10    0  271    1    0  463    0]
                                [   5    0  111    0    0  337    0]
                                [  25    0  914    1    1 1408    0]
                                [   0    0  102    0    0  288    0]
                                [   4    0  140    0    0  195    0]
                                [  29    0  582    0    0 1248    0]
                                [   5    0  108    0    0  368    0]]
Accuracy                       0.3282950423216445
Adjusted Rand Index            -0.006428085510263648
Normalized Mutual Information  0.008415955358042192


Losses                   Value
-------------------  ---------
discriminator_loss   0.0554763
reconstruction_loss  0.0369805
generator

 70%|███████   | 7/10 [00:25<00:10,  3.42s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05619337 -0.0561934 ]
                                [-0.00369519 -0.00369519]]
Confusion Matrix               [[   8    0  273    1    2  461    0]
                                [   2    0  121    0    1  329    0]
                                [  25    0  935    1    3 1385    0]
                                [   1    0  101    0    0  288    0]
                                [   8    0  138    0    3  190    0]
                                [  33    0  639    0    2 1185    0]
                                [   4    0  109    0    0  368    0]]
Accuracy                       0.32209794437726724
Adjusted Rand Index            -0.007186120719746333
Normalized Mutual Information  0.008687661675675442


Losses                   Value
-------------------  ---------
discriminator_loss   0.0555942
reconstruction_loss  0.0369629
generato

 80%|████████  | 8/10 [00:28<00:06,  3.43s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05606649 -0.05606649]
                                [-0.00237421 -0.00237421]]
Confusion Matrix               [[  13    0  274    1    0  457    0]
                                [   0    0  125    0    0  328    0]
                                [  20    1  937    0    0 1391    0]
                                [   2    0   99    0    0  289    0]
                                [   5    0  139    0    0  195    0]
                                [  10    1  637    0    0 1211    0]
                                [   5    0  106    0    0  370    0]]
Accuracy                       0.32663240628778717
Adjusted Rand Index            -0.006246173097579558
Normalized Mutual Information  0.00868536358440616
model saved at saved_results/thesis/pbmc/translation/CD4T/CD4T/train_epoch_9.pt 



 90%|█████████ | 9/10 [00:32<00:03,  3.42s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05606646 -0.05606649]
                                [-0.00198999 -0.00198999]]
Confusion Matrix               [[   4    1  282    0    0  458    0]
                                [   0    0  129    0    0  324    0]
                                [  11    1  956    0    0 1377    4]
                                [   1    0  101    0    0  288    0]
                                [   0    0  150    0    0  189    0]
                                [   7    1  663    0    0 1188    0]
                                [   1    0  108    0    0  371    1]]
Accuracy                       0.3248186215235792
Adjusted Rand Index            -0.006870036655446086
Normalized Mutual Information  0.008766087338492115


Losses                   Value
-------------------  ---------
discriminator_loss   0.0555608
reconstruction_loss  0.0369381
generator

100%|██████████| 10/10 [00:36<00:00,  3.61s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05582297 -0.05582298]
                                [-0.00215461 -0.00215461]]
Confusion Matrix               [[   5    0  275    0    0  463    2]
                                [   2    0  123    0    0  328    0]
                                [  13    0  946    0    0 1384    6]
                                [   1    0  101    0    0  288    0]
                                [   3    0  144    0    1  190    1]
                                [   7    0  635    0    1 1214    2]
                                [   2    0  106    0    0  371    2]]
Accuracy                       0.32769044740024184
Adjusted Rand Index            -0.006267101988215142
Normalized Mutual Information  0.008452413107276518





## Single-task cell type annotation (classification) via UnitedNet

The parameters of the translation networks are ignored for this task.

As mentioned above, all metrics are monitored. For the classification task, we are interested in:
- Confusion matrix
- Adjusted Rand Index
- Accuracy
- Normalized Mutual Information
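
As a minimal sketch of how these metrics relate to one another, the block below recomputes accuracy, per-class recall, and the Adjusted Rand Index (ARI) directly from a confusion matrix. The matrix is copied from the epoch-1 output of the classification run logged below. Treating rows as true cell types is an assumption (the row sums stay constant across epochs, which is consistent with it), and the helper names are illustrative and not part of UnitedNet.

```python
import numpy as np

# Epoch-1 confusion matrix copied from the classification training log.
# Assumption: rows are true cell types, columns are predicted cell types.
cm = np.array([
    [740,   0,    2,   1,   2,   0,   0],
    [  0, 452,    0,   0,   0,   1,   0],
    [  6,   7, 2200, 135,   0,   1,   0],
    [  0,   0,    0, 381,   0,   0,   9],
    [  0,   2,    0,   0, 334,   3,   0],
    [  3, 860,    2,   0,  64, 923,   7],
    [  0,   0,    1,   5,   0,   0, 475],
])


def pairs(x):
    """Number of unordered pairs that can be drawn from x items."""
    return x * (x - 1) / 2.0


def accuracy_from_cm(cm):
    """Fraction of cells on the diagonal (correctly classified)."""
    return np.trace(cm) / cm.sum()


def recall_per_class(cm):
    """Diagonal count divided by the row total, one value per true class."""
    return np.diag(cm) / cm.sum(axis=1)


def ari_from_cm(cm):
    """Adjusted Rand Index computed from the contingency table."""
    sum_cells = pairs(cm).sum()
    sum_rows = pairs(cm.sum(axis=1)).sum()
    sum_cols = pairs(cm.sum(axis=0)).sum()
    expected = sum_rows * sum_cols / pairs(cm.sum())
    max_index = 0.5 * (sum_rows + sum_cols)
    return (sum_cells - expected) / (max_index - expected)


accuracy = accuracy_from_cm(cm)   # about 0.832
ari = ari_from_cm(cm)             # about 0.761
recall = recall_per_class(cm)
print(accuracy, ari, recall.round(3))
```

Both values agree with the accuracy and ARI printed for that epoch. The per-class recall adds what the scalar metrics hide: most errors come from the row at index 5, where a large share of the cells is assigned to column 1.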

In [18]:
from unitednet.interface import UnitedNet


pbmc_config_classification = {
    "train_batch_size": 512,
    "finetune_batch_size": 5000,
    "transfer_batch_size": 512,
    "train_epochs": 10,
    "finetune_epochs": 10,
    "transfer_epochs": 20,
    "train_task": "supervised_group_identification_only", # classification
    "finetune_task": None,
    "transfer_task": None,
    "train_loss_weight": None,
    "finetune_loss_weight": None,
    "transfer_loss_weight": None,
    "lr": 0.1,
    "checkpoint": 1,
    "n_head": 1,
    "noise_level":[0,0],
    "fuser_type":"WeightedMean",
    "encoders": [
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
    ],
    "latent_projector": None,
    "decoders": [
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "sigmoid"],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", None],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
    ],
    "discriminators": [
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
    ],
    "projectors": {
        "input": 64,
        "hiddens": [],
        "output": 100,
        "use_biases": [True],
        "dropouts": [0],
        "activations": ["relu"],
        "use_batch_norms": [False],
        "use_layer_norms": [True],
    },
    "clusters": {
        "input": 100,
        "hiddens": [],
        "output": len(cell_type_list),
        "use_biases": [False],
        "dropouts": [0],
        "activations": [None],
        "use_batch_norms": [False],
        "use_layer_norms": [False],
    },
}

device = "cuda:0"
test_batch = batch
root_save_path = "saved_results/thesis/pbmc/classification"

model_classification = UnitedNet(f"{root_save_path}/{test_batch}", device=device, technique=pbmc_config_classification)

In [19]:
model_classification.train(adatas_train, verbose=True)


training


  0%|          | 0/10 [00:00<?, ?it/s]



Losses                        Value
-------------------------  --------
cross_entropy_loss_head_0  0.226991
best_head                  0
best model saved at saved_results/thesis/pbmc/classification/CD4T/CD4T/train_best.pt 

model saved at saved_results/thesis/pbmc/classification/CD4T/CD4T/train_epoch_1.pt 



 10%|█         | 1/10 [00:04<00:37,  4.22s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -452.13931274  -348.13018799]
                                [-4292.48925781  -453.43054199]]
Confusion Matrix               [[ 740    0    2    1    2    0    0]
                                [   0  452    0    0    0    1    0]
                                [   6    7 2200  135    0    1    0]
                                [   0    0    0  381    0    0    9]
                                [   0    2    0    0  334    3    0]
                                [   3  860    2    0   64  923    7]
                                [   0    0    1    5    0    0  475]]
Accuracy                       0.8320737605804112
Adjusted Rand Index            0.7605106205207351
Normalized Mutual Information  0.8246357413941272


Losses                         Value
-------------------------  ---------
cross_entropy_loss_head_0  0.0244222
best_head      

 20%|██        | 2/10 [00:08<00:32,  4.06s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -450.79714966  -339.36514282]
                                [-3584.08374023  -287.04818726]]
Confusion Matrix               [[ 744    0    0    0    1    0    0]
                                [   0  452    0    0    0    1    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    1  389    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   1   36    1    0   15 1806    0]
                                [   0    0    0    0    0    1  480]]
Accuracy                       0.9913845223700121
Adjusted Rand Index            0.9834376510081113
Normalized Mutual Information  0.975097656907263


Losses                         Value
-------------------------  ---------
cross_entropy_loss_head_0  0.0165457
best_head       

 30%|███       | 3/10 [00:12<00:28,  4.03s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -456.08181763  -337.98233032]
                                [-3652.35205078  -281.98300171]]
Confusion Matrix               [[ 744    0    0    0    0    1    0]
                                [   0  446    0    0    0    7    0]
                                [   0    0 2346    0    0    0    3]
                                [   0    0    2  387    0    0    1]
                                [   0    0    0    0  339    0    0]
                                [   0    1    2    0    2 1853    1]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9969770253929867
Adjusted Rand Index            0.9934999581188468
Normalized Mutual Information  0.9875742590031654
model saved at saved_results/thesis/pbmc/classification/CD4T/CD4T/train_epoch_4.pt 



 40%|████      | 4/10 [00:16<00:24,  4.13s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -448.58981323  -343.15869141]
                                [-3563.70922852  -389.36175537]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2344    5    0    0    0]
                                [   0    0    2  388    0    0    0]
                                [   0    7    0    0  329    3    0]
                                [   0   88    1    0    0 1770    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9839782345828295
Adjusted Rand Index            0.9702427947216413
Normalized Mutual Information  0.9632255565190891


Losses                         Value
-------------------------  ---------
cross_entropy_loss_head_0  0.0118426
best_head      

 50%|█████     | 5/10 [00:20<00:20,  4.12s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -447.04916382  -345.290802  ]
                                [-3948.14868164  -275.65328979]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [  13    0 2325   10    0    0    1]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   1    7    1    0    6 1844    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.994105199516324
Adjusted Rand Index            0.9864122184025789
Normalized Mutual Information  0.9791647527694716


Losses                          Value
-------------------------  ----------
cross_entropy_loss_head_0  0.00420023
best_head    

 60%|██████    | 6/10 [00:24<00:16,  4.13s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -448.85372925  -346.35324097]
                                [-3850.38476562  -365.29223633]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    2  388    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    4    0    0    1 1854    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9989419588875453
Adjusted Rand Index            0.9978430681512408
Normalized Mutual Information  0.9956609471403723
model saved at saved_results/thesis/pbmc/classification/CD4T/CD4T/train_epoch_7.pt 



 70%|███████   | 7/10 [00:29<00:12,  4.29s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -446.55487061  -343.48538208]
                                [-3803.09863281  -358.09603882]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0  144    0    0    0 1715    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9782345828295043
Adjusted Rand Index            0.9596632337953627
Normalized Mutual Information  0.9623345294881893
model saved at saved_results/thesis/pbmc/classification/CD4T/CD4T/train_epoch_8.pt 



 80%|████████  | 8/10 [00:33<00:08,  4.36s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -447.47564697  -344.02441406]
                                [-3830.68359375  -332.47918701]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0   18    0    0    0 1841    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.997279322853688
Adjusted Rand Index            0.9947218098479961
Normalized Mutual Information  0.9919325919933397
model saved at saved_results/thesis/pbmc/classification/CD4T/CD4T/train_epoch_9.pt 



 90%|█████████ | 9/10 [00:38<00:04,  4.35s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -449.81253052  -346.45788574]
                                [-4052.74536133  -403.01547241]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0   50    0    0    0 1809    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9924425634824667
Adjusted Rand Index            0.9855006098757555
Normalized Mutual Information  0.9822068752851634


Losses                           Value
-------------------------  -----------
cross_entropy_loss_head_0  0.000172609
best_head

100%|██████████| 10/10 [00:42<00:00,  4.24s/it]



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -451.26821899  -345.86209106]
                                [-3976.87792969  -398.92633057]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0





## Conclusion

Perturbation response prediction is evaluated on a held-out cell type group. Given the control cells of a particular cell type, the model predicts the stimulated gene expression profiles, and the differentially expressed genes (DEGs) are then identified in a post-processing analysis. scButterfly, scGen, and scPreGAN base their comparison on 1) the common DEGs, and 2) .

The issue with our initial thinking is that the classification task cannot be tested with this setup. Holding out a whole cell type group to test the perturbation response does not give an equally meaningful evaluation for classification, since the test set would contain only one cell type. At the same time, the cell type must already be known to train and test the model, so cell type classification does not seem to be a meaningful task here.

In other words, UnitedNet's classification task is meaningful in its original setting, since an experiment can measure several modalities for the same cell, such as gene expression, protein, or DNA accessibility information. From a perturbation perspective, however, the pairing between a control cell and a stimulated cell is unknown (we need the cell type plus optimal transport to create this pseudo-pairing). On the other hand, we could assume that, given a control cell, a model predicts the stimulated profile, and that both profiles together are then used to infer the cell type. This would be an interesting use case for testing the classification task, but the limitations of the model in perturbation response prediction would be carried along as input to the other task. That is why we should explore other workflows.
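
To make the pseudo-pairing idea concrete, here is a minimal sketch of entropy-regularised optimal transport (Sinkhorn iterations) between control and stimulated cells. It is an illustration under stated assumptions, not the pairing procedure used in this notebook: the data are random stand-ins, the function names are hypothetical, the cost is a squared Euclidean distance, and in practice the pairing would be computed separately within each cell type, presumably on a low-dimensional embedding.

```python
import numpy as np


def sinkhorn_plan(cost, eps=0.5, n_iter=200):
    """Entropy-regularised transport plan between two uniform marginals."""
    n, m = cost.shape
    a = np.full(n, 1.0 / n)          # mass on each control cell
    b = np.full(m, 1.0 / m)          # mass on each stimulated cell
    # Rescale the cost to [0, 1] so that eps has a comparable meaning
    # across datasets and the exponential stays well conditioned.
    K = np.exp(-cost / (eps * cost.max()))
    u = np.ones(n)
    v = np.ones(m)
    for _ in range(n_iter):
        u = a / (K @ v)              # match the row marginals
        v = b / (K.T @ u)            # match the column marginals
    return u[:, None] * K * v[None, :]


def pseudo_pair(control, stimulated, **kwargs):
    """Return, for each control cell, the stimulated cell with the most mass."""
    diff = control[:, None, :] - stimulated[None, :, :]
    cost = (diff ** 2).sum(axis=-1)  # squared Euclidean distance
    plan = sinkhorn_plan(cost, **kwargs)
    return plan.argmax(axis=1), plan


# Hypothetical toy data: 6 control and 8 stimulated cells, 5 features each.
rng = np.random.default_rng(0)
control = rng.normal(size=(6, 5))
stimulated = rng.normal(loc=1.0, size=(8, 5))

partner, plan = pseudo_pair(control, stimulated)
print(partner)
```

Taking the `argmax` of each row turns the soft plan into hard pairs, which is one simple choice among several. A smaller `eps` gives a sharper plan but needs more iterations to converge.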

Overall, we could change the way we test by keeping a portion of every cell type for testing, instead of holding out a whole cell type group (as we currently do for the validation set). With all of the above combined, UnitedNet, as it is, with the translation and classification tasks, can be explored for perturbation modeling.
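
The alternative split described above could look like the following sketch, which keeps a fraction of every (cell type, condition) group for testing. The helper `stratified_split`, the toy labels, and the 25% test fraction are assumptions made for illustration. They are not taken from the notebook.

```python
import numpy as np


def stratified_split(labels, test_fraction=0.2, seed=0):
    """Split indices so that every label contributes cells to the test set."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for group in np.unique(labels):
        # Shuffle the positions of this group, then cut off the test share.
        idx = rng.permutation(np.flatnonzero(labels == group))
        n_test = max(1, int(round(test_fraction * len(idx))))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return (np.array(sorted(train_idx), dtype=int),
            np.array(sorted(test_idx), dtype=int))


# Hypothetical toy labels standing in for the cell type and condition columns.
cell_type = np.array(["CD4T"] * 10 + ["B"] * 6 + ["NK"] * 4)
condition = np.array(["control", "stimulated"] * 10)

# Stratify on the combination so both conditions of each cell type are kept.
strata = np.array([f"{c}|{s}" for c, s in zip(cell_type, condition)])
train_idx, test_idx = stratified_split(strata, test_fraction=0.25)
print(len(train_idx), len(test_idx))
```

With the notebook's `pbmc` object, the labels would presumably come from `pbmc.obs['cell_type']` and `pbmc.obs['condition']`, and the returned integer indices could be used to subset the AnnData object into train and test sets. Under this split the classification metrics become meaningful on the test set, because every cell type appears in it.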