## Multi-task UnitedNet on perturbation prediction with the PBMC dataset

The motivation is scButterfly's promising results on perturbation response prediction for the PBMC dataset stimulated with IFN-beta. scButterfly and UnitedNet share similar architectures, but UnitedNet additionally leverages multi-task learning. We therefore use this simple dataset, which has a single perturbation type, to explore UnitedNet on two tasks:

- cross-modal prediction on perturbation (perturbation response prediction task)
- classification task for cell type annotation

We want to investigate whether multi-task learning benefits perturbation modeling. To do so, we compare the multi-task model with single-task models for classification and for cross-modal prediction.

In [1]:
# prepare dataset
# dataset from https://github.com/theislab/scgen-reproducibility, already preprocessed

import scanpy as sc

train_data = "../data/pbmc/train_pbmc.h5ad"
valid_data = "../data/pbmc/valid_pbmc.h5ad"

# pbmc data from scGen is already split into train and valid, but we will concatenate them to do a custom split
sc_data_train = sc.read_h5ad(train_data)
sc_data_valid = sc.read_h5ad(valid_data)
pbmc = sc_data_train.concatenate(sc_data_valid)




In [2]:
pbmc

AnnData object with n_obs × n_vars = 18868 × 6998
    obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type', 'batch'
    var: 'gene_symbol', 'n_cells'
    obsm: 'X_pca', 'X_tsne', 'X_umap'

In [3]:
pbmc.obs.groupby(['cell_type', 'condition']).size()


cell_type    condition 
CD4T         control       2715
             stimulated    3483
CD14+Mono    control       2184
             stimulated     698
B            control        928
             stimulated    1105
CD8T         control        643
             stimulated     594
NK           control        571
             stimulated     733
FCGR3A+Mono  control       1232
             stimulated    2790
Dendritic    control        670
             stimulated     522
dtype: int64

In [4]:
pbmc.obs.groupby(['condition']).size()


condition
control       8943
stimulated    9925
dtype: int64

In [5]:
# simple inspection

# sc.pp.pca(pbmc)
# sc.pp.neighbors(pbmc)
# sc.tl.tsne(pbmc)
# sc.pl.tsne(pbmc, color=['condition', 'cell_type'], legend_loc='on data', legend_fontsize='small')

### Create pairs of control and stimulated data using the scButterfly method

scButterfly showed the potential of a dual VAE scheme for perturbation response prediction. It splits the control and stimulated cells by cell type and then, within each cell type group, pairs control and stimulated gene expression profiles using optimal transport.

Note that we cannot measure the same cell before and after perturbation, because current experimental methods do not provide this information. Deep learning architectures that rely on paired inputs, such as scButterfly and UnitedNet (which combines modalities), therefore need a strategy for creating pseudo pairs and need to check whether those pseudo pairs are beneficial.

scButterfly showed this exact potential.
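The pairing idea can be sketched on toy data. The block below is a minimal illustration, not scButterfly's implementation: it assumes a squared Euclidean cost, uniform weights on both groups, and an entropic (Sinkhorn) solver, and it picks one stimulated partner per control cell by taking the largest entry in each row of the transport plan. The function names `sinkhorn_plan` and `pair_cells` are hypothetical, and scButterfly may differ in cost function, solver, and how pairs are drawn from the plan.

```python
import numpy as np


def sinkhorn_plan(cost, reg=0.1, n_iters=200):
    """Entropy-regularised transport plan between two uniform distributions."""
    n, m = cost.shape
    a = np.full(n, 1.0 / n)  # uniform mass on control cells
    b = np.full(m, 1.0 / m)  # uniform mass on stimulated cells
    K = np.exp(-cost / reg)  # Gibbs kernel
    u = np.ones(n)
    v = np.ones(m)
    for _ in range(n_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def pair_cells(control, stimulated, reg=0.1):
    """Return, for each control cell, the stimulated cell receiving most of its mass."""
    # Squared Euclidean cost, rescaled to [0, 1] for numerical stability.
    # The dense broadcast is only suitable for toy sizes.
    diff = control[:, None, :] - stimulated[None, :, :]
    cost = (diff ** 2).sum(axis=-1)
    cost = cost / cost.max()
    plan = sinkhorn_plan(cost, reg=reg)
    return plan.argmax(axis=1), plan


# Toy stand-in for one cell type: 5 control and 4 stimulated cells, 3 genes each
rng = np.random.default_rng(0)
control = rng.normal(size=(5, 3))
stimulated = rng.normal(size=(4, 3)) + 1.0

partner, plan = pair_cells(control, stimulated)
print(partner)     # index of the stimulated partner for each control cell
print(plan.shape)  # (5, 4)
```

Because the two groups have different sizes, several control cells can share the same stimulated partner under this rule, which is one reason the result is a set of pseudo pairs rather than true pairs.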

In [6]:
from scButterfly.split_datasets import unpaired_split_dataset_perturb


# create pairs of data using scButterfly technique
control_data = pbmc[pbmc.obs.condition == 'control']
stimulate_data = pbmc[pbmc.obs.condition == 'stimulated']

control_data.obs.index = [str(i) for i in range(control_data.X.shape[0])]
stimulate_data.obs.index = [str(i) for i in range(stimulate_data.X.shape[0])]

cell_type_list = list(control_data.obs.cell_type.cat.categories)



In [7]:
id_list, id_list_dict = unpaired_split_dataset_perturb(control_data, stimulate_data)

optimal transport array torch.Size([2715, 3483])
CD4T, control num 2715 stimulate num 3483

optimal transport array torch.Size([2184, 698])
CD14+Mono, control num 2184 stimulate num 698

optimal transport array torch.Size([928, 1105])
B, control num 928 stimulate num 1105

optimal transport array torch.Size([643, 594])
CD8T, control num 643 stimulate num 594

optimal transport array torch.Size([571, 733])
NK, control num 571 stimulate num 733

optimal transport array torch.Size([1232, 2790])
FCGR3A+Mono, control num 1232 stimulate num 2790

optimal transport array torch.Size([670, 522])
Dendritic, control num 670 stimulate num 522


Start CD4T
Batch list ['Dendritic', 'B', 'FCGR3A+Mono', 'CD14+Mono', 'CD4T', 'NK', 'CD8T']
Test batch control ['Dendritic']
Validation batch control ['B', 'FCGR3A+Mono', 'CD14+Mono', 'CD4T', 'NK', 'CD8T']
Train batch control ['B', 'FCGR3A+Mono', 'CD14+Mono', 'CD4T', 'NK', 'CD8T']
Test batch stimulated ['Dendritic']
Validation batch stimulated ['B', 'FCGR3A+

In [8]:
#print(len(id_list[0]))
for idx, cell_type in enumerate(cell_type_list):
    print(cell_type)
    train_id_control, train_id_peturb, validation_id_control, validation_id_peturb, test_id_control, test_id_peturb = id_list[idx]

    control = len(train_id_control) + len(validation_id_control) + len(test_id_control)
    peturb = len(train_id_peturb) + len(validation_id_peturb) + len(test_id_peturb) 
    control_dict = id_list_dict[cell_type]["control"]
    peturb_dict = id_list_dict[cell_type]["stimulated"]
    print("control", len(train_id_control), len(validation_id_control), len(test_id_control), control)
    print("control", len(control_dict["train"]), len(control_dict["validation"]), len(control_dict["test"]))
    print("peturb", len(train_id_peturb), len(validation_id_peturb), len(test_id_peturb), peturb)    
    print("peturb", len(peturb_dict["train"]), len(peturb_dict["validation"]), len(peturb_dict["test"]))
    print()
    
# Notes on the structure returned by scButterfly:
# - id_list has one row per cell type, but the cell type that labels a row carries no meaning.
# - Each row is one fold of a k-fold cross validation in which one batch (a cell type) is held out.
# - The cell type that labels a row and the cell type held out in that fold are unrelated
#   (for example, the "CD4T" row holds out Dendritic).
# - In the test split, control and stimulated can differ in size; train and validation need pairs, so their sizes match.

CD4T
control 6616 1657 670 8943
control 6616 1657 670
peturb 6616 1657 522 8795
peturb 6616 1657 522

CD14+Mono
control 6410 1605 928 8943
control 6410 1605 928
peturb 6410 1605 1105 9120
peturb 6410 1605 1105

B
control 6167 1544 1232 8943
control 6167 1544 1232
peturb 6167 1544 2790 10501
peturb 6167 1544 2790

CD8T
control 5405 1354 2184 8943
control 5405 1354 2184
peturb 5405 1354 698 7457
peturb 5405 1354 698

NK
control 4980 1248 2715 8943
control 4980 1248 2715
peturb 4980 1248 3483 9711
peturb 4980 1248 3483

FCGR3A+Mono
control 6696 1676 571 8943
control 6696 1676 571
peturb 6696 1676 733 9105
peturb 6696 1676 733

Dendritic
control 6638 1662 643 8943
control 6638 1662 643
peturb 6638 1662 594 8894
peturb 6638 1662 594



In [15]:
# create adatas for pair of control and stimulated
import numpy as np


batch = cell_type_list[0]
id_list_batch = id_list_dict[batch]
control = id_list_batch["control"]
peturb = id_list_batch["stimulated"]
train_id_control = control["train"]
validation_id_control = control["validation"]
test_id_control = control["test"]
train_id_peturb = peturb["train"]
validation_id_peturb = peturb["validation"]
test_id_peturb = peturb["test"]

control_train = pbmc[train_id_control]
control_valid = pbmc[validation_id_control]
control_test = pbmc[test_id_control]

peturb_train = pbmc[train_id_peturb]
peturb_valid = pbmc[validation_id_peturb]
peturb_test = pbmc[test_id_peturb]

# avoid shadowing the built-in all()
all_adatas = [control_train, control_valid, control_test, peturb_train, peturb_valid, peturb_test]
for adata in all_adatas:
    adata.obs["label"] = list(adata.obs["cell_type"])
    adata.X = adata.X.toarray()
    print("min", np.min(adata.X))
    print("max", np.max(adata.X))
    print()

adatas_train = [control_train, peturb_train]
adatas_valid = [control_valid, peturb_valid]
adatas_test = [control_test, peturb_test]

features_num = adatas_train[0].n_vars
adatas_train


min 0.0
max 7.0346947

min 0.0
max 6.822089

min 0.0
max 6.7637987

min 0.0
max 6.879698

min 0.0
max 7.0346947

min 0.0
max 5.961386



[AnnData object with n_obs × n_vars = 6616 × 6998
     obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type', 'batch', 'label'
     var: 'gene_symbol', 'n_cells'
     obsm: 'X_pca', 'X_tsne', 'X_umap',
 AnnData object with n_obs × n_vars = 6616 × 6998
     obs: 'condition', 'n_counts', 'n_genes', 'mt_frac', 'cell_type', 'batch', 'label'
     var: 'gene_symbol', 'n_cells'
     obsm: 'X_pca', 'X_tsne', 'X_umap']
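Since the split returns positional ids, it is worth checking which table those ids refer to before building the pairs. The sketch below uses toy lists only and assumes the ids are positions within the condition-specific subsets (`control_data` and `stimulate_data`) rather than within the concatenated `pbmc` object; that assumption should be confirmed against the scButterfly source. The helper `condition_counts` is hypothetical.

```python
from collections import Counter


def condition_counts(conditions, ids):
    """Count the condition labels of the rows selected by positional ids."""
    return Counter(conditions[i] for i in ids)


# Toy stand-in for pbmc.obs["condition"]: 3 control cells followed by 3 stimulated cells
conditions = ["control"] * 3 + ["stimulated"] * 3
stimulated_only = [c for c in conditions if c == "stimulated"]

# Hypothetical ids that are positions within the stimulated-only subset
stim_ids = [0, 1]

in_subset = condition_counts(stimulated_only, stim_ids)
in_full = condition_counts(conditions, stim_ids)

print(in_subset)  # Counter({'stimulated': 2})
print(in_full)    # Counter({'control': 2})
```

Applied to the notebook, the same check would pass `list(pbmc.obs["condition"])` together with, for example, `train_id_peturb`; a split that is meant to hold only stimulated cells should then report a single condition.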

### Prepare multi-task model

UnitedNet attempts to solve two tasks: cell type annotation and cross-modal prediction.

Metrics overview:

Translation
- R2

Classification
- Confusion matrix
- Adjusted Rand Index
- Accuracy
- Normalized Mutual Information
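To make two of these metrics concrete, the sketch below recomputes accuracy from the first-epoch confusion matrix reported in the training log and shows how R2 behaves on a toy vector. It uses the textbook definition R2 = 1 - SS_res / SS_tot; how UnitedNet aggregates R2 over genes and modalities is not shown here, so the helper names and the aggregation are assumptions.

```python
import numpy as np


def accuracy_from_confusion(cm):
    """Accuracy is the share of samples on the diagonal of the confusion matrix."""
    cm = np.asarray(cm)
    return np.trace(cm) / cm.sum()


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = ((y_true - y_pred) ** 2).sum()
    ss_tot = ((y_true - y_true.mean()) ** 2).sum()
    return 1.0 - ss_res / ss_tot


# Confusion matrix printed after the first training epoch
cm_epoch_1 = [
    [740, 0, 4, 0, 1, 0, 0],
    [0, 398, 0, 0, 0, 55, 0],
    [1, 0, 2262, 84, 1, 0, 1],
    [0, 0, 0, 389, 0, 0, 1],
    [0, 0, 0, 0, 338, 1, 0],
    [4, 12, 0, 0, 29, 1814, 0],
    [0, 0, 0, 2, 0, 0, 479],
]
acc = accuracy_from_confusion(cm_epoch_1)
print(acc)  # 6420 / 6616, about 0.9704

y = [1.0, 2.0, 3.0, 4.0]
r2_perfect = r2_score(y, y)                     # 1.0
r2_mean = r2_score(y, [2.5, 2.5, 2.5, 2.5])     # 0.0
r2_reversed = r2_score(y, [4.0, 3.0, 2.0, 1.0])  # -3.0
```

A negative R2 therefore means the prediction is worse than always predicting the mean of the target, which is useful context when reading the R2 rows in the training log.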

UnitedNet trains the parameters of both the classification and the translation networks jointly, in a multi-task fashion.
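In the configuration of the next cell, every network block lists its per-layer options (`use_biases`, `dropouts`, `activations`, `use_batch_norms`, `use_layer_norms`) with one entry per layer, that is `len(hiddens) + 1` entries. This rule is inferred from the configuration itself, not taken from the UnitedNet documentation, and the helper `check_mlp_spec` below is a hypothetical sketch for catching length mismatches before training.

```python
PER_LAYER_KEYS = [
    "use_biases",
    "dropouts",
    "activations",
    "use_batch_norms",
    "use_layer_norms",
]


def check_mlp_spec(spec):
    """Return {key: actual_length} for per-layer lists whose length is not len(hiddens) + 1."""
    n_layers = len(spec["hiddens"]) + 1
    return {
        key: len(spec[key])
        for key in PER_LAYER_KEYS
        if len(spec[key]) != n_layers
    }


# Encoder block copied from the configuration, with the gene count filled in
encoder = {
    "input": 6998,
    "hiddens": [64, 64],
    "output": 64,
    "use_biases": [True, True, True],
    "dropouts": [0, 0, 0],
    "activations": ["relu", "relu", "relu"],
    "use_batch_norms": [True, True, True],
    "use_layer_norms": [False, False, False],
    "is_binary_input": False,
}
ok = check_mlp_spec(encoder)

# Same block with one dropout entry missing
broken = dict(encoder, dropouts=[0, 0])
problems = check_mlp_spec(broken)

print(ok)        # {}
print(problems)  # {'dropouts': 2}
```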

In [10]:
from unitednet.interface import UnitedNet


pbmc_config = {
    "train_batch_size": 256,
    "finetune_batch_size": 5000,
    "transfer_batch_size": 512,
    "train_epochs": 50,
    "finetune_epochs": 10,
    "transfer_epochs": 20,
    "train_task": "supervised_group_identification", # -> translation, classification
    "finetune_task": None,
    "transfer_task": None,
    "train_loss_weight": None,
    "finetune_loss_weight": None,
    "transfer_loss_weight": None,
    "lr": 0.1,
    "checkpoint": 1,
    "n_head": 1,
    "noise_level": [0, 0],
    "fuser_type": "WeightedMean",
    "encoders": [
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
    ],
    "latent_projector": None,
    "decoders": [
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "sigmoid"],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", None],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
    ],
    "discriminators": [
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
    ],
    "projectors": {
        "input": 64,
        "hiddens": [],
        "output": 100,
        "use_biases": [True],
        "dropouts": [0],
        "activations": ["relu"],
        "use_batch_norms": [False],
        "use_layer_norms": [True],
    },
    "clusters": {
        "input": 100,
        "hiddens": [],
        "output": len(cell_type_list),
        "use_biases": [False],
        "dropouts": [0],
        "activations": [None],
        "use_batch_norms": [False],
        "use_layer_norms": [False],
    },
}

device = "cuda:0"
test_batch = batch
root_save_path = "../saved_results/thesis/pbmc/multi_task"


model = UnitedNet("multitask", f"{root_save_path}/{test_batch}", device=device, technique=pbmc_config)
model.train(adatas_train=adatas_train)
#model.finetune(adatas_train)
#model.transfer(adatas_train, adatas_transfer = adatas_test, verbose=True)
#print(model.evaluate(adatas_test))
# evaluation should be different for perturbation

training


  0%|          | 0/50 [00:00<?, ?it/s]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_1.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05472612 -0.05900957]
                                [-1.84497035 -1.84497035]]
Confusion Matrix               [[ 740    0    4    0    1    0    0]
                                [   0  398    0    0    0   55    0]
                                [   1    0 2262   84    1    0    1]
                                [   0    0    0  389    0    0    1]
                                [   0    0    0    0  338    1    0]
                                [   4   12    0    0   29 1814    0]
                                [   0    0    0    2    0    0  479]]
Accuracy                       0.9703748488512697
Adjusted Rand Index            0.9393411933748044
Normalized Mutual Information  0.9276699113348741


  2%|▏         | 1/50 [00:08<06:54,  8.46s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_2.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.12899895 -0.12899895]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  338    1    0]
                                [   1    8    0    0    0 1849    1]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9983373639661427
Adjusted Rand Index            0.9967376651495392
Normalized Mutual Information  0.9936836497302288


  4%|▍         | 2/50 [00:15<05:57,  7.45s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_3.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.01190785 -0.01190785]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  451    0    0    0    2    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   2   10    0    0   40 1807    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9918379685610641
Adjusted Rand Index            0.9847986438519976
Normalized Mutual Information  0.9778700907718905


  6%|▌         | 3/50 [00:21<05:20,  6.82s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_4.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00429278 -0.00429278]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2348    1    0    0    0]
                                [   0    0    1  389    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   1    3    0    0    5 1845    5]
                                [   0    0    0    5    1    0  475]]
Accuracy                       0.9966747279322854
Adjusted Rand Index            0.9945442156805141
Normalized Mutual Information  0.987310230328031


  8%|▊         | 4/50 [00:27<05:06,  6.65s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_5.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00394938 -0.00394938]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2282   64    1    0    2]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    1    0    0    0 1858    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9897218863361548
Adjusted Rand Index            0.9766147246626197
Normalized Mutual Information  0.9756725290021383


 10%|█         | 5/50 [00:34<04:59,  6.65s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_6.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00438534 -0.00438534]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  452    0    0    0    1    0]
                                [   1    0 2323   24    1    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  337    2    0]
                                [   1    2    0    0    0 1855    1]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9950120918984281
Adjusted Rand Index            0.9888418414112929
Normalized Mutual Information  0.9836185818147879


 12%|█▏        | 6/50 [00:40<04:51,  6.61s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_7.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00462746 -0.00462746]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   1    0 2348    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9996052627117882
Normalized Mutual Information  0.9992561145216492


 14%|█▍        | 7/50 [00:47<04:46,  6.66s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_8.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00551282 -0.00551282]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    3  387    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999546553808948
Adjusted Rand Index            0.9989530052056911
Normalized Mutual Information  0.9981559644724224


 16%|█▌        | 8/50 [00:54<04:40,  6.68s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_9.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00555034 -0.00555034]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 18%|█▊        | 9/50 [01:00<04:28,  6.54s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_10.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.0057021  -0.0057021 ]]
Confusion Matrix               [[ 744    0    1    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2346    3    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9993954050785974
Adjusted Rand Index            0.9985578566336064
Normalized Mutual Information  0.9974122016927038


 20%|██        | 10/50 [01:07<04:23,  6.59s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_11.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00605845 -0.00605845]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2347    0    1    1    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    7    2    0   17 1833    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9957678355501813
Adjusted Rand Index            0.9912577802531397
Normalized Mutual Information  0.9858553080100517


 22%|██▏       | 11/50 [01:13<04:15,  6.55s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_12.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00644008 -0.00644008]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  451    0    0    0    2    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    1    0    0    0 1858    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999546553808948
Adjusted Rand Index            0.9991161378041702
Normalized Mutual Information  0.9979941644570681


 24%|██▍       | 12/50 [01:20<04:06,  6.49s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_13.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00647214 -0.00647214]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  452    0    0    0    1    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    1    0    0  332    6    0]
                                [   0    0    0    0    0 1858    1]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.998639661426844
Adjusted Rand Index            0.9976298880710341
Normalized Mutual Information  0.994734805392876


 26%|██▌       | 13/50 [01:26<03:56,  6.39s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_14.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00729888 -0.00729888]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  452    0    0    0    1    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9997051240746986
Normalized Mutual Information  0.9992893182711852


 28%|██▊       | 14/50 [01:32<03:49,  6.38s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_15.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00712604 -0.00712604]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2347    2    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    1 1858    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999546553808948
Adjusted Rand Index            0.9990207506034441
Normalized Mutual Information  0.9979997065908862


 30%|███       | 15/50 [01:39<03:44,  6.42s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_16.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00765267 -0.00765267]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    3 1856    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999546553808948
Adjusted Rand Index            0.9991594354660043
Normalized Mutual Information  0.9982068829067328


 32%|███▏      | 16/50 [01:45<03:40,  6.47s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_17.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00794911 -0.00794911]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0  111 1748    0]
                                [   0    0    0    0    1    0  480]]
Accuracy                       0.9830713422007256
Adjusted Rand Index            0.970000833960853
Normalized Mutual Information  0.9690961061744795


 34%|███▍      | 17/50 [01:52<03:41,  6.70s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_18.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00801118 -0.00801118]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 36%|███▌      | 18/50 [02:00<03:37,  6.81s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_19.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00827838 -0.00827838]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 38%|███▊      | 19/50 [02:07<03:37,  7.03s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_20.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00832983 -0.00832983]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 40%|████      | 20/50 [02:14<03:30,  7.01s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_21.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00879232 -0.00879232]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 42%|████▏     | 21/50 [02:21<03:22,  6.98s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_22.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00889857 -0.00889857]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   9    0 2340    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  338    1    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9984885126964933
Adjusted Rand Index            0.9961732441202164
Normalized Mutual Information  0.9944046971645242


 44%|████▍     | 22/50 [02:27<03:09,  6.76s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_23.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00935492 -0.00935492]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  384    0    0    6]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.999093107617896
Adjusted Rand Index            0.999337736098755
Normalized Mutual Information  0.9971222733333066


 46%|████▌     | 23/50 [02:33<02:57,  6.58s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_24.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476969]
                                [-0.00971907 -0.00971907]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  451    0    0    0    2    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    2    0    0  479]]
Accuracy                       0.9993954050785974
Adjusted Rand Index            0.9991888081884325
Normalized Mutual Information  0.9975454906786807


 48%|████▊     | 24/50 [02:41<02:56,  6.78s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_25.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476968]
                                [-0.00961243 -0.00961243]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 50%|█████     | 25/50 [02:48<02:53,  6.94s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_26.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05476969 -0.05476968]
                                [-0.00992751 -0.00992751]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 52%|█████▏    | 26/50 [02:54<02:42,  6.79s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_27.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05506452]
                                [-0.00996751 -0.00996751]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 54%|█████▍    | 27/50 [03:03<02:46,  7.25s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_28.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05506452]
                                [-0.01036059 -0.01036059]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 56%|█████▌    | 28/50 [03:10<02:42,  7.40s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_29.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05563611]
                                [-0.01018077 -0.01018077]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    9    0    0    0 1850    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.998639661426844
Adjusted Rand Index            0.9973527881760367
Normalized Mutual Information  0.9954003829193265


 58%|█████▊    | 29/50 [03:17<02:32,  7.25s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_30.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.06849693]
                                [-0.01046568 -0.01046568]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 60%|██████    | 30/50 [03:24<02:22,  7.11s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_31.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05506452]
                                [-0.01078609 -0.01078609]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 62%|██████▏   | 31/50 [03:31<02:12,  6.99s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_32.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05506452]
                                [-0.01084994 -0.01084994]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 64%|██████▍   | 32/50 [03:37<02:03,  6.88s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_33.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05520742]
                                [-0.01097941 -0.01097941]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 66%|██████▌   | 33/50 [03:44<01:54,  6.71s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_34.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05549321]
                                [-0.01141557 -0.01141557]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 68%|██████▊   | 34/50 [03:51<01:48,  6.77s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_35.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05506452]
                                [-0.01111746 -0.01111746]]
Confusion Matrix               [[ 744    0    0    1    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2301   48    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9925937122128174
Adjusted Rand Index            0.9832614624128413
Normalized Mutual Information  0.9818883585728778


 70%|███████   | 35/50 [03:58<01:43,  6.92s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_36.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.05506452]
                                [-0.00866119 -0.00866119]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2347    2    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9996977025392987
Adjusted Rand Index            0.9993012830728246
Normalized Mutual Information  0.998696984374044


 72%|███████▏  | 36/50 [04:06<01:43,  7.40s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_37.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05506452 -0.06849495]
                                [-0.00169769 -0.00169769]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0  433    0    0    0 1426    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9345525997581621
Adjusted Rand Index            0.8931154133209355
Normalized Mutual Information  0.9277161816357347


 74%|███████▍  | 37/50 [04:14<01:35,  7.33s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_38.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00431911  0.00431911]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    1    0    0    0 1858    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9997050677033135
Normalized Mutual Information  0.9992893335825449


 76%|███████▌  | 38/50 [04:21<01:27,  7.29s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_39.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.0043975   0.0043975 ]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 78%|███████▊  | 39/50 [04:28<01:19,  7.21s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_40.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.06670892]
                                [ 0.00398215  0.00398215]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2348    1    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    3    9    0    0  469]]
Accuracy                       0.9980350665054414
Adjusted Rand Index            0.9975815309974926
Normalized Mutual Information  0.9934281313772423


 80%|████████  | 40/50 [04:35<01:11,  7.16s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_41.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00396012  0.00396012]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 82%|████████▏ | 41/50 [04:42<01:03,  7.10s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_42.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00360769  0.00360769]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 84%|████████▍ | 42/50 [04:49<00:56,  7.01s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_43.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00350734  0.00350734]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    2    0    0    0 1857    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9996977025392987
Adjusted Rand Index            0.9994103344058928
Normalized Mutual Information  0.9987046261815996


 86%|████████▌ | 43/50 [04:56<00:50,  7.21s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_44.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00304605  0.00304605]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 88%|████████▊ | 44/50 [05:04<00:44,  7.45s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_45.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00348069  0.00348069]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 90%|█████████ | 45/50 [05:12<00:37,  7.52s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_46.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00333501  0.00333501]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 92%|█████████▏| 46/50 [05:19<00:29,  7.32s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_47.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.0026004   0.0026004 ]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2348    1    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9996505603685436
Normalized Mutual Information  0.9992855115365117


 94%|█████████▍| 47/50 [05:25<00:21,  7.09s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_48.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00218566  0.00218566]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   2    0 2347    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    3    0    0    0 1856    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9992442563482467
Adjusted Rand Index            0.9983260764226445
Normalized Mutual Information  0.9968061363178338


 96%|█████████▌| 48/50 [05:32<00:13,  6.88s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_49.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00307248  0.00307248]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


 98%|█████████▊| 49/50 [05:38<00:06,  6.82s/it]

model saved at ../saved_results/thesis/pbmc/multi_task/CD4T/train_epoch_50.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05642027 -0.05642027]
                                [ 0.00206853  0.00206853]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2349    0    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       1.0
Adjusted Rand Index            1.0
Normalized Mutual Information  1.0


100%|██████████| 50/50 [05:45<00:00,  6.91s/it]
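
As a sanity check on the log above, the reported accuracy can be recomputed from a logged confusion matrix as the sum of the diagonal divided by the total number of cells. The sketch below is a minimal illustration using the matrix printed right after `train_epoch_37.pt` was saved. The helper name `accuracy_from_confusion` is hypothetical and not part of UnitedNet. The only assumption is that correct predictions sit on the diagonal; the result does not depend on whether rows or columns hold the true labels.

```python
# Confusion matrix copied from the log entry that follows train_epoch_37.pt.
confusion = [
    [745, 0, 0, 0, 0, 0, 0],
    [0, 453, 0, 0, 0, 0, 0],
    [0, 0, 2349, 0, 0, 0, 0],
    [0, 0, 0, 390, 0, 0, 0],
    [0, 0, 0, 0, 339, 0, 0],
    [0, 433, 0, 0, 0, 1426, 0],
    [0, 0, 0, 0, 0, 0, 481],
]


def accuracy_from_confusion(matrix):
    """Hypothetical helper: share of all cells that lie on the diagonal."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total


acc = accuracy_from_confusion(confusion)
print(f"accuracy = {acc:.6f}")
```

The value agrees with the accuracy logged for that epoch (about 0.9346): 433 of the 6616 cells fall off the diagonal.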


## Single-task translation via UnitedNet

Here UnitedNet is trained on the translation (cross-modal prediction) task only, so the classification network's parameters are ignored.

Note that all metrics reported in the multi-task setting are still monitored in the single-task setting. Since we are only interested in translation here, the metric to focus on is R2.
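
To help read the R2 values in the logs, the sketch below uses the standard coefficient of determination, 1 - SS_res / SS_tot, on toy numbers. It is only an illustration: the helper `r2` and the toy values are hypothetical, and how UnitedNet aggregates R2 across genes and arranges it into the 2×2 matrix is not shown in this notebook. The point is that R2 is 1 for a perfect prediction, 0 for a model that always predicts the mean, and negative for a model that does worse than the mean.

```python
def r2(y_true, y_pred):
    """Hypothetical helper: standard R2 = 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Toy values for illustration only, not taken from the PBMC data.
y_true = [1.0, 2.0, 3.0, 4.0]

perfect = r2(y_true, y_true)                 # exact prediction
mean_only = r2(y_true, [2.5, 2.5, 2.5, 2.5])  # always predict the mean
reversed_pred = r2(y_true, [4.0, 3.0, 2.0, 1.0])  # worse than the mean

print(perfect, mean_only, reversed_pred)
```

Under this definition, the slightly negative values logged during multi-task training would mean the translated expression is marginally worse than a mean-only baseline.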

In [11]:
from unitednet.interface import UnitedNet


pbmc_config_translation = {
    "train_batch_size": 512,
    "finetune_batch_size": 5000,
    "transfer_batch_size": 512,
    "train_epochs": 10,
    "finetune_epochs": 10,
    "transfer_epochs": 20,
    "train_task": "cross_model_prediction", # translation
    "finetune_task": None,
    "transfer_task": None,
    "train_loss_weight": None,
    "finetune_loss_weight": None,
    "transfer_loss_weight": None,
    "lr": 0.1,
    "checkpoint": 1,
    "n_head": 1,
    "noise_level": [0, 0],
    "fuser_type": "WeightedMean",
    "encoders": [
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
    ],
    "latent_projector": None,
    "decoders": [
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "sigmoid"],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", None],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
    ],
    "discriminators": [
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
    ],
    "projectors": {
        "input": 64,
        "hiddens": [],
        "output": 100,
        "use_biases": [True],
        "dropouts": [0],
        "activations": ["relu"],
        "use_batch_norms": [False],
        "use_layer_norms": [True],
    },
    "clusters": {
        "input": 100,
        "hiddens": [],
        "output": len(cell_type_list),
        "use_biases": [False],
        "dropouts": [0],
        "activations": [None],
        "use_batch_norms": [False],
        "use_layer_norms": [False],
    },
}

device = "cuda:0"
test_batch = batch
root_save_path = "../saved_results/thesis/pbmc/translation"

model_translation = UnitedNet("translation", f"{root_save_path}/{test_batch}", device=device, technique=pbmc_config_translation)

In [12]:
model_translation.train(adatas_train, verbose=True)


training


  0%|          | 0/10 [00:00<?, ?it/s]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_1.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05630496 -0.05630494]
                                [-4.47049189 -4.47049189]]
Confusion Matrix               [[   0    0  556    0    0  143   46]
                                [   0    0  328    0    0   96   29]
                                [   0    0 1888    0    0  255  206]
                                [   0    0  334    0    0   29   27]
                                [   0    0  219    0    0   93   27]
                                [   0    0 1203    0    0  438  218]
                                [   0    0  397    0    0   34   50]]
Accuracy                       0.3591293833131802
Adjusted Rand Index            0.008820493532019568
Normalized Mutual Information  0.018271206722835588


 10%|█         | 1/10 [00:04<00:36,  4.05s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_2.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05740947 -0.05740947]
                                [-1.77818835 -1.77818835]]
Confusion Matrix               [[   0    2  721    0    0   22    0]
                                [   0    2  446    0    0    5    0]
                                [   0    7 2170    0    0  172    0]
                                [   0    1  360    0    0   29    0]
                                [   0    2  320    0    1   16    0]
                                [   0   47 1655    0   19  138    0]
                                [   0    2  431    0    0   48    0]]
Accuracy                       0.3493047158403869
Adjusted Rand Index            -0.0074504385595088015
Normalized Mutual Information  0.018393853836542146


 20%|██        | 2/10 [00:08<00:34,  4.26s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_3.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05740947 -0.05740947]
                                [-0.48814806 -0.48814806]]
Confusion Matrix               [[   0    0  700    0    0    4   41]
                                [   0    0  368    0    0   80    5]
                                [   0    0 2144    0    0   24  181]
                                [   0    0  352    0    0    1   37]
                                [   0    0  289    0    2   36   12]
                                [   0    0 1355    0    1  405   98]
                                [   0    0  431    0    0    8   42]]
Accuracy                       0.3919286577992745
Adjusted Rand Index            0.008982578324036174
Normalized Mutual Information  0.0703032601794225


 30%|███       | 3/10 [00:13<00:30,  4.40s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_4.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05740947 -0.05740947]
                                [-0.12387719 -0.12387719]]
Confusion Matrix               [[   0    0  690    0    3   52    0]
                                [   0    1  427    0    2   23    0]
                                [   0    0 2131    0    6  212    0]
                                [   0    0  337    0    2   51    0]
                                [   0    0  313    0    4   22    0]
                                [   0    2 1599    0   19  239    0]
                                [   0    0  411    0    1   69    0]]
Accuracy                       0.3589782345828295
Adjusted Rand Index            -0.001009996972991858
Normalized Mutual Information  0.007818893413208126


 40%|████      | 4/10 [00:17<00:26,  4.41s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_5.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05740947 -0.05740947]
                                [-0.03337806 -0.03337806]]
Confusion Matrix               [[   0    1  673    0    0   71    0]
                                [   0    1  387    0    0   65    0]
                                [   0    5 2103    0    1  240    0]
                                [   0    1  337    0    0   52    0]
                                [   0    1  291    0    0   47    0]
                                [   0   10 1346    0    0  503    0]
                                [   0    0  399    0    0   82    0]]
Accuracy                       0.39404474002418377
Adjusted Rand Index            0.0071958521691532
Normalized Mutual Information  0.02189013125604787


 50%|█████     | 5/10 [00:22<00:22,  4.48s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_6.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05740947 -0.05740947]
                                [-0.00998109 -0.00998109]]
Confusion Matrix               [[   0    3  678    0    0   64    0]
                                [   0    5  383    0    0   65    0]
                                [   0    5 2083    0    0  261    0]
                                [   0    1  335    0    0   53    1]
                                [   0    4  293    0    0   42    0]
                                [   0   13 1349    0    0  497    0]
                                [   0    3  401    0    0   75    2]]
Accuracy                       0.3910217654171705
Adjusted Rand Index            0.004942116970259271
Normalized Mutual Information  0.021400146390759184


 60%|██████    | 6/10 [00:25<00:17,  4.28s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_7.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05730562 -0.05730562]
                                [-0.00412386 -0.00412386]]
Confusion Matrix               [[   3    2  663    0    0   77    0]
                                [   0    1  404    0    0   48    0]
                                [   8    7 2018    0    0  316    0]
                                [   0    0  328    0    0   62    0]
                                [   1    1  290    0    0   47    0]
                                [   0    4 1406    1    0  448    0]
                                [   2    0  389    0    0   90    0]]
Accuracy                       0.3733373639661427
Adjusted Rand Index            -0.001966001722448588
Normalized Mutual Information  0.012476071430539516


 70%|███████   | 7/10 [00:30<00:12,  4.25s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_8.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05730561 -0.05730561]
                                [-0.00267519 -0.00267519]]
Confusion Matrix               [[   1    0  665    0    0   79    0]
                                [   0    0  383    0    0   70    0]
                                [   1    0 2045    0    0  303    0]
                                [   0    0  329    0    0   61    0]
                                [   0    0  287    0    0   52    0]
                                [   0    0 1296    0    0  563    0]
                                [   0    0  395    0    0   86    0]]
Accuracy                       0.39434703748488514
Adjusted Rand Index            0.005143447920724046
Normalized Mutual Information  0.02113158951875281


 80%|████████  | 8/10 [00:33<00:08,  4.09s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_9.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05729266 -0.05729265]
                                [-0.00222913 -0.00222913]]
Confusion Matrix               [[   0    0  662    0    0   83    0]
                                [   0    0  375    0    0   78    0]
                                [   0    0 2047    0    0  302    0]
                                [   0    0  330    0    0   60    0]
                                [   0    0  284    0    0   55    0]
                                [   0    0 1314    0    1  544    0]
                                [   0    0  395    0    0   86    0]]
Accuracy                       0.3916263603385732
Adjusted Rand Index            0.005903229545142348
Normalized Mutual Information  0.018500712246052094


 90%|█████████ | 9/10 [00:37<00:04,  4.03s/it]

model saved at ../saved_results/thesis/pbmc/translation/CD4T/train_epoch_10.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[-0.05769907 -0.05769907]
                                [-0.00230476 -0.00230476]]
Confusion Matrix               [[   0    0  662    0    0   83    0]
                                [   0    0  383    0    0   70    0]
                                [   0    0 2029    0    0  320    0]
                                [   0    0  328    0    0   62    0]
                                [   0    0  289    0    0   50    0]
                                [   0    0 1322    0    0  536    1]
                                [   0    0  393    0    0   88    0]]
Accuracy                       0.3876964933494559
Adjusted Rand Index            0.0026792206741629515
Normalized Mutual Information  0.01679305761277922


100%|██████████| 10/10 [00:41<00:00,  4.15s/it]


## Single-task cell type annotation (classification) via UnitedNet

For this task, the parameters of the translation networks are ignored.

As mentioned above, all metrics are monitored. For the classification task we are interested in:
- Confusion matrix
- Adjusted Rand Index
- Accuracy
- Normalized Mutual Information
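
As a reference for how these four metrics relate to each other, here is a minimal sketch that computes them with scikit-learn on toy labels. It is only an illustration: whether UnitedNet computes its logged metrics with these exact functions is an assumption, and `classification_metrics`, `y_true`, and `y_pred` are hypothetical names that do not appear in the notebook.

```python
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    adjusted_rand_score,
    confusion_matrix,
    normalized_mutual_info_score,
)


def classification_metrics(y_true, y_pred):
    """Return the four classification metrics monitored in the training logs."""
    return {
        # rows = true classes, columns = predicted classes
        "confusion_matrix": confusion_matrix(y_true, y_pred),
        "accuracy": accuracy_score(y_true, y_pred),
        # ARI and NMI compare two partitions and ignore the label names
        "ari": adjusted_rand_score(y_true, y_pred),
        "nmi": normalized_mutual_info_score(y_true, y_pred),
    }


# Toy labels (hypothetical, not taken from the PBMC data): one of six cells is misclassified.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 1, 2, 2, 2])

metrics = classification_metrics(y_true, y_pred)
print(metrics["confusion_matrix"])
print({name: value for name, value in metrics.items() if name != "confusion_matrix"})
```

Accuracy depends on the predicted label matching the true label exactly, while ARI and NMI only measure how well the two groupings agree, so a model that merges most cells into one class scores near zero on ARI and NMI even when its accuracy looks moderate.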

In [13]:
from unitednet.interface import UnitedNet


pbmc_config_classification = {
    "train_batch_size": 512,
    "finetune_batch_size": 5000,
    "transfer_batch_size": 512,
    "train_epochs": 10,
    "finetune_epochs": 10,
    "transfer_epochs": 20,
    "train_task": "supervised_group_identification_only", # classification
    "finetune_task": None,
    "transfer_task": None,
    "train_loss_weight": None,
    "finetune_loss_weight": None,
    "transfer_loss_weight": None,
    "lr": 0.1,
    "checkpoint": 1,
    "n_head": 1,
    "noise_level":[0,0],
    "fuser_type":"WeightedMean",
    "encoders": [
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
        {
            "input": features_num,
            "hiddens": [64, 64],
            "output": 64,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "relu"],
            "use_batch_norms": [True, True, True],
            "use_layer_norms": [False, False, False],
            "is_binary_input": False,
        },
    ],
    "latent_projector": None,
    "decoders": [
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", "sigmoid"],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
        {
            "input": 64,
            "hiddens": [64, 64],
            "output": features_num,
            "use_biases": [True, True, True],
            "dropouts": [0, 0, 0],
            "activations": ["relu", "relu", None],
            "use_batch_norms": [False, False, False],
            "use_layer_norms": [False, False, False],
        },
    ],
    "discriminators": [
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
        {
            "input": features_num,
            "hiddens": [64],
            "output": 1,
            "use_biases": [True, True],
            "dropouts": [0, 0],
            "activations": ["relu", "sigmoid"],
            "use_batch_norms": [False, False],
            "use_layer_norms": [False, True],
        },
    ],
    "projectors": {
        "input": 64,
        "hiddens": [],
        "output": 100,
        "use_biases": [True],
        "dropouts": [0],
        "activations": ["relu"],
        "use_batch_norms": [False],
        "use_layer_norms": [True],
    },
    "clusters": {
        "input": 100,
        "hiddens": [],
        "output": len(cell_type_list),
        "use_biases": [False],
        "dropouts": [0],
        "activations": [None],
        "use_batch_norms": [False],
        "use_layer_norms": [False],
    },
}

device = "cuda:0"
test_batch = batch
root_save_path = "../saved_results/thesis/pbmc/classification"

model_classification = UnitedNet("classification", f"{root_save_path}/{test_batch}", device=device, technique=pbmc_config_classification)

In [14]:
model_classification.train(adatas_train, verbose=True)


training


  0%|          | 0/10 [00:00<?, ?it/s]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_1.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -456.57522583  -343.61352539]
                                [-6887.48681641  -431.98233032]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [  50    3 1962  331    0    3    0]
                                [   0    0    0  380    0    0   10]
                                [   1    0    0    0  337    1    0]
                                [   2  670    2    1  115 1066    3]
                                [   0    0    0    2    0    0  479]]
Accuracy                       0.8195284159613059
Adjusted Rand Index            0.6891361118659025
Normalized Mutual Information  0.7920490580732706


 10%|█         | 1/10 [00:04<00:42,  4.77s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_2.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -459.04974365  -343.96884155]
                                [-6372.09667969  -415.36251831]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  451    0    0    0    2    0]
                                [   0    0 2343    5    0    1    0]
                                [   0    0    0  389    0    0    1]
                                [   0    0    0    0  333    6    0]
                                [   0   32    0    0    2 1824    1]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9924425634824667
Adjusted Rand Index            0.9851843871001829
Normalized Mutual Information  0.9766101335090821


 20%|██        | 2/10 [00:09<00:39,  4.90s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_3.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -445.1177063   -338.70904541]
                                [-5351.71533203  -372.30682373]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2347    2    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    5    0    0    1 1852    1]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.998639661426844
Adjusted Rand Index            0.9972494893583417
Normalized Mutual Information  0.9944664818328194


 30%|███       | 3/10 [00:14<00:33,  4.80s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_4.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -458.28692627  -340.18048096]
                                [-6903.26318359  -428.5944519 ]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  346    0    0    0  107    0]
                                [   1    0 2346    1    0    1    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    1 1858    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9832224909310762
Adjusted Rand Index            0.9686647959860603
Normalized Mutual Information  0.9667129920110877


 40%|████      | 4/10 [00:18<00:27,  4.66s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_5.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -443.11416626  -340.34884644]
                                [-6311.97705078  -423.41625977]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  443    0    0    0   10    0]
                                [   0    0 2345    3    0    1    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    3    0    0    1 1855    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.997279322853688
Adjusted Rand Index            0.9943235867012179
Normalized Mutual Information  0.9898315166169429


 50%|█████     | 5/10 [00:23<00:24,  4.81s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_6.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -445.84524536  -340.25708008]
                                [-6489.96728516  -419.35403442]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2348    1    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    1    0    0    0 1858    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9996977025392987
Adjusted Rand Index            0.9993555554257957
Normalized Mutual Information  0.9985749540194646


 60%|██████    | 6/10 [00:28<00:19,  4.82s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_7.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -445.00866699  -340.08703613]
                                [-6794.06347656  -489.80105591]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   1    0 2346    1    1    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    1    0    0  338    0    0]
                                [   0  208    0    0    0 1651    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.967956469165659
Adjusted Rand Index            0.9419360907804107
Normalized Mutual Information  0.9495031414583308


 70%|███████   | 7/10 [00:33<00:14,  4.83s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_8.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -441.54043579  -343.38574219]
                                [-6296.71582031  -578.20947266]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2348    1    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0   22    0    0    0 1837    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9965235792019347
Adjusted Rand Index            0.9932067066460201
Normalized Mutual Information  0.9898271682436064


 80%|████████  | 8/10 [00:38<00:09,  4.95s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_9.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -442.73681641  -347.91082764]
                                [-6529.03564453  -631.69818115]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   0    0 2348    1    0    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0    0    0    0    0 1859    0]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9998488512696493
Adjusted Rand Index            0.9996505603685436
Normalized Mutual Information  0.9992855115365117


 90%|█████████ | 9/10 [00:43<00:04,  4.82s/it]

model saved at ../saved_results/thesis/pbmc/classification/CD4T/train_epoch_10.pt 



Metrics                        Value
-----------------------------  --------------------------------------
R2                             [[ -446.23391724  -349.51376343]
                                [-6712.47851562  -612.07824707]]
Confusion Matrix               [[ 745    0    0    0    0    0    0]
                                [   0  453    0    0    0    0    0]
                                [   1    0 2347    0    1    0    0]
                                [   0    0    0  390    0    0    0]
                                [   0    0    0    0  339    0    0]
                                [   0   28    0    0    1 1828    2]
                                [   0    0    0    0    0    0  481]]
Accuracy                       0.9950120918984281
Adjusted Rand Index            0.9902046138824164
Normalized Mutual Information  0.9851338471320409


100%|██████████| 10/10 [00:47<00:00,  4.77s/it]


## Conclusion

Perturbation response prediction is evaluated on a held-out cell type. Given the control cells of a particular cell type, the model predicts the stimulated gene expression profiles, and the DEGs are then identified in a post-processing analysis. scButterfly, scGen, and scPreGAN base their comparison on 1) the common DEGs, and 2) .
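
The "common DEGs" comparison can be sketched as a simple overlap between two ranked gene lists. This is a minimal sketch under stated assumptions: both lists are assumed to be already ranked by a differential expression test (stimulated vs. control), `common_degs` and `top_k` are hypothetical names, and the gene lists are illustrative toy input, not results from this notebook.

```python
def common_degs(real_degs, predicted_degs, top_k=100):
    """Return the genes shared by the top_k real and top_k predicted DEGs.

    Both inputs are gene names ordered from most to least significant.
    The result keeps the ranking order of the real DEG list.
    """
    real_top = list(real_degs)[:top_k]
    predicted_top = set(list(predicted_degs)[:top_k])
    return [gene for gene in real_top if gene in predicted_top]


# Toy ranked lists (hypothetical): DEGs from real stimulated cells vs. predicted ones.
real_ranked = ["ISG15", "IFI6", "IFIT1"]
predicted_ranked = ["IFI6", "ISG15", "CXCL10"]

shared = common_degs(real_ranked, predicted_ranked, top_k=3)
print(f"{len(shared)} of 3 top DEGs are shared: {shared}")
```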

The issue with our initial idea is that the classification task cannot be tested with the above setup. Holding out a whole cell type to test the perturbation response does not give an equally meaningful evaluation of the classification, since the test set then contains only one cell type. At the same time, the cell type labels are a prerequisite for training and testing the model, so cell type classification does not seem to be a meaningful task here.

In other words, UnitedNet's classification task is meaningful in its original setting, since a single experiment can provide several modalities for the same cell, such as gene expression, protein, or DNA accessibility information. From a perturbation perspective, however, the pairing between a control cell and a stimulated cell is unknown (we need the cell type plus optimal transport to create this pseudo-pairing). On the other hand, we could assume that, given a control cell, a model can predict its stimulated counterpart, and that both together could then be leveraged to predict the cell type. This would be an interesting use case for testing the classification task, but the limitations of the model in perturbation response prediction would carry over as input to the second task. That is why we should explore other workflows.
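
To make the pseudo-pairing idea concrete, here is a minimal sketch that pairs control and stimulated cells of one cell type by solving a minimum-cost one-to-one assignment. This is a deliberately simple stand-in for optimal transport, not the pairing procedure used in this notebook or in scButterfly: `pair_cells` is a hypothetical helper, the inputs are assumed to be low-dimensional embeddings (for example PCA coordinates), and the toy arrays are made up.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def pair_cells(control, stimulated):
    """Pair each control cell with one stimulated cell of the same cell type.

    control:    (n_control, n_features) embedding of control cells
    stimulated: (n_stimulated, n_features) embedding of stimulated cells
    Returns two index arrays of equal length, min(n_control, n_stimulated).
    """
    # Pairwise squared Euclidean distances act as the transport cost.
    cost = cdist(control, stimulated, metric="sqeuclidean")
    # Minimum-cost one-to-one matching (handles rectangular cost matrices).
    control_idx, stimulated_idx = linear_sum_assignment(cost)
    return control_idx, stimulated_idx


# Toy embeddings (hypothetical): 2 control cells and 3 stimulated cells.
control = np.array([[0.0, 0.0], [10.0, 10.0]])
stimulated = np.array([[10.0, 11.0], [0.0, 1.0], [5.0, 5.0]])

control_idx, stimulated_idx = pair_cells(control, stimulated)
for c, s in zip(control_idx, stimulated_idx):
    print(f"control cell {c} <-> stimulated cell {s}")
```

Running the matching separately for every cell type keeps pairs within the same cell type, which is why the cell type labels are needed before any pairing can be built.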

Overall, we could change the way we test by holding out a portion of every cell type for testing, instead of holding out a whole cell type (as we currently do for the validation set). With this change, UnitedNet, as it is, with the translation and classification tasks, can be explored for perturbation modeling.
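
The proposed split can be sketched with scikit-learn's stratified `train_test_split`. This is a minimal sketch, assuming the cell type labels are available as an array (in this notebook they would come from `pbmc.obs["cell_type"]`); `stratified_split_indices`, the 20% test fraction, and the toy label counts are hypothetical choices.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def stratified_split_indices(cell_types, test_size=0.2, seed=0):
    """Split cell indices so that every cell type appears in train and test."""
    indices = np.arange(len(cell_types))
    train_idx, test_idx = train_test_split(
        indices,
        test_size=test_size,
        stratify=cell_types,  # keep the cell type proportions in both parts
        random_state=seed,
    )
    return np.sort(train_idx), np.sort(test_idx)


# Toy labels (hypothetical counts), standing in for pbmc.obs["cell_type"].
cell_types = np.array(["CD4T"] * 50 + ["B"] * 30 + ["NK"] * 20)

train_idx, test_idx = stratified_split_indices(cell_types)
print("cell types in test set:", sorted(set(cell_types[test_idx])))
print("train / test sizes:", len(train_idx), len(test_idx))
```

With such a split the test set contains every cell type, so the classification metrics are meaningful on it. Stratifying on the combination of cell type and condition would additionally keep control and stimulated cells balanced, if that is needed.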