# Prediction on yellow fever simulated reads (Sqlite)

This is a reference notebook for running prediction inference on simulated yellow fever reads, using the original CNN_Virus model and saving predictions, probabilities and metadata in a SQLite database.

- Simulated reads come from an alignment file generated by the ART Illumina simulator (`*.aln` file).
- Uses the generator provided by `AlnFileReader.cnn_virus_input_generator` to read batches of simulated reads and their metadata.
- Uses the `cnn_virus` model to predict label and position probabilities and classes for each simulated read.
- Creates a prediction report and saves it in a SQLite database for easier retrieval and analysis later.

> **Note**: 
>
>When an `*.aln` file contains a very large number of simulated reads, running a prediction on all of them is very time consuming. Therefore, we also provide a function `skip_existing_predictions`, applied to the generator, which skips all simulated reads up to the last one for which a prediction was already saved in the database. This allows building the database step by step.

# 1. Imports and setup environment

In [1]:
# Install required custom packages if not installed yet.
import importlib.util
if not importlib.util.find_spec('ecutilities'):
    print('installing package: `ecutilities`')
    ! pip install -qqU ecutilities
else:
    print('`ecutilities` already installed')
if not importlib.util.find_spec('metagentools'):
    print('installing package: `metagentools`')
    ! pip install -qqU metagentools
else:
    print('`metagentools` already installed')

`ecutilities` already installed
`metagentools` already installed


In [2]:
# Import all required packages
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import re
import sqlite3
import time

from datetime import datetime
from ecutilities.core import files_in_tree
from ecutilities.ipython import nb_setup
from functools import partial
from IPython.display import display, update_display, Markdown, HTML
from nbdev import show_doc
from pandas import HDFStore
from pathlib import Path
from pprint import pprint
from tqdm.notebook import tqdm, trange
from typing import List, Tuple, Dict, Any, Generator

# Setup the notebook for development
nb_setup()

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
import tensorflow as tf
from tensorflow.python.client import device_lib
from tensorflow.keras.models import load_model
print(f"Tensorflow version: {tf.__version__}\n")

from metagentools.cnn_virus.data import _base_hot_encode, strings_to_tensors
from metagentools.cnn_virus.data import split_kmer_into_50mers, combine_prediction_batch
from metagentools.cnn_virus.data import FastaFileReader, FastqFileReader, AlnFileReader
from metagentools.cnn_virus.data import OriginalLabels
from metagentools.cnn_virus.data import string_input_batch_to_tensors, split_kmer_batch_into_50mers
from metagentools.cnn_virus.architecture import create_model_original
from metagentools.core import ProjectFileSystem, TextFileBaseReader, SqliteDatabase

Set autoreload mode
Tensorflow version: 2.8.2



List all computing devices available on the machine

In [3]:
devices = device_lib.list_local_devices()
print('\nDevices:')
for d in devices:
    t = d.device_type
    name = d.physical_device_desc
    l = [item.split(':', 1) for item in name.split(', ')]
    name_attr = dict([x for x in l if len(x)==2])
    dev = name_attr.get('name', ' ')
    print(f"  - {t}  {d.name} {dev:25s}")


Devices:
  - CPU  /device:CPU:0                          
  - GPU  /device:GPU:0  NVIDIA GeForce GTX 1050 


# 2. Setup paths to files

Key folders and system information

In [4]:
pfs = ProjectFileSystem()
pfs.info()

Running linux on local computer
Device's home directory: /home/vtec
Project file structure:
 - Root ........ /home/vtec/projects/bio/metagentools 
 - Data Dir .... /home/vtec/projects/bio/metagentools/data 
 - Notebooks ... /home/vtec/projects/bio/metagentools/nbs


- `p2model`: path to file with saved original pretrained model
- `p2virus_labels`: path to file with virus names and labels mapping for original model

In [5]:
p2model = pfs.data / 'saved/cnn_virus_original/pretrained_model.h5'
assert p2model.is_file(), f"No file found at {p2model.absolute()}"

p2virus_labels = pfs.data / 'CNN_Virus_data/virus_name_mapping'
assert p2virus_labels.is_file(), f"No file found at {p2virus_labels.absolute()}"

Path to the simulated read we want to use

In [6]:
fnames = files_in_tree(pfs.data / 'ncbi/simreads/yf', pattern='69seq')

simreads
  |--yf
  |    |--single_1seq_150bp
  |    |--single_69seq_150bp
  |    |    |--single_69seq_150bp.fq (0)
  |    |    |--single_69seq_150bp.aln (1)
  |    |--paired_1seq_150bp
  |    |--paired_69seq_150bp
  |    |    |--paired_69seq_150bp1.fq (2)
  |    |    |--paired_69seq_150bp2.fq (3)
  |    |    |--paired_69seq_150bp1.aln (4)
  |    |    |--paired_69seq_150bp2.aln (5)


In [7]:
file_stem = 'single_69seq_150bp'

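# Paired files end with '1' or '2' (e.g. paired_69seq_150bp1): drop that single digit to get the folder name
p2aln = pfs.data / f"ncbi/simreads/yf/{file_stem[:-1] if file_stem[-1] in ['1', '2'] else file_stem}/{file_stem}.aln"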
assert p2aln.exists()

aln = AlnFileReader(p2aln)
print(f"Reading alignment file: {p2aln.name}:\n")
for i, aln_read in enumerate(aln):
    pass
print(f"  - {i+1:,d} simulated reads in file '{p2aln.name}' from {len(aln.header['reference sequences'])} reference sequences.")

print('  - ART command: ',aln.header['command'])
print('  - Reference Sequences:')
print('     ','\n      '.join(aln.header['reference sequences']))

Reading alignment file: single_69seq_150bp.aln:

  - 1,161,034 simulated reads in file 'single_69seq_150bp.aln' from 69 reference sequences.
  - ART command:  /usr/bin/art_illumina -i /home/vtec/projects/bio/metagentools/data/ncbi/refsequences/yf/yf_2023_yellow_fever.fa -ss HS25 -l 150 -f 250 -o /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_69seq_150bp/single_69seq_150bp -rs 1724163599
  - Reference Sequences:
      @SQ	11089:ncbi:1	1	AY968064	11089	ncbi	Angola_1971	10237
      @SQ	11089:ncbi:2	2	U54798	11089	ncbi	Ivory_Coast_1982	10237
      @SQ	11089:ncbi:3	3	DQ235229	11089	ncbi	Ethiopia_1961	10237
      @SQ	11089:ncbi:4	4	AY572535	11089	ncbi	Gambia_2001	10237
      @SQ	11089:ncbi:5	5	MF405338	11089	ncbi	Ghana_Hsapiens_1927	10237
      @SQ	11089:ncbi:6	6	U21056	11089	ncbi	Senegal_1927	10237
      @SQ	11089:ncbi:7	7	AY968065	11089	ncbi	Uganda_1948	10237
      @SQ	11089:ncbi:8	8	JX898871	11089	ncbi	ArD114896_Senegal_1995	10237
      @SQ	11089:ncbi:9	9	JX898872	11089

# 3. (Optional) Test each inference step

Let's test the steps to prepare model inputs, using a small batch size:
1. create the generator using the `.aln` file to yield pairs of batches (metadata and read strings) using `aln.cnn_virus_input_generator`
2. transform the batch of string reads into a base hot encoded tensor, using the preprocessing function `string_input_batch_to_tensors`
3. split each k-mer read into (k-49) 50-mer reads to present to the model

In [8]:
show_doc(aln.cnn_virus_input_generator)

---

[source](https://github.com/vtecftwy/metagentools/blob/main/metagentools/cnn_virus/data.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### AlnFileReader.cnn_virus_input_generator

>      AlnFileReader.cnn_virus_input_generator (label:int=118, bs:int=32)

Create a generator yielding a metadata batch (dict) and a read batch (tensor of strings)

The metadata dictionary contains lists of metadata for each read in the batch. For example:
``` 
{
    'readid': ['2591237:ncbi:1-40200','2591237:ncbi:1-40199','2591237:ncbi:1-40198', ...],
    'refseqid': ['2591237:ncbi:1','2591237:ncbi:1','2591237:ncbi:1', ...],
    'read_pos': [5, 6, 1, ...],
    'refsource': ['ncbi', 'ncbi', 'ncbi', ...],
    ...
}
```

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| label | int | 118 | label for this batch (assuming all reads are from the same species) |
| bs | int | 32 | batch size |
| **Returns** | **tuple** |  | **dict of metadata list and tensor of strings** |

We need to define the model label for yellow fever and a batch size.

In [14]:
OriginalLabels().search(s='yellow')

Yellow_fever_virus. Label: 118


In [15]:
show_doc(string_input_batch_to_tensors)

---

[source](https://github.com/vtecftwy/metagentools/blob/main/metagentools/cnn_virus/data.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### string_input_batch_to_tensors

>      string_input_batch_to_tensors (b:tensorflow.python.framework.ops.Tensor,
>                                     k:int=150)

Function converting a batch of bp strings into three tensors: (x_seqs, (y_labels, y_pos))

Expects input strings to have the format: 'read sequence     label   position' where:

- read sequence: a string of length k (kmer)
- label: an integer between 0 and 186 representing the read virus label
- position: an integer between 0 and 9 representing the read position in the genome

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| b | Tensor |  | batch of strings representing the inputs (kmer, label, position) |
| k | int | 150 | maximum read length in the batch |
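As a rough illustration of the input format described above, the following NumPy-only sketch parses one input string and one-hot encodes each part. It is not the library's implementation: the helper name `encode_read`, the base order `ACGTN` and the tab separator are assumptions made for this example only.

```python
import numpy as np

BASES = "ACGTN"  # assumed base order, for illustration only

def encode_read(line: str, k: int = 150, nb_labels: int = 187, nb_pos: int = 10):
    """Parse 'sequence<TAB>label<TAB>position' and one-hot encode each part (hypothetical helper)."""
    seq, label, pos = line.strip().split("\t")
    x = np.zeros((k, len(BASES)), dtype=np.float32)
    for i, base in enumerate(seq[:k]):
        # any unexpected character is treated as 'N'
        x[i, BASES.index(base) if base in BASES else BASES.index("N")] = 1.0
    y_label = np.eye(nb_labels, dtype=np.float32)[int(label)]
    y_pos = np.eye(nb_pos, dtype=np.float32)[int(pos)]
    return x, y_label, y_pos

# a 150-base toy read with label 118 (yellow fever) and position 3
x, y_label, y_pos = encode_read("ACGT" * 37 + "AC" + "\t118\t3")
print(x.shape, y_label.shape, y_pos.shape)
```

For a batch of 8 reads, stacking these arrays gives the shapes reported by the notebook: `(8, 150, 5)`, `(8, 187)` and `(8, 10)`.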

In [12]:
b = 8
true_label = 118

aln.reset_iterator()
for batch_meta, batch_reads in aln.cnn_virus_input_generator(bs=b, label=true_label):
    reads_kmer, (labels_kmer, positions_kmer) = string_input_batch_to_tensors(batch_reads, k=150)
    break

print('Review metadata batch yielded by the generator:')
print(f"  List of metadata keys:")
print('  -','\n  - '.join(batch_meta.keys()))
print(f"  'readid' included in this batch':", ', '.join(batch_meta['readid']))
print('\nReview batch of string reads yielded by the generator:')
print(f"  - Shape: {batch_reads.shape}")
print('\nReview the read kmer tensor after preprocessing:')
print(f"  - Shape: {reads_kmer.shape}")
print('\nReview ground truth tensors:')
print(f"  - Shape true label tensor:    {labels_kmer.shape}")
print(f"  - Shape true position tensor: {positions_kmer.shape}")

Review metadata batch yielded by the generator:
  List of metadata keys:
  - aln_start_pos
  - readid
  - readnb
  - refseq_strand
  - refseqid
  - refseqnb
  - refsource
  - reftaxonomyid
  - read_pos
  'readid' included in this batch': 11089:ncbi:1-17000, 11089:ncbi:1-16999, 11089:ncbi:1-16998, 11089:ncbi:1-16997, 11089:ncbi:1-16996, 11089:ncbi:1-16995, 11089:ncbi:1-16994, 11089:ncbi:1-16993

Review batch of string reads yielded by the generator:
  - Shape: (8,)

Review the read kmer tensor after preprocessing:
  - Shape: (8, 150, 5)

Review ground truth tensors:
  - Shape true label tensor:    (8, 187)
  - Shape true position tensor: (8, 10)


The model only accepts 50-mer reads, so we need to split kmer reads into 50mers. For each kmer read, k-49 50mer reads will be generated, by shifting a window of 50 nucleotides by 1 nucleotide at a time.
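The sliding-window split can be sketched with NumPy alone. This is a minimal illustration of the idea, not the code of `split_kmer_batch_into_50mers`; the function name `split_into_50mers` is hypothetical and the ordering of the output windows is an assumption.

```python
import numpy as np

def split_into_50mers(reads_kmer: np.ndarray, window: int = 50) -> np.ndarray:
    """Split (b, k, 5) encoded reads into (b * (k - window + 1), window, 5) windows."""
    b, k, c = reads_kmer.shape
    # sliding_window_view appends the window axis last: (b, k - window + 1, c, window)
    windows = np.lib.stride_tricks.sliding_window_view(reads_kmer, window, axis=1)
    # move the window axis before the channel axis, then merge the first two axes
    return windows.transpose(0, 1, 3, 2).reshape(b * (k - window + 1), window, c)

reads = np.random.rand(8, 150, 5)   # stand-in for a batch of 8 encoded 150-mers
out = split_into_50mers(reads)
print(out.shape)
```

With `b = 8` and `k = 150`, each read gives 101 windows, so the result has 808 rows, each a `(50, 5)` array.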

In [13]:
show_doc(split_kmer_batch_into_50mers)

---

[source](https://github.com/vtecftwy/metagentools/blob/main/metagentools/cnn_virus/data.py#LNone){target="_blank" style="float:right; font-size:smaller"}

### split_kmer_batch_into_50mers

>      split_kmer_batch_into_50mers
>                                    (kmer:tensorflow.python.framework.ops.Tensor)

Converts a batch of k-mer reads into several 50-mer reads, by shifting the k-mer one base at a time.

For a batch of `b` k-mer reads, returns a batch of `b * (k - 49)` 50-mer reads

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| kmer | Tensor | tensor representing a batch of k-mer reads, after base encoding |

In [14]:
reads_50mer = split_kmer_batch_into_50mers(reads_kmer)
nb_50mer_per_kmer = reads_kmer.shape[1]-49 
nb_50mer_reads = (nb_50mer_per_kmer) * reads_kmer.shape[0]
assert reads_50mer.shape == (nb_50mer_reads, 50,5)

print(f"Each kmer is split into {nb_50mer_per_kmer} 50-mer reads. Total 50-mer reads in a batch: {nb_50mer_reads}")
print('\nReview the reads tensor:')
print(f"  - Shape kmer tensor:   {reads_kmer.shape}")
print(f"  - Shape 50-mer tensor: {reads_50mer.shape}")

Each kmer is split into 101 50-mer reads. Total 50-mer reads in a batch: 808

Review the reads tensor:
  - Shape kmer tensor:   (8, 150, 5)
  - Shape 50-mer tensor: (808, 50, 5)


If everything runs smoothly, the generator and the preprocessing work as expected, and we can run the prediction loop.

# 4. Run the Loop

## Utility Functions

In [28]:
def create_tables(db, dry_run=False):

    top_n = 5

    db.connect()
    # Create table for predictions and its index
    pred_cols_str = 'readid refseqid refsource refseq_strand taxonomyid'.split(' ')
    pred_cols_int = 'lbl_true lbl_pred pos_true pos_pred'.split(' ')
    top_pred_cols = [f"top_{top_n}_lbl_pred_{i}" for i in range(top_n)]
    query = """
    CREATE TABLE IF NOT EXISTS predictions (
        id INTEGER PRIMARY KEY,
    """
    for col in pred_cols_str:
        query += f"{col} TEXT, "
    for col in pred_cols_int:
        query += f"{col} INTEGER, "
    for col in top_pred_cols:
        query += f"{col} INTEGER, "
    query = query[:-2]+') ;'
    print(query)
    if not dry_run: db.execute(query)

    query = "CREATE INDEX IF NOT EXISTS idx_preds_refseqid ON predictions (refseqid);"
    print(query)
    if not dry_run: db.execute(query)

    query = "CREATE INDEX IF NOT EXISTS idx_preds_3 ON predictions (readid, refseqid, pos_true);"
    print(query)
    if not dry_run: db.execute(query)

    # Create table for probabilities (one row per 50-mer in order to keep a small nb of columns in table)
    query = f"""
    CREATE TABLE IF NOT EXISTS label_probabilities (
        id INTEGER PRIMARY KEY,
        read_kmer_id TEXT,
        read_50mer_nb INTEGER,
        refseqid TEXT, 
    """
    query += ' '.join([f"prob_{i:03d} REAL, " for i in range(187)])
    query += "FOREIGN KEY (read_kmer_id) REFERENCES predictions(readid)"
    query += ')'
    print(query)
    if not dry_run: db.execute(query)

    query = "CREATE INDEX IF NOT EXISTS idx_probs_refseqid ON label_probabilities (refseqid);"
    print(query)
    if not dry_run: db.execute(query)

    query = "CREATE INDEX IF NOT EXISTS idx_probs_ids ON label_probabilities (refseqid, id);"
    print(query)
    if not dry_run: db.execute(query)

    query = "CREATE INDEX IF NOT EXISTS idx_probs_3 ON label_probabilities (read_kmer_id, read_50mer_nb, refseqid);"
    print(query)
    if not dry_run: db.execute(query)


    # Create view joining predictions and label_probabilities
    view_name = 'preds_probs'

    # top prediction columns from table predictions:
    top_lbl_pred_n = ','.join([f"p.top_5_lbl_pred_{i}" for i in range(5)])

    # probabilities columns from table label_probabilities 
    probs_n = ','.join([f"lp.prob_{i:03d}" for i in range(187)])

    query = f"""
    CREATE VIEW IF NOT EXISTS {view_name} AS
    SELECT 
        lp.id,
        lp.refseqid,
        p.lbl_true, p.lbl_pred,
        p.pos_true, p.pos_pred,
        {top_lbl_pred_n},
        lp.read_kmer_id, lp.read_50mer_nb,
        {probs_n}
    FROM 
        label_probabilities lp
    INNER JOIN 
        predictions p
    ON 
        lp.read_kmer_id = p.readid
    """
    print(query)
    if not dry_run: db.execute(query)

# p2db = pfs.data / '/mnt/s/metagentools/ncbi/infer_results/yf-ncbi/test-2.db'
# db = SqliteDatabase(p2db)
# db.close()
# create_tables(db, dry_run=False)
# db.print_schema()
# db.close()

In [13]:
def skip_existing_predictions(gen: Generator,           # generator of batches (metadata, reads)
                              db: sqlite3.Connection,   # open connection to the sqlite database
                              bs: int                   # batch size
                              ) -> Tuple[int, int]:     # number of batches and kmer reads skipped
    """Advance `gen` past the batch holding the last read already saved in table `predictions`"""
    # Identify the readid of the last saved prediction
    print('Checking predictions already in database...')
    last_predictions_id = db.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
    if last_predictions_id is None:
        return 0, 0
    else:
        last_readid = db.execute(f"SELECT readid FROM predictions WHERE id = {last_predictions_id};").fetchone()[0]
        
    # print(f"Database includes {nb_predictions:,d} predictions, corresponding to {nb_predictions//bs:,d} batches")
    print(f"Last prediction id: {last_predictions_id:,d} for kmer read '{last_readid}'")
    print(f"Skipping already processed predictions ...")

    # Initialize the counter so the function also returns cleanly if the read is never found
    nb_batches_skipped = 0
    for i, (batch_meta, batch_reads) in enumerate(gen):
        nb_batches_skipped = i + 1
        if i%100 == 0: 
            print(f"   Skipped first {(i+1)*bs:,d} kmer reads ({i+1:,d} batches)")
 
        if last_readid in batch_meta['readid']:
            print(f"   Reached last batch of saved prediction (batch {i+1:,d})")
            print('Can proceed with normal prediction inference')
            break
    else:
        # Generator exhausted without finding the last saved read
        print(f"   Warning: kmer read '{last_readid}' not found; all batches were skipped")
    return nb_batches_skipped, nb_batches_skipped * bs

# aln.reset_iterator()
# gen2 = aln.cnn_virus_input_generator(bs=512, label=118)
# # p2db = pfs.data / '/mnt/s/metagentools/ncbi/infer_results/yf-ncbi/test-2.db'
# p2db = pfs.data / '/mnt/s/metagentools/ncbi/infer_results/yf-ncbi/single_69seq_150bp.db'
# db = SqliteDatabase(p2db)
# db.connect();
# nbs, nrs = skip_existing_predictions(gen=gen2, db=db, bs=4)
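The resume logic can be tried in isolation with a toy generator and an in-memory database. This is only a sketch of the pattern: `toy_generator`, the two-column `predictions` table and the read ids are made up for the example and do not reflect the real schema.

```python
import sqlite3

def toy_generator(read_ids, bs):
    """Yield (metadata, reads) batches, mimicking the shape of the real generator (hypothetical)."""
    for start in range(0, len(read_ids), bs):
        ids = read_ids[start:start + bs]
        yield {"readid": ids}, [f"seq-{r}" for r in ids]

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE predictions (id INTEGER PRIMARY KEY, readid TEXT)")
read_ids = [f"r{i}" for i in range(10)]
# pretend the first two batches (4 reads) were saved during a previous run
con.executemany("INSERT INTO predictions (readid) VALUES (?)", [(r,) for r in read_ids[:4]])

# find the read id of the last saved prediction
last_id = con.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
last_readid = con.execute("SELECT readid FROM predictions WHERE id = ?", (last_id,)).fetchone()[0]

# consume the generator until the batch holding that read has been passed
gen = toy_generator(read_ids, bs=2)
skipped = 0
for meta, _ in gen:
    skipped += 1
    if last_readid in meta["readid"]:
        break

# the same generator object now resumes with the first unsaved batch
remaining = [rid for meta, _ in gen for rid in meta["readid"]]
print(skipped, remaining)
```

Because the generator is consumed in place, the prediction loop that follows must reuse the same generator object rather than create a new one.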

In [14]:
def top_predictions(probs, n=5):
    """Returns the n most frequent top-n label predictions for each kmer read"""

    def top_n_most_frequent(preds, n=5):
        """Returns the n most frequent labels among the 50-mer predictions of one kmer read"""
        # print(preds.shape)
        uniques, counts = np.unique(preds, return_counts=True)
        top_idx = np.argsort(counts)[-n:]
        return uniques.take(top_idx)

    top_preds_in_50mers = np.argsort(probs, axis=-1)[:, :, -n:]
    nb_kmers, nb_50mers, nb_lbls = top_preds_in_50mers.shape
    # print(top_preds_in_50mers.shape)
    top_preds_in_kmer = top_preds_in_50mers.reshape(nb_kmers,nb_50mers * nb_lbls)
    # print(top_preds_in_kmer.shape)

    return np.apply_along_axis(top_n_most_frequent, axis=1, arr=top_preds_in_kmer, n=n)
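A small worked example may help to see what this aggregation does. The toy probabilities below are invented (one k-mer read, three windows, four labels) and the steps are written inline rather than by calling `top_predictions`.

```python
import numpy as np

# toy probabilities: 1 k-mer read, 3 windows (50-mers), 4 labels
probs = np.array([[[0.10, 0.60, 0.25, 0.05],
                   [0.15, 0.50, 0.30, 0.05],
                   [0.70, 0.20, 0.06, 0.04]]])
n = 2

# indices of the n most probable labels in each window: shape (1, 3, 2)
top_per_window = np.argsort(probs, axis=-1)[:, :, -n:]

# pool the candidates of all windows of the read, then count how often each label appears
flat = top_per_window.reshape(probs.shape[0], -1)
uniques, counts = np.unique(flat[0], return_counts=True)

# keep the n most frequent labels, most frequent first
top_labels = uniques[np.argsort(counts)[-n:]][::-1]
print(top_labels)   # prints [1 2]
```

Label 1 appears in the top 2 of all three windows and label 2 in two of them, so they are retained in that order. The reversal with `[::-1]` plays the same role as `top_preds[:, ::-1]` in the prediction loop.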

## The loop

### Inference for 25% of the simreads on 69 sequences ALN

In [15]:
fnames = files_in_tree(pfs.data / 'ncbi/simreads/yf', pattern='69seq')

simreads
  |--yf
  |    |--single_1seq_150bp
  |    |--single_69seq_150bp
  |    |    |--single_69seq_150bp.fq (0)
  |    |    |--single_69seq_150bp.aln (1)
  |    |--paired_1seq_150bp
  |    |--paired_69seq_150bp
  |    |    |--paired_69seq_150bp1.fq (2)
  |    |    |--paired_69seq_150bp2.fq (3)
  |    |    |--paired_69seq_150bp1.aln (4)
  |    |    |--paired_69seq_150bp2.aln (5)


In [16]:
file_stem = 'single_69seq_150bp'

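# Paired files end with '1' or '2' (e.g. paired_69seq_150bp1): drop that single digit to get the folder name
p2aln = pfs.data / f"ncbi/simreads/yf/{file_stem[:-1] if file_stem[-1] in ['1', '2'] else file_stem}/{file_stem}.aln"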
assert p2aln.exists()

aln = AlnFileReader(p2aln)
print(f"Reading alignment file: '{p2aln.name}' (in {p2aln.parent})\n")
for i, aln_read in enumerate(aln):
    pass
nb_kmer_reads = i + 1   # enumerate is zero-based, so the number of reads is i + 1
print(f"  - {i+1:,d} simulated reads in file '{p2aln.name}' from {len(aln.header['reference sequences'])} reference sequences.")

print('  - ART command: ',aln.header['command'])
print('  - Reference Sequences:')
print('     ','\n      '.join(aln.header['reference sequences']))

Reading alignment file: 'single_69seq_150bp.aln' (in /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_69seq_150bp)

  - 1,161,034 simulated reads in file 'single_69seq_150bp.aln' from 69 reference sequences.
  - ART command:  /usr/bin/art_illumina -i /home/vtec/projects/bio/metagentools/data/ncbi/refsequences/yf/yf_2023_yellow_fever.fa -ss HS25 -l 150 -f 250 -o /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_69seq_150bp/single_69seq_150bp -rs 1724163599
  - Reference Sequences:
      @SQ	11089:ncbi:1	1	AY968064	11089	ncbi	Angola_1971	10237
      @SQ	11089:ncbi:2	2	U54798	11089	ncbi	Ivory_Coast_1982	10237
      @SQ	11089:ncbi:3	3	DQ235229	11089	ncbi	Ethiopia_1961	10237
      @SQ	11089:ncbi:4	4	AY572535	11089	ncbi	Gambia_2001	10237
      @SQ	11089:ncbi:5	5	MF405338	11089	ncbi	Ghana_Hsapiens_1927	10237
      @SQ	11089:ncbi:6	6	U21056	11089	ncbi	Senegal_1927	10237
      @SQ	11089:ncbi:7	7	AY968065	11089	ncbi	Uganda_1948	10237
      @SQ	11089:ncbi:8	8	JX89

In [17]:
print(f"'{aln.path.name}' used for prediction (in {aln.path.parent}).")

'single_69seq_150bp.aln' used for prediction (in /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_69seq_150bp).


In [34]:
# p2db = pfs.data / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
# p2db = pfs.data / 'ncbi/infer_results/yf-ncbi' / f'selected_test.db'
# p2db = Path('/mnt/s/metagentools') / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
p2db = Path('/mnt/s/metagentools') / 'ncbi/infer_results/yf-ncbi' / f'single_selected_7seq_150bp.db'
print(f"'{p2db.name}' sqlite db used, (in {p2db.parent})")

nl = '\n'
msg = f"Are you sure you want to use this database?{nl}Database file '{p2db.name}' does not correspond to aln '{aln.path.name}'"
if aln.path.stem != p2db.stem and 'selected' not in p2db.stem: raise Warning(msg)

'single_selected_7seq_150bp.db' sqlite db used, (in /mnt/s/metagentools/ncbi/infer_results/yf-ncbi)


> **NOTE**
>
> Estimated space required to save prediction and probability reports for the simreads simulated on 69 sequences is *470 GB*. 
>
> This is currently too large to save, even on my NAS. 
>
> We will first build a prediction dataset with 25% of the total reads: `nb_batches_to_run = int(nb_kmer_reads / b * 0.25)`

In [35]:
refseq_metadata = aln.parse_header_reference_sequences()

selected_refseqs = ['11089:ncbi:10','11089:ncbi:13', '11089:ncbi:30','11089:ncbi:37', '11089:ncbi:32', '11089:ncbi:35', '11089:ncbi:1' ]
# Originally, 11089:ncbi:27 was also selected, but its accession (MK457701) is not in the distance matrix
print(f"      {'refseqid':^15s}  {'Accession':^13s}  {'Description':^20s}")
print(f"      {'-'*15:^15s}  {'-'*10:^13s}  {'-'*20:^20s}")
for i,rsid in enumerate(selected_refseqs):
    print(f"{i+1:2d}.   {rsid:15s}  {refseq_metadata[rsid]['refseq_accession']:^13s}  {refseq_metadata[rsid]['organism']:^20s}")


         refseqid        Accession        Description     
      ---------------   ----------    --------------------
 1.   11089:ncbi:10      GQ379163      Peru_Hsapiens_2007 
 2.   11089:ncbi:13      KU978764     Sudan_Hsapiens_1941 
 3.   11089:ncbi:30      KU978763     Nigeria_Hsapiens_1946
 4.   11089:ncbi:37      JF912190     Brazil_Hsapiens_2002
 5.   11089:ncbi:32      JF912181     Brazil_Hsapiens_1983
 6.   11089:ncbi:35      JF912182     Brazil_Hsapiens_1984
 7.   11089:ncbi:1       AY968064         Angola_1971     


In [36]:
# Review and set following parameters

# 2. Data parameters
b = 850             # number of k-mers in a batch; 850 covers half a set of one refseq (1700)
# b = 4             
k = 150             # read length
true_label = 118    # yellow fever virus
top_n = 5           # n for top-n prediction to keep

# 3. Inference loop parameters
run_all_batches = True
# nb_batches_to_run = 2
# nb_batches_to_run = int(nb_kmer_reads / b * 0.25)

#====================================================================================================
# Setup prediction Loop
#====================================================================================================
nb_50mer = k - 49
uid = datetime.today().strftime('%Y-%m-%d_%H_%M_%S')

aln.reset_iterator()
model = create_model_original(path2parameters=p2model)
print(f"Model loaded and ready to run ...")

# Open connection to sqlite db and create tables if empty database
db = SqliteDatabase(p2db)
db.connect()
tables = db.get_result("SELECT name FROM sqlite_master WHERE type='table';")
if 'predictions' not in [t[0] for t in tables]:
    print("Empty database. Creating tables ...")
    create_tables(db, dry_run=False)
    db.print_schema()

# Create list of columns for prediction and probabilities reports
pred_cols_str = 'readid refseqid refsource refseq_strand taxonomyid'.split(' ')
pred_cols_int = 'lbl_true lbl_pred pos_true pos_pred'.split(' ')
top_pred_cols = [f"top_{top_n}_lbl_pred_{i}" for i in range(top_n)]
prob_cols = [f"prob_{i:03d}" for i in range(187)]

def tprint(string):
    print(f"{datetime.now().strftime('%H:%M:%S')}    {string}")

#====================================================================================================
# Run prediction loop
#====================================================================================================
print(f"Run prediction loop with the following parameters:")
print(f"   {b} k-mer per batch; {k} bp per sequence; keep top-{top_n} predictions")
tprint(f"Starting prediction loop ...")
gen = aln.cnn_virus_input_generator(bs=b, label=true_label)

# Skip kmer reads that are already processed
nb_batches_skipped, nb_reads_skipped = skip_existing_predictions(gen=gen, db=db, bs=b)
tprint(f"Skipped {nb_batches_skipped:,d} batches ({nb_reads_skipped:,d} kmer reads)")

# Proceed with prediction inference 
for i,(metadata_batch, reads_batch) in enumerate(gen):
    # skip any reference sequence not in the selected list
    any_selected_refseq_in_batch = any([rsid in selected_refseqs for rsid in metadata_batch['refseqid']])
    if not any_selected_refseq_in_batch:
        tprint("Skipping batch: it does not include any selected reference sequence")
        continue

    loop_start = datetime.now()
    tprint(f"Batch {i+1:3,d} (aln batch {nb_batches_skipped+i+1:3,d}) ...")

    reads_kmer, (labels_true, position_true) = string_input_batch_to_tensors(reads_batch, k=k)
    reads_50mer = split_kmer_batch_into_50mers(reads_kmer)
    assert reads_50mer.shape == ((reads_kmer.shape[1]-49) * b, 50, 5), f"Problem with shape in batch {i+1}: {reads_50mer.shape}"

    tprint(f'  Starting prediction for {b:,} kmer reads ...')
    label_probs, position_probs = model.predict(reads_50mer)

    tprint('  Reshaping predictions ...')
    label_probs_kmer = tf.reshape(label_probs, shape=(b,nb_50mer,-1))
    position_probs_kmer = tf.reshape(position_probs, shape=(b,nb_50mer,-1))

    tprint('  Combining predictions ...')
    combined_predictions = tf.map_fn(
        fn=combine_prediction_batch,
        elems=[label_probs_kmer, position_probs_kmer], 
        fn_output_signature=tf.int64
        )

    label_predictions = combined_predictions[:,0]
    position_predictions = combined_predictions[:,1]
    top_preds = top_predictions(label_probs_kmer, n=top_n)

    # Add results for current batch
    tprint('  Preparing prediction report ...')
    preds_report = np.concatenate(
        [
            np.expand_dims(np.array(metadata_batch['readid']), axis=1),         # readid 
            np.expand_dims(np.array(metadata_batch['refseqid']), axis=1),       # refseqid
            np.expand_dims(np.array(metadata_batch['refsource']), axis=1),      # refsource
            np.expand_dims(np.array(metadata_batch['refseq_strand']), axis=1),  # refseq_strand
            np.expand_dims(np.array(metadata_batch['reftaxonomyid']), axis=1),  # taxonomyid
            np.expand_dims(np.array([true_label]*b), axis=1),                   # lbl_true
            np.expand_dims(label_predictions, axis=1),                          # lbl_pred
            np.expand_dims(np.array(metadata_batch['aln_start_pos']), axis=1),  # pos_true
            np.expand_dims(position_predictions, axis=1),                       # pos_pred
            top_preds[:, ::-1],                                                 # top_5_lbl_pred_0, top_5_lbl_pred_1, top_5_lbl_pred_2, top_5_lbl_pred_3, top_5_lbl_pred_4
        ],
        axis=1
    )

    df_preds = pd.DataFrame(
        data=preds_report, 
        columns=pred_cols_str + pred_cols_int + top_pred_cols
        )
    tprint('  Saving batch prediction report to db...')
    db.dataframe_to_table(df_preds, 'predictions', if_exists='append', index=False)

    tprint('  Preparing label probabilities report ...')
    df_probs = None
    for read_50mer_nb in range(nb_50mer):
        probs_report = np.concatenate(
            [
                np.expand_dims(np.array(metadata_batch['readid']), axis=1),     # readid 
                np.expand_dims(np.array([read_50mer_nb]*b), axis=1),            # read_50mer_nb
                np.expand_dims(np.array(metadata_batch['refseqid']), axis=1),   # refseqid
                label_probs_kmer[:, read_50mer_nb, :]                           # label probabilities
            ],
            axis=1
        )

        df = pd.DataFrame(
            data=probs_report, 
            columns=['read_kmer_id', 'read_50mer_nb', 'refseqid'] + prob_cols
            )
        df_probs = df if df_probs is None else pd.concat([df_probs, df], axis=0)

    tprint('  Saving batch label probabilities report to db...')
    db.dataframe_to_table(df_probs, 'label_probabilities', if_exists='append', index=False)

    tprint(f"  Batch processing time: {(datetime.now() - loop_start).total_seconds():.2f} sec")
    if not run_all_batches and i+1 >= nb_batches_to_run: 
        print('Stopping')
        break

db.close()

Creating CNN Model (Original)
Loading parameters from pretrained_model.h5
Created pretrained model
Model loaded and ready to run ...
Empty database. Creating tables ...

    CREATE TABLE IF NOT EXISTS predictions (
        id INTEGER PRIMARY KEY,
    readid TEXT, refseqid TEXT, refsource TEXT, refseq_strand TEXT, taxonomyid TEXT, lbl_true INTEGER, lbl_pred INTEGER, pos_true INTEGER, pos_pred INTEGER, top_5_lbl_pred_0 INTEGER, top_5_lbl_pred_1 INTEGER, top_5_lbl_pred_2 INTEGER, top_5_lbl_pred_3 INTEGER, top_5_lbl_pred_4 INTEGER) ;
CREATE INDEX IF NOT EXISTS idx_preds_refseqid ON predictions (refseqid);
CREATE INDEX IF NOT EXISTS idx_preds_3 ON predictions (readid, refseqid, pos_true);

    CREATE TABLE IF NOT EXISTS label_probabilities (
        id INTEGER PRIMARY KEY,
        read_kmer_id TEXT,
        read_50mer_nb INTEGER,
        refseqid TEXT, 
    prob_000 REAL,  prob_001 REAL,  prob_002 REAL,  prob_003 REAL,  prob_004 REAL,  prob_005 REAL,  prob_006 REAL,  prob_007 REAL,  prob_00

Review reads distribution in the database

In [70]:
p2db

Path('/mnt/s/metagentools/ncbi/infer_results/yf-ncbi/single_selected_8seq_150bp.db')

In [100]:
df = db.get_dataframe("SELECT refseqid, COUNT(*) AS count FROM predictions GROUP BY refseqid")
df['organism'] = [refseq_metadata[i]['organism'] for i in df['refseqid']]
df = df.loc[:, ['refseqid', 'organism', 'count']]
total_count = df['count'].sum()
print(f"{total_count:,d} predictions and {total_count * 101:,d} in database")
display(df)
print(sorted(selected_refseqs))

53,248 predictions and 5,378,048 in database


|    | refseqid | organism | count |
| -- | -------- | -------- | ----- |
| 0 | 11089:ncbi:1 | Angola_1971 | 17000 |
| 1 | 11089:ncbi:10 | Peru_Hsapiens_2007 | 17000 |
| 2 | 11089:ncbi:11 | Spain_Vaccine_2004 | 496 |
| 3 | 11089:ncbi:12 | Singapore_2017 | 224 |
| 4 | 11089:ncbi:13 | Sudan_Hsapiens_1941 | 17000 |
| 5 | 11089:ncbi:14 | ArD181250_Senegal_2005 | 184 |
| 6 | 11089:ncbi:2 | Ivory_Coast_1982 | 408 |
| 7 | 11089:ncbi:26 | Netherlands_Hsapiens_Gambia_2018 | 144 |
| 8 | 11089:ncbi:27 | Nigeria_Hsapiens_2018 | 368 |
| 9 | 11089:ncbi:9 | Senegal_Aedes-aegypti_1995 | 424 |


['11089:ncbi:1', '11089:ncbi:10', '11089:ncbi:13', '11089:ncbi:27', '11089:ncbi:30', '11089:ncbi:32', '11089:ncbi:35', '11089:ncbi:37']


This needs to be rerun for the entire set of selected reference sequences, which would yield 136,000 k-mer reads (8 reference sequences × 17,000 reads) and about 13.7 million 50-mer reads (136,000 × 101).
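The read counts can be checked with a few lines of arithmetic. The byte figure is only a rough lower bound for the probability values, assuming each of the 187 `REAL` columns takes 8 bytes per row; it ignores text columns, indexes and SQLite overhead.

```python
reads_per_refseq = 17_000   # count per fully processed reference sequence, from the table above
nb_refseqs = 8              # length of the printed list of selected reference sequences
k, window = 150, 50

nb_kmer_reads = reads_per_refseq * nb_refseqs
nb_50mer_reads = nb_kmer_reads * (k - window + 1)   # 101 windows per 150-mer

# lower bound for the probability values alone (assumption: 8 bytes per REAL value)
min_prob_bytes = nb_50mer_reads * 187 * 8

print(f"{nb_kmer_reads:,d} k-mer reads, {nb_50mer_reads:,d} 50-mer reads")
print(f"at least {min_prob_bytes / 1e9:.1f} GB for label probabilities")
```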

In [57]:
# db.execute("DROP TABLE IF EXISTS predictions;")
# db.execute("DROP TABLE IF EXISTS label_probabilities;")
# db.execute("DROP VIEW IF EXISTS preds_probs;")
# db.print_schema()

In [101]:
# display(db.get_dataframe("SELECT * FROM predictions;"))
# display(db.get_dataframe("SELECT * FROM predictions WHERE readid = '11089:ncbi:1-17000';"))
# display(db.get_dataframe("SELECT * FROM label_probabilities"))
# display(db.get_dataframe("SELECT * FROM label_probabilities WHERE read_kmer_id = '11089:ncbi:1-17000';"))

In [103]:
last_readid = db.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
readid = db.execute(f"SELECT readid FROM predictions WHERE id = {last_readid};").fetchone()[0]
regex = re.compile(r'\d*:ncbi:(?P<read_nb>\d*-\d*)')
m = regex.search(readid)
read_nb = '???' if m is None else m.group('read_nb')

last_readid, readid, read_nb

(53248, '11089:ncbi:27-16633', '27-16633')
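A read id such as the one above can also be split into all its parts with named groups. The group names below are chosen for this example, and reading the first field as the taxonomy id is an assumption based on the `@SQ` header lines.

```python
import re

# hypothetical pattern: <taxonomy id>:<source>:<refseq number>-<read number>
READID_RE = re.compile(r"(?P<taxonomyid>\d+):(?P<source>\w+):(?P<refseqnb>\d+)-(?P<readnb>\d+)")

m = READID_RE.fullmatch("11089:ncbi:27-16633")
parts = m.groupdict()

# the reference sequence id is the read id without the trailing read number
refseqid = f"{parts['taxonomyid']}:{parts['source']}:{parts['refseqnb']}"
print(parts, refseqid)
```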

In [104]:
last_predictions_id = db.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
nb_predictions = db.execute("SELECT COUNT(*) FROM predictions").fetchone()[0]
last_readid = db.execute(f"SELECT readid FROM predictions WHERE id = {last_predictions_id};").fetchone()[0]

print(f"Last id: {last_predictions_id:,d}, Nbr predictions: {nb_predictions:,d}, last readid: {last_readid}")

Last id: 53,248, Nbr predictions: 53,248, last readid: 11089:ncbi:27-16633


## Technical note to accelerate `skip_existing_predictions`

The slow step in `skip_existing_predictions` is the query that counts the total number of rows in the table. We can speed this up by maintaining the row count in a separate table, updated each time new rows are inserted into the SQLite table `predictions`, which has a primary key `id` and an indexed column `readid`.

This can be done with the following steps:

1. **Create a trigger** that fires after an INSERT operation on the "predictions" table. The trigger will update the row count in a separate table or column.

2. **Create a table** to store the row count, for example:
```sql
    CREATE TABLE table_stats (
    id INTEGER PRIMARY KEY,
    row_count INTEGER NOT NULL
    );
```

3. **Create the trigger** to update the row count after an INSERT:
```sql
    CREATE TRIGGER update_prediction_count
    AFTER INSERT ON predictions
    BEGIN
    UPDATE table_stats 
    SET row_count = row_count + 1
    WHERE id = 1;
    END;
```

This trigger assumes that the `table_stats` table holds a single row with an id of 1. If you want to store the count per `readid` instead, `table_stats` needs a `readid` column with a UNIQUE constraint (required by `ON CONFLICT(readid)`), and the trigger becomes:

```sql
    CREATE TRIGGER update_prediction_count
    AFTER INSERT ON predictions  
    BEGIN
    INSERT INTO table_stats (readid, row_count)
    VALUES (NEW.readid, 1)
    ON CONFLICT(readid) DO UPDATE SET row_count = row_count + 1;
    END;
```

This will insert a new row for each unique "readid" with an initial count of 1, and update the row_count if the "readid" already exists.

4. **Insert a row** into the `table_stats` table with an initial count:
```sql
    INSERT INTO table_stats (id, row_count) VALUES (1, 0);
```

Now, whenever a new row is inserted into the `predictions` table, the trigger automatically updates the row count in the `table_stats` table. This maintains an accurate count without needing a full table scan with `COUNT(*)`.

Remember to handle deletions as well by creating a BEFORE DELETE trigger that decrements the row count accordingly.
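The whole mechanism can be tried with Python's `sqlite3` module and an in-memory database. This is a minimal sketch: the `predictions` table is reduced to two columns, and the name of the delete trigger (`decrement_prediction_count`) is made up for the example.

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.executescript("""
    CREATE TABLE predictions (id INTEGER PRIMARY KEY, readid TEXT);
    CREATE TABLE table_stats (id INTEGER PRIMARY KEY, row_count INTEGER NOT NULL);
    INSERT INTO table_stats (id, row_count) VALUES (1, 0);

    -- increment the counter for every inserted row
    CREATE TRIGGER update_prediction_count AFTER INSERT ON predictions
    BEGIN
        UPDATE table_stats SET row_count = row_count + 1 WHERE id = 1;
    END;

    -- decrement the counter for every deleted row
    CREATE TRIGGER decrement_prediction_count BEFORE DELETE ON predictions
    BEGIN
        UPDATE table_stats SET row_count = row_count - 1 WHERE id = 1;
    END;
""")

# insert five rows, then delete one of them
con.executemany("INSERT INTO predictions (readid) VALUES (?)", [(f"r{i}",) for i in range(5)])
con.execute("DELETE FROM predictions WHERE readid = 'r0'")

row_count = con.execute("SELECT row_count FROM table_stats WHERE id = 1").fetchone()[0]
true_count = con.execute("SELECT COUNT(*) FROM predictions").fetchone()[0]
print(row_count, true_count)
```

Both values should match after any sequence of inserts and deletes. The trade-off is one extra `UPDATE` per inserted row, which may matter for the large bulk inserts done in the prediction loop.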


# End of Section