# Dataset Exploration
This notebook serves to dive deeper into the properties of our data.

In [1]:
import pandas as pd
from pathlib import Path

%load_ext autoreload
%autoreload 2

In [3]:
# repository layout, relative to the location of this notebook
git_root = Path('..')
data_root = git_root / 'data'
assert data_root.exists()

# train / validation / test split tables, one set per batch
splits_root = data_root / 'splits'
b1_train_path = splits_root / 'batch1_training_filtered.tsv'
b1_val_path = splits_root / 'batch1_val_filtered.tsv'
b1_test_path = splits_root / 'batch1_test_filtered.tsv'
b2_train_path = splits_root / 'batch2_training_filtered.tsv'
b2_val_path = splits_root / 'batch2_val_filtered.tsv'
b2_test_path = splits_root / 'batch2_test_filtered.tsv'

In [4]:
# per-batch tables from the preprocessing directory (not the split files)
preprocessing_root = data_root / 'preprocessing'
b1_full_data = preprocessing_root / 'classification_median_batch_1.tsv'
b2_full_data = preprocessing_root / 'classification_median_batch_2.tsv'

## Barcode - gene pairs
Previously, we noticed during embedding calculation that not all barcodes have targets/labels for all genes, which is required for supervised learning.
Thus, we analyse the barcode - gene pairs in the data here.
In particular, we want to identify the minimal gene set that is shared among all barcodes of a batch.

In [None]:
# algorithm idea: make iterative intersections.
# 1. start with the set of all unique genes
# 2. for each unique barcode, intersect with the unique gene ids associated with this barcode
# 3. return the result of all intersections

In [13]:
def find_minimal_gene_set(df: pd.DataFrame) -> set:
    """
    Find the gene ids that occur for every barcode in ``df``.

    The result is the intersection of the per-barcode gene sets. The frame is
    grouped by barcode once instead of being re-filtered for every barcode,
    so the rows are not rescanned per barcode.

    Rows whose barcode is missing (NaN) are ignored, because ``groupby`` drops
    missing keys by default.

    :param df: table with (at least) the columns ``barcode`` and ``gene_id``
    :return: set of gene ids shared by all barcodes; an empty set if ``df``
        contains no barcodes
    """
    minimal_gene_set = None
    # sort=False: group order does not matter for an intersection.
    # observed=True: only barcodes that actually occur form a group, even if
    # the column is categorical.
    grouped_gene_ids = df.groupby('barcode', sort=False, observed=True)['gene_id']
    for _, gene_ids in grouped_gene_ids:
        barcode_genes = set(gene_ids)
        if minimal_gene_set is None:
            minimal_gene_set = barcode_genes
        else:
            minimal_gene_set &= barcode_genes
        if not minimal_gene_set:
            # the intersection can never grow again, so stop early
            break
    return minimal_gene_set if minimal_gene_set is not None else set()

In [8]:
# Load the full batch 1 table here: `b1_df` is read by this cell and the
# following ones but was not defined by any earlier cell.
b1_df = pd.read_csv(b1_full_data, sep='\t')
uniq_gene_ids = set(b1_df['gene_id'])
len(uniq_gene_ids)  # value is discarded: only the last expression of a cell is displayed
list(uniq_gene_ids)[:10]

['ENSG00000186047',
 'ENSG00000188001',
 'ENSG00000233008',
 'ENSG00000154485',
 'ENSG00000196787',
 'ENSG00000287200',
 'ENSG00000224982',
 'ENSG00000130844',
 'ENSG00000152583',
 'ENSG00000053254']

In [9]:
# all distinct barcodes in the batch 1 table
uniq_barodes = set(b1_df['barcode'])
len(uniq_barodes)  # value is discarded: only the last expression of a cell is displayed
list(uniq_barodes)[:10]  # preview of ten barcodes; a set has no defined order

['ACCCTGTTCCAGGAAA-1',
 'TATATCCTCCTGGTCT-1',
 'TTCGGTACAAATTGCT-1',
 'GTACTGGTCATGTGGT-1',
 'CTTGCGCGTTTATGGG-1',
 'GCACTAAGTTTACGTC-1',
 'GCCTTAACATTGACAT-1',
 'TGAAGGATCATTTGCT-1',
 'CTCCTGAGTTGTGACA-1',
 'TTAAGTGTCAGGTCCA-1']

In [10]:
# Use the `find_minimal_gene_set` helper instead of repeating its intersection
# loop inline; this also removes the dependency on the two preview cells above.
minimal_gene_overlap = find_minimal_gene_set(b1_df)
len(minimal_gene_overlap)

2000

In [None]:
# `find_minimal_gene_set` takes a single parameter named `df`; passing the
# frame as `b1_df=...` raises a TypeError, so it is passed positionally.
b1_min_gene_set = find_minimal_gene_set(
    pd.read_csv(b1_full_data, sep='\t')
)
len(b1_min_gene_set)

Investigate the minimal gene set for batch 2.

In [12]:
# full batch 2 table; `read_table` parses tab-separated files by default
b2_df = pd.read_table(b2_full_data)

In [14]:
# genes that are present for every barcode of the full batch 2 table
b2_min_gene_set = find_minimal_gene_set(b2_df)
len(b2_min_gene_set)  # number of shared genes

2000

### Minimal gene set of splits
As we see, the minimal gene overlap contains all 2000 genes we selected for our analysis.
Thus, we check whether the error arises from the split data.

#### Batch 1 train
Minimal gene overlap set for training split of batch 1.

In [15]:
# genes shared by every barcode in the batch 1 training split
b1_train_min_gene_set = find_minimal_gene_set(pd.read_table(b1_train_path))
len(b1_train_min_gene_set)

0

#### Batch 1 validation
Minimal gene overlap set for validation split of batch 1.

In [16]:
# genes shared by every barcode in the batch 1 validation split
b1_val_min_gene_set = find_minimal_gene_set(pd.read_table(b1_val_path))
len(b1_val_min_gene_set)

0

#### Batch 1 test
Minimal gene overlap set for test split of batch 1.

In [17]:
# genes shared by every barcode in the batch 1 test split
b1_test_min_gene_set = find_minimal_gene_set(pd.read_table(b1_test_path))
len(b1_test_min_gene_set)

0

#### Batch 2 train

In [18]:
# genes shared by every barcode in the batch 2 training split
b2_train_min_gene_set = find_minimal_gene_set(pd.read_table(b2_train_path))
len(b2_train_min_gene_set)

0

#### Batch 2 validation

In [19]:
# genes shared by every barcode in the batch 2 validation split
b2_val_min_gene_set = find_minimal_gene_set(pd.read_table(b2_val_path))
len(b2_val_min_gene_set)

0

#### Batch 2 test

In [20]:
# genes shared by every barcode in the batch 2 test split
b2_test_min_gene_set = find_minimal_gene_set(pd.read_table(b2_test_path))
len(b2_test_min_gene_set)

0

## Dataset class
This section tests the dataset class `CnvDataset` from `src/data/dataset.py`.

In [None]:
import sys
sys.path.append('..') # add the parent directory to system path
from src.data.dataset import CnvDataset

### File format benchmark
We discussed multiple file formats to use in the backend for storing computed embeddings on the disk. Suggestions were the vanilla pytorch format `.pt`, pickle files `.pkl` and the scipy matrix format `.mtx`.
Since pytorch is using pickle in the backend for creating `.pt` files, we decided to only use the `.pt` and `.mtx` formats.
In the following we benchmark reading from these file types, as this could be a bottleneck during training. 

In [None]:
# batch 1 validation split, with embeddings requested in the `mtx` file format
b1_val_path = data_root / 'splits' / 'batch1_val_filtered.tsv'
b1_val_dataset = CnvDataset(
    root=data_root.joinpath('embeddings', 'batch_1', 'val'),
    data_df=pd.read_table(b1_val_path),
    embedding_mode='single_gene_barcode',
    file_format='mtx',
)

In [None]:
# batch 2 validation split, with embeddings requested in the `pt` file format
# NOTE(review): unlike the batch 1 dataset, `root` has no 'val' component --
# confirm this matches the directory layout on disk.
b2_val_path = data_root / 'splits' / 'batch2_val_filtered.tsv'
b2_val_dataset = CnvDataset(
    root=data_root.joinpath('embeddings', 'batch_2'),
    data_df=pd.read_table(b2_val_path),
    embedding_mode='single_gene_barcode',
    file_format='pt',
)

In [None]:
# Look up the embedding file of one sample (row position 42) in each dataset;
# these two files are used for the raw file-read benchmarks.
# NOTE(review): these assignments overwrite `b1_test_path` / `b2_test_path`,
# which the setup cell defines as the test split .tsv files. Re-running the
# "Batch 1 test" / "Batch 2 test" cells after this one would read an embedding
# file instead of a split table. Consider renaming these variables, together
# with the cells that pass them to `mmread` and `pyt_load`.
b1_test_path = b1_val_dataset.data_df['embedding_path'].iloc[42]
b2_test_path = b2_val_dataset.data_df['embedding_path'].iloc[42]
print(b1_test_path)
print(b2_test_path)

../data/embeddings/batch_1/val/single_gene_barcode/AAAGGTTAGGGTGGAT-1/ENSG00000117984.mtx
../data/embeddings/batch_2/single_gene_barcode/AAACCAACATTGCGGT-2/ENSG00000172985.pt


In [None]:
from torch import load as pyt_load
from scipy.io import mmread

In [None]:
%%timeit
# time reading the sampled `.mtx` embedding file with scipy's `mmread`
t = mmread(b1_test_path)

3.93 ms ± 40.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%%timeit
# time reading the sampled `.pt` embedding file with `torch.load`
t = pyt_load(b2_test_path)

257 μs ± 646 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
%%timeit
# time one item access on the dataset created with file_format='mtx'
t = b1_val_dataset[42]

5.16 ms ± 204 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
%%timeit
# time one item access on the dataset created with file_format='pt'
t = b2_val_dataset[42]

511 μs ± 25.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


This means that, based on the per-sample access times above (5.16 ms vs. 511 μs), using the `mtx` format is roughly `# samples * # iterations * 4650 μs` longer