# Prediction on yellow fever simulated reads (SQLite)

This is a reference notebook for running prediction inference on yellow fever simulated reads with the original CNN_Virus model, saving predictions, probabilities and metadata in a SQLite database.

- Simulated reads come from an alignment file generated by the ART Illumina simulator (`*.aln` file).
- The generator `AlnFileReader.cnn_virus_input_generator` reads batches of simulated reads and their metadata.
- The `cnn_virus` model predicts the label and position probabilities and classes for each simulated read.
- A prediction report is created and saved in a SQLite database for easier retrieval and later analysis.

> **Note**: 
>
> When an `*.aln` file contains a very large number of simulated reads, running a prediction on all of them is very time-consuming. Therefore, we also provide a function `skip_existing_predictions`, applied to the generator, which skips all simulated reads up to the last one whose prediction is already saved in the database. This makes it possible to build the database step by step.
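
The resume logic behind this note amounts to consuming the generator until the batch holding the last saved read id has gone past. Below is a minimal, framework-free sketch, assuming batches are `(metadata, reads)` pairs with a `'readid'` list in the metadata, as yielded by `cnn_virus_input_generator`. `skip_until_readid` is a hypothetical helper for illustration, not the notebook's `skip_existing_predictions`.

```python
from typing import Iterable, Iterator, List, Tuple

Batch = Tuple[dict, List[str]]   # (metadata dict, list of read strings)

def skip_until_readid(batches: Iterable[Batch], last_readid: str) -> Tuple[int, Iterator[Batch]]:
    """Consume batches up to and including the one whose metadata lists `last_readid`.

    Returns the number of batches consumed and the iterator, now positioned on the
    first batch without a saved prediction.
    """
    it = iter(batches)
    for nb_skipped, (meta, _reads) in enumerate(it, start=1):
        if last_readid in meta['readid']:
            return nb_skipped, it
    raise ValueError(f"readid {last_readid!r} not found: nothing left to resume from")

# Usage with a toy generator of 5 batches of 2 reads each
toy = (({'readid': [f'r{i}', f'r{i + 1}']}, ['ACGT', 'ACGT']) for i in range(0, 10, 2))
nb_skipped, remaining = skip_until_readid(toy, last_readid='r3')
print(nb_skipped, next(remaining)[0]['readid'])   # 2 ['r4', 'r5']
```

Raising when the read id is never found is a design choice: silently returning an exhausted generator would make the prediction loop do nothing without any sign of a problem.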

# 1. Imports and setup environment

In [1]:
# Install required custom packages if not installed yet.
import importlib.util
if not importlib.util.find_spec('ecutilities'):
    print('installing package: `ecutilities`')
    ! pip install -qqU ecutilities
else:
    print('`ecutilities` already installed')
if not importlib.util.find_spec('metagentools'):
    print('installing package: `metagentools`')
    ! pip install -qqU metagentools
else:
    print('`metagentools` already installed')

`ecutilities` already installed
`metagentools` already installed


In [2]:
# Import all required packages
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import re
import sqlite3
import time

from datetime import datetime
from ecutilities.core import files_in_tree
from ecutilities.ipython import nb_setup
from functools import partial
from IPython.display import display, update_display, Markdown, HTML
from nbdev import show_doc
from pandas import HDFStore
from pathlib import Path
from pprint import pprint
from tqdm.notebook import tqdm, trange
from typing import List, Tuple, Dict, Any, Generator

# Setup the notebook for development
nb_setup()

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
import tensorflow as tf
from tensorflow.python.client import device_lib
from tensorflow.keras.models import load_model
print(f"Tensorflow version: {tf.__version__}\n")

from metagentools.cnn_virus.data import _base_hot_encode, strings_to_tensors
from metagentools.cnn_virus.data import split_kmer_into_50mers, combine_prediction_batch
from metagentools.cnn_virus.data import FastaFileReader, FastqFileReader, AlnFileReader
from metagentools.cnn_virus.data import OriginalLabels
from metagentools.cnn_virus.data import string_input_batch_to_tensors, split_kmer_batch_into_50mers
from metagentools.cnn_virus.architecture import create_model_original
from metagentools.core import ProjectFileSystem, TextFileBaseReader

Set autoreload mode
Tensorflow version: 2.8.2



List all computing devices available on the machine

In [3]:
devices = device_lib.list_local_devices()
print('\nDevices:')
for d in devices:
    t = d.device_type
    name = d.physical_device_desc
    l = [item.split(':', 1) for item in name.split(', ')]
    name_attr = dict([x for x in l if len(x)==2])
    dev = name_attr.get('name', ' ')
    print(f"  - {t}  {d.name} {dev:25s}")


Devices:
  - CPU  /device:CPU:0                          
  - GPU  /device:GPU:0  NVIDIA GeForce GTX 1050 


# 2. Setup paths to files

Key folders and system information

In [4]:
pfs = ProjectFileSystem()
pfs.info()

Running linux on local computer
Device's home directory: /home/vtec
Project file structure:
 - Root ........ /home/vtec/projects/bio/metagentools 
 - Data Dir .... /home/vtec/projects/bio/metagentools/data 
 - Notebooks ... /home/vtec/projects/bio/metagentools/nbs


- `p2model`: path to the file with the saved original pretrained model
- `p2virus_labels`: path to the file mapping virus names to labels for the original model

In [5]:
p2model = pfs.data / 'saved/cnn_virus_original/pretrained_model.h5'
assert p2model.is_file(), f"No file found at {p2model.absolute()}"

p2virus_labels = pfs.data / 'CNN_Virus_data/virus_name_mapping'
assert p2virus_labels.is_file(), f"No file found at {p2virus_labels.absolute()}"

Path to the simulated reads file we want to use

In [6]:
fnames = files_in_tree(pfs.data / 'ncbi/simreads/yf', pattern='69seq')

simreads
  |--yf
  |    |--single_1seq_150bp
  |    |--single_69seq_150bp
  |    |    |--single_69seq_150bp.fq (0)
  |    |    |--single_69seq_150bp.aln (1)
  |    |--paired_1seq_150bp
  |    |--paired_69seq_150bp
  |    |    |--paired_69seq_150bp1.fq (2)
  |    |    |--paired_69seq_150bp2.fq (3)
  |    |    |--paired_69seq_150bp1.aln (4)
  |    |    |--paired_69seq_150bp2.aln (5)


In [7]:
file_stem = 'single_69seq_150bp'

p2aln = pfs.data / f"ncbi/simreads/yf/{file_stem[:-2] if file_stem[-1] in ['1', '2'] else file_stem}/{file_stem}.aln"
assert p2aln.exists()

aln = AlnFileReader(p2aln)
print(f"Reading alignment file: {p2aln.name}:\n")
for i, aln_read in enumerate(aln):
    pass
print(f"  - {i+1:,d} simulated reads in file '{p2aln.name}' from {len(aln.header['reference sequences'])} reference sequences.")

print('  - ART command: ',aln.header['command'])
print('  - Reference Sequences:')
print('     ','\n      '.join(aln.header['reference sequences']))

Reading alignment file: single_69seq_150bp.aln:

  - 1,161,034 simulated reads in file 'single_69seq_150bp.aln' from 69 reference sequences.
  - ART command:  /usr/bin/art_illumina -i /home/vtec/projects/bio/metagentools/data/ncbi/refsequences/yf/yf_2023_yellow_fever.fa -ss HS25 -l 150 -f 250 -o /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_69seq_150bp/single_69seq_150bp -rs 1724163599
  - Reference Sequences:
      @SQ	11089:ncbi:1	1	AY968064	11089	ncbi	Angola_1971	10237
      @SQ	11089:ncbi:2	2	U54798	11089	ncbi	Ivory_Coast_1982	10237
      @SQ	11089:ncbi:3	3	DQ235229	11089	ncbi	Ethiopia_1961	10237
      @SQ	11089:ncbi:4	4	AY572535	11089	ncbi	Gambia_2001	10237
      @SQ	11089:ncbi:5	5	MF405338	11089	ncbi	Ghana_Hsapiens_1927	10237
      @SQ	11089:ncbi:6	6	U21056	11089	ncbi	Senegal_1927	10237
      @SQ	11089:ncbi:7	7	AY968065	11089	ncbi	Uganda_1948	10237
      @SQ	11089:ncbi:8	8	JX898871	11089	ncbi	ArD114896_Senegal_1995	10237
      @SQ	11089:ncbi:9	9	JX898872	11089

# 3. (Optional) Test each inference step

Let's test the steps that prepare the model inputs, using a small batch size:
1. create the generator from the `.aln` file with `aln.cnn_virus_input_generator`, yielding pairs of batches (metadata and read strings)
2. transform the batch of string reads into a base-hot-encoded tensor with the preprocessing function `string_input_batch_to_tensors`
3. split each k-mer read into (k-49) 50-mer reads to present to the model

In [8]:
show_doc(aln.cnn_virus_input_generator)

---

[source](https://github.com/vtecftwy/metagentools/blob/main/metagentools/cnn_virus/data.py#LNone)

### AlnFileReader.cnn_virus_input_generator

>      AlnFileReader.cnn_virus_input_generator (label:int=118, bs:int=32)

Create a generator yielding a metadata batch (dict) and a read batch (tensor of strings)

The metadata dictionary contains lists of metadata for each read in the batch. For example:
``` 
{
    'readid': ['2591237:ncbi:1-40200','2591237:ncbi:1-40199','2591237:ncbi:1-40198', ...],
    'refseqid': ['2591237:ncbi:1','2591237:ncbi:1','2591237:ncbi:1', ...],
    'read_pos': [5, 6, 1, ...],
    'refsource': ['ncbi', 'ncbi', 'ncbi', ...],
    ...
}
```

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| label | int | 118 | label for this batch (assuming all reads are from the same species) |
| bs | int | 32 | batch size |
| **Returns** | **tuple** |  | **dict of metadata list and tensor of strings** |

We need to define the model label for yellow fever and a batch size.

In [14]:
OriginalLabels().search(s='yellow')

Yellow_fever_virus. Label: 118


In [15]:
show_doc(string_input_batch_to_tensors)

---

[source](https://github.com/vtecftwy/metagentools/blob/main/metagentools/cnn_virus/data.py#LNone)

### string_input_batch_to_tensors

>      string_input_batch_to_tensors (b:tensorflow.python.framework.ops.Tensor,
>                                     k:int=150)

Function converting a batch of bp strings into three tensors: (x_seqs, (y_labels, y_pos))

Expects input strings to have the format: 'read sequence     label   position' where:

- read sequence: a string of length k (kmer)
- label: an integer between 0 and 186 representing the read virus label
- position: an integer between 0 and 9 representing the read position in the genome

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| b | Tensor |  | batch of strings representing the inputs (kmer, label, position) |
| k | int | 150 | maximum read length in the batch |
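
To make the expected input format concrete, here is a NumPy-only sketch of such a conversion. It is not the library's implementation: the whitespace-separated fields, the `A, C, G, T, N` channel order and the zero-padding of reads shorter than `k` are assumptions. Only the output shapes, `(b, k, 5)`, `(b, 187)` and `(b, 10)`, match those the notebook prints for the real function.

```python
import numpy as np

BASES = 'ACGTN'                 # hypothetical channel order; the library's may differ
NB_LABELS, NB_POSITIONS = 187, 10

def parse_and_encode(batch, k=150):
    """Turn 'sequence label position' strings into one-hot arrays (NumPy sketch)."""
    b = len(batch)
    x = np.zeros((b, k, len(BASES)), dtype=np.float32)        # reads shorter than k stay zero-padded
    y_lbl = np.zeros((b, NB_LABELS), dtype=np.float32)
    y_pos = np.zeros((b, NB_POSITIONS), dtype=np.float32)
    for i, line in enumerate(batch):
        seq, label, pos = line.split()                         # any whitespace separates the three fields
        for j, base in enumerate(seq[:k]):
            # characters outside ACGTN are mapped to the N channel (assumption)
            x[i, j, BASES.index(base) if base in BASES else BASES.index('N')] = 1.0
        y_lbl[i, int(label)] = 1.0
        y_pos[i, int(pos)] = 1.0
    return x, (y_lbl, y_pos)

x, (y_lbl, y_pos) = parse_and_encode(['ACGTN\t118\t3', 'TTTT\t5\t0'], k=6)
print(x.shape, y_lbl.shape, y_pos.shape)   # (2, 6, 5) (2, 187) (2, 10)
```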

In [12]:
b = 8
true_label = 118

aln.reset_iterator()
for batch_meta, batch_reads in aln.cnn_virus_input_generator(bs=b, label=true_label):
    reads_kmer, (labels_kmer, positions_kmer) = string_input_batch_to_tensors(batch_reads, k=150)
    break

print('Review metadata batch yielded by the generator:')
print(f"  List of metadata keys:")
print('  -','\n  - '.join(batch_meta.keys()))
print(f"  'readid' included in this batch':", ', '.join(batch_meta['readid']))
print('\nReview batch of string reads yielded by the generator:')
print(f"  - Shape: {batch_reads.shape}")
print('\nReview the read kmer tensor after preprocessing:')
print(f"  - Shape: {reads_kmer.shape}")
print('\nReview ground truth tensors:')
print(f"  - Shape true label tensor:    {labels_kmer.shape}")
print(f"  - Shape true position tensor: {positions_kmer.shape}")

Review metadata batch yielded by the generator:
  List of metadata keys:
  - aln_start_pos
  - readid
  - readnb
  - refseq_strand
  - refseqid
  - refseqnb
  - refsource
  - reftaxonomyid
  - read_pos
  'readid' included in this batch': 11089:ncbi:1-17000, 11089:ncbi:1-16999, 11089:ncbi:1-16998, 11089:ncbi:1-16997, 11089:ncbi:1-16996, 11089:ncbi:1-16995, 11089:ncbi:1-16994, 11089:ncbi:1-16993

Review batch of string reads yielded by the generator:
  - Shape: (8,)

Review the read kmer tensor after preprocessing:
  - Shape: (8, 150, 5)

Review ground truth tensors:
  - Shape true label tensor:    (8, 187)
  - Shape true position tensor: (8, 10)


The model only accepts 50-mer reads, so we need to split each k-mer read into 50-mers. For each k-mer read, k-49 50-mer reads are generated by sliding a 50-nucleotide window along the read, one nucleotide at a time.
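
The sliding window can be sketched with NumPy's `sliding_window_view` on a one-hot batch of shape `(b, k, 5)`. This is an illustration, not the library's `split_kmer_batch_into_50mers`; `split_into_50mers` is a hypothetical name. The read-major output order (all windows of read 0, then read 1, ...) is an assumption, consistent with the `(b, nb_50mer, -1)` reshape used later in the prediction loop.

```python
import numpy as np

def split_into_50mers(kmers: np.ndarray, w: int = 50) -> np.ndarray:
    """(b, k, c) one-hot k-mers -> (b * (k - w + 1), w, c) windows with stride 1."""
    b, k, c = kmers.shape
    # shape (b, k - w + 1, c, w): the window dimension is appended as the last axis
    windows = np.lib.stride_tricks.sliding_window_view(kmers, window_shape=w, axis=1)
    # put the window axis before the channel axis, then merge the read and window axes
    return windows.transpose(0, 1, 3, 2).reshape(b * (k - w + 1), w, c)

x = np.random.rand(8, 150, 5)
print(split_into_50mers(x).shape)   # (808, 50, 5): 8 reads x 101 windows
```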

In [13]:
show_doc(split_kmer_batch_into_50mers)

---

[source](https://github.com/vtecftwy/metagentools/blob/main/metagentools/cnn_virus/data.py#LNone)

### split_kmer_batch_into_50mers

>      split_kmer_batch_into_50mers (kmer:tensorflow.python.framework.ops.Tensor)

Converts a batch of k-mer reads into several 50-mer reads, by shifting the k-mer one base at a time.

For a batch of `b` k-mer reads, returns a batch of `b * (k - 49)` 50-mer reads

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| kmer | Tensor | tensor representing a batch of k-mer reads, after base encoding |

In [14]:
reads_50mer = split_kmer_batch_into_50mers(reads_kmer)
nb_50mer_per_kmer = reads_kmer.shape[1]-49 
nb_50mer_reads = (nb_50mer_per_kmer) * reads_kmer.shape[0]
assert reads_50mer.shape == (nb_50mer_reads, 50,5)

print(f"Each kmer is split into {nb_50mer_per_kmer} 50-mer reads. Total 50-mer reads in a batch: {nb_50mer_reads}")
print('\nReview the reads tensor:')
print(f"  - Shape kmer tensor:   {reads_kmer.shape}")
print(f"  - Shape 50-mer tensor: {reads_50mer.shape}")

Each kmer is split into 101 50-mer reads. Total 50-mer reads in a batch: 808

Review the reads tensor:
  - Shape kmer tensor:   (8, 150, 5)
  - Shape 50-mer tensor: (808, 50, 5)


If everything runs smoothly, the generator and preprocessing work as expected, and we can run the prediction loop.

# 4. Run the Loop

## Utility Functions

In [8]:
# Database functions
def open_db(p2db: Path,    # path to the sqlite database
            k: int,        # k-mer size
            top_n: int=5   # number of top predictions to store, 5 by default
            ):
    """Open a connection to the database, creating the tables if they do not exist yet"""
    conn = sqlite3.connect(p2db)
    cursor = conn.cursor()
    query = "SELECT name FROM sqlite_master WHERE type='table' AND name='predictions'"
    res = cursor.execute(query).fetchone()
    if res is None:
        print('Creating tables in database...')
        create_tables(cursor, k, top_n)   # pass top_n so the schema matches the loop
        conn.commit()
    return conn

def table_columns(cursor:sqlite3.Cursor,  # cursor to the database  
                  table: str              # name of the table
                  ):
    """Returns the name of the columns in the passed table"""
    query = f"PRAGMA table_info({table})"
    cursor.execute(query)
    cols = [row[1] for row in cursor.fetchall()]
    return cols

def create_tables(cursor: sqlite3.Cursor,   # cursor to the database  
                  k:int,                    # k-mer size
                  top_n: int=5              # number of top predictions to store, 5 by default   
                  ):
    """Create tables for the database"""
    # Create table for predictions
    pred_cols_str = 'readid refseqid refsource refseq_strand taxonomyid'.split(' ')
    pred_cols_int = 'lbl_true lbl_pred pos_true pos_pred'.split(' ')
    top_pred_cols = [f"top_{top_n}_lbl_pred_{i}" for i in range(top_n)]

    query = """
    CREATE TABLE IF NOT EXISTS predictions (
        id INTEGER PRIMARY KEY,
    """
    for col in pred_cols_str:
        query += f"{col} TEXT, "
    for col in pred_cols_int:
        query += f"{col} INTEGER, "
    for col in top_pred_cols:
        query += f"{col} INTEGER, "
    query = query[:-2]+')'
    # print(query)
    cursor.execute(query)

    query = "CREATE INDEX IF NOT EXISTS idx_preds ON predictions (readid, refseqid, pos_true);"
    cursor.execute(query)
    print('Table `predictions` created.')

    # Create table for probabilities (one row per 50-mer, to keep the number of columns small)

    query = """
    CREATE TABLE IF NOT EXISTS label_probabilities (
        id INTEGER PRIMARY KEY,
        read_kmer_id TEXT,
        read_50mer_nb INTEGER,
    """
    # Declare the REAL type on every probability column, not only on the last one
    query += ', '.join([f"prob_{i:03d} REAL" for i in range(187)]) + ", "
    query += "FOREIGN KEY (read_kmer_id) REFERENCES predictions(readid)"
    query += ')'
    cursor.execute(query)

    query = "CREATE INDEX IF NOT EXISTS idx_probs ON label_probabilities (read_kmer_id, read_50mer_nb);"
    cursor.execute(query)
    print('Table `label_probabilities` created.')
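
A subtle point when generating a `CREATE TABLE` statement programmatically: a type keyword appended after a joined list of column names applies only to the last column. In SQLite, a column without a declared type has BLOB affinity and stores values exactly as given (so string-encoded probabilities remain text), whereas a REAL column converts numeric strings. A small in-memory check with three hypothetical columns, using only the standard `sqlite3` module:

```python
import sqlite3

cols = [f"prob_{i:03d}" for i in range(3)]   # 3 columns instead of 187, for brevity

conn = sqlite3.connect(':memory:')
# Type keyword appended once, after the joined names: only the last column is typed
conn.execute("CREATE TABLE t_join (" + ','.join(cols) + " REAL)")
# Type keyword attached to every column name
conn.execute("CREATE TABLE t_typed (" + ', '.join(f"{c} REAL" for c in cols) + ")")

def declared_types(table):
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
    return [row[2] for row in conn.execute(f"PRAGMA table_info({table})")]

print(declared_types('t_join'))    # ['', '', 'REAL']
print(declared_types('t_typed'))   # ['REAL', 'REAL', 'REAL']

# A numeric string stays text in an untyped column, but becomes a REAL in a typed one
conn.execute("INSERT INTO t_join VALUES ('0.5', '0.5', '0.5')")
print(conn.execute("SELECT typeof(prob_000), typeof(prob_002) FROM t_join").fetchone())   # ('text', 'real')
```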

In [10]:
def skip_existing_predictions(gen: Generator,       # generator of batches (metadata, reads)
                              p2db: Path,           # path to the sqlite database 
                              bs: int               # batch size
                              ) -> Tuple[int, int]: # number of batches and kmer reads skipped
    """Consume `gen` up to and including the batch holding the last read saved in the database"""
    # Identify the readid of the last saved prediction
    print('Checking predictions already in database...')
    with open_db(p2db=p2db, k=150) as conn:
        last_predictions_id = conn.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
        if last_predictions_id is None:
            return 0, 0
        nb_predictions = conn.execute("SELECT COUNT(*) FROM predictions").fetchone()[0]
        last_readid = conn.execute("SELECT readid FROM predictions WHERE id = ?;", (last_predictions_id,)).fetchone()[0]
        
    print(f"Database includes {nb_predictions:,d} predictions, corresponding to {nb_predictions//bs:,d} batches")
    print(f"Last prediction id: {last_predictions_id:,d} for kmer read '{last_readid}'")
    print(f"Skipping already processed predictions ...")

    # Defaults, in case the last saved read is never found in the generator
    nb_batches_skipped, nb_kmer_reads_skipped = 0, 0
    for i, (batch_meta, batch_reads) in enumerate(gen):
        if i%100 == 0: 
            print(f"   Skipped first {(i+1)*bs:,d} kmer reads ({i+1:,d} batches)")
 
        if last_readid in batch_meta['readid']:
            print(f"   Reached last batch of saved prediction (batch {i+1:,d})")
            print('Can proceed with normal prediction inference')
            nb_kmer_reads_skipped = nb_predictions
            nb_batches_skipped = i+1
            break
    else:
        print(f"   Warning: readid '{last_readid}' not found, the generator is now exhausted")
    return nb_batches_skipped, nb_kmer_reads_skipped

# aln.reset_iterator()
# gen2 = aln.cnn_virus_input_generator(bs=512, label=118)
# nbs, nrs = skip_existing_predictions(gen=gen2, p2db=p2db, bs=512)

In [11]:
def top_predictions(probs, n=5):
    """Returns, for each kmer read, the n labels found most often among the top-n predictions of its 50-mers (least frequent first)"""

    def top_n_most_frequent(preds, n=5):
        """Returns the n most frequent values in `preds` (least frequent first)"""
        # print(preds.shape)
        uniques, counts = np.unique(preds, return_counts=True)
        top_idx = np.argsort(counts)[-n:]
        return uniques.take(top_idx)

    top_preds_in_50mers = np.argsort(probs, axis=-1)[:, :, -n:]
    nb_kmers, nb_50mers, nb_lbls = top_preds_in_50mers.shape
    # print(top_preds_in_50mers.shape)
    top_preds_in_kmer = top_preds_in_50mers.reshape(nb_kmers,nb_50mers * nb_lbls)
    # print(top_preds_in_kmer.shape)

    return np.apply_along_axis(top_n_most_frequent, axis=1, arr=top_preds_in_kmer, n=n)
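
A tiny worked example makes the voting in `top_predictions` concrete: 2 k-mer reads, 3 50-mers each and 4 labels, with `n=2`. `top_labels_by_vote` is a hypothetical standalone copy of the same NumPy logic, and the probabilities are made up, chosen without ties so that the ranking is unambiguous.

```python
import numpy as np

def top_labels_by_vote(probs: np.ndarray, n: int) -> np.ndarray:
    """Same logic as `top_predictions`: for each k-mer, the n labels that appear most
    often among the top-n labels of its 50-mers, ordered least to most frequent."""
    def most_frequent(labels: np.ndarray) -> np.ndarray:
        uniques, counts = np.unique(labels, return_counts=True)
        return uniques.take(np.argsort(counts)[-n:])

    top_in_50mers = np.argsort(probs, axis=-1)[:, :, -n:]   # (kmers, 50mers, n) label ids
    flat = top_in_50mers.reshape(probs.shape[0], -1)         # (kmers, 50mers * n)
    return np.apply_along_axis(most_frequent, 1, flat)

probs = np.array([
    # k-mer 0: label 1 is in the top 2 of all three 50-mers, label 2 in two of them
    [[0.1, 0.6, 0.3, 0.0], [0.0, 0.5, 0.3, 0.2], [0.7, 0.2, 0.1, 0.0]],
    # k-mer 1: label 3 is in the top 2 of all three 50-mers, label 2 in two of them
    [[0.0, 0.1, 0.2, 0.7], [0.1, 0.0, 0.3, 0.6], [0.0, 0.4, 0.1, 0.5]],
])
top = top_labels_by_vote(probs, n=2)
print(top[:, ::-1])   # most frequent first, i.e. [[1, 2], [3, 2]]
```

The result lists labels from least to most frequent, which is why the prediction loop reverses it with `top_preds[:, ::-1]` before saving.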

## The loop

### Inference for 25% of the simreads on 69 sequences

In [12]:
fnames = files_in_tree(pfs.data / 'ncbi/simreads/yf', pattern='69seq')

simreads
  |--yf
  |    |--single_1seq_150bp
  |    |--single_69seq_150bp
  |    |    |--single_69seq_150bp.fq (0)
  |    |    |--single_69seq_150bp.aln (1)
  |    |--paired_1seq_150bp
  |    |--paired_69seq_150bp
  |    |    |--paired_69seq_150bp1.fq (2)
  |    |    |--paired_69seq_150bp2.fq (3)
  |    |    |--paired_69seq_150bp1.aln (4)
  |    |    |--paired_69seq_150bp2.aln (5)


In [13]:
file_stem = 'single_69seq_150bp'

p2aln = pfs.data / f"ncbi/simreads/yf/{file_stem[:-2] if file_stem[-1] in ['1', '2'] else file_stem}/{file_stem}.aln"
assert p2aln.exists()

aln = AlnFileReader(p2aln)
print(f"Reading alignment file: '{p2aln.name}' (in {p2aln.parent})\n")
for i, aln_read in enumerate(aln):
    pass
nb_kmer_reads = i + 1  # enumerate is 0-based
print(f"  - {nb_kmer_reads:,d} simulated reads in file '{p2aln.name}' from {len(aln.header['reference sequences'])} reference sequences.")

print('  - ART command: ',aln.header['command'])
print('  - Reference Sequences:')
print('     ','\n      '.join(aln.header['reference sequences']))

Reading alignment file: 'single_69seq_150bp.aln' (in /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_69seq_150bp)

  - 1,161,034 simulated reads in file 'single_69seq_150bp.aln' from 69 reference sequences.
  - ART command:  /usr/bin/art_illumina -i /home/vtec/projects/bio/metagentools/data/ncbi/refsequences/yf/yf_2023_yellow_fever.fa -ss HS25 -l 150 -f 250 -o /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_69seq_150bp/single_69seq_150bp -rs 1724163599
  - Reference Sequences:
      @SQ	11089:ncbi:1	1	AY968064	11089	ncbi	Angola_1971	10237
      @SQ	11089:ncbi:2	2	U54798	11089	ncbi	Ivory_Coast_1982	10237
      @SQ	11089:ncbi:3	3	DQ235229	11089	ncbi	Ethiopia_1961	10237
      @SQ	11089:ncbi:4	4	AY572535	11089	ncbi	Gambia_2001	10237
      @SQ	11089:ncbi:5	5	MF405338	11089	ncbi	Ghana_Hsapiens_1927	10237
      @SQ	11089:ncbi:6	6	U21056	11089	ncbi	Senegal_1927	10237
      @SQ	11089:ncbi:7	7	AY968065	11089	ncbi	Uganda_1948	10237
      @SQ	11089:ncbi:8	8	JX89

In [14]:
print(f"'{aln.path.name}' used for prediction (in {aln.path.parent}).")

'single_69seq_150bp.aln' used for prediction (in /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_69seq_150bp).


In [15]:
# p2db = pfs.data / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
p2db = Path('/mnt/k/metagentools') / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
print(f"'{p2db.name}' sqlite db used, (in {p2db.parent})")

nl = '\n'
msg = f"Are you sure you want to use this database?{nl}Database file '{p2db.name}' does not correspond to aln '{aln.path.name}'"
if aln.path.stem != p2db.stem: raise Warning(msg)

'single_69seq_150bp.db' sqlite db used, (in /mnt/k/metagentools/ncbi/infer_results/yf-ncbi)


> **NOTE**
>
> Estimated space required to save the prediction and probability reports for the reads simulated from 69 sequences is *470 GB*.
>
> This is currently too large to save, even on my NAS.
>
> We will first build a prediction dataset with 25% of the total reads: `nb_batches_to_run = int(nb_kmer_reads / b * 0.25)`
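
The scale follows from the table layout: every k-mer read produces k - 49 = 101 rows in `label_probabilities`, each holding 187 probabilities. A back-of-the-envelope sketch; the bytes-per-value figures are assumptions (SQLite stores a REAL in 8 bytes plus record overhead, and text-encoded floats take more), so the loop brackets the order of magnitude rather than reproducing the 470 GB estimate.

```python
nb_kmer_reads = 1_161_034   # simulated reads in single_69seq_150bp.aln (printed above)
k, w = 150, 50              # read length and model window length
nb_labels = 187             # probability columns per 50-mer
b = 512                     # batch size used in the loop

nb_50mer_rows = nb_kmer_reads * (k - w + 1)        # rows in label_probabilities
nb_prob_values = nb_50mer_rows * nb_labels         # probabilities to store
nb_batches_25pct = int(nb_kmer_reads / b * 0.25)   # same formula as in the note

print(f"{nb_50mer_rows:,d} rows, {nb_prob_values:,d} probability values")
print(f"{nb_batches_25pct:,d} batches of {b} reads for the 25% subset")
for bytes_per_value in (8, 16, 20):  # assumed on-disk sizes, not measured
    print(f"  {bytes_per_value:2d} bytes/value -> ~{nb_prob_values * bytes_per_value / 1e9:,.0f} GB")
```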

In [None]:
# Review and set the following parameters

# 1. Data parameters
b = 512             # number of k-mer reads in a batch
k = 150             # read length
true_label = 118    # yellow fever virus
top_n = 5           # n for the top-n predictions to keep

# 2. Inference loop parameters
run_all_batches = False
# nb_batches_to_run = 4
nb_batches_to_run = int(nb_kmer_reads / b * 0.25)

#====================================================================================================
# Setup prediction Loop
#====================================================================================================
nb_50mer = k - 49
uid = datetime.today().strftime('%Y-%m-%d_%H_%M_%S')

aln.reset_iterator()
model = create_model_original(path2parameters=p2model)
print(f"Model loaded and ready to run ...")

# Open connection to sqlite db
conn = open_db(p2db, k=k)
cursor = conn.cursor()

# Create list of columns for prediction and probabilities reports
pred_cols_str = 'readid refseqid refsource refseq_strand taxonomyid'.split(' ')
pred_cols_int = 'lbl_true lbl_pred pos_true pos_pred'.split(' ')
top_pred_cols = [f"top_{top_n}_lbl_pred_{i}" for i in range(top_n)]
prob_cols = [f"prob_{i:03d}" for i in range(187)]

def tprint(string):
    print(f"{datetime.now().strftime('%H:%M:%S')}    {string}")

#====================================================================================================
# Run prediction loop
#====================================================================================================
print(f"Run prediction loop with the following parameters:")
print(f"   {b} k-mer per batch; {k} bp per sequence; keep top-{top_n} predictions")
tprint(f"Starting prediction loop ...")
gen = aln.cnn_virus_input_generator(bs=b, label=true_label)

# Skip kmer reads that are already processed
nb_batches_skipped, nb_kmer_reads_skipped = skip_existing_predictions(gen=gen, p2db=p2db, bs=b)
tprint(f"Skipped {nb_batches_skipped:,d} batches ({nb_kmer_reads_skipped:,d} kmer reads)")

# Proceed with prediction inference 
for i,(metadata_batch, reads_batch) in enumerate(gen):
    loop_start = datetime.now()
    tprint(f"Batch {i+1:3,d} (aln batch {nb_batches_skipped+i+1:3,d}) ...")

    reads_kmer, (labels_true, position_true) = string_input_batch_to_tensors(reads_batch, k=k)
    reads_50mer = split_kmer_batch_into_50mers(reads_kmer)
    assert reads_50mer.shape == ((reads_kmer.shape[1]-49) * b, 50, 5), f"Problem with shape in batch {i+1}: {reads_50mer.shape}"

    tprint(f'  Starting prediction for {b:,} kmer reads ...')
    label_probs, position_probs = model.predict(reads_50mer)

    tprint('  Reshaping predictions ...')
    label_probs_kmer = tf.reshape(label_probs, shape=(b,nb_50mer,-1))
    position_probs_kmer = tf.reshape(position_probs, shape=(b,nb_50mer,-1))

    tprint('  Combining predictions ...')
    combined_predictions = tf.map_fn(
        fn=combine_prediction_batch,
        elems=[label_probs_kmer, position_probs_kmer], 
        fn_output_signature=tf.int64
        )

    label_predictions = combined_predictions[:,0]
    position_predictions = combined_predictions[:,1]
    top_preds = top_predictions(label_probs_kmer, n=top_n)

    # Add results for current batch
    tprint('  Preparing prediction report ...')
    preds_report = np.concatenate(
        [
            np.expand_dims(np.array(metadata_batch['readid']), axis=1),         # readid 
            np.expand_dims(np.array(metadata_batch['refseqid']), axis=1),       # refseqid
            np.expand_dims(np.array(metadata_batch['refsource']), axis=1),      # refsource
            np.expand_dims(np.array(metadata_batch['refseq_strand']), axis=1),  # refseq_strand
            np.expand_dims(np.array(metadata_batch['reftaxonomyid']), axis=1),  # taxonomyid
            np.expand_dims(np.array([true_label]*b), axis=1),                   # lbl_true
            np.expand_dims(label_predictions, axis=1),                          # lbl_pred
            np.expand_dims(np.array(metadata_batch['aln_start_pos']), axis=1),  # pos_true
            np.expand_dims(position_predictions, axis=1),                       # pos_pred
            top_preds[:, ::-1],                                                 # top_5_lbl_pred_0, top_5_lbl_pred_1, top_5_lbl_pred_2, top_5_lbl_pred_3, top_5_lbl_pred_4
        ],
        axis=1
    )

    df_preds = pd.DataFrame(
        data=preds_report, 
        columns=pred_cols_str + pred_cols_int + top_pred_cols
        )
    tprint('  Saving batch prediction report to db...')
    df_preds.to_sql('predictions', conn, if_exists='append', index=False)

    tprint('  Preparing label probabilities report ...')
    df_probs = None
    for read_50mer_nb in range(nb_50mer):
        probs_report = np.concatenate(
            [
                np.expand_dims(np.array(metadata_batch['readid']), axis=1),        # readid 
                np.expand_dims(np.array([read_50mer_nb]*b), axis=1),               # read_50mer_nb
                label_probs_kmer[:, read_50mer_nb, :]
            ],
            axis=1
        )

        df = pd.DataFrame(
            data=probs_report, 
            columns=['read_kmer_id', 'read_50mer_nb'] + prob_cols
            )
        df_probs = df if df_probs is None else pd.concat([df_probs, df], axis=0)

    tprint('  Saving batch label probabilities report to db...')
    df_probs.to_sql('label_probabilities', conn, if_exists='append', index=False)

    tprint(f"  Batch processing time: {(datetime.now() - loop_start).total_seconds():.2f} sec")
    if not run_all_batches and i+1 >= nb_batches_to_run: 
        print('Stopping')
        break

conn.close()
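
The loop builds `df_probs` from one DataFrame per 50-mer, each made from a NumPy array that mixes read ids and floats; NumPy generally gives such an array a string dtype, so values are likely to reach SQLite as text. A possible vectorized alternative, sketched on hypothetical miniature data (3 reads, 2 50-mers, 4 labels) rather than the notebook's real tensors, builds the whole report in one step and keeps numeric columns numeric:

```python
import numpy as np
import pandas as pd
import sqlite3

# Hypothetical miniature batch: 3 k-mer reads, 2 50-mers per read, 4 labels
readids = ['r-1', 'r-2', 'r-3']
label_probs_kmer = np.random.rand(3, 2, 4).astype(np.float32)   # (reads, 50-mers, labels)
nb_reads, nb_50mer, nb_labels = label_probs_kmer.shape

# One row per (50-mer, read), in the same 50-mer-major order as the loop above
df_probs = pd.DataFrame(
    label_probs_kmer.transpose(1, 0, 2).reshape(-1, nb_labels),
    columns=[f"prob_{i:03d}" for i in range(nb_labels)],
)
df_probs.insert(0, 'read_kmer_id', np.tile(readids, nb_50mer))
df_probs.insert(1, 'read_50mer_nb', np.repeat(np.arange(nb_50mer), nb_reads))

conn = sqlite3.connect(':memory:')
df_probs.to_sql('label_probabilities', conn, index=False)
print(conn.execute(
    "SELECT typeof(read_50mer_nb), typeof(prob_000) FROM label_probabilities"
).fetchone())   # ('integer', 'real')
```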

In [17]:
with open_db(p2db=p2db, k=150) as conn:
    display(pd.read_sql_query("SELECT * FROM predictions WHERE readid = '11089:ncbi:1-17000';", conn))
    display(pd.read_sql_query("SELECT * FROM label_probabilities WHERE read_kmer_id = '11089:ncbi:1-17000';", conn))

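| | id | readid | refseqid | refsource | refseq_strand | taxonomyid | lbl_true | lbl_pred | pos_true | pos_pred | top_5_lbl_pred_0 | top_5_lbl_pred_1 | top_5_lbl_pred_2 | top_5_lbl_pred_3 | top_5_lbl_pred_4 |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 0 | 1 | 11089:ncbi:1-17000 | 11089:ncbi:1 | ncbi | - | 11089 | 118 | 10 | 7804 | 0 | 118 | 32 | 10 | 62 | 94 |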


| | id | read_kmer_id | read_50mer_nb | prob_000 | prob_001 | prob_002 | prob_003 | prob_004 | prob_005 | prob_006 | ... | prob_177 | prob_178 | prob_179 | prob_180 | prob_181 | prob_182 | prob_183 | prob_184 | prob_185 | prob_186 |
| -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
| 0 | 1 | 11089:ncbi:1-17000 | 0 | 9.452837e-16 | 1.8033755e-24 | 3.709243e-18 | 3.2934015e-16 | 1.11393635e-13 | 4.909559e-09 | 3.2027174e-19 | ... | 4.620669e-10 | 2.1956457e-23 | 4.973716e-17 | 1.2870734e-19 | 6.685301e-13 | 2.7962992e-13 | 5.0206733e-22 | 2.3932547e-15 | 1.2987661e-21 | 3.557021e-18 |
| 1 | 513 | 11089:ncbi:1-17000 | 1 | 9.987287e-13 | 9.640493e-23 | 6.864665e-13 | 5.297059e-17 | 5.2123617e-09 | 1.9526384e-07 | 6.811955e-14 | ... | 0.00029164975 | 9.457506e-20 | 4.247617e-15 | 2.2502056e-20 | 1.3939608e-09 | 4.1289537e-12 | 1.0989764e-16 | 5.3598934e-11 | 1.09030056e-14 | 1.850025e-18 |
| 2 | 1025 | 11089:ncbi:1-17000 | 2 | 4.1150392e-14 | 1.0409154e-21 | 3.1731182e-18 | 4.837065e-13 | 1.8273633e-12 | 5.6225136e-13 | 1.5127502e-14 | ... | 2.7924674e-08 | 9.881187e-25 | 1.2789581e-19 | 5.2265872e-23 | 3.6289507e-16 | 5.854743e-18 | 4.597412e-22 | 2.7675505e-16 | 3.6334886e-19 | 7.191011e-22 |
| 3 | 1537 | 11089:ncbi:1-17000 | 3 | 2.4102365e-13 | 3.9195824e-19 | 4.2152547e-16 | 5.363886e-15 | 7.9354416e-13 | 1.9598592e-10 | 2.0204985e-14 | ... | 2.3507164e-12 | 3.4494873e-22 | 1.4030163e-15 | 5.3813346e-20 | 2.1669545e-18 | 3.9168753e-20 | 4.6393943e-18 | 2.162604e-17 | 5.868014e-17 | 3.671587e-12 |
| 4 | 2049 | 11089:ncbi:1-17000 | 4 | 5.402446e-15 | 1.3982798e-24 | 5.1237304e-19 | 2.6393469e-12 | 7.070969e-12 | 4.5762413e-08 | 3.941572e-13 | ... | 2.2205168e-10 | 2.4770586e-20 | 2.2223746e-16 | 1.3230725e-16 | 1.473535e-13 | 1.0368611e-15 | 7.554584e-18 | 6.905381e-15 | 6.823121e-19 | 5.281248e-14 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 96 | 49153 | 11089:ncbi:1-17000 | 96 | 1.266811e-12 | 1.7694516e-13 | 6.9926794e-11 | 7.442386e-16 | 2.1827995e-12 | 2.2899127e-11 | 6.6601347e-13 | ... | 2.1055688e-09 | 2.6956715e-11 | 4.02884e-10 | 3.6730546e-06 | 7.332998e-13 | 9.214267e-17 | 4.41918e-18 | 1.0378567e-09 | 3.9226554e-16 | 1.512387e-09 |
| 97 | 49665 | 11089:ncbi:1-17000 | 97 | 3.6751678e-11 | 5.2148962e-11 | 1.1142021e-07 | 1.3979555e-13 | 0.00023110698 | 1.8854205e-11 | 1.128952e-07 | ... | 3.426913e-08 | 3.5629816e-17 | 2.7815642e-11 | 7.9123413e-07 | 9.112657e-12 | 1.8065088e-14 | 2.2273976e-15 | 2.6934228e-07 | 3.2327324e-15 | 1.599009e-12 |
| 98 | 50177 | 11089:ncbi:1-17000 | 98 | 3.987377e-10 | 3.995657e-08 | 4.0260434e-06 | 4.5705413e-09 | 0.00026610115 | 3.9758719e-10 | 1.0729461e-07 | ... | 6.970071e-08 | 1.210523e-17 | 2.2871915e-10 | 3.9424535e-06 | 1.6976076e-09 | 4.1861083e-09 | 3.708444e-10 | 6.0206723e-05 | 4.0571067e-09 | 5.209376e-10 |
| 99 | 50689 | 11089:ncbi:1-17000 | 99 | 3.773077e-12 | 6.944977e-13 | 6.4031043e-09 | 3.3788874e-09 | 7.6050675e-08 | 4.657467e-08 | 3.358606e-07 | ... | 3.0940438e-08 | 1.2781211e-18 | 2.6286571e-08 | 1.7933372e-08 | 4.844457e-09 | 1.0032285e-13 | 2.9824456e-15 | 3.3517054e-08 | 4.4160476e-16 | 2.173458e-09 |


In [26]:
with open_db(p2db=p2db, k=150) as conn:
    last_id = conn.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
    readid = conn.execute("SELECT readid FROM predictions WHERE id = ?;", (last_id,)).fetchone()[0]
    regex = re.compile(r'\d*:ncbi:(?P<read_nb>\d*-\d*)')
    m = regex.search(readid)
    read_nb = '???' if m is None else m.group('read_nb')

last_id, readid, read_nb

(419328, '11089:ncbi:25-5673', '25-5673')

In [21]:
with open_db(p2db=p2db, k=150) as conn:
    last_predictions_id = conn.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
    nb_predictions = conn.execute("SELECT COUNT(*) FROM predictions").fetchone()[0]
    last_readid = conn.execute(f"SELECT readid FROM predictions WHERE id = {last_predictions_id};").fetchone()[0]

print(f"Last id: {last_predictions_id:,d}, Nbr predictions: {nb_predictions:,d}, last readid: {last_readid}")

Last id: 709,120, Nbr predictions: 709,120, last readid: 11089:ncbi:42-4881
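
The output above shows `Last id` equal to `Nbr predictions` (709,120). Since `id` is an `INTEGER PRIMARY KEY` (an alias for SQLite's rowid) assigned automatically, and rows are only ever appended, `MAX(id)` gives the row count through a single b-tree lookup, while `COUNT(*)` has to walk the table. This only holds as long as no row is deleted and no id is set explicitly. A minimal in-memory sketch of that assumption with the standard `sqlite3` module:

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE predictions (id INTEGER PRIMARY KEY, readid TEXT)")
# ids are assigned automatically, starting at 1
conn.executemany("INSERT INTO predictions (readid) VALUES (?)",
                 [(f"11089:ncbi:1-{n}",) for n in range(1000)])

max_id = conn.execute("SELECT MAX(id) FROM predictions").fetchone()[0]   # rowid b-tree lookup
count = conn.execute("SELECT COUNT(*) FROM predictions").fetchone()[0]   # walks the whole table
print(max_id, count)   # 1000 1000 for an append-only table
```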


## Technical note to accelerate `skip_existing_predictions`

The slow step in `skip_existing_predictions` is the query returning the total number of rows in the table (`COUNT(*)`). We can speed it up by maintaining an accurate row count in a separate table, updated whenever new rows are inserted into the `predictions` table (which has a primary key `id` and an indexed column `readid`).

This can be done with the following steps:

1. **Create a table** to store the row count, for example:
```sql
    CREATE TABLE table_stats (
    id INTEGER PRIMARY KEY,
    row_count INTEGER NOT NULL
    );
```

2. **Create a trigger** that fires after each INSERT on the `predictions` table and increments the row count in `table_stats`:
```sql
    CREATE TRIGGER update_prediction_count
    AFTER INSERT ON predictions
    BEGIN
    UPDATE table_stats 
    SET row_count = row_count + 1
    WHERE id = 1;
    END;
```

This trigger assumes there is only one row in the `table_stats` table, with an `id` of 1. If you want to store the count per `readid` instead, you need a separate table in which `readid` is unique (`ON CONFLICT` requires a unique or primary key column), and a trigger using an UPSERT (available since SQLite 3.24.0):

```sql
    CREATE TABLE readid_stats (
    readid TEXT PRIMARY KEY,
    row_count INTEGER NOT NULL
    );

    CREATE TRIGGER update_readid_count
    AFTER INSERT ON predictions  
    BEGIN
    INSERT INTO readid_stats (readid, row_count)
    VALUES (NEW.readid, 1)
    ON CONFLICT(readid) DO UPDATE SET row_count = row_count + 1;
    END;
```

This inserts a new row for each unique `readid` with an initial count of 1, and increments `row_count` if the `readid` already exists.

3. **Insert a row** into the `table_stats` table with an initial count:
```sql
    INSERT INTO table_stats (id, row_count) VALUES (1, 0);
    -- or, for a database that already contains predictions:
    INSERT INTO table_stats (id, row_count) SELECT 1, COUNT(*) FROM predictions;
```

Now, whenever a new row is inserted into the `predictions` table, the trigger automatically updates the row count in `table_stats`. This maintains an accurate count without a full table scan with `COUNT(*)`.

Remember to handle deletions as well, by creating a BEFORE DELETE trigger that decrements the row count accordingly.
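
The steps above can be checked end to end on an in-memory database with Python's standard `sqlite3` module. This is a sketch under simplifying assumptions: the `predictions` table is reduced to `id` and `readid`, and `decrement_prediction_count` is a hypothetical name for the BEFORE DELETE trigger mentioned in the last step.

```python
import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript("""
CREATE TABLE predictions (id INTEGER PRIMARY KEY, readid TEXT);
CREATE INDEX idx_readid ON predictions (readid);

CREATE TABLE table_stats (id INTEGER PRIMARY KEY, row_count INTEGER NOT NULL);
INSERT INTO table_stats (id, row_count) VALUES (1, 0);

-- fires once per inserted row
CREATE TRIGGER update_prediction_count AFTER INSERT ON predictions
BEGIN
    UPDATE table_stats SET row_count = row_count + 1 WHERE id = 1;
END;

-- fires once per deleted row
CREATE TRIGGER decrement_prediction_count BEFORE DELETE ON predictions
BEGIN
    UPDATE table_stats SET row_count = row_count - 1 WHERE id = 1;
END;
""")

conn.executemany("INSERT INTO predictions (readid) VALUES (?)", [(f"r-{n}",) for n in range(100)])
conn.execute("DELETE FROM predictions WHERE id <= 10")
conn.commit()

row_count = conn.execute("SELECT row_count FROM table_stats WHERE id = 1").fetchone()[0]
print(row_count)   # 90
```

The trade-off is one extra `UPDATE` per inserted or deleted row, in exchange for reading the count from a single-row table instead of scanning `predictions`.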

Citations:
- 1. https://www.sqlitetutorial.net/sqlite-insert/
- 2. https://www.sqlitetutorial.net/sqlite-count-function/
- 3. https://stackoverflow.com/questions/55007800/dynamic-way-to-insert-data-into-sqlite-table-when-column-counts-change
- 4. https://docs.python.org/es/3/library/sqlite3.html
- 5. https://www.sql-easy.com/learn/sqlite-count/
- 6. https://stackoverflow.com/questions/4474873/what-is-the-most-efficient-way-to-count-rows-in-a-table-in-sqlite
- 7. https://sqlite.org/forum/info/57c04743e1b6aa10
- 8. https://sqlite.org/forum/forumpost/f832398c19

# End of Section