<a href="https://colab.research.google.com/github/vtecftwy/metagenomics/blob/refactor_cnn_virus/nbs/2_02_EC_preprocess_data_dataset_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Use `tf.data.Dataset` to preprocess data

In [3]:
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import psutil
import os
import shutil
import sys
import tensorflow as tf
from datetime import datetime, timedelta
from pathlib import Path
from tensorflow.keras.utils import to_categorical
from tensorflow.python.client import device_lib
print(f"Tensorflow version: {tf.__version__}\n")

%load_ext autoreload
%autoreload 2

devices = device_lib.list_local_devices()
print('\nDevices:')
for d in devices:
    t = d.device_type
    name = d.physical_device_desc
    l = [item.split(':', 1) for item in name.split(', ')]
    name_attr = dict([x for x in l if len(x)==2])
    dev = name_attr.get('name', ' ')
    print(f"  - {t}  {d.name} {dev:25s}")

Tensorflow version: 2.8.2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

Devices:
  - CPU  /device:CPU:0                          


In [4]:
try:
    from google.colab import drive
    ON_COLAB = True
    print('Running on colab')
    print('Installing custom project code')
    try:
        from src.architecture import build_model
    except:
        !pip install -U git+https://github.com/vtecftwy/metagenomics.git@refactor_cnn_virus
    
    drive.mount('/content/gdrive')
    p2drive = Path('/content/gdrive/MyDrive/Metagenonics')
    assert p2drive.is_dir()
    p2data =  p2drive / 'CNN_Virus_data'
    assert p2data.is_dir()

except:
    ON_COLAB = False
    print('Running locally')
    print('Make sure you have installed the custom project code in your environment')
    pdata = Path('data/cnn_virus')
    assert p2data.is_dir()

Running on colab
Installing custom project code
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/vtecftwy/metagenomics.git@refactor_cnn_virus
  Cloning https://github.com/vtecftwy/metagenomics.git (to revision refactor_cnn_virus) to /tmp/pip-req-build-rlocd821
  Running command git clone -q https://github.com/vtecftwy/metagenomics.git /tmp/pip-req-build-rlocd821
  Running command git checkout -b refactor_cnn_virus --track origin/refactor_cnn_virus
  Switched to a new branch 'refactor_cnn_virus'
  Branch 'refactor_cnn_virus' set up to track remote branch 'refactor_cnn_virus' from 'origin'.
Building wheels for collected packages: src
  Building wheel for src (setup.py) ... [?25l[?25hdone
  Created wheel for src: filename=src-1.0.2-py3-none-any.whl size=14773 sha256=60f848650ced82a9fd766df29a177070a589ec97896d8b3b07d53b57e7e4ffce
  Stored in directory: /tmp/pip-ephem-wheel-cache-rvivvrr8/wheels/10/e4/2b/

# Import custom code and setup paths

In [5]:
from src.architecture import build_model
from src.preprocessing import get_learning_weights, get_params_50mer, get_kmer_from_50mer
from src.preprocessing import DataGenerator_from_50mer
from src.utils import TrainingExperiment

In [6]:
#path for the training file
filepath_train= p2data /"50mer_training"
assert filepath_train.is_file()
#path for the validating file
filepath_val= p2data / "50mer_validating"
assert filepath_val.is_file()
#path for the learning weights file
filepath_weights=p2data / "weight_of_classes"
assert filepath_weights.is_file()

# Review training data

- check size, structure, time to load

In [7]:
with open(filepath_train, 'r') as fp:
    lines = []
    for i in range(10):
        lines.append(fp.readline())
print(''.join(lines))

TCAAAATAATCAGAAATGTTGAACCTAGGGTTGGACACATAATGACCAGC	76	0
ATTGTTTAACAATTTGTGCTCGTCCCGGTCACCCGCATCCAATCTTGATG	4	9
AATCTTGTCCTATCCTACCCGCAGGGGAATTGATGATAGANGTGCTTTTA	181	0
GGAGCGGAGCCAACCCCTATGCTCACTTGCAACCCAAGGGGCGTTCCAGT	74	3
TGGATCCTGCGCGGGACGTCCTTTGTCTACGTCCCGTCGGCGCATCCCGC	60	3
GAGAGACTTACTAAAAAGCTGGCACTTACCATCAGTGTTTCACCTACATG	44	0
ACACACGACACTAGAGATAATGTGTCAGTGGATTATAAACAAACCAAGTT	43	7
TTGTAGCATAAGAACTGGTCTTCGCTGAAATTCTTGTCTTGATCTCATCT	35	2
TGGCCCTGCGGTCTGGGGCCCAGAAGCATATGTCAAGTCCTTTGAGAAGT	73	4
TAGATTTAGTGGTTAGGTAGTAAGGCTACAATGTAAACACGTAGTGGCAA	11	6



Compare reading time to pass the full file if on gdrive and on server disk

In [8]:
# From gdrive
%%time
nlines = 0
with open(filepath_train, 'r') as fp:
    while True:
        line = fp.readline()
        if line == '':
            break
        else:
            nlines += 1
print(f"{nlines:,d}")

50,903,296
CPU times: user 14.1 s, sys: 1.6 s, total: 15.7 s
Wall time: 29.5 s


From gdrive:
```
    50,903,296
    CPU times: user 17.4 s, sys: 1.49 s, total: 18.9 s
    Wall time: 19.5 s

    50,903,296
    CPU times: user 14.1 s, sys: 1.6 s, total: 15.7 s
    Wall time: 29.5 s
```

In [9]:
p2train_txt = Path('train')
shutil.copy(filepath_train, p2train_txt)

PosixPath('train')

In [11]:
# Locally on server
%%time
nlines = 0
with open(p2train_txt, 'r') as fp:
    while True:
        line = fp.readline()
        if line == '':
            break
        else:
            nlines += 1
print(f"{nlines:,d}")

50,903,296
CPU times: user 11.5 s, sys: 719 ms, total: 12.3 s
Wall time: 12.3 s


From "local" file:
```
    50,903,296
    CPU times: user 12.3 s, sys: 752 ms, total: 13.1 s
    Wall time: 12.9 s

    50,903,296
    CPU times: user 11.5 s, sys: 719 ms, total: 12.3 s
    Wall time: 12.3 s
```
Seems that copying the file locally before processing will accelerate a little the process. But not very clear.

## Review Data in `50mer_training` and `50mer_validating`

Data is provided as long tab separated text, organized as follows:
- one line per sample
- three columns per sample: *sequence*, *virus label* , *position label*
```
TCAAAATAATCAGAAATGTTGAACCTAGGGTTGGACACATAATGACCAGC	76	0
ATTGTTTAACAATTTGTGCTCGTCCCGGTCACCCGCATCCAATCTTGATG	4	9
AATCTTGTCCTATCCTACCCGCAGGGGAATTGATGATAGANGTGCTTTTA	181	0
GGAGCGGAGCCAACCCCTATGCTCACTTGCAACCCAAGGGGCGTTCCAGT	74	3
TGGATCCTGCGCGGGACGTCCTTTGTCTACGTCCCGTCGGCGCATCCCGC	60	3
GAGAGACTTACTAAAAAGCTGGCACTTACCATCAGTGTTTCACCTACATG	44	0
ACACACGACACTAGAGATAATGTGTCAGTGGATTATAAACAAACCAAGTT	43	7
TTGTAGCATAAGAACTGGTCTTCGCTGAAATTCTTGTCTTGATCTCATCT	35	2
TGGCCCTGCGGTCTGGGGCCCAGAAGCATATGTCAAGTCCTTTGAGAAGT	73	4
TAGATTTAGTGGTTAGGTAGTAAGGCTACAATGTAAACACGTAGTGGCAA	11	6
```

# Preprocess data

We can use 
- `tf.data.TextLineDataset` to define the initial dataset pointing to the file and returning strings
- then apply a transform to convet the string into 3 tensors datasets: `train_ds`, `label_ds`, `pos_ds`

## 1. Load data from text file using a dataset: `text_ds`
```python
    tf.data.TextLineDataset(
        filenames,
        compression_type=None,
        buffer_size=None,
        num_parallel_reads=None,
        name=None
    )
```

Create `text_ds` and print a few small batches

In [14]:
text_ds = tf.data.TextLineDataset(
    filepath_val,
    compression_type='',
    name='text_ds'
)

text_ds

<TextLineDatasetV2 element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

In [15]:
text_ds.element_spec

TensorSpec(shape=(), dtype=tf.string, name=None)

```
TensorSpec(shape=(), dtype=tf.string, name=None)
```
- `shape=()` means each element is a scalar (a single value)
- `dtype=tf.string` means that each element is of type string.

Each element is a single string

Let's use `.take(n)` method to retrieve n elements one by one

In [16]:
text_ds.take(5)

<TakeDataset element_spec=TensorSpec(shape=(), dtype=tf.string, name=None)>

In [18]:
for element in text_ds.take(5):
    display(element)

<tf.Tensor: shape=(), dtype=string, numpy=b'CCATCGGCGTCCCGGAATCGTATACCGGGCACACGAAGCGTTATAACAAT\t6\t8'>

<tf.Tensor: shape=(), dtype=string, numpy=b'GTGGCTATGGTGAGGAGTTTGGAGGAAATAATATACATCATATATTCTGA\t6\t4'>

<tf.Tensor: shape=(), dtype=string, numpy=b'CCCTCTTCTGCAGACTGCTTACGGTTTCGTCCGTGTTGCAGTCGATTATC\t117\t0'>

<tf.Tensor: shape=(), dtype=string, numpy=b'GGAACGCGAACACGCCCGGAAGATTCTTCATCGTAATAAATGGACAGGTA\t2\t3'>

<tf.Tensor: shape=(), dtype=string, numpy=b'CAAACTGATATTCTTAGTGAAGAAAGACCACCTAATCATCATACCTACAT\t20\t4'>

What we want is to work with batches of elements. We can do that using the `.batch(n)` method.

In [19]:
for batch in text_ds.batch(4).take(3):
    display(batch)

<tf.Tensor: shape=(4,), dtype=string, numpy=
array([b'CCATCGGCGTCCCGGAATCGTATACCGGGCACACGAAGCGTTATAACAAT\t6\t8',
       b'GTGGCTATGGTGAGGAGTTTGGAGGAAATAATATACATCATATATTCTGA\t6\t4',
       b'CCCTCTTCTGCAGACTGCTTACGGTTTCGTCCGTGTTGCAGTCGATTATC\t117\t0',
       b'GGAACGCGAACACGCCCGGAAGATTCTTCATCGTAATAAATGGACAGGTA\t2\t3'],
      dtype=object)>

<tf.Tensor: shape=(4,), dtype=string, numpy=
array([b'CAAACTGATATTCTTAGTGAAGAAAGACCACCTAATCATCATACCTACAT\t20\t4',
       b'TTTACTATTTATATGNCTACTTGGATATCTTCAGTGTTGTGTGCGTCAGT\t32\t9',
       b'CTCTTATGTTTGAGCAACATCAAAATTCACTTATGAAGGAGAGGACACAT\t26\t1',
       b'GTGGTCTAGATATTTAAATAAATTTGACTATTATTTTCGGAGGTAAAATA\t52\t1'],
      dtype=object)>

<tf.Tensor: shape=(4,), dtype=string, numpy=
array([b'CCGCCGCCCGCCGCCCGGACACGTGCGCCCGGAGCGCCGCCGCCTCCTCG\t21\t6',
       b'GATGACTTGATAATAAAGTTGATACCTTCCACGTTCGATTTATGATCGGT\t18\t4',
       b'GCAGCCATGTCTATGGGTGCACAAGAACCCAGAAACATATTCCTGTTGTG\t13\t0',
       b'TTGTGAGCGAGGTTACCGGTGTCCAAGTCGTAATCCGAAAATCATATATT\t12\t7'],
      dtype=object)>

- `text_ds.batch(n)` returns batches as a tensor  of shape (4,) and `dtype=tf.string`, that is, a tensor of 4 rows, each row being one single string.
- each string corresponds to one line in the dataset


## 2. Create a function to transform the data into the desired formats
The function must take in a tensor of strings and returns three tensors. 

For a batch of `n` samples, the three tensors are:
- `x_seqs`: the input tensor, of shape (n, 50, 5) where each base letter in "base-hot-encoded" (BHE), i.e. "A":[1,0,0,0,0], "C":[0,1,0,0,0], "G":[0,0,1,0,0], "T":[0,0,0,1,0], "N":[0,0,0,0,1].
- `y_labels`: the virus target tensor, of shape (n, 187) with "one-hot-encoded" (OHE) virus labels
- `y_pos`: the position target tensor, of shape (n, 10) with "one-hot-encoded" (OHE) position labels

Steps to design this function
1. First, experiment with the transform, step by step. To do this, we pick a single batch from the dataset, then build the three tensors step by step.
2. When it works, create a function and test it on one batch of data
3. Finally, test it again by applying directly to the text_ds dataset

> Techical Note: 
>
> When applied to the dataset, the function may not make any assumption on the size of the batch, and cannot retrieve it either.

First attempt. It is rather slow.

In [None]:
RUN_SLOW_CODE = False
# can be run for output comparison with new method

if RUN_SLOW_CODE:
    now = datetime.now()
    batch_size = 1024

    # Create a text dataset and test each step on it
    print('Create text line dataset')
    ds = text_ds.batch(batch_size)
    it = iter(ds)
    b = next(it)
    duration = (datetime.now()-now).total_seconds()
    print('>>', duration)

    # EXPERIMENT WITH THE TRANSFORM
    # 1. Split the string in three: 
    # Notes:
    # tf.split returns a ragged tensor. It must be converted into a normal tensor
    # Tensor shape is (batch size, 3), one sequences, one for labels and one for position. Dtype is all tf.strings
    t = tf.strings.split(b, '\t').to_tensor(default_value = '', shape=[None, 3])
    print('\nLoad and split lines in three sections:')
    print(t.shape)
    print(t[:3, :])
    duration = (datetime.now()-now).total_seconds() - duration
    print('>>', duration)

    # 2. Split into sequences, labels and positions
    print('\nSplit string tensor into three tensors:')
    # Split string sequence in a sequence of single base strings 
    seqs = tf.strings.bytes_split(t[:, 0]).to_tensor(shape=(None, 50))
    # Labels and Posisionts are converted from tf.string to tf.int32
    labels = tf.strings.to_number(t[:, 1], out_type=tf.int32)
    pos = tf.strings.to_number(t[:, 2], out_type=tf.int32)
    print(seqs.shape, labels.shape, pos.shape)
    duration = (datetime.now()-now).total_seconds() - duration
    print('>>', duration)

    # 3. One-Hot-Encode labels, using only tf functions for performance
    print('\nOne-Hot-Encode labels')
    n_labels = 187
    y_labels = tf.gather(tf.eye(n_labels), labels)
    print(y_labels.shape)
    print(y_labels[:3, :15])
    duration = (datetime.now()-now).total_seconds() - duration
    print('>>', duration)

    # 4. One-Hot-Encode positions, using only tf functions for performance
    print('\nOne-Hot-Encode positions')
    n_pos = 10
    y_pos= tf.gather(tf.eye(n_pos), pos)
    print(y_pos.shape)
    print(y_pos[:3])
    duration = (datetime.now()-now).total_seconds() - duration
    print('>>', duration)

    # 5. Base-Hot-Encode sequences
    # Each base letter (A, C, G, T, N) is replaced by a OHE vector
    # Notes:
    # a. the batch of sequence seqs has a shape (batch_size, 50) after splitting each byte. 
    #    need to flatten it to apply the transform on each base, then reshape to original shape
    # b. We need to map each letter to one vector/tensor. 
    #    Would normally use a dict for that, but it does not work in this case because the key comes
    #    from the incoming tensor, and tensors are not hashable. Instead, we will use tf.case()
    print('\nBase-Hote-Encode the sequences')
    flattened_seqs = tf.reshape(seqs, shape=[-1])

    def base_hot_encoder(a):

        # Define the encoding functions returning the encoding tensor for each base
        def encode_A(): return tf.constant([1,0,0,0,0])
        def encode_C(): return tf.constant([0,1,0,0,0])
        def encode_G(): return tf.constant([0,0,1,0,0])
        def encode_T(): return tf.constant([0,0,0,1,0])
        def encode_N(): return tf.constant([0,0,0,0,1])

        # Define the mapping, as a list of tuples: (condition, encoding function)
        # values of the strings are tested against b'A', ... as bytes
        case_mapping = [
            (tf.math.equal(a, b'A'), encode_A),
            (tf.math.equal(a, b'C'), encode_C),
            (tf.math.equal(a, b'G'), encode_G),
            (tf.math.equal(a, b'T'), encode_T),
            (tf.math.equal(a, b'N'), encode_N)
        ]

        return tf.case(pred_fn_pairs = case_mapping)

    # Process seqs tensor
    processed_seqs = tf.map_fn(base_hot_encoder, flattened_seqs, fn_output_signature=tf.int32)
    print(processed_seqs.shape)
    x_seqs = tf.reshape(processed_seqs, shape=(-1, 50, 5))
    print(x_seqs.shape)
    duration = (datetime.now()-now).total_seconds() - duration
    print('>>', duration)

    # print(x_seqs)

Create text line dataset
>> 0.021305

Load and split lines in three sections:
(1024, 3)
tf.Tensor(
[[b'CCATCGGCGTCCCGGAATCGTATACCGGGCACACGAAGCGTTATAACAAT' b'6' b'8']
 [b'GTGGCTATGGTGAGGAGTTTGGAGGAAATAATATACATCATATATTCTGA' b'6' b'4']
 [b'CCCTCTTCTGCAGACTGCTTACGGTTTCGTCCGTGTTGCAGTCGATTATC' b'117' b'0']], shape=(3, 3), dtype=string)
>> 0.013077000000000002

Split string tensor into three tensors:
(1024, 50) (1024,) (1024,)
>> 0.042208999999999997

One-Hot-Encode labels
(1024, 187)
tf.Tensor(
[[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(3, 15), dtype=float32)
>> 0.021422999999999998

One-Hot-Encode positions
(1024, 10)
tf.Tensor(
[[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(3, 10), dtype=float32)
>> 0.04852000000000001

Base-Hote-Encode the sequences
(51200, 5)
(1024, 50, 5)
>> 108.536776


NOTE: This Method is very slow. 
Saved output for batch size of 1024
```
    Create text line dataset
    >> 0.020242

    Load and split lines in three sections:
    (1024, 3)
    tf.Tensor(
    [[b'CCATCGGCGTCCCGGAATCGTATACCGGGCACACGAAGCGTTATAACAAT' b'6' b'8']
    [b'GTGGCTATGGTGAGGAGTTTGGAGGAAATAATATACATCATATATTCTGA' b'6' b'4']
    [b'CCCTCTTCTGCAGACTGCTTACGGTTTCGTCCGTGTTGCAGTCGATTATC' b'117' b'0']], shape=(3, 3), dtype=string)
    >> 0.014626000000000004

    Split string tensor into three tensors:
    (1024, 50) (1024,) (1024,)
    >> 0.04225999999999999

    One-Hot-Encode labels
    (1024, 187)
    tf.Tensor(
    [[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
    [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
    [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(3, 15), dtype=float32)
    >> 0.023466000000000015

    One-Hot-Encode positions
    (1024, 10)
    tf.Tensor(
    [[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
    [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
    [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(3, 10), dtype=float32)
    >> 0.049815999999999985

    Base-Hote-Encode the sequences
    (51200, 5)
    (1024, 50, 5)
    >> 154.10547400000002
```
See next cell for much faster method, 

Other approach, faster by avoiding case function, and using simple operations. Typical results are:

```
    Create text line dataset
    >> 0.018218

    Load and split lines in three sections:
    (1024, 3)
    >> 0.026377999999999995

    Split string tensor into three tensors:
    (1024, 50) (1024,) (1024,)
    >> 0.054369

    One-Hot-Encode labels
    (1024, 187)
    >> 0.041302000000000005

    One-Hot-Encode positions
    (1024, 10)
    >> 0.063554

    Base-Hote-Encode the sequences
    (1024, 50, 5)
    >> 0.095141
```

In [21]:
now = datetime.now()
batch_size = 1024

# Create a text dataset with bach and retrieve the first batch with an iterator
print('Create text line dataset')
ds = text_ds.batch(batch_size)
it = iter(ds)
b = next(it)

duration = (datetime.now()-now).total_seconds()
print('>>', duration)

# EXPERIMENT WITH THE TRANSFORM

# 1. Split the string in three: 
# Notes:
# tf.strings.split returns a ragged tensor. It must be converted into a normal tensor
# Tensor shape is (batch size, 3), one sequences, one for labels and one for position. Dtype is all tf.strings
t = tf.strings.split(b, '\t').to_tensor(default_value = '', shape=[None, 3])
print('\nLoad and split lines in three sections:')
print(t.shape)
print(t[:, :])
duration = (datetime.now()-now).total_seconds() - duration
print('>>', duration)

# 2. Split into sequences, labels and positions
print('\nSplit string tensor into three tensors:')
# Split string sequence in a sequence of single base strings 
seqs = tf.strings.bytes_split(t[:, 0]).to_tensor(shape=(None, 50))
# Labels and Posisionts are converted from tf.string to tf.int32
labels = tf.strings.to_number(t[:, 1], out_type=tf.int32)
pos = tf.strings.to_number(t[:, 2], out_type=tf.int32)
print(seqs.shape, labels.shape, pos.shape)
duration = (datetime.now()-now).total_seconds() - duration
print('>>', duration)

# 3. One-Hot-Encode labels, using only tf functions for performance
print('\nOne-Hot-Encode labels')
n_labels = 187
y_labels = tf.gather(tf.eye(n_labels), labels)
print(y_labels.shape)
print(y_labels[:3, :15])
duration = (datetime.now()-now).total_seconds() - duration
print('>>', duration)

# 4. One-Hot-Encode positions, using only tf functions for performance
print('\nOne-Hot-Encode positions')
n_pos = 10
y_pos= tf.gather(tf.eye(n_pos), pos)
print(y_pos.shape)
print(y_pos[:3])
duration = (datetime.now()-now).total_seconds() - duration
print('>>', duration)

# 5. Base-Hot-Encode sequences
# Each base letter (A, C, G, T, N) is replaced by a OHE vector
# Notes:
# a. the batch of sequence seqs has a shape (batch_size, 50) after splitting each byte. 
#    need to flatten it to apply the transform on each base, then reshape to original shape
# b. We need to map each letter to one vector/tensor. 
#    Using tf.case seems slow. Trying other approach: 
#       - Convert bytes seqs in integer sequence (uint8 to work byte by byte)
#       - Make a matrix with one column for each of the 5 base letters
#       - Create BHE tensor as a matrix multiplication with the 5 base tensors [1, 0, 0, 0, 0] ...
# 
print('\nBase-Hote-Encode the sequences')
display(seqs[:3])
seqs_uint8 = tf.io.decode_raw(seqs, out_type=tf.uint8)
    # note: tf.io.decode_raw adds one dimension at the end in the process
    #       [b'C', b'A', b'T'] will return [[67], [65], [84]] and not [67, 65, 84]
    #       this is actually what we want to contatenate the values for each base letter
# display(seqs_uint8[:3, :])

A, C, G, T, N = 65, 67, 71, 84, 78

# Create 5 tensors seqs_X where the value of a base in the sequence is 1 if the base is X and 0 otherwise.
# We do that with a boolean slicing casted as a float32 (0.0 or 1.0)
# The final tensor is the concatenation of these 5 tensors
seqs_A = tf.cast(seqs_uint8 == A, tf.float32)
seqs_C = tf.cast(seqs_uint8 == C, tf.float32)
seqs_G = tf.cast(seqs_uint8 == G, tf.float32)
seqs_T = tf.cast(seqs_uint8 == T, tf.float32)
seqs_N = tf.cast(seqs_uint8 == N , tf.float32)
x_seqs_2 = tf.concat([seqs_A, seqs_C, seqs_G, seqs_T, seqs_N], axis=2)
print(x_seqs_2.shape)
duration = (datetime.now()-now).total_seconds() - duration
print('>>', duration)

x_seqs_2.shape, y_labels.shape, y_pos.shape

Create text line dataset
>> 0.03397

Load and split lines in three sections:
(1024, 3)
tf.Tensor(
[[b'CCATCGGCGTCCCGGAATCGTATACCGGGCACACGAAGCGTTATAACAAT' b'6' b'8']
 [b'GTGGCTATGGTGAGGAGTTTGGAGGAAATAATATACATCATATATTCTGA' b'6' b'4']
 [b'CCCTCTTCTGCAGACTGCTTACGGTTTCGTCCGTGTTGCAGTCGATTATC' b'117' b'0']
 ...
 [b'AAGTGCATTCAAGTTTTAATTGATATTTAGTTATGTAGTCATTTAGAGTA' b'115' b'3']
 [b'AGAAGCTGGCTCCGGAGCAGCAGTAGAGGGAAAACCACGGAGGCNGACAG' b'62' b'7']
 [b'CCTTGGTGAAGGTATTAACAAATCGATTAAGTTGGGAGGGATGCATGCGA' b'27' b'1']], shape=(1024, 3), dtype=string)
>> 0.008703000000000002

Split string tensor into three tensors:
(1024, 50) (1024,) (1024,)
>> 0.06675

One-Hot-Encode labels
(1024, 187)
tf.Tensor(
[[0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(3, 15), dtype=float32)
>> 0.014312999999999992

One-Hot-Encode positions
(1024, 10)
tf.Tensor(
[[0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 

<tf.Tensor: shape=(3, 50), dtype=string, numpy=
array([[b'C', b'C', b'A', b'T', b'C', b'G', b'G', b'C', b'G', b'T', b'C',
        b'C', b'C', b'G', b'G', b'A', b'A', b'T', b'C', b'G', b'T', b'A',
        b'T', b'A', b'C', b'C', b'G', b'G', b'G', b'C', b'A', b'C', b'A',
        b'C', b'G', b'A', b'A', b'G', b'C', b'G', b'T', b'T', b'A', b'T',
        b'A', b'A', b'C', b'A', b'A', b'T'],
       [b'G', b'T', b'G', b'G', b'C', b'T', b'A', b'T', b'G', b'G', b'T',
        b'G', b'A', b'G', b'G', b'A', b'G', b'T', b'T', b'T', b'G', b'G',
        b'A', b'G', b'G', b'A', b'A', b'A', b'T', b'A', b'A', b'T', b'A',
        b'T', b'A', b'C', b'A', b'T', b'C', b'A', b'T', b'A', b'T', b'A',
        b'T', b'T', b'C', b'T', b'G', b'A'],
       [b'C', b'C', b'C', b'T', b'C', b'T', b'T', b'C', b'T', b'G', b'C',
        b'A', b'G', b'A', b'C', b'T', b'G', b'C', b'T', b'T', b'A', b'C',
        b'G', b'G', b'T', b'T', b'T', b'C', b'G', b'T', b'C', b'C', b'G',
        b'T', b'G', b'T', b'T', b'G', b'C', b'A'

(1024, 50, 5)
>> 0.02928199999999999


(TensorShape([1024, 50, 5]), TensorShape([1024, 187]), TensorShape([1024, 10]))

Note on iterators in python doc
- [iterator](https://docs.python.org/3.8/glossary.html#term-iterator)
```
    The iterator objects themselves are required to support the following two methods, which together 
    form the iterator protocol:

    iterator.__iter__()
    Return the iterator object itself. This is required to allow both containers and iterators to be used with 
    the for and in statements. This method corresponds to the tp_iter slot of the type structure 
    for Python objects in the Python/C API.

    iterator.__next__()
    Return the next item from the container. If there are no further items, raise the StopIteration exception. 
    This method corresponds to the tp_iternext slot of the type structure for Python objects in the Python/C API.
```
- [`iter()`](https://docs.python.org/3.8/library/functions.html?highlight=iter#iter) `.__iter__()`
```
    iter(object[, sentinel])
    Return an iterator object. The first argument is interpreted very differently depending on the presence of the 
    second argument. Without a second argument, object must be a collection object which supports the iteration 
    protocol (the __iter__() method), or it must support the sequence protocol (the __getitem__() method with integer 
    arguments starting at 0). If it does not support either of those protocols, TypeError is raised. If the second 
    argument, sentinel, is given, then object must be a callable object. The iterator created in this case will call 
    object with no arguments for each call to its __next__() method; if the value returned is equal to sentinel, 
    StopIteration will be raised, otherwise the value will be returned.
```
- [`next()`](https://docs.python.org/3.8/library/functions.html?highlight=iter#next) `.__next__()`
``` 
    next(iterator[, default])
    Retrieve the next item from the iterator by calling its __next__() method. 
    If default is given, it is returned if the iterator is exhausted, otherwise StopIteration is raised.
```

In [66]:
sequence = [1, 2, 3, 4, 5, 6]
it = iter(sequence)
it

<list_iterator at 0x7f1f7eacf810>

In [67]:
it.__iter__()

<list_iterator at 0x7f1f7eacf810>

In [68]:
next(it), next(it)

(1, 2)

In [69]:
it.__next__(), it.__next__()

(3, 4)

In [70]:
try:
    print(next(it))
    print(next(it))
    print(next(it))
    print(next(it))
    print(next(it))
except StopIteration:
    print('Iterator reached the end of the sequence')

5
6
Iterator reached the end of the sequence


## Compare slow and fast methods to show result are the same

In [71]:
# np.array_equal(x_seqs.numpy().astype('float'), x_seqs_2.numpy())

This is much faster and provides the same results:
- former method: > 150 sec for batch of 1024
- new method: 0.03 sec for same batch size

And the result is the same

### Transform function

In [73]:
def strings_to_tensors(b):
    """Convert batch of strings into three tensors: (x_seqs, (y_labels, y_pos))"""
    
    # Split the string in three : returns a ragged tensor which needs to be converted into a normal tensor
    t = tf.strings.split(b, '\t').to_tensor(default_value = '', shape=[None, 3])

    # Split string sequence into a sequence of single base strings 
    seqs = tf.strings.bytes_split(t[:, 0]).to_tensor(shape=(None, 50))

    # OHE labels
    n_labels = 187
    y_labels = tf.strings.to_number(t[:, 1], out_type=tf.int32)
    y_labels = tf.gather(tf.eye(n_labels), y_labels)

    # OHE positions
    n_pos = 10
    y_pos = tf.strings.to_number(t[:, 2], out_type=tf.int32)
    y_pos= tf.gather(tf.eye(n_pos), y_pos)

    # BHE sequences
    # Each base letter (A, C, G, T, N) is replaced by a OHE vector
    # Notes:
    # a. the batch of sequence seqs has a shape (batch_size, 50) after splitting each byte. 
    #    need to flatten it to apply the transform on each base, then reshape to original shape
    # b. We need to map each letter to one vector/tensor. 
    #    Using tf.case seems slow. Trying other approach: 
    #       - Convert bytes seqs in integer sequence (uint8 to work byte by byte)
    #       - For each base letter (A, C, G, T, N) create one tensor (batch_size, 50)
    #         Value is 1 if it is the base in the sequence, otherwise
    #       - Concatenate these 5 tensors into a tensor of shape (batch_size, 50, 5)
 
    seqs_uint8 = tf.io.decode_raw(seqs, out_type=tf.uint8)
        # note: tf.io.decode_raw adds one dimension at the end in the process
        #       [b'C', b'A', b'T'] will return [[67], [65], [84]] and not [67, 65, 84]
        #       this is actually what we want to contatenate the values for each base letter

    # base_codes = tf.constant(['A', 'C', 'G', 'T', 'N'], dtype=tf.string)
    # A, C, G, T, N = tf.reshape(tf.io.decode_raw(base_codes, out_type=tf.uint8), shape=(-1)).numpy()
    A, C, G, T, N = 65, 67, 71, 84, 78

    seqs_A = tf.cast(seqs_uint8 == A, tf.float32)
    seqs_C = tf.cast(seqs_uint8 == C, tf.float32)
    seqs_G = tf.cast(seqs_uint8 == G, tf.float32)
    seqs_T = tf.cast(seqs_uint8 == T, tf.float32)
    seqs_N = tf.cast(seqs_uint8 == N , tf.float32)

    x_seqs = tf.concat([seqs_A, seqs_C, seqs_G, seqs_T, seqs_N], axis=2)

    return (x_seqs, (y_labels, y_pos))

In [74]:
display(b[:2])
x_seqs, (y_labels, y_pos) = strings_to_tensors(b)
b.shape, x_seqs.shape, y_labels.shape, y_pos.shape

<tf.Tensor: shape=(2,), dtype=string, numpy=
array([b'CCATCGGCGTCCCGGAATCGTATACCGGGCACACGAAGCGTTATAACAAT\t6\t8',
       b'GTGGCTATGGTGAGGAGTTTGGAGGAAATAATATACATCATATATTCTGA\t6\t4'],
      dtype=object)>

(TensorShape([1024]),
 TensorShape([1024, 50, 5]),
 TensorShape([1024, 187]),
 TensorShape([1024, 10]))

In [None]:
def strings_to_tensors_old(b):
    """Convert batch of strings into three tensors: (x_seqs, (y_labels, y_pos))"""
    
    # Split the string in three : returns a ragged tensor which needs to be converted into a normal tensor
    t = tf.strings.split(b, '\t').to_tensor(default_value = '', shape=[None, 3])

    # Split string sequence into a sequence of single base strings 
    seqs = tf.strings.bytes_split(t[:, 0]).to_tensor(shape=(None, 50))

    # OHE labels
    n_labels = 187
    y_labels = tf.strings.to_number(t[:, 1], out_type=tf.int32)
    y_labels = tf.gather(tf.eye(n_labels), y_labels)

    # OHE positions
    n_pos = 10
    y_pos = tf.strings.to_number(t[:, 2], out_type=tf.int32)
    y_pos= tf.gather(tf.eye(n_pos), y_pos)

    # BHE sequences

    # flatten the sequence tensor to map bytes (b'A', ...) to BHE vector
    flattened_seqs = tf.reshape(seqs, shape=[-1])

    def base_hot_encoder(a):
       
        # Define the encoding functions returning the encoding tensor for each base
        def encode_A(): return tf.constant([1,0,0,0,0])
        def encode_C(): return tf.constant([0,1,0,0,0])
        def encode_G(): return tf.constant([0,0,1,0,0])
        def encode_T(): return tf.constant([0,0,0,1,0])
        def encode_N(): return tf.constant([0,0,0,0,1])

        # Define the mapping, as a list of tuples: (condition, encoding function)
        # values of the strings are tested against b'A', ... as bytes
        case_mapping = [
            (tf.math.equal(a, b'A'), encode_A),
            (tf.math.equal(a, b'C'), encode_C),
            (tf.math.equal(a, b'G'), encode_G),
            (tf.math.equal(a, b'T'), encode_T),
            (tf.math.equal(a, b'N'), encode_N)
        ]

        return tf.case(pred_fn_pairs = case_mapping)

    # Process seqs tensor
    processed_seqs = tf.map_fn(base_hot_encoder, flattened_seqs, fn_output_signature=tf.int32)
    x_seqs = tf.reshape(processed_seqs, shape=(-1, 50, 5))

    return (x_seqs, (y_labels, y_pos))


## 3. Apply transformation to text_ds

In [75]:
batch_size = 1024
text_ds = tf.data.TextLineDataset(
    filepath_train,
    compression_type='',
    name='text_ds'
).batch(batch_size)

tensor_ds = text_ds.map(strings_to_tensors)
print(tensor_ds.element_spec)

(TensorSpec(shape=(None, 50, None), dtype=tf.float32, name=None), (TensorSpec(shape=(None, 187), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None)))


In [76]:
%%time
for seqs_b, (labels_b, pos_b) in tensor_ds.take(2):
    print(seqs_b.shape, labels_b.shape, pos_b.shape)

(1024, 50, 5) (1024, 187) (1024, 10)
(1024, 50, 5) (1024, 187) (1024, 10)
CPU times: user 63.2 ms, sys: 50.9 ms, total: 114 ms
Wall time: 137 ms


## 4. Check that the output are the same as in the original code

We use the validation set and pick the first `bs` values

In [77]:
bs = 2048
text_ds = tf.data.TextLineDataset(
    filepath_val,
    compression_type='',
    name='text_ds'
).batch(bs)

tensor_ds = text_ds.map(strings_to_tensors)
it = iter(zip(text_ds, tensor_ds))
b_text, (b_seqs, (b_labels, b_pos)) = next(it)
print(b_text.shape, b_seqs.shape, b_labels.shape, b_pos.shape)

(2048,) (2048, 50, 5) (2048, 187) (2048, 10)


Using the original code `src.preprocessing.DataGenerator_from_50mer`

```python
def __data_generation(self, index):
        x_train=[]
        for i in index:
            seq=self.matrix[i]
            seq_list=[j for j in seq]
            x_train.append(seq_list)
        x_train=np.array(x_train)
        x_tensor=np.zeros(list(x_train.shape)+[5])
        for row in range(len(x_train)):
            for col in range(50):
                x_tensor[row,col,d_nucl[x_train[row,col]]]=1
        y_pos=[]
        y_label=[self.labels[i] for i in index]
        y_label=np.array(y_label)
        y_label=to_categorical(y_label, num_classes=self.n_classes)
        y_pos=[self.pos[i] for i in index]
        y_pos=np.array(y_pos)
        y_pos=to_categorical(y_pos, num_classes=10)
        return x_tensor,{'output1': y_label, 'output2': y_pos}
```

In [84]:
def data_generation(index):
    d_nucl={"A":0,"C":1,"G":2,"T":3,"N":4}
    x_train=[]
    for i in index:
        seq=f_matrix[i]
        seq_list=[j for j in seq]
        x_train.append(seq_list)
    x_train=np.array(x_train)
    x_tensor=np.zeros(list(x_train.shape)+[5])
    for row in range(len(x_train)):
        for col in range(50):
            x_tensor[row,col,d_nucl[x_train[row,col]]]=1
    y_pos=[]
    y_label=[f_labels[i] for i in index]
    y_label=np.array(y_label)
    y_label=to_categorical(y_label, num_classes=187)
    y_pos=[f_pos[i] for i in index]
    y_pos=np.array(y_pos)
    y_pos=to_categorical(y_pos, num_classes=10)
    return x_tensor, y_label, y_pos

f_matrix,f_labels,f_pos=get_kmer_from_50mer(filepath_val, max_seqs=bs)
x_seqs_orig, y_label_orig, y_pos_orig = data_generation(list(range(bs)))
print(x_seqs_orig.shape, y_label_orig.shape, y_pos_orig.shape)

(2048, 50, 5) (2048, 187) (2048, 10)


### Test for the BHE encoded sequences

In [85]:
sample_idx = 0

print('Processed Values for x_seqs from Dataset:\n')
print(b_text[sample_idx].numpy())
print(b_seqs.numpy( )[sample_idx, :10, :])

Processed Values for x_seqs from Dataset:

b'CCATCGGCGTCCCGGAATCGTATACCGGGCACACGAAGCGTTATAACAAT\t6\t8'
[[0. 1. 0. 0. 0.]
 [0. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0.]
 [0. 0. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]


In [86]:
print('Processed Values for x_seqs_orig from original code:\n')
print(f_matrix[sample_idx])
print(x_seqs_orig[sample_idx, :10, :].astype('int'))

Processed Values for x_seqs_orig from original code:

CCATCGGCGTCCCGGAATCGTATACCGGGCACACGAAGCGTTATAACAAT
[[0 1 0 0 0]
 [0 1 0 0 0]
 [1 0 0 0 0]
 [0 0 0 1 0]
 [0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 1 0 0]
 [0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]]


In [87]:
print(b_seqs.numpy().shape)
print(x_seqs_orig.astype('int').shape)
np.array_equal(b_seqs.numpy(),x_seqs_orig.astype('int') )

(2048, 50, 5)
(2048, 50, 5)


True

### Test for the labels and Position

In [88]:
print(b_labels.shape, y_label_orig.shape)
np.array_equal(b_labels.numpy(), y_label_orig.astype('int'))

(2048, 187) (2048, 187)


True

In [89]:
print(b_pos.shape, y_pos_orig.shape)
np.array_equal(b_pos.numpy(), y_pos_orig.astype('int'))

(2048, 10) (2048, 10)


True