In [234]:
# pylint: disable=import-error, redefined-outer-name
"""Workbooks to analyze metadata."""
import itertools
import os
import random
from collections import Counter, defaultdict
from pathlib import Path

import pandas as pd
from IPython.display import display

from epi_ml.core.metadata import Metadata
from epi_ml.utils.modify_metadata import TRACKS_MAPPING, epiatlas_assays, filter_by_pairs

# Metadata categories used as the cell type and assay labels throughout this notebook.
CELL_TYPE = "harmonized_sample_ontology_intermediate"
ASSAY = "assay_epiclass"

# The md5 lists written to disk below are drawn with `random`; seeding it here
# lets a Restart & Run All regenerate the same files.
SEED = 42
random.seed(SEED)

In [210]:
# Directory holding the metadata JSON. Set EPIATLAS_METADATA_DIR to point elsewhere
# instead of editing this machine-specific default path.
base = Path(
    os.environ.get(
        "EPIATLAS_METADATA_DIR",
        "/home/local/USHERBROOKE/rabj2301/Projects/epilap/input/metadata",
    )
)
path = base / "hg38_2023_epiatlas_dfreeze_plus_encode_noncore_formatted_JR.json"
my_metadata = Metadata(path)

In [233]:
# my_metadata.get_categories()

In [251]:
def display_gen_info(metadata: Metadata):
    """Print the label breakdown of the track type, assay and cell type categories, in that order."""
    for category in ("track_type", ASSAY, CELL_TYPE):
        metadata.display_labels(category)

In [213]:
# Class counts for the metadata file as loaded, before any filtering.
display_gen_info(my_metadata)


Label breakdown for track_type
0 labels missing and ignored from count
raw: 9552
fc: 5471
pval: 5471
Unique_minusRaw: 1463
Unique_plusRaw: 1463
ctl_raw: 965
gembs_neg: 644
gembs_pos: 644
Unique_raw: 159
For a total of 25832 examples


Label breakdown for assay_epiclass
0 labels missing and ignored from count
h3k27ac: 4739
non-core: 3599
h3k4me1: 2958
h3k4me3: 2411
rna_seq: 2365
h3k36me3: 2159
h3k27me3: 2111
h3k9me3: 2049
wgbs-standard: 1024
input: 965
mrna_seq: 720
CTCF: 468
wgbs-pbat: 264
For a total of 25832 examples


Label breakdown for harmonized_sample_ontology_intermediate
4067 labels missing and ignored from count
T cell: 3004
monocyte: 2090
neutrophil: 2059
brain: 1595
lymphocyte of B lineage: 1516
myeloid cell: 1149
macrophage: 741
mesoderm-derived structure: 704
venous blood: 681
endoderm-derived structure: 613
colon: 605
connective tissue cell: 559
hepatocyte: 496
muscle organ: 400
mammary gland epithelial cell: 393
hematopoietic cell: 340
extraembryonic cell: 322
epitheli

In [214]:
# Restrict my_metadata (in place) to the track types listed in TRACKS_MAPPING
# and to the assays in epiatlas_assays.
my_metadata.select_category_subsets("track_type", TRACKS_MAPPING.keys())
my_metadata.select_category_subsets(ASSAY, epiatlas_assays)

In [215]:
# Drop datasets whose cell type label is empty, then cell type classes with fewer than 10 datasets.
my_metadata.remove_category_subsets(CELL_TYPE, [""])
my_metadata.remove_small_classes(10, CELL_TYPE)

0 labels missing and ignored from count
55/66 labels left from harmonized_sample_ontology_intermediate after removing classes with less than 10 signals.


In [216]:
# Class counts after the track type / assay / cell type filtering above.
display_gen_info(my_metadata)


Label breakdown for track_type
0 labels missing and ignored from count
raw: 5443
Unique_plusRaw: 1451
ctl_raw: 957
gembs_pos: 635
For a total of 8486 examples


Label breakdown for assay_epiclass
0 labels missing and ignored from count
h3k27ac: 1575
rna_seq: 1174
h3k4me1: 982
input: 957
h3k4me3: 796
h3k36me3: 712
h3k27me3: 697
h3k9me3: 681
wgbs-standard: 508
mrna_seq: 277
wgbs-pbat: 127
For a total of 8486 examples


Label breakdown for harmonized_sample_ontology_intermediate
0 labels missing and ignored from count
T cell: 1189
monocyte: 860
neutrophil: 820
lymphocyte of B lineage: 596
brain: 575
myeloid cell: 448
macrophage: 293
mesoderm-derived structure: 273
venous blood: 266
endoderm-derived structure: 240
colon: 236
connective tissue cell: 221
hepatocyte: 195
muscle organ: 164
mammary gland epithelial cell: 153
extraembryonic cell: 132
hematopoietic cell: 130
meso-epithelial cell: 119
epithelial cell derived cell line: 118
digestive system: 117
endo-epithelial cell: 105
mucosa: 7

In [252]:
def count_trios(metadata: Metadata) -> Counter:
    """
    Tally how many datasets share each (track_type, assay, cell_type) combination.

    Returns:
        Counter: Maps each (track_type, assay, cell_type) tuple to its number of datasets.
    """
    return Counter(
        (dataset["track_type"], dataset[ASSAY], dataset[CELL_TYPE])
        for dataset in metadata.datasets
    )

In [218]:
# NOTE(review): presumably keeps only (assay, cell type) pairs with at least
# min_per_pair=10 datasets — confirm against epi_ml.utils.modify_metadata.filter_by_pairs.
my_metadata = filter_by_pairs(
    my_metadata=my_metadata, assay_cat=ASSAY, cat2=CELL_TYPE, nb_pairs=1, min_per_pair=10
)

Applying metadata filter function 'filter_by_pairs'.
Still 28 harmonized_sample_ontology_intermediate labels.
Keeping 28 harmonized_sample_ontology_intermediate labels.

Label breakdown for harmonized_sample_ontology_intermediate
0 labels missing and ignored from count
T cell: 1189
monocyte: 854
neutrophil: 819
lymphocyte of B lineage: 596
brain: 575
myeloid cell: 448
macrophage: 280
mesoderm-derived structure: 268
venous blood: 266
colon: 232
endoderm-derived structure: 231
connective tissue cell: 214
hepatocyte: 194
muscle organ: 162
mammary gland epithelial cell: 152
extraembryonic cell: 132
hematopoietic cell: 113
meso-epithelial cell: 107
digestive system: 100
epithelial cell derived cell line: 97
endo-epithelial cell: 97
mucosa: 60
lymphoma or leukaemia cell line: 51
lymph node: 45
neural cell: 31
cancer cell line: 20
kidney: 12
ESC derived cell line: 10
For a total of 7355 examples



In [219]:
# Class counts after filter_by_pairs.
display_gen_info(my_metadata)


Label breakdown for track_type
0 labels missing and ignored from count
raw: 4745
Unique_plusRaw: 1251
ctl_raw: 860
gembs_pos: 499
For a total of 7355 examples


Label breakdown for assay_epiclass
0 labels missing and ignored from count
h3k27ac: 1449
rna_seq: 1072
h3k4me1: 866
input: 860
h3k4me3: 690
h3k36me3: 589
h3k9me3: 580
h3k27me3: 571
wgbs-standard: 428
mrna_seq: 179
wgbs-pbat: 71
For a total of 7355 examples


Label breakdown for harmonized_sample_ontology_intermediate
0 labels missing and ignored from count
T cell: 1189
monocyte: 854
neutrophil: 819
lymphocyte of B lineage: 596
brain: 575
myeloid cell: 448
macrophage: 280
mesoderm-derived structure: 268
venous blood: 266
colon: 232
endoderm-derived structure: 231
connective tissue cell: 214
hepatocyte: 194
muscle organ: 162
mammary gland epithelial cell: 152
extraembryonic cell: 132
hematopoietic cell: 113
meso-epithelial cell: 107
digestive system: 100
epithelial cell derived cell line: 97
endo-epithelial cell: 97
mucosa: 60
l

In [235]:
# Number of unique (track_type, assay, cell type) trios, and the number of files
# obtained when sampling 5 per trio (select_datasets' default n).
trios = count_trios(my_metadata)
print(len(trios), len(trios) * 5)

211 1055


In [236]:
# display(trios.most_common())

In [222]:
# Class counts of the filtered metadata, plus the donor sex breakdown.
display_gen_info(my_metadata)
my_metadata.display_labels("harmonized_donor_sex")


Label breakdown for track_type
0 labels missing and ignored from count
raw: 4745
Unique_plusRaw: 1251
ctl_raw: 860
gembs_pos: 499
For a total of 7355 examples


Label breakdown for assay_epiclass
0 labels missing and ignored from count
h3k27ac: 1449
rna_seq: 1072
h3k4me1: 866
input: 860
h3k4me3: 690
h3k36me3: 589
h3k9me3: 580
h3k27me3: 571
wgbs-standard: 428
mrna_seq: 179
wgbs-pbat: 71
For a total of 7355 examples


Label breakdown for harmonized_sample_ontology_intermediate
0 labels missing and ignored from count
T cell: 1189
monocyte: 854
neutrophil: 819
lymphocyte of B lineage: 596
brain: 575
myeloid cell: 448
macrophage: 280
mesoderm-derived structure: 268
venous blood: 266
colon: 232
endoderm-derived structure: 231
connective tissue cell: 214
hepatocyte: 194
muscle organ: 162
mammary gland epithelial cell: 152
extraembryonic cell: 132
hematopoietic cell: 113
meso-epithelial cell: 107
digestive system: 100
epithelial cell derived cell line: 97
endo-epithelial cell: 97
mucosa: 60
l

In [254]:
def select_datasets(metadata: Metadata, n=5, rng=None) -> list:
    """
    Select a random subset of n datasets for each unique (track_type, assay, cell_type) trio.

    Prints the number of unique trios found.

    Args:
        metadata (Metadata): Metadata object whose `items` yields (md5sum, dataset) pairs.
        n (int, optional): Number of datasets to sample per trio. A trio with fewer
            than n datasets contributes all of them. Defaults to 5.
        rng (random.Random, optional): Random source, e.g. random.Random(seed), for
            reproducible sampling. Defaults to the global `random` module.

    Returns:
        list: A list of sampled md5sums of the selected datasets, grouped by trio.
    """
    sampler = random if rng is None else rng

    trio_files = defaultdict(list)
    for md5sum, dset in metadata.items:
        trio = (dset["track_type"], dset[ASSAY], dset[CELL_TYPE])
        trio_files[trio].append(md5sum)
    print(len(trio_files))

    nb_short = sum(len(md5_list) < n for md5_list in trio_files.values())
    if nb_short:
        print(f"{nb_short} trios have fewer than {n} datasets; all of their datasets are kept.")

    # random.sample raises ValueError when asked for more items than the population,
    # so cap the sample size per trio.
    sampled_md5s = list(
        itertools.chain.from_iterable(
            sampler.sample(md5_list, min(n, len(md5_list)))
            for md5_list in trio_files.values()
        )
    )
    return sampled_md5s

In [239]:
# Sample md5sums per (track_type, assay, cell type) trio (5 each by default) as the SHAP background set.
md5s = select_datasets(my_metadata)
print(len(md5s))

211
1055


In [225]:
# Write the background md5sums one per line (no header, no index).
pd.DataFrame(md5s).to_csv("md5_shap_assay_background.list", index=False, header=False)

In [227]:
# Reload from disk: my_metadata has been filtered in place by the cells above.
# meta2 keeps every track type except Unique_raw, restricted to the EpiATLAS assays.
meta2 = Metadata(path)

meta2.remove_category_subsets("track_type", ["Unique_raw"])
meta2.select_category_subsets(ASSAY, epiatlas_assays)

# Same cell type cleanup as applied to my_metadata: no empty labels, classes of at least 10.
meta2.remove_category_subsets(CELL_TYPE, [""])
meta2.remove_small_classes(10, CELL_TYPE)

display_gen_info(meta2)

0 labels missing and ignored from count
63/66 labels left from harmonized_sample_ontology_intermediate after removing classes with less than 10 signals.

Label breakdown for track_type
0 labels missing and ignored from count
raw: 5485
fc: 5471
pval: 5471
Unique_minusRaw: 1455
Unique_plusRaw: 1455
ctl_raw: 965
gembs_neg: 637
gembs_pos: 637
For a total of 21576 examples


Label breakdown for assay_epiclass
0 labels missing and ignored from count
h3k27ac: 4739
h3k4me1: 2958
h3k4me3: 2411
rna_seq: 2354
h3k36me3: 2159
h3k27me3: 2111
h3k9me3: 2049
wgbs-standard: 1018
input: 965
mrna_seq: 556
wgbs-pbat: 256
For a total of 21576 examples


Label breakdown for harmonized_sample_ontology_intermediate
0 labels missing and ignored from count
T cell: 2924
monocyte: 2088
neutrophil: 2059
brain: 1594
lymphocyte of B lineage: 1498
myeloid cell: 1149
macrophage: 741
mesoderm-derived structure: 702
venous blood: 681
endoderm-derived structure: 613
colon: 604
connective tissue cell: 559
hepatocyte: 494
m

In [255]:
def count_cell_types(metadata: Metadata):
    """
    Tally cell types separately for each assay.

    Returns:
        defaultdict(Counter): For each assay, a Counter mapping cell type to its number of datasets.
    """
    counts_per_assay = defaultdict(Counter)
    for dataset in metadata.datasets:
        counts_per_assay[dataset[ASSAY]][dataset[CELL_TYPE]] += 1
    return counts_per_assay


def select_cell_types(
    metadata: Metadata, n=70, first_cell_type: str = "T cell"
) -> defaultdict:
    """
    Determine which cell types are needed to reach n datasets, for each assay.

    `first_cell_type` is taken first whenever the assay has it. The remaining cell
    types are then added from most to least common until the running total of
    datasets reaches n. A message is printed for assays that cannot reach n.

    Args:
        metadata (Metadata): A Metadata object containing dataset metadata.
        n (int, optional): Target number of datasets per assay. Each cell type
            contributes at most the number of datasets still missing. Defaults to 70.
        first_cell_type (str, optional): Cell type always selected first when present.
            Defaults to "T cell".

    Returns:
        defaultdict(list): Selected cell types, in selection order, for each assay.
    """
    cell_count = count_cell_types(metadata)

    selected_ct = defaultdict(list)
    for assay, counter in cell_count.items():
        total = 0
        # Counter lookup of a missing key returns 0 without inserting it.
        if counter[first_cell_type] > 0:
            selected_ct[assay].append(first_cell_type)
            total = min(counter[first_cell_type], n)

        # most_common() orders by decreasing count (ties keep insertion order);
        # walk it once instead of re-sorting the counter after every pick.
        for cell_type, count in counter.most_common():
            if total >= n:
                break
            if cell_type == first_cell_type:
                continue
            total += min(count, n - total)
            selected_ct[assay].append(cell_type)

        if total < n:
            print(f"There is not at least {n} files for {assay}. Final number={total}")

    return selected_ct

In [247]:
# For each assay, the cell types (T cell first) needed to reach 70 datasets in my_metadata.
selected_ct = select_cell_types(my_metadata, n=70)
display(selected_ct)

defaultdict(list,
            {'h3k27ac': ['T cell'],
             'h3k27me3': ['T cell', 'lymphocyte of B lineage'],
             'h3k36me3': ['T cell', 'lymphocyte of B lineage'],
             'input': ['T cell'],
             'h3k4me1': ['T cell'],
             'h3k4me3': ['T cell'],
             'h3k9me3': ['T cell', 'lymphocyte of B lineage'],
             'rna_seq': ['T cell'],
             'wgbs-standard': ['T cell', 'lymphocyte of B lineage'],
             'mrna_seq': ['T cell',
              'colon',
              'endoderm-derived structure',
              'connective tissue cell'],
             'wgbs-pbat': ['T cell',
              'hepatocyte',
              'extraembryonic cell',
              'endo-epithelial cell',
              'connective tissue cell']})

In [248]:
def select_explain_files(
    metadata: Metadata, selected_cell_types: dict, n: int, rng=None
):
    """
    Sample 'n' random datasets for each assay from the given cell types.

    Args:
        metadata (Metadata): Metadata object whose `items` yields (md5sum, dataset) pairs.
        selected_cell_types (dict): Maps each assay to the cell types allowed for it.
            Assays missing from the mapping contribute no files. The mapping is not modified.
        n (int): Number of files to select for each assay. An assay with fewer than n
            matching files contributes all of them and a message is printed.
        rng (random.Random, optional): Random source, e.g. random.Random(seed), for
            reproducible sampling. Defaults to the global `random` module.

    Returns:
        list: A list of sampled md5sums of the selected files, grouped by assay.
    """
    sampler = random if rng is None else rng

    md5sums = defaultdict(list)
    for md5sum, dset in metadata.items:
        assay, ct = dset[ASSAY], dset[CELL_TYPE]
        # .get() instead of [] so a defaultdict argument is not silently
        # extended with empty entries for assays it does not cover.
        if ct in selected_cell_types.get(assay, ()):
            md5sums[assay].append(md5sum)

    sampled_md5s = []
    for assay, md5_list in md5sums.items():
        if len(md5_list) < n:
            print(f"Only {len(md5_list)} files available for {assay}, fewer than {n}.")
        # random.sample raises ValueError if asked for more items than available.
        sampled_md5s.extend(sampler.sample(md5_list, min(n, len(md5_list))))

    return sampled_md5s

In [249]:
# Sample 70 files per assay from meta2, limited to the cell types selected on my_metadata.
files_to_explain = select_explain_files(meta2, selected_ct, 70)
len(files_to_explain)

770

In [250]:
# Write the md5sums to explain one per line (no header, no index).
pd.DataFrame(files_to_explain).to_csv(
    "md5_shap_assay_explain.list", index=False, header=False
)