# Assess label transfer using weighted accuracy

The reviewer suggested we assess the result of the label transfer using weighted accuracy.
Unfortunately, scikit-learn does not have a dedicated function for calculating weighted accuracy.
It has balanced accuracy, which essentially calculates the accuracy for each true cell type and takes the mean (**note: not the weighted mean**).

The `classification_report` function calculates precision, recall, and F1-score for each true cell type.
It also reports accuracy, but only as a single overall value, not for each true cell type.
That is, it is just the accuracy in general: number of correctly labelled cells / total number of cells.
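
To make the difference between these metrics concrete, here is a minimal sketch on made-up labels (the toy `y_true` and `y_pred` lists are assumptions, not taken from the dataset). It uses `recall_score(..., average=None)` for the per-class accuracy, since the recall of a class is the fraction of its true cells that were labelled correctly.

```python
from sklearn.metrics import accuracy_score, balanced_accuracy_score, recall_score

# Toy labels (hypothetical): class "A" has 8 cells, class "B" has 2.
y_true = ["A", "A", "A", "A", "A", "A", "A", "A", "B", "B"]
y_pred = ["A", "A", "A", "A", "A", "A", "A", "A", "A", "B"]

# Overall accuracy: correctly labelled cells / total number of cells.
overall = accuracy_score(y_true, y_pred)

# Per-class accuracy (recall of each true class), in sorted label order.
per_class = recall_score(y_true, y_pred, average=None)

# Balanced accuracy: the unweighted mean of the per-class values.
balanced = balanced_accuracy_score(y_true, y_pred)

print(overall)    # 0.9
print(per_class)  # [1.  0.5]
print(balanced)   # 0.75
```

The small class "B" pulls balanced accuracy down to 0.75, while the overall accuracy stays at 0.9 because "A" dominates.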

So, we have to do the calculation by hand. 

There is a big challenge here in that the true and predicted labels are not the same.
Not only do the names differ, but the predicted labels (the CITEseq labels) are also much more granular.
Additionally, some true labels may not exist in the predicted labels, and vice versa.
For example, Basophils in the Cytof data (true label) do not exist in the CITEseq data.

So to get around this, we will try to match the labels between the two modalities ***manually***.
The cell types that have differentiated can be grouped rather easily.
Those that have not, typically stem cells and progenitors, can't be grouped easily.
We will do our best.

For cell types that only exist in one modality, we will just leave the labels as they are.
In other words, cells of types that only exist in the Cytof data, such as Basophils, are guaranteed to all be misclassified.
Likewise, if a cell in the Cytof data is given a CITEseq label that does not exist in the Cytof data, e.g. GammaDelta T cells, it simply counts as wrongly classified.
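
As a rough illustration of this rule, the sketch below uses a small made-up table (the `toy` data frame and its column names are hypothetical) to list the labels that exist on only one side, and to show that any cell involving such a label ends up counted as wrong.

```python
import pandas as pd

# Hypothetical toy table; the column names are for illustration only.
toy = pd.DataFrame({
    "true_label": ["Basophils", "Monocytes", "Monocytes", "pDCs"],
    "predicted_label": ["pDCs", "Monocytes", "GammaDelta T cells", "pDCs"],
})

true_labels = set(toy["true_label"])
predicted_labels = set(toy["predicted_label"])

# Labels that have no counterpart in the other modality.
only_in_true = sorted(true_labels - predicted_labels)
only_in_predicted = sorted(predicted_labels - true_labels)

# A cell is correct only when both labels are identical, so any cell
# involving a one-sided label is automatically counted as wrong.
toy["correct"] = toy["true_label"] == toy["predicted_label"]

print(only_in_true)             # ['Basophils']
print(only_in_predicted)        # ['GammaDelta T cells']
print(toy["correct"].tolist())  # [False, True, False, True]
```

Running the same set difference on the reconciled columns would be one way to double-check which labels are left unmatched after the manual mapping.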

In [6]:
import pandas as pd
import numpy as np
from sklearn.metrics import classification_report

Note: the data below has been pre-processed in R, where cells that are not assigned a cell type have been removed.

See the `draw_umap.R` script in `code/label_transfer/additional_code`.

In [7]:
df = pd.read_csv("~/Documents/GitHub/SuperCellCyto-analysis/output/label_transfer/umap.csv")
df.head()

Unnamed: 0,Time,Cell_length,DNA1,DNA2,CD45RA,CD133,CD19,CD22,CD11b,CD4,...,CD41_asinh_cf5,Gated_Population,Sample,CellId,UMAP_X,UMAP_Y,harmony_supercell,rpca_supercell,rpca_singlecell,harmony_singlecell
0,2693.0,22,201.783295,253.016647,0.817049,-0.147947,-0.033482,0.332183,-0.045922,1.858334,...,-0.001961,Basophils,H1,Cell_1,-8.491841,0.86061,Late erythroid progenitor,Aberrant erythroid,Aberrant erythroid,Late erythroid progenitor
1,3736.0,35,191.828598,308.86908,3.801385,-0.191446,-0.083274,0.372388,4.494379,-0.177158,...,0.868014,Basophils,H1,Cell_2,-7.229076,-0.115077,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Plasmacytoid dendritic cell progenitors
2,7015.0,32,116.111923,200.839218,3.204439,-0.161106,0.369613,-0.214952,-0.009404,-0.043904,...,-0.010413,Basophils,H1,Cell_3,-8.54015,1.096968,Aberrant erythroid,Aberrant erythroid,Late erythroid progenitor,Late erythroid progenitor
3,7099.0,29,176.248505,313.022461,2.237382,-0.138071,-0.088311,-0.22043,4.006598,-0.095335,...,-0.026039,Basophils,H1,Cell_4,-7.897731,0.045576,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Plasmacytoid dendritic cells
4,7700.0,25,133.332779,226.467758,-0.044047,-0.151509,0.402548,2.581769,6.74206,2.906627,...,-0.040488,Basophils,H1,Cell_5,-7.830106,-0.183927,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Conventional dendritic cell 1


In [8]:
df.columns

Index(['Time', 'Cell_length', 'DNA1', 'DNA2', 'CD45RA', 'CD133', 'CD19',
       'CD22', 'CD11b', 'CD4', 'CD8', 'CD34', 'Flt3', 'CD20', 'CXCR4',
       'CD235ab', 'CD45', 'CD123', 'CD321', 'CD14', 'CD33', 'CD47', 'CD11c',
       'CD7', 'CD15', 'CD16', 'CD44', 'CD38', 'CD13', 'CD3', 'CD61', 'CD117',
       'CD49d', 'HLA-DR', 'CD64', 'CD41', 'Viability', 'file_number',
       'event_number', 'CD45RA_asinh_cf5', 'CD133_asinh_cf5', 'CD19_asinh_cf5',
       'CD22_asinh_cf5', 'CD11b_asinh_cf5', 'CD4_asinh_cf5', 'CD8_asinh_cf5',
       'CD34_asinh_cf5', 'Flt3_asinh_cf5', 'CD20_asinh_cf5', 'CXCR4_asinh_cf5',
       'CD235ab_asinh_cf5', 'CD45_asinh_cf5', 'CD123_asinh_cf5',
       'CD321_asinh_cf5', 'CD14_asinh_cf5', 'CD33_asinh_cf5', 'CD47_asinh_cf5',
       'CD11c_asinh_cf5', 'CD7_asinh_cf5', 'CD15_asinh_cf5', 'CD16_asinh_cf5',
       'CD44_asinh_cf5', 'CD38_asinh_cf5', 'CD13_asinh_cf5', 'CD3_asinh_cf5',
       'CD61_asinh_cf5', 'CD117_asinh_cf5', 'CD49d_asinh_cf5',
       'HLA-DR_asinh_cf5', '

We will reconcile the subsets of HSPCs and HSCs into one population named "CD34+_HSCs_and_HSPCs".

Why do we do this? Because the CITEseq data contains too many subsets of progenitors, and it is not clear how they were defined based on CD38 and CD123 expression.
Since they are all CD34+, they have that in common, and thus we can group them.

However, we then also need to make sure that the CITEseq subsets we group together and match to CD34+_HSCs_and_HSPCs are CD34+.

In [9]:
# Reconcile the various HSC and HSPC subsets into a single HSCs-and-HSPCs label
reconciled_stem_and_prog_label = 'CD34+_HSCs_and_HSPCs'
def reconcile_gated_label(label):
    if label in ['CD34+CD38+CD123-_HSPCs', 'CD34+CD38+CD123+_HSPCs', 'CD34+CD38lo_HSCs']:
        return reconciled_stem_and_prog_label
    else:
        return label
df['Gated_Population_reconciled'] = df['Gated_Population'].apply(reconcile_gated_label)
df.head()

Unnamed: 0,Time,Cell_length,DNA1,DNA2,CD45RA,CD133,CD19,CD22,CD11b,CD4,...,Gated_Population,Sample,CellId,UMAP_X,UMAP_Y,harmony_supercell,rpca_supercell,rpca_singlecell,harmony_singlecell,Gated_Population_reconciled
0,2693.0,22,201.783295,253.016647,0.817049,-0.147947,-0.033482,0.332183,-0.045922,1.858334,...,Basophils,H1,Cell_1,-8.491841,0.86061,Late erythroid progenitor,Aberrant erythroid,Aberrant erythroid,Late erythroid progenitor,Basophils
1,3736.0,35,191.828598,308.86908,3.801385,-0.191446,-0.083274,0.372388,4.494379,-0.177158,...,Basophils,H1,Cell_2,-7.229076,-0.115077,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Plasmacytoid dendritic cell progenitors,Basophils
2,7015.0,32,116.111923,200.839218,3.204439,-0.161106,0.369613,-0.214952,-0.009404,-0.043904,...,Basophils,H1,Cell_3,-8.54015,1.096968,Aberrant erythroid,Aberrant erythroid,Late erythroid progenitor,Late erythroid progenitor,Basophils
3,7099.0,29,176.248505,313.022461,2.237382,-0.138071,-0.088311,-0.22043,4.006598,-0.095335,...,Basophils,H1,Cell_4,-7.897731,0.045576,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Plasmacytoid dendritic cells,Basophils
4,7700.0,25,133.332779,226.467758,-0.044047,-0.151509,0.402548,2.581769,6.74206,2.906627,...,Basophils,H1,Cell_5,-7.830106,-0.183927,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Conventional dendritic cell 1,Basophils


In [10]:
def reconcile_label(label):
    if label == 'CD56brightCD16- NK cells':
        return 'CD16-_NK_cells'
    elif label == 'CD56dimCD16+ NK cells':
        return 'CD16+_NK_cells'
    elif label in ['CD4+ cytotoxic T cells', 'CD4+ memory T cells', 'Naive CD4+T cells']:
        return 'CD4_T_cells'
    elif label in ["CD8+ central memory T cells", "CD8+ effector memory T cells", "CD8+ naive T cells", "CD8+CD103+ tissue resident memory T cells"]:
        return 'CD8_T_cells'
    elif label in ['CD11c+ memory B cells', 'Mature naive B cells', 'Class switched memory B cells', 'Nonswitched memory B cells']:
        return 'Mature_B_cells'
    elif label in ['Classical Monocytes', 'Non-classical monocytes']:
        return 'Monocytes'
    elif label == 'Plasmacytoid dendritic cells':
        return 'pDCs'
    elif label == 'Plasma cells':
        return 'Plasma_B_cells'
    elif label == 'Small pre-B cell':
        return 'Pre_B_cells'
    elif label == 'pro-B cells':
        return 'Pro_B_cells'
    elif label in ['Lymphoid-primed multipotent progenitors', 'Megakaryocyte progenitors', 
                   'Erythro-myeloid progenitors', 'NK cell progenitors', 'HSCs & MPPs']:
        return reconciled_stem_and_prog_label
    else:
        return label
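
The same kind of mapping could also be written as a lookup table, which some may find easier to audit than a chain of `elif` branches. This is only a sketch of an alternative: `label_map` holds a partial, illustrative excerpt of the mapping, and the `predicted` series is made up.

```python
import pandas as pd

# Hypothetical excerpt of the mapping; the remaining labels would be
# added in the same way.
label_map = {
    "CD56brightCD16- NK cells": "CD16-_NK_cells",
    "CD56dimCD16+ NK cells": "CD16+_NK_cells",
    "Classical Monocytes": "Monocytes",
    "Non-classical monocytes": "Monocytes",
    "Plasmacytoid dendritic cells": "pDCs",
}

predicted = pd.Series([
    "Classical Monocytes",
    "Plasmacytoid dendritic cells",
    "Aberrant erythroid",  # not in the mapping, so it is kept unchanged
])

# map() returns NaN for labels missing from the dict; fillna() then
# restores the original label for those cells.
reconciled = predicted.map(label_map).fillna(predicted)

print(reconciled.tolist())  # ['Monocytes', 'pDCs', 'Aberrant erythroid']
```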

In [11]:
df['harmony_supercell_reconciled'] = df['harmony_supercell'].apply(reconcile_label)
df['rpca_supercell_reconciled'] = df['rpca_supercell'].apply(reconcile_label)
df['rpca_singlecell_reconciled'] = df['rpca_singlecell'].apply(reconcile_label)
df['harmony_singlecell_reconciled'] = df['harmony_singlecell'].apply(reconcile_label)
df.head()

Unnamed: 0,Time,Cell_length,DNA1,DNA2,CD45RA,CD133,CD19,CD22,CD11b,CD4,...,UMAP_Y,harmony_supercell,rpca_supercell,rpca_singlecell,harmony_singlecell,Gated_Population_reconciled,harmony_supercell_reconciled,rpca_supercell_reconciled,rpca_singlecell_reconciled,harmony_singlecell_reconciled
0,2693.0,22,201.783295,253.016647,0.817049,-0.147947,-0.033482,0.332183,-0.045922,1.858334,...,0.86061,Late erythroid progenitor,Aberrant erythroid,Aberrant erythroid,Late erythroid progenitor,Basophils,Late erythroid progenitor,Aberrant erythroid,Aberrant erythroid,Late erythroid progenitor
1,3736.0,35,191.828598,308.86908,3.801385,-0.191446,-0.083274,0.372388,4.494379,-0.177158,...,-0.115077,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Plasmacytoid dendritic cell progenitors,Basophils,pDCs,pDCs,Aberrant erythroid,Plasmacytoid dendritic cell progenitors
2,7015.0,32,116.111923,200.839218,3.204439,-0.161106,0.369613,-0.214952,-0.009404,-0.043904,...,1.096968,Aberrant erythroid,Aberrant erythroid,Late erythroid progenitor,Late erythroid progenitor,Basophils,Aberrant erythroid,Aberrant erythroid,Late erythroid progenitor,Late erythroid progenitor
3,7099.0,29,176.248505,313.022461,2.237382,-0.138071,-0.088311,-0.22043,4.006598,-0.095335,...,0.045576,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Plasmacytoid dendritic cells,Basophils,pDCs,pDCs,Aberrant erythroid,pDCs
4,7700.0,25,133.332779,226.467758,-0.044047,-0.151509,0.402548,2.581769,6.74206,2.906627,...,-0.183927,Plasmacytoid dendritic cells,Plasmacytoid dendritic cells,Aberrant erythroid,Conventional dendritic cell 1,Basophils,pDCs,pDCs,Aberrant erythroid,Conventional dendritic cell 1


# Weighted accuracy

Calculate the accuracy per class (true cell type), i.e. number of correctly classified cells / number of cells in the class.
The weight of a class is computed as number of cells in the class / total number of cells.

The weighted accuracy per class is then accuracy * weight.

The final reported value is the sum of the weighted accuracies of all classes.
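
The arithmetic can be sketched on a handful of made-up labels (the toy arrays below are assumptions, not real cells). One property worth noting under this definition: the class size in the weight cancels the class size in the per-class accuracy, so the sum works out to correctly classified cells / total cells, which is what `accuracy_score` returns. The sketch checks this on the toy data.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Toy labels (hypothetical): 6 "T" cells, 3 "B" cells, 1 "NK" cell.
y_true = np.array(["T", "T", "T", "T", "T", "T", "B", "B", "B", "NK"])
y_pred = np.array(["T", "T", "T", "T", "T", "B", "B", "B", "T", "T"])

# Weight of each class: number of cells in the class / total cells.
classes, counts = np.unique(y_true, return_counts=True)
weights = counts / y_true.size

# Accuracy of each class: correctly classified / cells in the class.
per_class_accuracy = np.array([
    np.mean(y_pred[y_true == c] == c) for c in classes
])

# Final value: sum over classes of accuracy * weight.
weighted_accuracy = np.sum(weights * per_class_accuracy)
overall_accuracy = accuracy_score(y_true, y_pred)

print(round(weighted_accuracy, 6))  # 0.7
print(round(overall_accuracy, 6))   # 0.7
```

The per-class accuracies remain informative on their own, since they show which cell types drive the overall number.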

In [12]:
# First, compute the weight of each class: its proportion of all cells
# (despite its name, n_cells holds proportions, not counts).
n_cells = df['Gated_Population_reconciled'].value_counts() / df.shape[0]
n_cells

Gated_Population_reconciled
CD4_T_cells             0.253071
Monocytes               0.202517
CD8_T_cells             0.193005
Mature_B_cells          0.158566
Pre_B_cells             0.058886
CD34+_HSCs_and_HSPCs    0.043337
CD16-_NK_cells          0.037482
CD16+_NK_cells          0.021577
pDCs                    0.011883
Basophils               0.011585
Pro_B_cells             0.004924
Plasma_B_cells          0.003167
Name: count, dtype: float64

In [13]:
weighted_accuracy_per_method = {}
accuracy_per_method_and_celltype = {}
methods = ['harmony_supercell_reconciled', 'harmony_singlecell_reconciled', 'rpca_supercell_reconciled', 'rpca_singlecell_reconciled']
for method in methods:
    accuracy_per_method_and_celltype[method] = {}
    weighted_accuracies = []
    # n_cells maps each true cell type to its weight (proportion of all cells)
    for cell_type, weight in n_cells.items():
        df_subset = df[df['Gated_Population_reconciled'] == cell_type]
        # per-class accuracy: correctly labelled cells / cells in the class
        accuracy = (df_subset[method] == cell_type).sum() / df_subset.shape[0]
        accuracy_per_method_and_celltype[method][cell_type] = accuracy
        weighted_accuracies.append(weight * accuracy)

    # weighted accuracy of the method: sum over all true cell types
    weighted_accuracy_per_method[method] = np.sum(weighted_accuracies)
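
For reference, the same quantities could also be computed without explicit loops. The sketch below is one possible pandas formulation on a made-up table (the `toy` frame and the `method_a`/`method_b` columns are hypothetical stand-ins for the reconciled columns); it is not the code used to produce the results in this notebook.

```python
import pandas as pd

# Hypothetical toy table with one true-label column and two methods.
toy = pd.DataFrame({
    "true_label": ["Mono", "Mono", "Mono", "pDCs", "pDCs", "Baso"],
    "method_a": ["Mono", "Mono", "pDCs", "pDCs", "pDCs", "Mono"],
    "method_b": ["Mono", "pDCs", "pDCs", "pDCs", "Mono", "pDCs"],
})
methods = ["method_a", "method_b"]

# Mark each cell as correct or not for every method (row-wise comparison).
is_correct = toy[methods].eq(toy["true_label"], axis=0)

# Per-class accuracy: mean of the correct flags within each true class.
per_class = is_correct.groupby(toy["true_label"]).mean()

# Class weights: proportion of cells in each true class.
weights = toy["true_label"].value_counts(normalize=True)

# Multiply accuracies by weights (aligned on the class labels) and sum.
weighted = per_class.mul(weights, axis=0).sum()

print(per_class)
print(weighted)
```

Here `per_class` plays the role of `accuracy_per_method_and_celltype`, and `weighted` the role of `weighted_accuracy_per_method`.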

In [14]:
# have to reshape the dictionary
weighted_accuracy_per_method = pd.DataFrame({
    "method": weighted_accuracy_per_method.keys(), 
    "weighted_accuracy": weighted_accuracy_per_method.values()
})
weighted_accuracy_per_method

Unnamed: 0,method,weighted_accuracy
0,harmony_supercell_reconciled,0.665064
1,harmony_singlecell_reconciled,0.544441
2,rpca_supercell_reconciled,0.711549
3,rpca_singlecell_reconciled,0.650484


In [15]:
accuracy_per_method_and_celltype = pd.DataFrame(accuracy_per_method_and_celltype)
accuracy_per_method_and_celltype

Unnamed: 0,harmony_supercell_reconciled,harmony_singlecell_reconciled,rpca_supercell_reconciled,rpca_singlecell_reconciled
CD4_T_cells,0.748502,0.408101,0.911894,0.781802
Monocytes,0.34286,0.144651,0.512204,0.305986
CD8_T_cells,0.977671,0.971902,0.918689,0.967476
Mature_B_cells,0.952845,0.971308,0.917676,0.793462
Pre_B_cells,0.117196,0.111491,0.0,0.0
CD34+_HSCs_and_HSPCs,0.604208,0.551495,0.197121,0.280842
CD16-_NK_cells,0.191037,0.332907,0.395391,0.883483
CD16+_NK_cells,0.930605,0.985765,0.893238,0.966192
pDCs,0.267367,0.260097,0.390953,0.418417
Basophils,0.0,0.0,0.0,0.0


Export results

In [16]:
weighted_accuracy_per_method.to_csv("/Users/putri.g/Documents/GitHub/SuperCellCyto-analysis/output/label_transfer/weighted_accuracy.csv", 
                                    index=False)
accuracy_per_method_and_celltype.to_csv("/Users/putri.g/Documents/GitHub/SuperCellCyto-analysis/output/label_transfer/accuracy_per_celltype.csv")

In [29]:
# Rename the "count" column, since the values are proportions, not counts.
weights = pd.DataFrame(n_cells).rename(columns={"count": "proportion_of_cells"})
weights.to_csv("/Users/putri.g/Documents/GitHub/SuperCellCyto-analysis/output/label_transfer/accuracy_weights.csv")

# Additional code

Classification report. Run me if required.

In [None]:
print(classification_report(df['Gated_Population_reconciled'], df['harmony_supercell_reconciled']))

In [None]:
print(classification_report(df['Gated_Population_reconciled'], df['harmony_singlecell_reconciled']))

In [None]:
print(classification_report(df['Gated_Population_reconciled'], df['rpca_supercell_reconciled']))

In [None]:
print(classification_report(df['Gated_Population_reconciled'], df['rpca_singlecell_reconciled']))
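
Because some predicted labels never occur among the true labels, the reports above will also contain rows for those labels, with a support of 0. A possible way to restrict the report to the true classes, sketched on made-up labels (the toy lists are assumptions), is to pass `labels=` together with `zero_division=0` so that classes with no predictions get a precision of 0 instead of triggering a warning.

```python
from sklearn.metrics import classification_report

# Toy labels (hypothetical). "GammaDelta T cells" only occurs in the
# predictions, and "Basophils" only occurs in the true labels.
y_true = ["Monocytes", "Monocytes", "pDCs", "Basophils"]
y_pred = ["Monocytes", "GammaDelta T cells", "pDCs", "pDCs"]

# Only report on classes that exist among the true labels.
true_classes = sorted(set(y_true))

report = classification_report(
    y_true,
    y_pred,
    labels=true_classes,
    zero_division=0,   # undefined precision/recall is reported as 0
    output_dict=True,  # return a nested dict instead of a string
)

for cell_type in true_classes:
    print(cell_type, report[cell_type]["precision"], report[cell_type]["recall"])
```

The recall column of such a report corresponds to the per-class accuracy computed earlier in this notebook.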