In [None]:
#| default_exp cnn_virus.data

In [None]:
#| hide
from eccore.ipython import nb_setup
from eccore.core import files_in_tree
from fastcore.test import test_fail
from nbdev import show_doc, nbdev_export
from pprint import pprint

In [None]:
#| hide
nb_setup()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Set autoreload mode


In [None]:
#|export
import collections
import json
import os
import re
import warnings
from functools import partial, partialmethod
from pathlib import Path
from typing import Any, Optional

In [None]:
#|export
# Set pytorch as backend
os.environ['KERAS_BACKEND'] = 'torch'

In [None]:
#| export
import keras
import numpy as np
import pandas as pd
import torch
from eccore.core import safe_path, validate_path
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm.notebook import tqdm, trange

from metagentorch.bio import q_score2prob_error
from metagentorch.core import ProjectFileSystem, TextFileBaseReader

In [None]:
#| export
# Retrieve the package root
from metagentorch import __file__
CODE_ROOT = Path(__file__).parents[0]
PACKAGE_ROOT = Path(__file__).parents[1]

In [None]:
#|hide
print(f"Pytorch version: {torch.__version__} - Expected 2.5.1")
print(f"Keras version: {keras.__version__} - Expected 3.8.0")
print(f"metagentorch package location: {__file__}")

Pytorch version: 2.5.1 - Expected 2.5.1
Keras version: 3.8.0 - Expected 3.8.0
metagentorch package location: /home/vtec/projects/bio/metagentorch/metagentorch/__init__.py


# data

> Data preprocessing and transform tools for CNN Virus data.

# CNN Virus project data structure

There are many different types of files and datasets for this project. All data are located in directory `data`, under the project root. The following is an overview of the main types of data and where they sit in the directory tree.

## Data directory

Let's create an instance of `ProjectFileSystem` to access the data directory and its content.

In [None]:
#| hide
# Set configuration file to be the one in nbs-dev folder.
# As ProjectFileSystem is a singleton class, this only needs to be done once per notebook
p2dev_cfg = PACKAGE_ROOT / 'nbs-dev/metagentorch-dev.cfg'
pfs = ProjectFileSystem(config_fname=p2dev_cfg)

In [None]:
pfs = ProjectFileSystem()
pfs.info()

Running linux on local computer
Device's home directory: /home/vtec
Project file structure:
 - Root ........ /home/vtec/projects/bio/metagentorch 
 - Data Dir .... /home/vtec/projects/bio/metagentorch/nbs-dev/data_dev 
 - Notebooks ... /home/vtec/projects/bio/metagentorch/nbs


Note that we have set up this notebook to use the development data directory `metagentorch/nbs-dev/data_dev` and not the standard `metagentorch/data`.

In [None]:
#| hide
assert pfs.data.name == 'data_dev' 

In each directory, a `readme.md` file or another `*.md` file can be added to provide a description of the directory content. 

These `readme.md` files can be conveniently accessed using the `.readme(path)` method available with the class `ProjectFileSystem` (from core module).

In [None]:
pfs.readme()

ReadMe file for directory `nbs-dev/data_dev`:

### Data directory for this package development 
This directory includes all  data required to validate and test this package code.

```text
data_dev
 |--- CNN_Virus_data
 |     |--- 50mer_ds_100_seq
 |     |--- 150mer_ds_100_seq
 |     |--- train_short
 |     |--- val_short
 |     |--- weight_of_classes
 |--- ncbi
 |     |--- infer_results
 |     |     |--- cnn_virus
 |     |     |--- csv
 |     |     |--- xlsx
 |     |     |--- testdb.db
 |     |--- refsequences
 |     |     |--- cov
 |     |     |     |--cov_virus_sequence_one_metadata.json
 |     |     |     |--sequences_two_no_matching_rule.fa
 |     |     |     |--another_sequence.fa
 |     |     |     |--cov_virus_sequences_two.fa
 |     |     |     |--cov_virus_sequences_two_metadata.json
 |     |     |     |--cov_virus_sequence_one.fa
 |     |     |     |--single_1seq_150bp
 |     |     |     |    |--single_1seq_150bp.fq
 |     |     |     |    |--single_1seq_150bp.aln
 |     |     |     |--paired_1seq_150bp
 |     |     |     |    |--paired_1seq_150bp2.aln
 |     |     |     |    |--paired_1seq_150bp2.fq
 |     |     |     |    |--paired_1seq_150bp1.fq 
 |     |     |     |    |--paired_1seq_150bp1.aln 
 |     |--- simreads
 |     |     |--- cov
 |     |     |     |--- paired_1seq_50bp
 |     |     |     |      |--- paired_1seq_50bp_1.aln
 |     |     |     |      |--- paired_1seq_50bp_1.fq
 |     |     |     |--- single_1seq_50bp
 |     |     |     |      |--- single_1seq_50bp_1.aln
 |     |     |     |      |--- single_1seq_50bp_1.fq
 |     |     |--- cov
 |     |     |     |--single_1seq_50bp
 |     |     |     |    |--single_1seq_50bp.aln
 |     |     |     |    |--single_1seq_50bp.fq
 |     |     |     |--single_1seq_150bp
 |     |     |     |    |--single_1seq_150bp.fq
 |     |     |     |    |--single_1seq_150bp.aln
 |     |     |     |--paired_1seq_150bp
 |     |     |     |    |--paired_1seq_150bp2.aln
 |     |     |     |    |--paired_1seq_150bp2.fq
 |     |     |     |    |--paired_1seq_150bp1.fq
 |     |     |     |    |--paired_1seq_150bp1.aln
 |--- saved           
 |--- readme.md               
```

## Original datasets

The CNN Virus team provides training and validation/test datasets for their model. These datasets and the pretrained model parameters are available on their Google drive shared directory [here](https://drive.google.com/open?id=1sj0-NCSKjLta_Geg6EMo26rChmtWcOiI). A set of shortened datasets is also available in the `data_dev` development data directory for testing the code.

In [None]:
pfs.readme(pfs.data/'CNN_Virus_data')

ReadMe file for directory `nbs-dev/data_dev/CNN_Virus_data`:

### CNN Virus data (development directory version)

This directory includes a set of short data files used to test train and inference with the CNN Virus. 

#### File list and description:
##### 50-mer 
50-mer reads and their labels, in *text format* with one line per sample. Each line consists of three components, separated by tabs: the 50-mer read or sequence, the virus species label and the position label:
```text
'TTACNAGCTCCAGTCTAAGATTGTAACTGGCCTTTTTAAAGATTGCTCTA    94    5\n'
``` 
Files:

- `50mer_ds_100_seq`: small dataset with 100 reads
- `5train_short`: small 1000-read subset from the original training dataset for experiments
- `val_short`: small 500-read subset from the original validation dataset for experiments

##### 150-mer
150-mer reads and their labels in *text format* in a similar format as above:
```text
'TTCTTTCACCACCACAACCAGTCGGCCGTGGAGAGGCGTCGCCGCGTCTCGTTCGTCGAGGCCGATCGACTGCCGCATGAGAGCGGGTGGTATTCTTCCGAAGACGACGGAGACCGGGACGGTGATGAGGAAACTGGAGAGAGCCACAAC    6    0\n'
```
Files:

- `150mer_ds_100_reads`: small subset of 100 reads from original `ICTV_150mer_benchmarking` file

##### Other files:

- `virus_name_mapping`: mapping between virus species and their numerical label
- `weight_of_classes`:  weights for each virus species class in the training dataset



In [None]:
#|export
class OriginalLabels:
    """Converts between labels and species name as per original training dataset"""
    def __init__(
        self, 
        p2mapping:Path|None = None   # Path to the mapping file. Uses `virus_name_mapping` by default
        ):
        if p2mapping is None:
            p2mapping = ProjectFileSystem().data / 'CNN_Virus_data/virus_name_mapping'
        else:
            p2mapping = safe_path(p2mapping)
        if not p2mapping.is_file(): raise FileNotFoundError(f"Mapping file not found at {p2mapping}")
        df = pd.read_csv(p2mapping, sep='\t', header=None, names=['species', 'label'])
        self._label2species = df['species'].to_list()
        self._label2species.append('Unknown Virus Species')
        self._species2label = {specie:label for specie, label in zip(df['species'], df['label'])}
        self._species2label['Unknown Virus Species'] = len(self._label2species)

    def search(self, s:str  # string to search through all original virus species
                       ):
        """Prints all species whose name contains the passed string, with their numerical label"""
        print('\n'.join([f"{k}. Label: {v}" for k,v in self._species2label.items() if s in k.lower()]))

    def label2species(self, n:int # label to convert to species name
                      ):
        """Converts a numerical label into the correpsonding species label"""
        return self._label2species[n]

    def species2label(self, s:str  # string to convert to label
                      ):
        """Converts a species name into the corresponding label number"""
        return self._species2label[s]

Original data include 187 viruses, with label from 0 to 186. 

With the methods `.label2species(n)` and `.species2label(species)` we can convert between the label and the species name.

In [None]:
species = OriginalLabels()
for n in [0, 94, 117, 118]:
    print(f"{n:3d} -> {species.label2species(n)}")

  0 -> Variola_virus
 94 -> Middle_East_respiratory_syndrome-related_coronavirus
117 -> Severe_acute_respiratory_syndrome-related_coronavirus
118 -> Yellow_fever_virus


In [None]:
for s in ['Variola_virus', 'Yellow_fever_virus']:
    print(f"{s:20s} -> {species.species2label(s)}")

Variola_virus        -> 0
Yellow_fever_virus   -> 118


When looking for a numerical species label, it is often more convenient to use a partial name with the method `.search(species)`, because we do not need to know the full species name.

In [None]:
show_doc(OriginalLabels.search)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L59){target="_blank" style="float:right; font-size:smaller"}

### OriginalLabels.search

>      OriginalLabels.search (s:str)

*Prints all species whose name contains the passed string, with their numerical label*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| s | str | string to search through all original virus species |

In [None]:
species.search('fever')

Sandfly_fever_Naples_phlebovirus. Label: 35
Crimean-Congo_hemorrhagic_fever_orthonairovirus. Label: 76
Yellow_fever_virus. Label: 118
Rift_Valley_fever_phlebovirus. Label: 156


## NCBI reference sequences

We simulate reads using many reference sequences from the NCBI GenBank database. We group all reference sequences as well as all reads simulated from these reference sequences under the `ncbi` directory.

In [None]:
pfs = ProjectFileSystem()

In [None]:
pfs.readme(pfs.data / 'ncbi')

ReadMe file for directory `nbs-dev/data_dev/ncbi`:

### NCBI Data

This directory includes all data related to the work done with reference sequences from NCBI. 

The data is organized in the following subfolders:

- `refsequences`: reference CoV sequences downloaded from NCBI, and related metadata
- `simreads`: all data from simulated reads, using ART Illumina simulator and the reference sequences
- `infer_results`: results from the inference using models with the simulated reads
- `ds`: datasets in proper format for training or inference/prediction using the CNN Virus model


In [None]:
pfs.readme(pfs.data / 'ncbi/refsequences')

ReadMe file for directory `nbs-dev/data_dev/ncbi/refsequences`:

No markdown file in this folder


In [None]:
pfs.readme(pfs.data / 'ncbi/simreads')

ReadMe file for directory `nbs-dev/data_dev/ncbi/simreads`:

### NCBI simulated reads
This directory includes all sets of simulated read sequence files generated from NCBI viral sequences using  ARC Illumina. 

```ascii
this-directory
    |--cov
    |    |
    |    |--single_10seq_50bp
    |    |    |--single_10seq_50bp.fq
    |    |    |--single_10seq_50bp.alnEnd
    |    |-- ...
    |    |--single_100seq_150bp
    |    |    |--single_100seq_150bp.fq
    |    |    |--single_100seq_150bp.aln
    |    |--paired_100seq_50bp
    |    |    |--paired_100seq_50bp2.aln
    |    |    |--paired_100seq_50bp1.aln
    |    |    |--paired_100seq_50bp2.fq
    |    |    |--paired_100seq_50bp1.fq
    |    |-- ...
    |    |
    |---yf
    |    |
    |    |--yf_AY968064-single-150bp
    |    |    |--yf_AY968064-single-1seq-150bp.fq
    |    |    |--yf_AY968064-single-1seq-150bp.aln
    |    |
    |--mRhiFer1
    |    |--mRhiFer1_v1.p.dna_rm.primary_assembly.1
    |    |    |--mRhiFer1_v1.p.dna_rm.primary_assembly.1.fq
    |    |    |--mRhiFer1_v1.p.dna_rm.primary_assembly.1.aln
    |    |

```

This directory includes several subdirectories, each for one virus, e.g. `cov` for corona, `yf` for yellow fever.

In each virus subdirectory, several simreads directory includes simulated reads with various parameters, named as `<method>_<nb-seq>_<nb-bp>` where"
- `<method>` is either `single` or `paired` depending on the simulation method
- `<nb-seq>` is the number of reference sequences used for simulation, and refers to the `fa` file used
- `<nb-bp>` is the number of base pairs used to simulate reads


Each sub-directory includes simreads files made using a simulation method and a specific number of reference sequences.
- `xxx.fq` and `xxx.aln` files when method is `single`
- `xxx1.fq`, `xxx2.fq`, `xxx1.aln` and `xxx2.aln` files when method is `paired`.

Example:
- `paired_10seq_50bp` means that the simreads were generated by using the `paired` method to simulate 50-bp reads, and using the `fa` file `cov_virus_sequences_010-seqs.fa`.
- `single_100seq_50bp` means that the simreads were generated by using the `single` method to simulate 50-bp reads, and using the `fa` file `cov_virus_sequences_100-seqs.fa`. Note that this generated 20,660,104 reads !

#### Simread file formats

Simulated reads information is split between two files:
- **FASTQ** (`.fq`) files providing the read sequences and their ASCII quality scores
- **ALN** (`.aln`) files with alignment information

##### FASTQ (`.fq`)
FASTQ files generated by ART Illumina have the following structure (showing 5 reads), with 4 lines for each read:

```ascii
@2591237:ncbi:1-60400
ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG
+
CCCBCGFGBGGGGGGGBGGGGGGGGG>GGG1G=/GGGGGGGGGGGGGGGG
@2591237:ncbi:1-60399
GATCAATGTGGCATCTACAATACAGACAGCATGAAGCACCACCAAAGGAC
+
BCBCCFGGGGGGGG1CGGGG<GGBGGGGGFGCGGGGGGDGGG/GG1GGGG
@2591237:ncbi:1-60398
ATCTACCAGTGGTAGATGGGTTCTTAATAATGAACATTATAGAGCTCTAC
+
CCCCCGGGEGG1GGF1G/GGEGGGGGGGGGGGGFFGGGGGGGGGGDGGDG
@2591237:ncbi:1-60397
CGTAAAGTAGAGGCTGTATGGTAGCTAGCACAAATGCCAGCACCAATAGG
+
BCCCCGGGFGGGGGGFGGGGFGG1GGGGGGG>GG1GGGGGGGGGGE<GGG
@2591237:ncbi:1-60396
GGTATCGGGTATCTCCTGCATCAATGCAAGGTCTTACAAAGATAAATACT
+
CBCCCGGG@CGGGGGGGGGGGG=GFGGGGDGGGFG1GGGGGGGG@GGGGG
```
The following information can be parsed from the each read sequence in the FASTQ file:

- Line 1: `readid`, a unique ID for the read, under for format `@readid` 
- Line 2: `readseq`, the sequence of the read
- Line 3: a separator `+`
- Line 4: `read_qscores`, the base quality scores encoded in ASCII 

Example:
```
@2591237:ncbi:1-60400
ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG
+
CCCBCGFGBGGGGGGGBGGGGGGGGG>GGG1G=/GGGGGGGGGGGGGGGG
```
- `readid` = `2591237:ncbi:1-60400`
- `readseq` = `ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG`, a 50 bp read
- `read_qscores` = `CCCBCGFGBGGGGGGGBGGGGGGGGG>GGG1G=/GGGGGGGGGGGGGGGG`


#### ALN (`.aln`) 
ALN files generated by ART Illumina consist of :
- a header with the ART-Ilumina command used for the simulation (`@CM`) and info on each of the reference sequences used for the simulations (`@SQ`). Header always starts with `##ART_Illumina` and ends with `##Header End` :
- the body with 3 lines for each read:
    1. definition line with `readid`, 
        - reference sequence identification number `refseqid`, 
        - the position in the read in the reference sequence `aln_start_pos` 
        - the strand the read was taken from `ref_seq_strand`. `+` for coding strand and `-` for template strand
    2. aligned reference sequence, that is the sequence segment in the original reference corresponding to the read
    3. aligned read sequence, that is the simmulated read sequence, where each bp corresponds to the reference sequence bp in the same position.

Example of a ALN file generated by ART Illumina (showing 5 reads):

```ascii
##ART_Illumina    read_length    50
@CM    /bin/art_illumina -i /home/vtec/projects/bio/metagentools/data/cov_data/cov_virus_sequences_ten.fa -ss HS25 -l 50 -f 100 -o /home/vtec/projects/bio/metagentools/data/cov_simreads/single_10seq_50bp/single_10seq_50bp -rs 1674660835
@SQ    2591237:ncbi:1 1   MK211378    2591237    ncbi    1     Coronavirus BtRs-BetaCoV/YN2018D    30213
@SQ    11128:ncbi:2   2   LC494191    11128    ncbi    2     Bovine coronavirus    30942
@SQ    31631:ncbi:3   3   KY967361    31631    ncbi    3     Human coronavirus OC43        30661
@SQ    277944:ncbi:4  4   LC654455    277944    ncbi    4     Human coronavirus NL63    27516
@SQ    11120:ncbi:5   5   MN987231    11120    ncbi    5     Infectious bronchitis virus    27617
@SQ    28295:ncbi:6   6   KU893866    28295    ncbi    6     Porcine epidemic diarrhea virus    28043
@SQ    28295:ncbi:7   7   KJ645638    28295    ncbi    7     Porcine epidemic diarrhea virus    27998
@SQ    28295:ncbi:8   8   KJ645678    28295    ncbi    8     Porcine epidemic diarrhea virus    27998
@SQ    28295:ncbi:9   9   KR873434    28295    ncbi    9     Porcine epidemic diarrhea virus    28038
@SQ    1699095:ncbi:10 10  KT368904    1699095    ncbi    10     Camel alphacoronavirus    27395
##Header End
>2591237:ncbi:1    2591237:ncbi:1-60400    14770    +
ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG
ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG
>2591237:ncbi:1    2591237:ncbi:1-60399    17012    -
GATCAATGTGGCATCTACAATACAGACAGCATGAAGCACCACCAAAGGAC
GATCAATGTGGCATCTACAATACAGACAGCATGAAGCACCACCAAAGGAC
>2591237:ncbi:1    2591237:ncbi:1-60398    9188    +
ATCTACCAGTGGTAGATGGGTTCTTAATAATGAACATTATAGAGCTCTAC
ATCTACCAGTGGTAGATGGGTTCTTAATAATGAACATTATAGAGCTCTAC
.....
```

## Model related data

In [None]:
pfs = ProjectFileSystem()

In [None]:
pfs.readme(pfs.data / 'saved')

ReadMe file for directory `nbs-dev/data_dev/saved`:

### Saved data related to models

This directory includes all data related to models and saved:
- saved model parameters
- saved datasets

For example:
- `cnn_virus_original/pretrained_model.h5` is the saved model parameters for the CNN Virus model



# Parsing sequence files

The following classes make it easier to read and parse files of different formats into their underlying components to generate the training, validation, testing and inference datasets for the model.

Each class inherits from `TextFileBaseReader` and adds:

- One or several text parsing method(s) to parse metadata according to a specific format
- A file parsing method to extract metadata from all elements in the file, returning it as a key:value dictionary and optionally save the metadata as a json file.

## FASTA file

Extension of `TextFileBaseReader` class for fasta sequence files.

Structure of a FASTA sequence file:
```ascii
>definition line - format varies from dataset to dataset
sequence line: sequence of bases
```
Example for the NCBI datasets:
```ascii
>seqid seqnb accession taxonomyid source organism
TATTAGGTTTTCTACCTACCCAGGAAAAGCCAACCAACCTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAAT ...
>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D
TATTAGGTTTTCTACCTACCCAGGAAAAGCCAACCAACCTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAAT ...
```

In [None]:
p2fasta = pfs.data / 'ncbi/refsequences/cov/cov_virus_sequences_two.fa'

fasta = TextFileBaseReader(p2fasta, nlines=1)
for i, t in enumerate(fasta):
    txt = t.replace('\n', '')[:80] + ' ...'
    print(f"{txt}")

>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D ...
TATTAGGTTTTCTACCTACCCAGGAAAAGCCAACCAACCTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAAT ...
>11128:ncbi:2	2	LC494191	11128	ncbi	Bovine coronavirus ...
CATCCCGCTTCACTGATCTCTTGTTAGATCTTTTCATAATCTAAACTTTATAAAAACATCCACTCCCTGTAGTCTATGCC ...


In [None]:
#| export
class FastaFileReader(TextFileBaseReader):
    """Wrap a FASTA file and retrieve its content in raw format and parsed format

    Iterating over an instance returns one dictionary per sequence, built from two consecutive
    lines of the file: the definition line and the sequence line.

    File access (`_safe_readline`, `reset_iterator`) and metadata parsing (`parse_text`,
    `set_parsing_rules`, `re_pattern`, `re_keys`) are provided by `TextFileBaseReader`.
    """
    def __init__(
        self,
        path: str|Path,  # path to the Fasta file
    ):
        super().__init__(path, nlines=1)
        self.text_to_parse_key = 'definition line'
        # No pattern/keys passed: the parsing rule is selected by the base class
        self.set_parsing_rules(verbose=False)
        
    def __next__(self)-> dict[str, str]:   # `{'definition line': text in dfn line, 'sequence': full sequence as str}` 
        """Return one definition line and the corresponding sequence"""
        # One record = exactly two lines: the definition line, then the sequence on a single line
        # NOTE(review): presumably `_safe_readline` raises StopIteration at end of file - confirm in base class
        dfn_line = self._safe_readline().replace('\n', '')   # remove the next line symbol at the end of the line
        sequence = self._safe_readline().replace('\n', '')   # remove the next line symbol at the end of the line
        self._chunk_nb = self._chunk_nb + 1
        return {'definition line': dfn_line, 'sequence': sequence}

    @property
    def read_nb(self)-> int:
        """Value of the chunk counter, which is incremented once per sequence returned by `__next__`"""
        return self._chunk_nb
    
    def print_first_chunks(
        self, 
        nchunks:int=3,  # number of chunks to print out
    ):
        """Print the first `nchunks` chunks of text from the file

        Nothing is printed when `nchunks` is zero or negative. The iterator is reset before and 
        after printing.
        """
        self.reset_iterator()
        if nchunks > 0:
            for i, seq_dict in enumerate(self):
                print(f"\nSequence {i+1}:")
                print(seq_dict['definition line'])
                print(f"{seq_dict['sequence'][:80]} ...")
                if i >= nchunks-1: break
        self.reset_iterator()
            
    def parse_file(
        self,
        add_seq :bool=False,     # When True, add the full sequence to the parsed metadata dictionary
        save_json: bool=False    # When True, save the file metadata as a json file of same stem name
    )-> dict[str, dict]:         # Metadata as Key/Values pairs, one entry per sequence
        """Read fasta file and return a dictionary with definition line metadata and optionally sequences

        Each entry is keyed by the `seqid` parsed from the definition line. When the active parsing rule
        does not provide a (non empty) `seqid`, the definition line without its leading `>` is used as key.
        """
    
        self.reset_iterator()
        parsed = {}
        for d in self:
            dfn_line = d['definition line']
            metadata = self._parse_text_fn(dfn_line, self.re_pattern, self.re_keys)
            if add_seq: metadata['sequence'] = d['sequence']
            # Not every parsing rule defines a `seqid` group: fall back on the definition line itself
            # instead of failing with a KeyError
            seq_key = metadata.get('seqid') or dfn_line.lstrip('>')
            parsed[seq_key] = metadata
                        
        if save_json:
            # json file is saved next to the fasta file, as `<fasta stem>_metadata.json`
            p2json = self.path.parent / f"{self.path.stem}_metadata.json"
            with open(p2json, 'w') as fp:
                json.dump(parsed, fp, indent=4)
            print(f"Metadata for '{self.path.name}'> saved as <{p2json.name}> in  \n{p2json.parent.absolute()}\n")

        return parsed

    def review(self):
        """Prints the first and last sequences and metadata in the fasta file and returns the nb of sequences

        Returns 0, after printing the sequence count only, when the file contains no sequence.
        """

        self.reset_iterator()
        nb_seqs = 0
        seq = None
        for seq in self:
            if nb_seqs == 0:
                # Keep what is needed to print the first sequence once the whole file is counted
                first_dfn = seq['definition line']
                first_sequence = seq['sequence'][:80] + ' ...'
                first_meta = self.parse_text(seq['definition line'])
            nb_seqs += 1

        if nb_seqs == 0:
            # Empty file: there is no first or last sequence to show
            print("There are 0 sequences in this file")
            return 0

        print(f"There {'is' if nb_seqs == 1 else 'are'} {nb_seqs} sequence{'' if nb_seqs == 1 else 's'} in this file")
        print('\nFirst Sequence:')
        print(first_dfn)
        print(first_sequence)
        print(first_meta)
        if nb_seqs > 1:
            # After the loop, `seq` holds the last sequence returned by the iterator
            print('\nLast Sequence:')
            print(seq['definition line'])
            print(seq['sequence'][:80] + ' ...')
            print(self.parse_text(seq['definition line']))
        return nb_seqs



In [None]:
show_doc(FastaFileReader)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L75){target="_blank" style="float:right; font-size:smaller"}

### FastaFileReader

>      FastaFileReader (path:str|pathlib.Path)

*Wrap a FASTA file and retrieve its content in raw format and parsed format*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| path | str \| pathlib.Path | path to the Fasta file |

As an iterator, `FastaFileReader` returns a `dict` at each step, as follows:
```python
{
    'definition line': 'string in file as the definition line for the sequence',
    'sequence': 'the full sequence'
}
```

Illustration:

In [None]:
p2fasta = pfs.data / 'ncbi/refsequences/cov/cov_virus_sequences_two.fa'
fasta = FastaFileReader(p2fasta)
iteration_output = next(fasta)

print(iteration_output['definition line'][:80], '...')
print(iteration_output['sequence'][:80], '...')

>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D ...
TATTAGGTTTTCTACCTACCCAGGAAAAGCCAACCAACCTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAAT ...


In [None]:
print(f"output type :     {type(iteration_output)}")
print(f"keys :            {iteration_output.keys()}")
print(f"definition line : {iteration_output['definition line'][:80]} ...'")
print(f"sequence :       '{iteration_output['sequence'][:100]} ...'")

output type :     <class 'dict'>
keys :            dict_keys(['definition line', 'sequence'])
definition line : >2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D ...'
sequence :       'TATTAGGTTTTCTACCTACCCAGGAAAAGCCAACCAACCTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAATCTGTGTAGCTGTCGCTCGGC ...'


The `definition line` is a string, with tab separated values.

In [None]:
display(iteration_output['definition line'])

'>2591237:ncbi:1\t1\tMK211378\t2591237\tncbi\tCoronavirus BtRs-BetaCoV/YN2018D'

In [None]:
show_doc(FastaFileReader.review)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L136){target="_blank" style="float:right; font-size:smaller"}

### FastaFileReader.review

>      FastaFileReader.review ()

*Prints the first and last sequences and metadata in the fasta file and returns the nb or sequences*

In [None]:
nb_seqs = fasta.review()
nb_seqs

There are 2 sequences in this file

First Sequence:
>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D
TATTAGGTTTTCTACCTACCCAGGAAAAGCCAACCAACCTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAAT ...
{'accession': 'MK211378', 'organism': 'Coronavirus BtRs-BetaCoV/YN2018D', 'seqid': '2591237:ncbi:1', 'seqnb': '1', 'source': 'ncbi', 'taxonomyid': '2591237'}

Last Sequence:
>11128:ncbi:2	2	LC494191	11128	ncbi	Bovine coronavirus
CATCCCGCTTCACTGATCTCTTGTTAGATCTTTTCATAATCTAAACTTTATAAAAACATCCACTCCCTGTAGTCTATGCC ...
{'accession': 'LC494191', 'organism': 'Bovine coronavirus', 'seqid': '11128:ncbi:2', 'seqnb': '2', 'source': 'ncbi', 'taxonomyid': '11128'}


2

In [None]:
show_doc(FastaFileReader.print_first_chunks)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L99){target="_blank" style="float:right; font-size:smaller"}

### FastaFileReader.print_first_chunks

>      FastaFileReader.print_first_chunks (nchunks:int=3)

*Print the first `nchunks` chunks of text from the file*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| nchunks | int | 3 | number of chunks to print out |

This is convenient to quickly discover and explore new fasta files in raw text format:

In [None]:
fasta = FastaFileReader(p2fasta)
fasta.print_first_chunks(nchunks=2)


Sequence 1:
>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D
TATTAGGTTTTCTACCTACCCAGGAAAAGCCAACCAACCTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAAT ...

Sequence 2:
>11128:ncbi:2	2	LC494191	11128	ncbi	Bovine coronavirus
CATCCCGCTTCACTGATCTCTTGTTAGATCTTTTCATAATCTAAACTTTATAAAAACATCCACTCCCTGTAGTCTATGCC ...


### Parsing metadata

The class also provides methods to parse metadata from the file content (definition line, headers, ...).

A regex pattern is used for parsing metadata from the definition lines in the reference sequence fasta file.

Below, we parse the data from the definition line of our Corona virus NCBI dataset (rule `fasta_ncbi_std`):

Sequence 1:

- Definition Line:
```ascii
>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D
```
- Metadata:
    - `seqid` = `2591237:ncbi:1`
    - `taxonomyid` = `2591237`
    - `source` = `ncbi`
    - `seqnb` = `1`
    - `accession` = `MK211378`
    - `organism` = `Coronavirus BtRs-BetaCoV/YN2018D`

Sequence 2:

- Definition Line
```ascii
>11128:ncbi:2	2	LC494191	11128	ncbi	Bovine coronavirus
```

- Metadata:
    - `seqid` = `11128:ncbi:2`
    - `taxonomyid` = `11128`
    - `source` = `ncbi`
    - `seqnb` = `2`
    - `accession` = `LC494191`
    - `organism` = `Bovine coronavirus`

`FastaFileReader` offers:
- `parse_text` a method to parse the metadata
- an option to set a default "parsing rule" for one instance with `set_parsing_rules`.
- `parse_file` a method to parse the metadata from all sequences in the file and save it as a json file.

In [None]:
show_doc(FastaFileReader.parse_text)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/core.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### TextFileBaseReader.parse_text

>      TextFileBaseReader.parse_text (txt:str, pattern:str|None=None,
>                                     keys:list[str]|None=None)

*Parse text using regex pattern and keys. Return a metadata dictionary.

The passed text is parsed using the regex pattern. The method return a dictionary in the format:

    {
        'key_1': 'metadata 1',
        'key_2': 'metadata 2',
        ...
    }*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| txt | str |  | text to parse |
| pattern | str \| None | None | If None, uses standard regex pattern to extract metadata, otherwise, uses passed regex |
| keys | list[str] \| None | None | If None, uses standard regex list of keys, otherwise, uses passed list of keys (str) |
| **Returns** | **dict** |  | **parsed metadata in key/value format** |

Running the parser function with specifically defined `pattern` and `keys`.

In [None]:
fasta = FastaFileReader(p2fasta)
dfn_line, sequence = next(fasta).values()
print(dfn_line.replace('\n', ''))

>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D


In [None]:
# pattern = r"^>(?P<seqid>(?P<taxonomyid>\d+):(?P<source>ncbi):(?P<seqnb>\d*))[\s\t]*\[(?P<accession>[\w\d]*)\]([\s\t]*(?P=taxonomyid)[\s\t]*(?P=source)[\s\t]*(?P=seqnb)[\s\t]*\[(?P=accession)\][\s\t]*(?P=taxonomyid)[\s\t]*(?P<species>[\w\s\-\_\/]*))?"
pattern = r"^>(?P<seqid>(?P<taxonomyid>\d+):(?P<source>ncbi):(?P<seqnb>\d*))[\s\t]*(?P=seqnb)[\s\t](?P<accession>[\w\d]*)([\s\t]*(?P=taxonomyid)[\s\t]*(?P=source)[\s\t][\s\t]*(?P<organism>[\w\s\-\_\/]*))?"

keys = 'seqid taxonomyid accession source seqnb organism'.split(' ')

In [None]:
fasta.parse_text(dfn_line, pattern=pattern, keys=keys)

{'accession': 'MK211378',
 'organism': 'Coronavirus BtRs-BetaCoV/YN2018D',
 'seqid': '2591237:ncbi:1',
 'seqnb': '1',
 'source': 'ncbi',
 'taxonomyid': '2591237'}

When a `FastaFileReader` instance is created, all existing rules in the file `default_parsing_rules.json` are tested on the first definition line of the fasta file and the one rule that parses the most matches will be selected automatically and saved in instance attributes `re_rule_name`, `re_pattern` and `re_keys`. 

`parse_file` extracts metadata from each definition line in the fasta file and returns a dictionary with all metadata.

In [None]:
print(fasta.re_rule_name)
print(fasta.re_pattern)
print(fasta.re_keys)

fasta_ncbi_std
^>(?P<seqid>(?P<taxonomyid>\d+):(?P<source>ncbi):(?P<seqnb>\d*))[\s\t]*(?P=seqnb)[\s\t](?P<accession>[\w\d]*)([\s\t]*(?P=taxonomyid)[\s\t]*(?P=source)[\s\t][\s\t]*(?P<organism>[\w\s\-\_/]*))?
['seqid', 'taxonomyid', 'source', 'accession', 'seqnb', 'organism']


In [None]:
fasta.parse_text(dfn_line)

{'accession': 'MK211378',
 'organism': 'Coronavirus BtRs-BetaCoV/YN2018D',
 'seqid': '2591237:ncbi:1',
 'seqnb': '1',
 'source': 'ncbi',
 'taxonomyid': '2591237'}

When another fasta file, which has another definition line structure, is used, another parsing rule is selected.

In [None]:
p2other = pfs.data / 'ncbi/refsequences/cov/another_sequence.fa'
assert p2other.is_file()

it2 = FastaFileReader(path=p2other)

dfn_line, sequence = next(it2).values()
print(dfn_line.replace('\n', ''))

>1 dna_rm:primary_assembly primary_assembly:mRhiFer1_v1.p:1:1:124933378:1 REF


In [None]:
print(it2.re_rule_name)
print(it2.re_pattern)
print(it2.re_keys)

fasta_rhinolophus_ferrumequinum
^>\d[\s\t](?P<seq_type>dna_rm):(?P<id_type>[\w\_]*)[\s\w](?P=id_type):(?P<assy>[\w\d\_]*)\.(?P<seq_level>[\w]*):\d*:\d*:(?P<taxonomy>\d*):(?P<id>\d*)[\s	]REF$
['seq_type', 'id_type', 'assy', 'seq_level', 'taxonomy', 'id']


In [None]:
pprint(it2.parse_text(dfn_line))

{'assy': 'mRhiFer1_v1',
 'id': '1',
 'id_type': 'primary_assembly',
 'seq_level': 'p',
 'seq_type': 'dna_rm',
 'taxonomy': '124933378'}


This rule selection is performed by the method `set_parsing_rules`. The method can also be called with a specific `pattern` and `keys` to force a parsing rule that is not yet saved in the json file.

In [None]:
show_doc(FastaFileReader.set_parsing_rules)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/core.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### TextFileBaseReader.set_parsing_rules

>      TextFileBaseReader.set_parsing_rules (pattern:str|None=None,
>                                            keys:list[str]|None=None,
>                                            verbose:bool=False)

*Set the standard regex parsing rule for the file.

Rules can be set:

1. manually by passing specific custom values for `pattern` and `keys`
2. automatically, by testing all parsing rules saved in `parsing_rule.json` 

Automatic selection of parsing rules works by testing each rule saved in `parsing_rule.json` on the first 
definition line of the fasta file, and selecting the one rule that generates the most metadata matches.

Rules consists of two parameters:

- The regex pattern including one `group` for each metadata item, e.g `(?P<group_name>regex_code)`
- The list of keys, i.e. the list with the name of each regex groups, used as key in the metadata dictionary

This method updates the three following class attributes: `re_rule_name`, `re_pattern`, `re_keys`*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| pattern | str \| None | None | regex pattern to apply to parse the text, search in parsing rules json if None |
| keys | list[str] \| None | None | list of keys/group for regex, search in parsing rules json if None |
| verbose | bool | False | when True, provides information on each rule |
| **Returns** | **None** |  |  |

In [None]:
fasta = FastaFileReader(p2fasta)
dfn_line, sequence = next(fasta).values()
# rstrip() only removes a trailing newline/whitespace if there is one; slicing with [:-1] would
# cut the last real character when the definition line carries no trailing newline.
print(f"definition line: '{dfn_line.rstrip()}'")

definition line: '>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018'


Automatic rule selection works by testing each saved rule on the `definition line` of the first sequence in the fasta file.

In [None]:
print(f"key for text to parse: {fasta.text_to_parse_key}\n")
fasta.reset_iterator()
print('Text to parse for testing (extracted from first iteration):')
print(next(fasta)[fasta.text_to_parse_key])
print()
fasta.set_parsing_rules(verbose=True)

key for text to parse: definition line

Text to parse for testing (extracted from first iteration):
>2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D

--------------------------------------------------------------------------------
Rule <fasta_ncbi_std> generated 6 matches
--------------------------------------------------------------------------------
^>(?P<seqid>(?P<taxonomyid>\d+):(?P<source>ncbi):(?P<seqnb>\d*))[\s\t]*(?P=seqnb)[\s\t](?P<accession>[\w\d]*)([\s\t]*(?P=taxonomyid)[\s\t]*(?P=source)[\s\t][\s\t]*(?P<organism>[\w\s\-\_/]*))?
['seqid', 'taxonomyid', 'source', 'accession', 'seqnb', 'organism']
--------------------------------------------------------------------------------
Rule <fastq_art_illumina_ncbi_std> generated an error
No match on this line
--------------------------------------------------------------------------------
Rule <aln_art_illumina_ncbi_std> generated an error
No match on this line
--------------------------------------------------

If no saved rule generates a match, `re_rule_name`, `re_pattern` and `re_keys` remain `None` and a warning is issued asking the user to set a parsing rule manually.

In [None]:
p2nomatch = pfs.data / 'ncbi/refsequences/cov/sequences_two_no_matching_rule.fa'
fasta2 = FastaFileReader(p2nomatch)

        None of the saved parsing rules were able to extract metadata from the first line in this file.
        You must set a custom rule (pattern + keys) before parsing text, by using:
            `self.set_parsing_rules(custom_pattern, custom_list_of_keys)`
                


In [None]:
fasta2.re_rule_name is None

True

But we can still set a standard rule manually, by passing a regex pattern and the corresponding list of keys.

In [None]:
pat = r"^>(?P<seqid>(?P<taxonomyid>\d+):(?P<source>ncbi):(?P<seqnb>\d*))\s*(?P<text>[\w\s]*)$"
keys = "seqid taxonomyid source seqnb text".split()
fasta2.set_parsing_rules(pattern=pat, keys=keys)

print(fasta2.re_rule_name)
print(fasta2.re_pattern)
print(fasta2.re_keys)

Custom Rule
^>(?P<seqid>(?P<taxonomyid>\d+):(?P<source>ncbi):(?P<seqnb>\d*))\s*(?P<text>[\w\s]*)$
['seqid', 'taxonomyid', 'source', 'seqnb', 'text']


In [None]:
fasta2.reset_iterator()
dfn_line, sequence = next(fasta2).values()
# rstrip() only removes a trailing newline/whitespace if there is one; slicing with [:-1] would
# cut the last real character when the definition line carries no trailing newline.
print(f"definition line: '{dfn_line.rstrip()}'")
fasta2.parse_text(dfn_line)

definition line: '>2591237:ncbi:1 this sequence does not match any saved parsing rul'


{'seqid': '2591237:ncbi:1',
 'seqnb': '1',
 'source': 'ncbi',
 'taxonomyid': '2591237',
 'text': 'this sequence does not match any saved parsing rule'}

In [None]:
show_doc(FastaFileReader.parse_file)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L112){target="_blank" style="float:right; font-size:smaller"}

### FastaFileReader.parse_file

>      FastaFileReader.parse_file (add_seq:bool=False, save_json:bool=False)

*Read fasta file and return a dictionary with definition line metadata and optionally sequences*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| add_seq | bool | False | When True, add the full sequence to the parsed metadata dictionary |
| save_json | bool | False | When True, save the file metadata as a json file of same stem name |
| **Returns** | **dict** |  | **Metadata as Key/Values pairs** |

In [None]:
fasta = FastaFileReader(p2fasta)
pprint(fasta.parse_file())

{'11128:ncbi:2': {'accession': 'LC494191',
                  'organism': 'Bovine coronavirus',
                  'seqid': '11128:ncbi:2',
                  'seqnb': '2',
                  'source': 'ncbi',
                  'taxonomyid': '11128'},
 '2591237:ncbi:1': {'accession': 'MK211378',
                    'organism': 'Coronavirus BtRs-BetaCoV/YN2018D',
                    'seqid': '2591237:ncbi:1',
                    'seqnb': '1',
                    'source': 'ncbi',
                    'taxonomyid': '2591237'}}


In [None]:
fasta.parse_file(save_json=True);

Metadata for 'cov_virus_sequences_two.fa'> saved as <cov_virus_sequences_two_metadata.json> in  
/home/vtec/projects/bio/metagentorch/nbs-dev/data_dev/ncbi/refsequences/cov



In [None]:
# Build the path from PACKAGE_ROOT (as `AlnFileReader.set_header_parsing_rules` does) instead of
# a path relative to the current working directory, so the cell does not depend on where the kernel was started.
with open(PACKAGE_ROOT / 'default_parsing_rules.json', 'r') as fp:
    pprint(json.load(fp), width=20)

{'aln_art_illumina-refseq-ncbi-std': {'keys': 'refseqid '
                                              'reftaxonomyid '
                                              'refsource '
                                              'refseqnb '
                                              'refseq_accession '
                                              'organism '
                                              'refseq_length',
                                      'pattern': '^@SQ[\\t\\s]*(?P<refseqid>(?P<reftaxonomyid>\\d*):(?P<refsource>\\w*):(?P<refseqnb>\\d*))[\\t\\s]*(?P=refseqnb)[\\t\\s]*(?P<refseq_accession>[\\w\\d]*)[\\t\\s]*(?P=reftaxonomyid)[\\t\\s]*(?P=refsource)[\\t\\s](?P<organism>.*)[\\t\\s](?P<refseq_length>\\d*)$'},
 'aln_art_illumina_ncbi_std': {'keys': 'refseqid '
                                       'reftaxonomyid '
                                       'refsource '
                                       'refseqnb '
                                       'readid '
     

In [None]:
p2fasta = pfs.data / 'ncbi/refsequences/cov/cov_virus_sequence_one.fa'
fasta = FastaFileReader(p2fasta)
fasta_meta = fasta.parse_file(save_json=True)
pprint(fasta_meta)

Metadata for 'cov_virus_sequence_one.fa'> saved as <cov_virus_sequence_one_metadata.json> in  
/home/vtec/projects/bio/metagentorch/nbs-dev/data_dev/ncbi/refsequences/cov

{'2591237:ncbi:1': {'accession': 'MK211378',
                    'organism': 'Coronavirus BtRs-BetaCoV/YN2018D',
                    'seqid': '2591237:ncbi:1',
                    'seqnb': '1',
                    'source': 'ncbi',
                    'taxonomyid': '2591237'}}


## FASTQ file

Extension of `TextFileBaseReader` class for fastq sequence files.

Structure of a FASTQ sequence file:

In [None]:
p2fastq = pfs.data / 'ncbi/simreads/cov/single_1seq_150bp/single_1seq_150bp.fq'

fastq = TextFileBaseReader(p2fastq, nlines=1)
for i, t in enumerate(fastq):
    txt = t.replace('\n', '')[:80]
    print(f"{txt}")
    if i >= 11: break

@2591237:ncbi:1-40200
TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAAC
+
CCCGGGCGGGGGCJGJJJGJJGJJJGJGGJGJJJGJGGGGGGGGCJGJGGGGGJJJJGCCGGGGGJCGCGJGJCG=GGGG
@2591237:ncbi:1-40199
TCATAGTACTACAGATAGAGACACCAGCTACGGTGCGAGCTCTATTCTTTGCACTAATGGCGTACTTAAGAGTCATTTGA
+
=CCGGGGCGGGGGJJGJJGJGJG=GJJGJCGJJJCJ=JJJJGGJJCJGJGG=JGC1JJGG8GCJCGGGCGG(GCGGCGC=
@2591237:ncbi:1-40198
TAACATAGTGGTTCGTTTATCAAGGATAATCTATCTCCATAGGTTCTTCATCATCTAACTCTGAATATTTATTCTTAGTT
+
C=CGGGGGGGGGGCJJJJ=JJJJJJJJJJJGGJJJJ1GJJ8GJJGGGJGGJJC=JJGGGCCGG88GG=GGGGGGCJGGGG


In [None]:
#| export
class FastqFileReader(TextFileBaseReader):
    """Iterator going through a fastq file's sequences and return each section + prob error as a dict"""
    def __init__(
        self,
        path:str|Path,   # path to the fastq file
    )-> dict:           # key/value with keys: definition line; sequence; q score; prob error
        # NOTE: the return annotation above documents what each iteration yields (it is what the
        # rendered documentation shows); `__init__` itself returns None.
        # A fastq record spans 4 lines: definition line, sequence, '+' separator, quality scores
        self.nlines = 4
        super().__init__(path, nlines=self.nlines)
        self.text_to_parse_key = 'definition line'
        self.set_parsing_rules(verbose=False)

    def __next__(self):
        """Return the next read as a dict with definition line, sequence, ASCII q scores and probabilities of error"""
        lines = []
        for _ in range(self.nlines):
            lines.append(self._safe_readline().replace('\n', ''))

        # lines[2] is the '+' separator line and is not returned
        output = {
            'definition line': lines[0],
            'sequence': lines[1],
            'read_qscores': lines[3],
        }
        # one probability of error per ASCII quality character
        output['probs error'] = np.array([q_score2prob_error(q) for q in output['read_qscores']])
        self._chunk_nb = self._chunk_nb + 1
        return output

    @property
    def read_nb(self)-> int:
        """Number of reads returned so far by `__next__` (value of the internal chunk counter)"""
        return self._chunk_nb

    @staticmethod
    def _json_default(obj):
        """Fallback serializer for `json.dump`: convert numpy arrays and scalars into plain python objects"""
        if isinstance(obj, (np.ndarray, np.generic)):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def print_first_chunks(
        self, 
        nchunks:int=3,  # number of chunks to print out
    ):
        """Print the first `nchunks` chunks of text from the file"""
        # `zip` stops as soon as `range(nchunks)` is exhausted, so exactly `nchunks` reads are
        # consumed and printed (the previous `if i >= nchunks: break` printed one chunk too many).
        for i, seq_dict in zip(range(nchunks), self.__iter__()):
            print(f"\nSequence {i+1}:")
            print(seq_dict['definition line'])
            print(f"{seq_dict['sequence'][:80]} ...")

    def parse_file(
        self,
        add_readseq :bool=False,    # When True, add the full sequence to the parsed metadata dictionary
        add_qscores:bool=False,     # Add the read ASCII Q Scores to the parsed dictionary when True
        add_probs_error:bool=False, # Add the read probability of error to the parsed dictionary when True
        save_json: bool=False       # When True, save the file metadata as a json file of same stem name
    )-> dict[str]:                  # Metadata as Key/Values pairs
        """Read fastq file, return a dict with definition line metadata and optionally read sequence and q scores, ..."""

        # Always start from the first read, whatever was consumed before
        self.reset_iterator()
        parsed = {}
        for d in self:
            dfn_line = d['definition line']
            seq, q_scores, prob_e = d['sequence'], d['read_qscores'], d['probs error']
            metadata = self._parse_text_fn(dfn_line, self.re_pattern, self.re_keys)
            if add_readseq: metadata['readseq'] = seq
            if add_qscores: metadata['read_qscores'] = q_scores
            if add_probs_error: metadata['probs error'] = prob_e
            # one entry per read, keyed by the read id extracted from the definition line
            parsed[metadata['readid']] = metadata

        if save_json:
            p2json = self.path.parent / f"{self.path.stem}_metadata.json"
            with open(p2json, 'w') as fp:
                # 'probs error' values are numpy arrays, which `json` cannot serialize natively:
                # `default` converts them to lists in the file; the returned dict keeps the arrays.
                json.dump(parsed, fp, indent=4, default=self._json_default)
                print(f"Metadata for '{self.path.name}'> saved as <{p2json.name}> in  \n{p2json.parent.absolute()}\n")

        return parsed

In [None]:
show_doc(FastqFileReader)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L159){target="_blank" style="float:right; font-size:smaller"}

### FastqFileReader

>      FastqFileReader (path:str|pathlib.Path)

*Iterator going through a fastq file's sequences and return each section + prob error as a dict*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| path | str \| pathlib.Path | path to the fastq file |
| **Returns** | **dict** | **key/value with keys: definition line; sequence; q score; prob error** |

In [None]:
fastq = FastqFileReader(p2fastq)
iteration_output = next(fastq)

print(type(iteration_output))
print(iteration_output.keys())
print(f"Definition line:  {iteration_output['definition line']}")
print(f"Read sequence:    {iteration_output['sequence']}")
print(f"Q scores (ASCII): {iteration_output['read_qscores']}")
print(f"Prob error:       {','.join([f'{p:.4f}' for p in iteration_output['probs error']])}")

<class 'dict'>
dict_keys(['definition line', 'sequence', 'read_qscores', 'probs error'])
Definition line:  @2591237:ncbi:1-40200
Read sequence:    TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAACTTACATAGCTCGCGTCTCAGTTTCAAGGAACTTTTAGTGTATGCTGCTGATCCAGCCATGCATGCAGCTT
Q scores (ASCII): CCCGGGCGGGGGCJGJJJGJJGJJJGJGGJGJJJGJGGGGGGGGCJGJGGGGGJJJJGCCGGGGGJCGCGJGJCG=GGGG=CGGGGGG1GCGCGGGGCCGJC8GGGGGGGGGGGCGGGGGGGGGGGC8GGGGGGCGGC1GGGCGGGGGCC
Prob error:       0.0004,0.0004,0.0004,0.0002,0.0002,0.0002,0.0004,0.0002,0.0002,0.0002,0.0002,0.0002,0.0004,0.0001,0.0002,0.0001,0.0001,0.0001,0.0002,0.0001,0.0001,0.0002,0.0001,0.0001,0.0001,0.0002,0.0001,0.0002,0.0002,0.0001,0.0002,0.0001,0.0001,0.0001,0.0002,0.0001,0.0002,0.0002,0.0002,0.0002,0.0002,0.0002,0.0002,0.0002,0.0004,0.0001,0.0002,0.0001,0.0002,0.0002,0.0002,0.0002,0.0002,0.0001,0.0001,0.0001,0.0001,0.0002,0.0004,0.0004,0.0002,0.0002,0.0002,0.0002,0.0002,0.0001,0.0004,0.0002,0.0004,0.0002,0.0001,0.0002,0.0001,0.00

Five largest probabilities of error:

In [None]:
np.sort(iteration_output['probs error'])[-5:]

array([0.00158489, 0.00501187, 0.00501187, 0.02511886, 0.02511886])

In [None]:
np.argsort(iteration_output['probs error'])[-5:]

array([ 80, 127, 102, 138,  88])

In [None]:
dfn_line = iteration_output['definition line']
meta = fastq.parse_text(dfn_line)
meta

{'readid': '2591237:ncbi:1-40200',
 'readnb': '40200',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}

In [None]:
fastq = FastqFileReader(p2fastq)
next(fastq).keys()

dict_keys(['definition line', 'sequence', 'read_qscores', 'probs error'])

In [None]:
show_doc(FastqFileReader.parse_file)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L200){target="_blank" style="float:right; font-size:smaller"}

### FastqFileReader.parse_file

>      FastqFileReader.parse_file (add_readseq:bool=False,
>                                  add_qscores:bool=False,
>                                  add_probs_error:bool=False,
>                                  save_json:bool=False)

*Read fastq file, return a dict with definition line metadata and optionally read sequence and q scores, ...*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| add_readseq | bool | False | When True, add the full sequence to the parsed metadata dictionary |
| add_qscores | bool | False | Add the read ASCII Q Scores to the parsed dictionary when True |
| add_probs_error | bool | False | Add the read probability of error to the parsed dictionary when True |
| save_json | bool | False | When True, save the file metadata as a json file of same stem name |
| **Returns** | **dict** |  | **Metadata as Key/Values pairs** |

In [None]:
parsed = fastq.parse_file(add_readseq=False, add_qscores=False, add_probs_error=False)
for i, (k, v) in enumerate(parsed.items()):
    print(k)
    pprint(v)
    if i >=3: break

2591237:ncbi:1-40200
{'readid': '2591237:ncbi:1-40200',
 'readnb': '40200',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}
2591237:ncbi:1-40199
{'readid': '2591237:ncbi:1-40199',
 'readnb': '40199',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}
2591237:ncbi:1-40198
{'readid': '2591237:ncbi:1-40198',
 'readnb': '40198',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}
2591237:ncbi:1-40197
{'readid': '2591237:ncbi:1-40197',
 'readnb': '40197',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}


In [None]:
metadata = fastq.parse_file(add_readseq=True)
df = pd.DataFrame(metadata).T
df.head(10)

Unnamed: 0,readid,readnb,refseqnb,refsource,reftaxonomyid,readseq
2591237:ncbi:1-40200,2591237:ncbi:1-40200,40200,1,ncbi,2591237,TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCG...
2591237:ncbi:1-40199,2591237:ncbi:1-40199,40199,1,ncbi,2591237,TCATAGTACTACAGATAGAGACACCAGCTACGGTGCGAGCTCTATT...
2591237:ncbi:1-40198,2591237:ncbi:1-40198,40198,1,ncbi,2591237,TAACATAGTGGTTCGTTTATCAAGGATAATCTATCTCCATAGGTTC...
2591237:ncbi:1-40197,2591237:ncbi:1-40197,40197,1,ncbi,2591237,TAATCACTGATAGCAGCATTGCCATCCTGAGCAAAGAAGAAGTGTT...
2591237:ncbi:1-40196,2591237:ncbi:1-40196,40196,1,ncbi,2591237,CTAATGTCAGTACGCCTACAATGCCTGCATCACGCATAGCATCGCA...
2591237:ncbi:1-40195,2591237:ncbi:1-40195,40195,1,ncbi,2591237,AAGCTGAAGCATACATAACACAGTCCTTAAGCCGATAACCAGACAA...
2591237:ncbi:1-40194,2591237:ncbi:1-40194,40194,1,ncbi,2591237,AGTGGAAGAACTTCACCGTCAAGATGAAACTCGACGGGGCTCTCCA...
2591237:ncbi:1-40193,2591237:ncbi:1-40193,40193,1,ncbi,2591237,GCGTCTCGAGTGCTTCGAGTTCACCGTTCTTGAGAACAACCTCCTC...
2591237:ncbi:1-40192,2591237:ncbi:1-40192,40192,1,ncbi,2591237,CTGGTAGTATCTAAGGCTCCACTGAAATACTTGTACTTGTTATATA...
2591237:ncbi:1-40191,2591237:ncbi:1-40191,40191,1,ncbi,2591237,GTCTCTATCTGTAGTACTATGACAAATAGACAGTTTCATCAGAAAT...


In [None]:
fastq.set_parsing_rules(verbose=True)

--------------------------------------------------------------------------------
Rule <fasta_ncbi_std> generated an error
No match on this line
--------------------------------------------------------------------------------
Rule <fastq_art_illumina_ncbi_std> generated 5 matches
--------------------------------------------------------------------------------
^@(?P<readid>(?P<reftaxonomyid>\d*):(?P<refsource>\w*):(?P<refseqnb>\d*)-(?P<readnb>\d*(\/\d)?))$
['readid', 'reftaxonomyid', 'refsource', 'refseqnb', 'readnb']
--------------------------------------------------------------------------------
Rule <aln_art_illumina_ncbi_std> generated an error
No match on this line
--------------------------------------------------------------------------------
Rule <aln_art_illumina-refseq-ncbi-std> generated an error
No match on this line
--------------------------------------------------------------------------------
Rule <fasta_ncbi_cov> generated an error
No match on this line
-----------------

## ALN Alignment Files

Extension of `TextFileBaseReader` class for ALN read/sequence alignment files.

Structure of an ALN sequence file:

In [None]:
p2aln = pfs.data / 'ncbi/simreads/cov/single_1seq_150bp/single_1seq_150bp.aln'
assert p2aln.is_file()

aln = TextFileBaseReader(p2aln, nlines=1)
for i, t in enumerate(aln):
    txt = t.replace('\n', '')[:80]
    print(f"{txt}")
    if i >= 12: break

##ART_Illumina	read_length	150
@CM	/usr/bin/art_illumina -i /home/vtec/projects/bio/metagentools/data/ncbi/refs
@SQ	2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D	3021
##Header End
>2591237:ncbi:1	2591237:ncbi:1-40200	14370	+
TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAAC
TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAAC
>2591237:ncbi:1	2591237:ncbi:1-40199	15144	-
TCATAGTACTACAGATAGAGACACCAGCTACGGTGCGAGCTCTATTCTTTGCACTAATGGCGTACTTAAGATTCATTTGA
TCATAGTACTACAGATAGAGACACCAGCTACGGTGCGAGCTCTATTCTTTGCACTAATGGCGTACTTAAGAGTCATTTGA
>2591237:ncbi:1	2591237:ncbi:1-40198	2971	-
TAACATAGTGGTTCGTTTATCAAGGATAATCTATCTCCATAGGTTCTTCATCATCTAACTCTGAATATTTATTCTTAGTT
TAACATAGTGGTTCGTTTATCAAGGATAATCTATCTCCATAGGTTCTTCATCATCTAACTCTGAATATTTATTCTTAGTT


In [None]:
#| export
class AlnFileReader(TextFileBaseReader):
    """Iterator going through an ALN file"""
    def __init__(
        self,
        path:str|Path,   # path to the aln file
    )-> dict:            # key/value with keys: definition line; ref_seq_aligned; read_seq_aligned
        """Set TextFileBaseReader attributes and specific class attributes"""
        # The header is read first, one line at a time
        self.nlines = 1
        super().__init__(path, nlines=self.nlines)
        self.header = self.read_header()
        # After the header, each read spans 3 lines: definition line, aligned reference, aligned read
        self.nlines = 3
        self.text_to_parse_key = 'definition line'
        self.set_parsing_rules(verbose=False)
        self.set_header_parsing_rules(verbose=False)
        self.ref_sequences = self.parse_header_reference_sequences()

    def __next__(self):
        """Return definition line, aligned reference sequence and aligned read sequence for the next read"""
        lines = []
        for _ in range(self.nlines):
            lines.append(self._safe_readline().replace('\n', ''))

        output = {
            'definition line': lines[0],
            'ref_seq_aligned': lines[1],
            'read_seq_aligned': lines[2],
        }
        return output

    def read_header(self):
        """Read ALN file Header and return each section parsed in a dictionary

        The returned dictionary has two keys:

        - `'command'`: text following `@CM` on the second header line
        - `'reference sequences'`: list of all lines between the `@CM` line and `##Header End`

        Raises `ValueError` when the first two lines do not start with `##ART_Illumina` and `@CM`.
        """

        header = {}
        # (re)open the file to make sure reading starts from its first line
        if self.fp is not None:
            self.fp.close()
        self.fp = open(self.path, 'r')

        line = self._safe_readline().replace('\n', '')
        if not line.startswith('##ART_Illumina'):
            raise ValueError("Header of this file does not start with ##ART_Illumina")
        line = self._safe_readline().replace('\n', '')
        if not line.startswith('@CM'):
            raise ValueError("First header line should start with @CM")
        else:
            # drop the '@CM' tag, then tabs and surrounding white spaces
            header['command'] = line[3:].replace('\t', '').strip()

        refseqs = []
        while True:
            line = self._safe_readline().replace('\n', '')
            if line.startswith('##Header End'): break
            else:
                refseqs.append(line)
        header['reference sequences'] = refseqs

        return header

    def reset_iterator(self):
        """Reset the iterator to point to the first line in the file, by recreating a new file handle.
        
        `AlnFileReader` requires a specific `reset_iterator` method, in order to skip the header every time it is reset
        """
        if self.fp is not None:
            self.fp.close()
        self.fp = open(self.path, 'r')
        # consume all header lines, up to and including '##Header End'
        while True:
            line = self._safe_readline().replace('\n', '')
            if line.startswith('##Header End'): break

    def parse_definition_line_with_position(
        self, 
        dfn_line:str    # definition line string to be parsed
        )-> dict:       # parsed metadata in key/value format + relative position of the read
        """Parse definition line and adds relative position

        The relative position is stored under key `'read_pos'`. It is computed as
        `(aln_start_pos * 10) // refseq_length + 1`, i.e. 1 for a read starting in the first tenth
        of its reference sequence, up to 10 for a read starting in the last tenth.
        """
        read_meta = self.parse_text(dfn_line)
        read_refseqid = read_meta['refseqid']
        read_start_pos = int(read_meta['aln_start_pos'])
        # the length of the reference sequence comes from the parsed ALN header
        read_refseq_length = int(self.ref_sequences[read_refseqid]['refseq_length'])
        read_meta['read_pos'] = (read_start_pos * 10) // read_refseq_length + 1
        return read_meta

    def parse_file(
        self, 
        add_ref_seq_aligned:bool=False,   # Add the reference sequence aligned to the parsed dictionary when True
        add_read_seq_aligned:bool=False,  # Add the read sequence aligned to the parsed dictionary when True
    )-> dict[str]: 
        # Key/Values: one entry per read, keyed by `readid`. Each value is the metadata parsed from
        # the read definition line (keys depend on the selected parsing rule),
        # optionally with `ref_seq_aligned`,`read_seq_aligned`
        """Read ALN file, return a dict w/ alignment info for each read and optionally aligned reference sequence & read"""
        # Always start from the first read after the header
        self.reset_iterator()
        parsed = {}
        for d in self:
            dfn_line = d['definition line']
            ref_seq_aligned, read_seq_aligned = d['ref_seq_aligned'], d['read_seq_aligned']
            metadata = self.parse_text(dfn_line)
            if add_ref_seq_aligned: metadata['ref_seq_aligned'] = ref_seq_aligned
            if add_read_seq_aligned: metadata['read_seq_aligned'] = read_seq_aligned
            parsed[metadata['readid']] = metadata
        return parsed

    def parse_header_reference_sequences(
        self,
        pattern:str|None=None,     # regex pattern to apply to parse the reference sequence info
        keys:list[str]|None=None,  # list of keys: keys are both regex match group names and corresponding output dict keys 
        )->dict[str]:                  # parsed metadata in key/value format
        """Extract metadata from all header reference sequences

        When neither `pattern` nor `keys` is passed, the header rule selected by
        `set_header_parsing_rules` is used. The result is keyed by `refseqid`.
        """
        if pattern is None and keys is None:
            pattern, keys = self.re_header_pattern, self.re_header_keys
        parsed = {}
        for seq_dfn_line in self.header['reference sequences']:
            metadata = self.parse_text(seq_dfn_line, pattern, keys)
            parsed[metadata['refseqid']] = metadata

        return parsed

    def set_header_parsing_rules(
        self,
        pattern: str|None=None,      # regex pattern to apply to parse the text, search in parsing rules json if None
        keys: list[str]|None=None,   # list of keys/group for regex, search in parsing rules json if None
        verbose: bool=False          # when True, provides information on each rule
    )-> None:
        """Set the regex parsing rule for reference sequence in ALN header.
               
        Updates 3 class attributes: `re_header_rule_name`, `re_header_pattern`, `re_header_keys`

        Rules are set either manually, when both `pattern` and `keys` are passed, or automatically by
        testing every rule saved in `default_parsing_rules.json` on the first reference sequence line of
        the header and keeping the rule generating the most matches.

        Raises `ValueError` when a custom pattern fails to parse the first reference sequence line.
        
        TODO: refactor this and the method in Core: to use a single function for the common part and a parameter for the text_to_parse 
        """

        P2JSON = Path(f"{PACKAGE_ROOT}/default_parsing_rules.json")

        self.re_header_rule_name = None
        self.re_header_pattern = None
        self.re_header_keys = None

        # get the first reference sequence definition line in header
        text_to_parse = self.header['reference sequences'][0]
        divider_line = f"{'-'*80}"

        if pattern is not None and keys is not None:  # When specific pattern and keys are passed
            try:
                # only used to validate the custom rule: an unusable pattern raises here
                self.parse_text(text_to_parse, pattern, keys)
                self.re_header_rule_name = 'Custom Rule'
                self.re_header_pattern = pattern
                self.re_header_keys = keys
                if verbose:
                    print(divider_line)
                    print("Custom rule was set for header in this instance.")
            except Exception as err:
                raise ValueError(f"The pattern generates the following error:\n{err}") from err

        else:  # automatic rule selection among rules saved in json file
            # Load all existing rules from json file
            with open(P2JSON, 'r') as fp:
                parsing_rules = json.load(fp)

            # test all existing rules and keep the one with highest number of matches
            max_nbr_matches = 0
            for k, v in parsing_rules.items():
                re_header_pattern = v['pattern']
                re_header_keys = v['keys'].split(' ')
                try:
                    metadata_dict = self.parse_text(text_to_parse, re_header_pattern, re_header_keys)
                    nbr_matches = len(metadata_dict)
                    if verbose:
                        print(divider_line)
                        print(f"Rule <{k}> generated {nbr_matches:,d} matches")
                        print(divider_line)
                        print(re_header_pattern)
                        print(re_header_keys)

                    if nbr_matches > max_nbr_matches:
                        # remember the best score so far; without this update the last matching
                        # rule would always win, whatever its number of matches
                        max_nbr_matches = nbr_matches
                        self.re_header_pattern = re_header_pattern
                        self.re_header_keys = re_header_keys
                        self.re_header_rule_name = k
                except Exception as err:
                    # a rule that cannot parse the line is simply not a candidate
                    if verbose:
                        print(divider_line)
                        print(f"Rule <{k}> generated an error")
                        print(err)
            if self.re_header_rule_name is None:
                msg = """
        None of the saved parsing rules were able to extract metadata from the first reference sequence line in the header of this file.
        You must set a custom rule (pattern + keys) before parsing the header, by using:
            `self.set_header_parsing_rules(custom_pattern, custom_list_of_keys)`
                """
                warnings.warn(msg, category=UserWarning)

            if verbose:
                print(divider_line)
                print(f"Selected rule with most matches: {self.re_header_rule_name}")

            # Reset the iterator to make all read lines available again
            self.reset_iterator()

In [None]:
show_doc(AlnFileReader)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L229){target="_blank" style="float:right; font-size:smaller"}

### AlnFileReader

>      AlnFileReader (path:str|pathlib.Path)

*Iterator going through an ALN file*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| path | str \| pathlib.Path | path to the aln file |
| **Returns** | **dict** | **key/value with keys:** |

In [None]:
aln = AlnFileReader(p2aln)

`AlnFileReader` iterator returns elements one by one, as dictionaries with each data line related to the read, accessible through the following keys: 

- key `'definition line'`: **read definition line**, including read metadata 
- key `'ref_seq_aligned'`: **aligned reference sequence**, that is the sequence segment in the original reference corresponding to the read
- key `'read_seq_aligned'`: **aligned read**, that is the simulated read sequence, where each bp corresponds to the reference sequence bp in the same position.

In [None]:
one_iteration = next(aln)
one_iteration.keys()

dict_keys(['definition line', 'ref_seq_aligned', 'read_seq_aligned'])

In [None]:
pprint(one_iteration)

{'definition line': '>2591237:ncbi:1\t2591237:ncbi:1-40200\t14370\t+',
 'read_seq_aligned': 'TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAACTTACATAGCTCGCGTCTCAGTTTCAAGGAACTTTTAGTGTATGCTGCTGATCCAGCCATGCATGCAGCTT',
 'ref_seq_aligned': 'TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAACTTACATAGCTCGCGTCTCAGTTTCAAGGAACTTTTAGTGTATGCTGCTGATCCAGCCATGCATGCAGCTT'}


In [None]:
dfn_line, ref_seq_aligned, read_seq_aligned = one_iteration.values()

In [None]:
dfn_line

'>2591237:ncbi:1\t2591237:ncbi:1-40200\t14370\t+'

In [None]:
ref_seq_aligned[:100]

'TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAACTTACATAGCTCGCGTCTCAG'

In [None]:
read_seq_aligned[:100]

'TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAACTTACATAGCTCGCGTCTCAG'

In [None]:
another_iteration = next(aln)
pprint(another_iteration)

{'definition line': '>2591237:ncbi:1\t2591237:ncbi:1-40199\t15144\t-',
 'read_seq_aligned': 'TCATAGTACTACAGATAGAGACACCAGCTACGGTGCGAGCTCTATTCTTTGCACTAATGGCGTACTTAAGAGTCATTTGAGTTATAGTAGGGATGACATTACGCTTAGTATACGCGAAAAGTGCATCTTGATCCTCATAACTCATTGAGT',
 'ref_seq_aligned': 'TCATAGTACTACAGATAGAGACACCAGCTACGGTGCGAGCTCTATTCTTTGCACTAATGGCGTACTTAAGATTCATTTGAGTTATAGTAGGGATGACATTACGCTTAGTATACGCGAAAAGTGCATCTTGATCCTCATAACTCATTGAGT'}


In [None]:
aln.reset_iterator()
for i, d in enumerate(aln):
    print(d['definition line'])
    print(d['ref_seq_aligned'][:80], '...')
    print(d['read_seq_aligned'][:80], '...\n')
    if i >= 3: break

>2591237:ncbi:1	2591237:ncbi:1-40200	14370	+
TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAAC ...
TTGTAGATGGTGTTCCTTTTGTTGTTTCAACTGGATACCATTTTCGTGAGTTAGGAGTTGTACATAATCAGGATGTAAAC ...

>2591237:ncbi:1	2591237:ncbi:1-40199	15144	-
TCATAGTACTACAGATAGAGACACCAGCTACGGTGCGAGCTCTATTCTTTGCACTAATGGCGTACTTAAGATTCATTTGA ...
TCATAGTACTACAGATAGAGACACCAGCTACGGTGCGAGCTCTATTCTTTGCACTAATGGCGTACTTAAGAGTCATTTGA ...

>2591237:ncbi:1	2591237:ncbi:1-40198	2971	-
TAACATAGTGGTTCGTTTATCAAGGATAATCTATCTCCATAGGTTCTTCATCATCTAACTCTGAATATTTATTCTTAGTT ...
TAACATAGTGGTTCGTTTATCAAGGATAATCTATCTCCATAGGTTCTTCATCATCTAACTCTGAATATTTATTCTTAGTT ...

>2591237:ncbi:1	2591237:ncbi:1-40197	15485	-
TAATCACTGATAGCAGCATTGCCATCCTGAGCAAAGAAGAAGTGTTTTAGTTCAACAGAACTTCCTTCCTTAAAGAAACC ...
TAATCACTGATAGCAGCATTGCCATCCTGAGCAAAGAAGAAGTGTTTTAGTTCAACAGAACTTCCTTCCTTAAAGAAACC ...



Once instantiated, the `AlnFileReader` iterator gives access to the file's header information through `header` instance attribute. It is a dictionary with two keys: `'command'` and `'reference sequences'`:

```
    {'command':             'art-illumina command used to create the reads',
     'reference sequences': ['@SQ metadata on reference sequence 1 used for the reads',
                             '@SQ metadata on reference sequence 2 used for the reads', 
                             ...
                            ]
    }
```

In [None]:
print(aln.header['command'])

/usr/bin/art_illumina -i /home/vtec/projects/bio/metagentools/data/ncbi/refsequences/cov/cov_refseq_001-seq1.fa -ss HS25 -l 150 -f 200 -o /home/vtec/projects/bio/metagentools/data/ncbi/simreads/cov/single_1seq_150bp/single_1seq_150bp -rs 1723893089


In [None]:
for seq_info in aln.header['reference sequences']:
    print(seq_info)

@SQ	2591237:ncbi:1	1	MK211378	2591237	ncbi	Coronavirus BtRs-BetaCoV/YN2018D	30213


The **read definition line** includes key metadata, which need to be parsed using the appropriate parsing rule.

In [None]:
show_doc(AlnFileReader.parse_text)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/core.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### TextFileBaseReader.parse_text

>      TextFileBaseReader.parse_text (txt:str, pattern:str|None=None,
>                                     keys:list[str]|None=None)

*Parse text using regex pattern and keys. Return a metadata dictionary.

The passed text is parsed using the regex pattern. The method return a dictionary in the format:

    {
        'key_1': 'metadata 1',
        'key_2': 'metadata 2',
        ...
    }*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| txt | str |  | text to parse |
| pattern | str \| None | None | If None, uses standard regex pattern to extract metadata, otherwise, uses passed regex |
| keys | list[str] \| None | None | If None, uses standard regex list of keys, otherwise, uses passed list of keys (str) |
| **Returns** | **dict** |  | **parsed metadata in key/value format** |

In [None]:
#| hide
pattern, keys = aln.re_pattern, aln.re_keys

In [None]:
aln.parse_text(dfn_line, pattern, keys)

{'aln_start_pos': '14370',
 'readid': '2591237:ncbi:1-40200',
 'readnb': '40200',
 'refseq_strand': '+',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}

In [None]:
show_doc(AlnFileReader.parse_definition_line_with_position)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L297){target="_blank" style="float:right; font-size:smaller"}

### AlnFileReader.parse_definition_line_with_position

>      AlnFileReader.parse_definition_line_with_position (dfn_line:str)

*Parse definition line and adds relative position*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| dfn_line | str | fefinition line string to be parsed |
| **Returns** | **dict** | **parsed metadata in key/value format + relative position of the read** |

Upon instance creation, `AlnFileReader` automatically checks the `default_parsing_rules.json` file for a workable rule among saved rules. Saved rules include the rule for ART Illumina ALN files.

In [None]:
aln.re_rule_name

'aln_art_illumina_ncbi_std'

It is therefore not required to pass a specific `pattern` and `keys` parameter.


In [None]:
aln.parse_text(dfn_line)

{'aln_start_pos': '14370',
 'readid': '2591237:ncbi:1-40200',
 'readnb': '40200',
 'refseq_strand': '+',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}

ART Illumina ALN file definition lines consist of:

- The **read** ID: `readid`, e.g. `2591237:ncbi:1-20100`
- The **read** number (order in the file): `readnb`, e.g. `20100`
- The **read** start position in the reference sequence: `aln_start_pos`, e.g. `23878`
- The **reference sequence** ID: `refseqid`, e.g. `2591237:ncbi:1`
- The **reference sequence** number: `refseqnb`, e.g. `1`
- The **reference sequence** source: `refsource`, e.g. `ncbi`
- The **reference sequence** taxonomy: `reftaxonomyid`, e.g. `2591237`
- The **reference sequence** strand: `refseq_strand`, which is either `+` or `-`


In [None]:
show_doc(AlnFileReader.parse_file)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L309){target="_blank" style="float:right; font-size:smaller"}

### AlnFileReader.parse_file

>      AlnFileReader.parse_file (add_ref_seq_aligned:bool=False,
>                                add_read_seq_aligned:bool=False)

*Read ALN file, return a dict w/ alignment info for each read and optionaly aligned reference sequence & read*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| add_ref_seq_aligned | bool | False | Add the reference sequence aligned to the parsed dictionary when True |
| add_read_seq_aligned | bool | False | Add the read sequence aligned to the parsed dictionary when True |
| **Returns** | **dict** |  |  |

In [None]:
parsed = aln.parse_file()

for i, (k, v) in enumerate(parsed.items()):
    print(k)
    pprint(v)
    if i > 3: break

2591237:ncbi:1-40200
{'aln_start_pos': '14370',
 'readid': '2591237:ncbi:1-40200',
 'readnb': '40200',
 'refseq_strand': '+',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}
2591237:ncbi:1-40199
{'aln_start_pos': '15144',
 'readid': '2591237:ncbi:1-40199',
 'readnb': '40199',
 'refseq_strand': '-',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}
2591237:ncbi:1-40198
{'aln_start_pos': '2971',
 'readid': '2591237:ncbi:1-40198',
 'readnb': '40198',
 'refseq_strand': '-',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}
2591237:ncbi:1-40197
{'aln_start_pos': '15485',
 'readid': '2591237:ncbi:1-40197',
 'readnb': '40197',
 'refseq_strand': '-',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}
2591237:ncbi:1-40196
{'aln_start_pos': '16221',
 'readid': '2591237:ncbi:1-40196',
 'readnb': '40

In [None]:
show_doc(AlnFileReader.parse_header_reference_sequences)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L329){target="_blank" style="float:right; font-size:smaller"}

### AlnFileReader.parse_header_reference_sequences

>      AlnFileReader.parse_header_reference_sequences (pattern:str|None=None,
>                                                      keys:list[str]|None=None)

*Extract metadata from all header reference sequences*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| pattern | str \| None | None | regex pattern to apply to parse the reference sequence info |
| keys | list[str] \| None | None | list of keys: keys are both regex match group names and corresponding output dict keys |
| **Returns** | **dict** |  | **parsed metadata in key/value format** |

In [None]:
pprint(aln.parse_header_reference_sequences())

{'2591237:ncbi:1': {'organism': 'Coronavirus BtRs-BetaCoV/YN2018D',
                    'refseq_accession': 'MK211378',
                    'refseq_length': '30213',
                    'refseqid': '2591237:ncbi:1',
                    'refseqnb': '1',
                    'refsource': 'ncbi',
                    'reftaxonomyid': '2591237'}}


In [None]:
show_doc(AlnFileReader.set_header_parsing_rules)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L344){target="_blank" style="float:right; font-size:smaller"}

### AlnFileReader.set_header_parsing_rules

>      AlnFileReader.set_header_parsing_rules (pattern:str|bool=None,
>                                              keys:list[str]=None,
>                                              verbose:bool=False)

*Set the regex parsing rule for reference sequence in ALN header.

Updates 3 class attributes: `re_header_rule_name`, `re_header_pattern`, `re_header_keys`

TODO: refactor this and the method in Core: to use a single function for the common part and a parameter for the text_to_parse*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| pattern | str \| bool | None | regex pattern to apply to parse the text, search in parsing rules json if None |
| keys | list | None | list of keys/group for regex, search in parsing rules json if None |
| verbose | bool | False | when True, provides information on each rule |
| **Returns** | **None** |  |  |

In [None]:
aln.set_header_parsing_rules(verbose=True)

--------------------------------------------------------------------------------
Rule <fasta_ncbi_std> generated an error
No match on this line
--------------------------------------------------------------------------------
Rule <fastq_art_illumina_ncbi_std> generated an error
No match on this line
--------------------------------------------------------------------------------
Rule <aln_art_illumina_ncbi_std> generated an error
No match on this line
--------------------------------------------------------------------------------
Rule <aln_art_illumina-refseq-ncbi-std> generated 7 matches
--------------------------------------------------------------------------------
^@SQ[\t\s]*(?P<refseqid>(?P<reftaxonomyid>\d*):(?P<refsource>\w*):(?P<refseqnb>\d*))[\t\s]*(?P=refseqnb)[\t\s]*(?P<refseq_accession>[\w\d]*)[\t\s]*(?P=reftaxonomyid)[\t\s]*(?P=refsource)[\t\s](?P<organism>.*)[\t\s](?P<refseq_length>\d*)$
['refseqid', 'reftaxonomyid', 'refsource', 'refseqnb', 'refseq_accession', 'organism

In [None]:
print(aln.re_header_rule_name)
print(aln.re_header_pattern)
print(aln.re_header_keys)

aln_art_illumina-refseq-ncbi-std
^@SQ[\t\s]*(?P<refseqid>(?P<reftaxonomyid>\d*):(?P<refsource>\w*):(?P<refseqnb>\d*))[\t\s]*(?P=refseqnb)[\t\s]*(?P<refseq_accession>[\w\d]*)[\t\s]*(?P=reftaxonomyid)[\t\s]*(?P=refsource)[\t\s](?P<organism>.*)[\t\s](?P<refseq_length>\d*)$
['refseqid', 'reftaxonomyid', 'refsource', 'refseqnb', 'refseq_accession', 'organism', 'refseq_length']


# Build datasets

Reads are provided in various formats (plain text for original data, fastq + aln for simulated reads). The CNN_Virus model requires inputs in the form of three tensors:

- `read_seq_batch`: a batch of 50-mer read sequences, in "base-hot-encoded" format (BHE)
- `label_batch`: a batch of the species' labels, in one-hot-encoded format (187 classes) (OHE)
- `pos_batch`: a batch of the read's relative positions in the reference sequence, in one-hot-encoded format (10 classes) (OHE)

The classes below are custom pytorch `Dataset` used to load the reads from their file and transform it into BHE or OHE tensors.

## Plain text based data

The datasets provided by CNN Virus team are in plain text format, one sequence per line, with format as follows:

```ascii
AAAAAGATTTTGAGAGAGGTCGACCTGTCCTCCTAAAACGTTTACAAAAG	71	0
CATGTAACGCAGCTTAGTCCGATCGTGGCTATAATCCGTCTTTCGATTTG	1	7
AACAACATCTTGTTGATGATAACCGTCAAAGTGTTTTGGGTCTGGAGGGA	158	6
AGTACCTGGAGAGCGTTAAGAAACACAAACGGCTGGATGTAGTGCCGCGC	6	7
CCACGTCGATGAAGCTCCGACGAGAGTCGGCGCTGAGCCCGCGCACCTCC	71	6
AGCTCGTGGATCTCCCCTCCTTCTGCAGTTTCAACATCAGAAGCCCTGAA	87	1
```

The first column is the k-base (k-mer) read, followed by the species label and the relative position.

When using this type of data, the data pipeline consists of:

- Creating a `TextFileDataset` instance to load the data

- Using the created dataset in a `DataLoader` used for training or inference.

In [None]:
#| export
class TextFileDataset(IterableDataset):
    """Load data from text file and yield (BHE sequence tensor, (label OHE tensor, position OHE tensor))"""

    base2encoding = {
        'A': [1,0,0,0,0], 
        'C': [0,1,0,0,0], 
        'G': [0,0,1,0,0], 
        'T': [0,0,0,1,0], 
        'N': [0,0,0,0,1],
        '-': [0,0,0,0,1],
        }
    nb_labels = 187
    nb_pos = 10
    
    def __init__(
        self,
        p2file:str|Path,  # path to the file to read
    ):
        self.p2file = safe_path(p2file)

    def __iter__(self):
        with open(self.p2file, 'r') as f:
            for line in f:
                # wi = torch.utils.data.get_worker_info()
                # if wi:
                #     print(f"{wi.id} loading {line}")
                seq, lbl, pos = line.replace('\n','').strip().split('\t')
                seq_bhe = torch.tensor(list(map(self._bhe_fn, seq)))
                lbl_ohe = torch.zeros(self.nb_labels)
                lbl_ohe[int(lbl)] = 1
                pos_ohe = torch.zeros(self.nb_pos)
                pos_ohe[int(pos)] = 1
                yield seq_bhe, (lbl_ohe, pos_ohe)
    
    def _bhe_fn(self, base:str) -> list[int]:
        """Convert a base to a one hot encoding vector"""
        return self.base2encoding[base]

This `IterableDataset` reads the file line by line and yields a tuple `(seq_bhe, (lbl_ohe, pos_ohe))`:

- `seq_bhe`: tensor of shape [k, 5] with the k-mer read in base-hot-encoded format
- `lbl_ohe`: tensor of shape [187] with the species label in one-hot-encoded format
- `pos_ohe`: tensor of shape [10] with the relative position label in one-hot-encoded format

The dataset is then used with a pytorch `DataLoader` to handle batching. When using the GPU, `pin_memory=True` should be used to make the transfer of data to the GPU faster.

Let's use a small data file to illustrate the process.

In [None]:
p2file = Path('data_dev/CNN_Virus_data/50mer_ds_100_seq')
assert p2file.is_file()

ds = TextFileDataset(p2file)
dl = DataLoader(ds, batch_size=8, num_workers=2, pin_memory=False)

for seq_batch, (lbl_batch, pos_batch) in dl:
    print(seq_batch.shape, lbl_batch.shape, pos_batch.shape)

torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 187]) torch.Size([8, 10])
torch.Size([8, 50, 5]) torch.Size([8, 18

The read sequence tensor is BHE:

In [None]:
seq_batch[:2, :3, :]

tensor([[[1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [0, 1, 0, 0, 0]],

        [[1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0],
         [1, 0, 0, 0, 0]]])

In [None]:
lbl_batch[:2, :], lbl_batch.argmax(dim=1)[:2]

(tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0

## ALN file based

All simulated reads are stored in a `.fastq` and an `.aln` file. We use the `.aln` file because it includes all the metadata required to retrieve both the species and the relative position of each read.

`AlnFileDataset` allows to load reads and metadata stored in ART Illumina simulated ALN file.

In [None]:
p2aln = pfs.data / 'ncbi/simreads/cov/single_1seq_150bp/single_1seq_150bp.aln'
aln = AlnFileReader(p2aln)

for i, d in enumerate(aln):
    pass
print(f"{i+1:,d} reads in this ALN file")

40,200 reads in this ALN file


In [None]:
#| export
class AlnFileDataset(IterableDataset):
    """Load data and metadata from ALN file, yield BHE sequence, OHE label, OHE position tensors + metadata

    The iterator yield tupple (read tensor,(label tensor, position tensor)):
    
    - kmer read tensor in base hot encoded format (shape [k, 5])
    - label tensor in one hot encoded format (shape [187])
    - position tensor in one hote encoded format (shape [10])

    It also optionally returns a dictionary of the read metadata available in the ALN file.
    """

    base2encoding = {
        'A': [1,0,0,0,0], 
        'C': [0,1,0,0,0], 
        'G': [0,0,1,0,0], 
        'T': [0,0,0,1,0], 
        'N': [0,0,0,0,1],
        '-': [0,0,0,0,1],
        }
    nb_labels = 187
    nb_pos = 10
    
    def __init__(
        self,
        p2file:str|Path,            # path to the file to read
        label:int = 118,            # label for this batch (assuming all reads are from the same species)
        return_metadata:bool=False  # yield each read metadata as a dictionary when Trud
    ):
        self.p2file = safe_path(p2file)
        self.aln = AlnFileReader(self.p2file)
        self.label = label
        self.return_metadata = return_metadata

    def __iter__(self):
        for d in self.aln:
            metadata = self.aln.parse_definition_line_with_position(d['definition line'])
            seq = d['read_seq_aligned']
            seq_bhe = torch.tensor(list(map(self._bhe_fn, seq)))
            lbl_ohe = torch.zeros(self.nb_labels)
            lbl_ohe[int(self.label)] = 1
            pos = metadata['read_pos']
            pos_ohe = torch.zeros(self.nb_pos)
            pos_ohe[int(pos-1)] = 1
            if self.return_metadata:
                yield seq_bhe, (lbl_ohe, pos_ohe), metadata
            else:   
                yield seq_bhe, (lbl_ohe, pos_ohe)
    
    def _bhe_fn(self, base:str) -> list[int]:
        """Convert a base to a one hot encoding vector"""
        return self.base2encoding[base]

In [None]:
show_doc(AlnFileDataset)

---

[source](https://github.com/vtecftwy/metagentorch/blob/main/metagentorch/cnn_virus/data.py#L465){target="_blank" style="float:right; font-size:smaller"}

### AlnFileDataset

>      AlnFileDataset (p2file:str|pathlib.Path, label:int=118,
>                      return_metadata:bool=False)

*Load data and metadata from ALN file, yield BHE sequence, OHE label, OHE position tensors + metadata

The iterator yield tupple (read tensor,(label tensor, position tensor)):

- kmer read tensor in base hot encoded format (shape [k, 5])
- label tensor in one hot encoded format (shape [187])
- position tensor in one hote encoded format (shape [10])

It also optionally returns a dictionary of the read metadata available in the ALN file.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| p2file | str \| pathlib.Path |  | path to the file to read |
| label | int | 118 | label for this batch (assuming all reads are from the same species) |
| return_metadata | bool | False | yield each read metadata as a dictionary when Trud |

Create a dataset from an ALN file, then pass it to a `DataLoader` to create batches of data.

In [None]:
ds = AlnFileDataset(p2aln, label=118)

dl = DataLoader(ds, batch_size=16, shuffle=False)

for i, (s, (lbl,pos)) in enumerate(dl):
    print(f"Sample {i+1:000d}:")
    print(f"  Shapes of each data tensor:")
    print(f"    seq: {s.shape}, label: {lbl.shape}, pos: {pos.shape}")
    if i+1 >= 2:break

Sample 1:
  Shapes of each data tensor:
    seq: torch.Size([16, 150, 5]), label: torch.Size([16, 187]), pos: torch.Size([16, 10])
Sample 2:
  Shapes of each data tensor:
    seq: torch.Size([16, 150, 5]), label: torch.Size([16, 187]), pos: torch.Size([16, 10])


When an `AlnFileDataset` instance with `return_metadata` set as True is passed to `DataLoader`, the dataloader will yield (read_tensor,(lbl_tensor, pos_tensor), metadata) where:

- read_tensor is the batch of kmer read tensor in base hot encoded format (shape [bs, k, 5])
- lbl_tensor is the batch of label tensor in one hot encoded format (shape [bs, 187])
- pos_tensor is the batch of position tensor in one hot encoded format (shape [bs, 10])
- metadata is a dictionary of form `key:list` like:
```python
{
    'aln_start_pos': ['14370', '15144', '2971'], 
    'readid': ['2591237:ncbi:1-40200', '2591237:ncbi:1-40199', '2591237:ncbi:1-40198'], 
    'readnb': ['40200', '40199', '40198'], 
    'refseq_strand': ['+', '-', '-'], 
    'refseqid': ['2591237:ncbi:1', '2591237:ncbi:1', '2591237:ncbi:1'], 
    'refseqnb': ['1', '1', '1'], 
    'refsource': ['ncbi', 'ncbi', 'ncbi'], 
    'reftaxonomyid': ['2591237', '2591237', '2591237'], 
    'read_pos': tensor([ 5,  6,  1])
}
```

In [None]:
ds = AlnFileDataset(p2aln, label=118, return_metadata=True)
dl = DataLoader(ds, batch_size=3, shuffle=False)

for i, (s, (lbl,pos), d) in enumerate(dl):
    print(f"Sample {i+1:000d}:")
    print(f"  Shapes of each data tensor:")
    print(f"    seq: {s.shape}, label: {lbl.shape}, pos: {pos.shape}")
    print(f"  Metadata dictionary for the batch")
    print(f"    {d}")
    if i+1 >= 2:break

Sample 1:
  Shapes of each data tensor:
    seq: torch.Size([3, 150, 5]), label: torch.Size([3, 187]), pos: torch.Size([3, 10])
  Metadata dictionary for the batch
    {'aln_start_pos': ['14370', '15144', '2971'], 'readid': ['2591237:ncbi:1-40200', '2591237:ncbi:1-40199', '2591237:ncbi:1-40198'], 'readnb': ['40200', '40199', '40198'], 'refseq_strand': ['+', '-', '-'], 'refseqid': ['2591237:ncbi:1', '2591237:ncbi:1', '2591237:ncbi:1'], 'refseqnb': ['1', '1', '1'], 'refsource': ['ncbi', 'ncbi', 'ncbi'], 'reftaxonomyid': ['2591237', '2591237', '2591237'], 'read_pos': tensor([5, 6, 1])}
Sample 2:
  Shapes of each data tensor:
    seq: torch.Size([3, 150, 5]), label: torch.Size([3, 187]), pos: torch.Size([3, 10])
  Metadata dictionary for the batch
    {'aln_start_pos': ['15485', '16221', '18953'], 'readid': ['2591237:ncbi:1-40197', '2591237:ncbi:1-40196', '2591237:ncbi:1-40195'], 'readnb': ['40197', '40196', '40195'], 'refseq_strand': ['-', '-', '-'], 'refseqid': ['2591237:ncbi:1', '259123

## Handle long reads

### Preprocessing

CNN Virus handles 50-mer reads only. Longer reads need to be split into 50-mer reads by sliding a window of size 50 along the read sequence. The set of multiple 50-mer reads is then used as input to the model. After prediction on the set of 50-mer reads, the final prediction is obtained by keeping only the 50-mer predictions that yield a high enough probability and then voting for the most predicted species.

A k-mer (150-mer) read will be split in k-49 (101) 50-mer reads which will yield k-49 (101) probability tensors. The final prediction will be the species with the highest number of votes, as long as their respective probabilities > 90%.

In [None]:
#| export
def split_kmer_batch_into_50mers(
    kmer: torch.Tensor        # tensor representing a batch of k-mer reads, BHE format, shape [b, k, 5]
    ) -> torch.Tensor:
    """Convert a batch of k-mer reads into 50-mer reads, by shifting the k-mer one base at a time.

    Shapes: for a batch of `b` k-mer reads, returns a batch of `b * (k - 49)` 50-mer reads, i.e. a
    tensor of shape `[b * (k - 49), 50, c]` where `c` is the size of the last dimension of `kmer`.
    The 50-mers of the first read come first, then those of the second read, and so on.

    Raises `ValueError` when the reads are shorter than 50 bases.

    Technical Note: we use advanced indexing of the tensor to create the 50-mers, with no loop.
    """
    k, c = kmer.shape[1], kmer.shape[2]
    n = k - 49                          # number of 50-mers in one k-mer
    if n < 1:
        raise ValueError(f"Reads must be at least 50 bases long to be split into 50-mers, got k={k}")

    # window_idx[r, j] is the index in the k-mer of base j of the r-th 50-mer.
    # Its largest value is (n - 1) + 49 = k - 1, so no wrap around is required.
    starts = torch.arange(n, device=kmer.device)[:, None]    # shape (n, 1): first base of each 50-mer
    offsets = torch.arange(50, device=kmer.device)           # shape (50): position inside the 50-mer
    window_idx = starts + offsets                            # shape (n, 50)

    # Advanced indexing on dim 1 gives shape (b, n, 50, c), then the first two dims are merged
    return kmer[:, window_idx, :].reshape(-1, 50, c)

We can test the function, with a test tensor designed to make the validation easier.

We create a tensor with a batch of b k-mer reads in the following format:
```ascii
read 1: 10001, 10002, 10003, .... 10149
read 2: 20001, 20002, 20003, .... 20149
read 3: 30001, 30002, 30003, .... 30149
    ...
```

We then add an additional dimension to simulate the BHE encoding, by repeating the same value 5 times.

In [None]:
b = 4
k = 150

# Build a tensor of shape (b,k) with the pattern described above
kmer = torch.stack([(i+1) * 10000 + torch.arange(k) for i in range(b)], dim=0)

# simulate the BHE buy duplicating the 5 dimensions
kmer = torch.stack([kmer]*5, dim=2)

print(kmer.shape)
# Slice the tensor to only show the first and last three bases, and only one of the 5 BHE dimensions
idx_bases = np.array([0,1,2,147,148,149])
kmer[:, idx_bases, 0]

torch.Size([4, 150, 5])


tensor([[10000, 10001, 10002, 10147, 10148, 10149],
        [20000, 20001, 20002, 20147, 20148, 20149],
        [30000, 30001, 30002, 30147, 30148, 30149],
        [40000, 40001, 40002, 40147, 40148, 40149]])

We see that the function `split_kmer_batch_into_50mers` returns a tensor with shape [b * (k-49), 50, 5] as expected (here 4 * 101 = 404 50-mer reads).

In [None]:
split_tensor = split_kmer_batch_into_50mers(kmer)
split_tensor.shape

torch.Size([404, 50, 5])

Let's slice only the first and last two 50-mer reads of each of the first three k-mer reads, and a selection of bases, to confirm that the batch is properly split.

In [None]:
dim0_idxs = [0,1,99,100,101,102,200,201,202,203,301,302]
dim1_idxs = [0,1,25,26,48,49]
dim2_idxs = [0,1,2,3,4]
print(split_tensor[np.ix_(dim0_idxs, dim1_idxs, dim2_idxs)][:,:,0])

tensor([[10000, 10001, 10025, 10026, 10048, 10049],
        [10001, 10002, 10026, 10027, 10049, 10050],
        [10099, 10100, 10124, 10125, 10147, 10148],
        [10100, 10101, 10125, 10126, 10148, 10149],
        [20000, 20001, 20025, 20026, 20048, 20049],
        [20001, 20002, 20026, 20027, 20049, 20050],
        [20099, 20100, 20124, 20125, 20147, 20148],
        [20100, 20101, 20125, 20126, 20148, 20149],
        [30000, 30001, 30025, 30026, 30048, 30049],
        [30001, 30002, 30026, 30027, 30049, 30050],
        [30099, 30100, 30124, 30125, 30147, 30148],
        [30100, 30101, 30125, 30126, 30148, 30149]])


### Postprocessing

The CNN Virus model requires post-inference processing of the predictions:
- before inference, each k-mer was split into $k-49$ 50-mers, which were presented to the model for inference. Each k-mer led to $k-49$ predictions (probability tensors)
- now, we need to combine these $n = k-49$ probability tensors into a single prediction for the original k-mer. This is done by first filtering out all 50-mer reads that gave a max probability lower than a specific threshold (0.9 by default) and then combining the remaining predictions by a simple vote

Now we can build the function `combine_predictions` step by step.

First, let's create a test probabilities tensor representing a batch of `bs` k-mers, each split into `k-49` 50-mers, with a probability tensor of shape `[bs, k-49, nb_class]`.

In [None]:
def create_test_probs(bs=4, k=150, nb_class=187):
    """Create a random tensor standing for the model output on `bs` k-mer reads split into 50-mers.

    Returns a float64 tensor of shape `[bs, k-49, nb_class]`. Values are drawn independently and
    uniformly in [0, 1) with `np.random.rand`; they are not normalised to sum to 1 across classes.
    """
    # number of 50-mers per k-mer: computed from `k` here instead of being read from a notebook global
    n = k - 49
    print(f"Creating a batch of probabilities for {bs} {k}-mer reads with {nb_class} classes:")

    probs = torch.from_numpy(np.random.rand(bs * n ,nb_class)).reshape(-1,n,nb_class)
    print(f"probs {probs.shape}")
    return probs

bs, k, nb_class = 4, 55, 10
n = k - 49
probs = create_test_probs(bs, k, nb_class)

Creating a batch of probabilities for 4 55-mer reads with 10 classes:
probs torch.Size([4, 6, 10])


**Step 1**: Extract the predictions from probabilities for each 50-mer read, then filter the predictions with a probability lower than a threshold.

In [None]:
threshold = 0.8
preds = probs.argmax(dim=2)
print(f"preds {preds.shape}:\n",preds)

INVALID = 9999
invalid_labels_filter = probs.max(dim=2).values <= threshold
preds[invalid_labels_filter] = INVALID
print(f"preds {preds.shape}:\n",preds)

preds torch.Size([4, 6]):
 tensor([[9, 6, 2, 0, 5, 3],
        [7, 6, 8, 7, 6, 1],
        [5, 6, 4, 8, 5, 8],
        [8, 2, 8, 8, 4, 9]])
preds torch.Size([4, 6]):
 tensor([[   9,    6,    2,    0,    5,    3],
        [   7,    6,    8, 9999,    6,    1],
        [   5,    6,    4,    8, 9999,    8],
        [   8,    2,    8,    8,    4,    9]])


**Step 2**: Extract the prediction unique values and their counts.

> Code explanation note: 
>
> We do not want to use a python loop. This is why we use `torch.unique(preds)` with `return_inverse`. The function returns a flat tensor with the unique values across the entire `preds` and a tensor of the same shape as `preds` with, for each element of `preds`, the index of its value in the tensor of unique values.
>
```python
    x = torch.tensor([[10, 30, 20, 30],[20,20,20,50]], dtype=torch.long)

    uniques, inverse = torch.unique(x, sorted=True, return_inverse=True)

    > unique:  tensor([10, 20, 30, 50])
      inverse: tensor([[0, 2, 1, 2],
                       [1, 1, 1, 3]]))
```


In [None]:
# Get unique values and their counts for the entire tensor
unique_values, inverse_indices = torch.unique(preds, return_inverse=True)
inverse_indices = inverse_indices.view(preds.shape)
print(f"unique_values {unique_values.shape}:\n",unique_values)
print(f"inverse_indices {inverse_indices.shape}:\n",inverse_indices)

unique_values torch.Size([11]):
 tensor([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9, 9999])
inverse_indices torch.Size([4, 6]):
 tensor([[ 9,  6,  2,  0,  5,  3],
        [ 7,  6,  8, 10,  6,  1],
        [ 5,  6,  4,  8, 10,  8],
        [ 8,  2,  8,  8,  4,  9]])


**Step 3**: Compute the counts of each unique value for each k-mer read (one row of `preds`) and store them in a tensor of shape (batch size, nb of unique values in the whole tensor `preds`).

In [None]:
# Create a tensor to hold the counts (shape (bs, nb_unique_values_across_the_batch))
counts = torch.zeros((preds.shape[0], unique_values.shape[0]), dtype=torch.int64)
# Count occurrences of each unique value per row
counts.scatter_add_(dim=1, index=inverse_indices, src=torch.ones_like(inverse_indices, dtype=torch.int64))
print(f"unique_values {unique_values.shape}:\n", unique_values)
print(f"counts {counts.shape}:\n", counts)

unique_values torch.Size([11]):
 tensor([   0,    1,    2,    3,    4,    5,    6,    7,    8,    9, 9999])
counts torch.Size([4, 11]):
 tensor([[1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0],
        [0, 1, 0, 0, 0, 0, 2, 1, 1, 0, 1],
        [0, 0, 0, 0, 1, 1, 1, 0, 2, 0, 1],
        [0, 0, 1, 0, 1, 0, 0, 0, 3, 1, 0]])


**Step 4**: Get the index of the most frequent value for each k-mer read (one row of `counts`), excluding the INVALID placeholder, and extract the corresponding values into a tensor of shape (`bs`, 1).

In [None]:
# get the most voted value per k-mer read (one row of `counts`), excluding the placeholder INVALID.
# The INVALID column is located explicitly: it only exists in `unique_values` when at least one
# 50-mer was filtered out, so blindly dropping the last column could discard a real class.
valid_counts = counts.clone()
valid_counts[:, unique_values == INVALID] = -1   # real counts are >= 0, so INVALID can never win the vote
most_voted_idxs = valid_counts.argmax(dim=1)
most_voted_value = unique_values[most_voted_idxs][:, None]
print(f"most_voted_value {most_voted_value.shape}:\n", most_voted_value)

most_voted_value torch.Size([4, 1]):
 tensor([[0],
        [6],
        [8],
        [8]])


We put it all in the function `combine_predictions`

In [None]:
# | export
def combine_predictions(
    label_probs: torch.Tensor,   # Probabilities for the label classes for each 50-mer
    pos_probs: torch.Tensor,     # Probabilities for the position classes for each 50-mer
    threshold: float = 0.9       # Threshold to consider a prediction as valid
    ) -> torch.Tensor:           # Combined (label, position) per read: shape (bs, 2), or (2,) for a single sample
    """Combine a batch of 50-mer probabilities into one batch of final prediction for label and position

    Each 50-mer votes for its most probable class. A 50-mer whose highest *label*
    probability is not strictly above `threshold` is discarded for both the label
    vote and the position vote. The most voted class wins; ties go to the smallest
    class index.

    Note: the input must be of shape (batch_size, n, c) where n is the number of 50-mers
    per read (k-49 for a k-mer read, per the original note) and c is the nb of labels or
    positions. A single sample of shape (n, c) is also accepted and yields a result of shape (2,).

    A read with no valid 50-mer gets the placeholder value 9999 for both label and position.

    Raises `ValueError` when the inputs are neither 2- nor 3-dimensional.
    """
    INVALID = 9999
    is_batch = True

    assert label_probs.dim() == pos_probs.dim(), "Input do not have the same nb of dimensions"

    if label_probs.dim() not in (2, 3):
        raise ValueError(
            f"Expected probability tensors with 2 or 3 dimensions, got {label_probs.dim()}"
        )

    if label_probs.dim() != 3:
        is_batch = False
        print('Converting probability tensors to 3 dimensions')
        label_probs = label_probs.unsqueeze(0)
        pos_probs = pos_probs.unsqueeze(0)

    # Extract the prediction for each 50-mer read
    label_preds = label_probs.argmax(dim=2) # shape (bs, nb_50mers)
    pos_preds = pos_probs.argmax(dim=2)

    # Identify reads with too low prediction probability and replace their prediction by INVALID.
    # The label confidence alone decides validity, for the position prediction as well.
    invalid_labels_filter = label_probs.max(dim=2).values <= threshold
    label_preds[invalid_labels_filter] = INVALID
    pos_preds[invalid_labels_filter] = INVALID

    def most_common_value(preds):
        """Return the most voted non-INVALID value of each row of `preds`, as a tensor of shape (bs, 1)"""
        # Unique values across the whole batch and, for each element, its index into them
        unique_values, inverse_indices = torch.unique(preds, return_inverse=True)
        inverse_indices = inverse_indices.view(preds.shape)

        # Counts of each unique value per row (shape (bs, nb_unique_values_across_the_batch)),
        # allocated on the same device as the predictions so that it also works on GPU
        counts = torch.zeros(
            (preds.shape[0], unique_values.shape[0]), dtype=torch.int64, device=preds.device
        )
        counts.scatter_add_(dim=1, index=inverse_indices, src=torch.ones_like(inverse_indices, dtype=torch.int64))

        # Exclude the INVALID placeholder by value, not by position: it is only present among
        # the unique values when at least one read of the batch was rejected
        valid_counts = counts.masked_fill(unique_values == INVALID, 0)

        # argmax returns the first maximum, i.e. the smallest value in case of a tie
        most_voted_value = unique_values[valid_counts.argmax(dim=1)]

        # Rows without any valid vote have no meaningful winner: flag them as INVALID
        no_valid_vote = valid_counts.sum(dim=1) == 0
        most_voted_value = most_voted_value.masked_fill(no_valid_vote, INVALID)
        return most_voted_value[:, None]

    combined_labels = most_common_value(label_preds)
    combined_pos = most_common_value(pos_preds)

    # Concatenate combined_label and combined_position
    combined_preds = torch.cat([combined_labels, combined_pos], dim=1)

    return combined_preds if is_batch else combined_preds.squeeze(0)

Testing the function with the same test tensor

In [None]:
# Positions reuse the first 10 classes of the same test tensor
combine_predictions(label_probs=probs, pos_probs=probs[..., :10])

tensor([[2, 2],
        [6, 6],
        [6, 6],
        [8, 8]])

Testing the function with a test tensor corresponding to real probabilities

In [None]:
# One read per batch; 187 label classes and 10 position classes
bs, k = 1, 150
label_probs = create_test_probs(bs, k, 187)
pos_probs = create_test_probs(bs, k, 10)

# Batched input: tensors are passed as they are
print('\ncombine for a batch of data')
batch_preds = combine_predictions(label_probs, pos_probs)
print(batch_preds)

# Single sample: drop the batch dimension before calling
print('\ncombine for a single sample (not a batch)')
sample_preds = combine_predictions(label_probs[0], pos_probs[0])
print(sample_preds)

Creating a batch of probabilities for 1 150-mer reads with 187 classes:
probs torch.Size([1, 6, 187])
Creating a batch of probabilities for 1 150-mer reads with 10 classes:
probs torch.Size([1, 6, 10])

combine for a batch of data
tensor([[32,  4]])

combine for a single sample (not a batch)
Converting probability tensors to 3 dimensions
tensor([32,  4])


# Deprecated Items

When any of the following classes and functions is called, it will raise an exception with an error message indicating how to handle the required code refactoring.

Example:
```python
---------------------------------------------------------------------------
DeprecationWarning                        Traceback (most recent call last)
Input In [484], in <cell line: 1>()
----> 1 combine_prediction_batch(label_probs, pos_probs)

Input In [481], in combine_prediction_batch(*args, **kwargs)
      3 """Deprecated"""
      4 msg = """
      5 `combine_prediction_batch` is deprecated. 
      6 Use `combine_predictions` instead, with same capabilities and more."""
----> 7 raise DeprecationWarning(msg)

DeprecationWarning: 
    `combine_prediction_batch` is deprecated. 
    Use `combine_predictions` instead, with same capabilities and more.
```

In [None]:
# | export
def combine_prediction_batch(*args, **kwargs):
    """Deprecated: always raises `DeprecationWarning` pointing to `combine_predictions`"""
    # Arguments are accepted and ignored so that any legacy call reaches the explanatory error.
    # The message is spelled out line by line; the space after "deprecated." is part of it.
    raise DeprecationWarning(
        "\n"
        "    `combine_prediction_batch` is deprecated. \n"
        "    Use `combine_predictions` instead, with same capabilities and more."
    )

In [None]:
#| hide
# The deprecated entry point must fail, whatever it is called with, and say so in its message
test_fail(
    combine_prediction_batch,
    msg="`combine_prediction_batch` is deprecated.",
    contains="is deprecated.",
    args=[probs, probs[..., :10]],
)

In [None]:
#| hide
# Run the nbdev export step (cells flagged with the `export` directive go to the package module)
nbdev_export()