# Experiments for Inference and Analysis NCBI Yellow Fever

TBC

# 1. Imports and setup environment

In [1]:
# Install required custom packages if not installed yet.
import importlib.util
if not importlib.util.find_spec('ecutilities'):
    print('installing package: `ecutilities`')
    ! pip install -qqU ecutilities
else:
    print('`ecutilities` already installed')
if not importlib.util.find_spec('metagentools'):
    print('installing package: `metagentools')
    ! pip install -qqU metagentools
else:
    print('`metagentools` already installed')

`ecutilities` already installed
`metagentools` already installed


In [2]:
# Import all required packages
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import re
import sqlite3
import time

from datetime import datetime
from ecutilities.core import files_in_tree
from ecutilities.ipython import nb_setup
from functools import partial
from IPython.display import display, update_display, Markdown, HTML
from pandas import HDFStore
from pathlib import Path
from pprint import pprint
from tqdm.notebook import tqdm, trange

# Setup the notebook for development
nb_setup()

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
import tensorflow as tf
from tensorflow.python.client import device_lib
from tensorflow.keras.models import load_model
print(f"Tensorflow version: {tf.__version__}\n")

from metagentools.cnn_virus.data import _base_hot_encode, strings_to_tensors
from metagentools.cnn_virus.data import split_kmer_into_50mers, combine_prediction_batch
from metagentools.cnn_virus.data import FastaFileReader, FastqFileReader, AlnFileReader
from metagentools.cnn_virus.data import OriginalLabels
from metagentools.cnn_virus.data import string_input_batch_to_tensors, split_kmer_batch_into_50mers
from metagentools.cnn_virus.architecture import create_model_original
from metagentools.core import ProjectFileSystem, TextFileBaseReader

Set autoreload mode
Tensorflow version: 2.8.2



List all computing devices available on the machine

In [3]:
devices = device_lib.list_local_devices()
print('\nDevices:')
for d in devices:
    t = d.device_type
    name = d.physical_device_desc
    l = [item.split(':', 1) for item in name.split(', ')]
    name_attr = dict([x for x in l if len(x)==2])
    dev = name_attr.get('name', ' ')
    print(f"  - {t}  {d.name} {dev:25s}")


Devices:
  - CPU  /device:CPU:0                          
  - GPU  /device:GPU:0  NVIDIA GeForce GTX 1050 


# 2. Setup paths to files

Key folders and system information

In [4]:
pfs = ProjectFileSystem()
pfs.info()

Running linux on local computer
Device's home directory: /home/vtec
Project file structure:
 - Root ........ /home/vtec/projects/bio/metagentools 
 - Data Dir .... /home/vtec/projects/bio/metagentools/data 
 - Notebooks ... /home/vtec/projects/bio/metagentools/nbs


- `p2model`: path to file with saved original pretrained model
- `p2virus_labels` path to file with virus names and labels mapping for original model

In [5]:
p2model = pfs.data / 'saved/cnn_virus_original/pretrained_model.h5'
assert p2model.is_file(), f"No file found at {p2model.absolute()}"

p2virus_labels = pfs.data / 'CNN_Virus_data/virus_name_mapping'
assert p2virus_labels.is_file(), f"No file found at {p2virus_labels.absolute()}"

# 3. Load simulated reads and review

Check which simread files are already created:

In [9]:
files_in_tree(pfs.data / 'ncbi/simreads/yf');

simreads
  |--yf
  |    |--readme.md (0)
  |    |--single_1seq_150bp
  |    |    |--single_1seq_150bp.fq (1)
  |    |    |--single_1seq_150bp.aln (2)
  |    |--single_69seq_150bp
  |    |    |--single_69seq_150bp.fq (3)
  |    |    |--single_69seq_150bp.aln (4)
  |    |--paired_1seq_150bp
  |    |    |--paired_1seq_150bp2.aln (5)
  |    |    |--paired_1seq_150bp2.fq (6)
  |    |    |--paired_1seq_150bp1.fq (7)
  |    |    |--paired_1seq_150bp1.aln (8)
  |    |--paired_69seq_150bp
  |    |    |--paired_69seq_150bp1.fq (9)
  |    |    |--paired_69seq_150bp2.fq (10)
  |    |    |--paired_69seq_150bp1.aln (11)
  |    |    |--paired_69seq_150bp2.aln (12)


In [10]:
p2aln = pfs.data / 'ncbi/simreads/yf/single_1seq_150bp/single_1seq_150bp.aln'
assert p2aln.exists()

In [11]:
aln = AlnFileReader(p2aln)
for i, aln_read in enumerate(aln):
    pass
print(f"{i+1:,d} simulated reads in file {p2aln.name} from {len(aln.header['reference sequences'])} refseq.")

print('ART command: ',aln.header['command'])
print('Reference Sequences:\n','\n'.join(aln.header['reference sequences']))

17,000 simulated reads in file single_1seq_150bp.aln from 1 refseq.
ART command:  /usr/bin/art_illumina -i /home/vtec/projects/bio/metagentools/data/ncbi/refsequences/yf/yf_1971_angola.fa -ss HS25 -l 150 -f 250 -o /home/vtec/projects/bio/metagentools/data/ncbi/simreads/yf/single_1seq_150bp/single_1seq_150bp -rs 1724163574
Reference Sequences:
 @SQ	11089:ncbi:1	1	AY968064	11089	ncbi	Angola_1971	10237


# Prediction Loop

Label for Yellow Fever in `118`

In [12]:
# TODO: add functionality to do species search into OriginalLabels
[(k,v) for k, v in OriginalLabels()._species2label.items() if 'yellow' in k.lower()]

[('Yellow_fever_virus', 118)]

Let's test:
- the generator (`aln.cnn_virus_input_generator`)
- the string to tensor transform function (`string_input_batch_to_tensors`)

In [13]:
aln.reset_iterator()
for batch_idxs, batch_reads in aln.cnn_virus_input_generator(bs=8, label=118):
    reads_kmer, (labels_kmer, positions_kmer) = string_input_batch_to_tensors(batch_reads, k=150)
    break

reads_kmer.shape, labels_kmer.shape, positions_kmer.shape

(TensorShape([8, 150, 5]), TensorShape([8, 187]), TensorShape([8, 10]))

Then we apply `split_kmer_batch_into_50mers` to split the kmers into a series of 50-mer reads.

In [14]:
reads_50mer = split_kmer_batch_into_50mers(reads_kmer)
nb_50mer_reads = (reads_kmer.shape[1]-49) * reads_kmer.shape[0]
print('Nb 50-mer reads in a batch: ',nb_50mer_reads)
assert reads_50mer.shape == (nb_50mer_reads, 50,5)

reads_50mer.shape

Nb 50-mer reads in a batch:  808


TensorShape([808, 50, 5])

# Run the Loop

Before running the next cell, define:
- size of the batch `b`
- the number of bp in a k-mer
- the number n for top-n predictions to save
- whether to run the full loop or only a few batches

In [119]:
aln = AlnFileReader(p2aln)
p2aln.stem

'single_1seq_150bp'

In [None]:
aln.re_keys
o = next(aln)
aln.parse_definition_line_with_position(o['definition line']).keys()

dict_keys(['aln_start_pos', 'readid', 'readnb', 'refseq_strand', 'refseqid', 'refseqnb', 'refsource', 'reftaxonomyid', 'read_pos'])

In [23]:
def open_db(p2db: Path, k: int, top_n: int=5):
    conn = sqlite3.connect(p2db)
    cursor = conn.cursor()
    query = "SELECT name FROM sqlite_master WHERE type='table' AND name='predictions'"
    res = cursor.execute(query).fetchone()
    if res is None or 'predictions' not in res:
        print('Creating tables in database...')
        create_tables(cursor, k)
    return conn

def table_columns(cursor, table):
    query = f"PRAGMA table_info({table})"
    cursor.execute(query)
    cols = [row[1] for row in cursor.fetchall()]
    return cols

def create_tables(cursor, k, top_n=5):
    # Create table for predictions
    pred_cols_str = 'readid refseqid refsource refseq_strand taxonomyid'.split(' ')
    pred_cols_int = 'lbl_true lbl_pred pos_true pos_pred'.split(' ')
    top_pred_cols = [f"top_{top_n}_lbl_pred_{i}" for i in range(top_n)]

    query = """
    CREATE TABLE IF NOT EXISTS predictions (
        id INTEGER PRIMARY KEY,
    """
    for col in pred_cols_str:
        query += f"{col} TEXT, "
    for col in pred_cols_int:
        query += f"{col} INTEGER, "
    for col in top_pred_cols:
        query += f"{col} INTEGER, "
    query = query[:-2]+')'
    # print(query)
    cursor.execute(query)

    query = "CREATE INDEX idx_preds ON predictions (readid, refseqid, pos_true);"
    cursor.execute(query)
    print('Table `predictions` created.')

    # Create table for probabilities (one per 50-mer in order to keep small nb or columns in table)

    query = f"""
    CREATE TABLE IF NOT EXISTS label_probabilities (
        id INTEGER PRIMARY KEY,
        read_kmer_id TEXT,
        read_50mer_nb INTEGER,
    """
    query += ','.join([f"prob_{i:03d}" for i in range(187)]) + " REAL, "
    query += "FOREIGN KEY (read_kmer_id) REFERENCES predictions(readid)"
    query += ')'
    cursor.execute(query)

    query = "CREATE INDEX idx_probs ON label_probabilities (read_kmer_id, read_50mer_nb);"
    cursor.execute(query)
    print(f'Table `label_probabilities` created.')

In [24]:
def top_predictions(probs, n=3):

    def top_n_most_frequent(preds, n=3):
        """Returns the top n most frequent predictions for each read"""
        # print(preds.shape)
        uniques, counts = np.unique(preds, return_counts=True)
        top_idx = np.argsort(counts)[-n:]
        return uniques.take(top_idx)

    top_preds_in_50mers = np.argsort(probs, axis=-1)[:, :, -n:]
    nb_seq, nb_50mer, nb_lbls = top_preds_in_50mers.shape
    # print(top_preds_in_50mers.shape)
    top_preds_in_kmer = top_preds_in_50mers.reshape(nb_seq,nb_50mer * nb_lbls)
    # print(top_preds_in_kmer.shape)

    return np.apply_along_axis(top_n_most_frequent, axis=1, arr=top_preds_in_kmer, n=n)

In [None]:
p2db = pfs.data / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
conn = open_db(p2db, k=150)
conn.close()

Creating tables in database...
Table `predictions` created.
Table `label_probabilities` created.


In [None]:
b = 512   # number of k-mer in a batch
k = 150
true_label = 118
top_n = 5   # n for top-n prediction to keep
run_all_batches = True
nb_batches_to_run = 2

uid = datetime.today().strftime('%Y-%m-%d_%H_%M_%S')
p2db = pfs.data / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
nb_50mer = k - 49

print(f"Run prediction loop with the following parameters:")
print(f"   {b} k-mer per batch; {k} bp per sequence; keep top-{top_n} predictions")

aln.reset_iterator()
model = create_model_original(path2parameters=p2model)
print(f"Model loaded and ready to run ...")

# open connection to sqlite db
conn = open_db(p2db, k=k)
cursor = conn.cursor()

# create a dataframe to store results
pred_cols_str = 'readid refseqid refsource refseq_strand taxonomyid'.split(' ')
pred_cols_int = 'lbl_true lbl_pred pos_true pos_pred'.split(' ')
top_pred_cols = [f"top_{top_n}_lbl_pred_{i}" for i in range(top_n)]
prob_cols = [f"prob_{i:03d}" for i in range(187)]


def tprint(string):
    print(f"{datetime.now().strftime('%H:%M:%S')}    {string}")

print(f"Starting prediction loop ...")
for i,(metadata_batch, reads_batch) in enumerate(aln.cnn_virus_input_generator(bs=b, label=true_label)):
    loop_start = datetime.now()
    tprint(f"Batch {i+1:3d} ...")

    reads_kmer, (labels_true, position_true) = string_input_batch_to_tensors(reads_batch, k=k)
    reads_50mer = split_kmer_batch_into_50mers(reads_kmer)
    assert reads_50mer.shape == ((reads_kmer.shape[1]-49) * b, 50, 5), f"Problem with shape in batch {i+1}: {reads_50mer.shape}"

    tprint(f'  Starting prediction for {b:,} kmer reads ...')
    label_probs, position_probs = model.predict(reads_50mer)

    tprint('  Reshaping predictions ...')
    label_probs_kmer = tf.reshape(label_probs, shape=(b,nb_50mer,-1))
    position_probs_kmer = tf.reshape(position_probs, shape=(b,nb_50mer,-1))
    # successive_preds = count_successive_label_preds(label_probs_kmer)

    tprint('  Combining predictions ...')
    combined_predictions = tf.map_fn(
        fn=combine_prediction_batch,
        elems=[label_probs_kmer, position_probs_kmer], 
        fn_output_signature=tf.int64
        )

    label_predictions = combined_predictions[:,0]
    position_predictions = combined_predictions[:,1]
    top_preds = top_predictions(label_probs_kmer, n=top_n)

    # Add results for batch
    tprint('  Preparing prediction report ...')
    # 'readid', 'refseqid', 'refsource', 'refseq_strand', 'taxonomyid',
    # 'lbl_true', 'lbl_pred', 'pos_true', 'pos_pred'
    # 'top_5_lbl_pred_0', 'top_5_lbl_pred_1', 'top_5_lbl_pred_2',
    # 'top_5_lbl_pred_3', 'top_5_lbl_pred_4'
    preds_report = np.concatenate(
        [
            np.expand_dims(np.array(metadata_batch['readid']), axis=1),         # readid 
            np.expand_dims(np.array(metadata_batch['refseqid']), axis=1),       # refseqid
            np.expand_dims(np.array(metadata_batch['refsource']), axis=1),      # refsource
            np.expand_dims(np.array(metadata_batch['refseq_strand']), axis=1),  # refseq_strand
            np.expand_dims(np.array(metadata_batch['reftaxonomyid']), axis=1),  # taxonomyid
            np.expand_dims(np.array([true_label]*b), axis=1),                   # lbl_true
            np.expand_dims(label_predictions, axis=1),                          # lbl_pred
            np.expand_dims(np.array(metadata_batch['aln_start_pos']), axis=1),  # pos_true
            np.expand_dims(position_predictions, axis=1),                       # pos_pred
            top_preds[:, ::-1],                                                 # top_5_lbl_pred_0, top_5_lbl_pred_1, top_5_lbl_pred_2, top_5_lbl_pred_3, top_5_lbl_pred_4
        ],
        axis=1
    )

    df_preds = pd.DataFrame(
        data=preds_report, 
        columns=pred_cols_str + pred_cols_int + top_pred_cols
        )
    tprint('  Saving batch prediction report to db...')
    df_preds.to_sql('predictions', conn, if_exists='append', index=False)

    tprint('  Preparing label probabilities report ...')
    df_probs = None
    for read_50mer_nb in range(nb_50mer):
        probs_report = np.concatenate(
            [
                np.expand_dims(np.array(metadata_batch['readid']), axis=1),        # readid 
                np.expand_dims(np.array([read_50mer_nb]*b), axis=1),               # read_50mer_nb
                label_probs_kmer[:, read_50mer_nb, :]
            ],
            axis=1
        )

        df = pd.DataFrame(
            data=probs_report, 
            columns=['read_kmer_id', 'read_50mer_nb'] + prob_cols
            )
        df_probs = df if df_probs is None else pd.concat([df_probs, df], axis=0)

    tprint('  Saving batch label probabilities report to db...')
    df_probs.to_sql('label_probabilities', conn, if_exists='append', index=False)

    tprint(f"  Batch processing time: {(datetime.now() - loop_start).total_seconds():.2f} sec")
    if not run_all_batches and i+1 >= nb_batches_to_run: 
        print('Stopping')
        break

conn.close()

Run prediction loop with the following parameters:
   512 k-mer per batch; 150 bp per sequence; keep top-5 predictions
Creating CNN Model (Original)
Loading parameters from pretrained_model.h5
Created pretrained model
Model loaded and ready to run ...
Creating tables in database...
Table `predictions` created.
Table `label_probabilities` created.
Starting prediction loop ...
20:04:19    Batch   1 ...
20:04:19      Starting prediction for 512 kmer reads ...
20:04:49      Reshaping predictions ...
20:04:49      Combining predictions ...
20:05:01      Preparing prediction report ...
20:05:01      Saving batch prediction report to db...
20:05:01      Preparing label probabilities report ...
20:05:18      Saving batch label probabilities report to db...
20:05:32      Batch processing time: 73.55 sec
20:05:32    Batch   2 ...
20:05:33      Starting prediction for 512 kmer reads ...
20:06:02      Reshaping predictions ...
20:06:02      Combining predictions ...
20:06:13      Preparing predict

Estimated size of DB:
- 17,000 kmer reads -> 6,844,380 kb
- kb per kmer reads = 405 kb

Total nbr of kmer reads for 69 sequences: 1,161,034
- 1,161,034 * 405 kb = 470,000,000 kb = **470 Gb**

Inference time: one batch of 512 kmer read takes 30 s
- 17,000 / 512 * 30 s = 995 s = 16.6 min
- 1,161,034/512 * 30 s = 57,000 s = 15.8 h

In [None]:
with open_db(p2db=p2db, k=150) as conn:
    display(pd.read_sql_query("SELECT * FROM predictions WHERE readid = '11089:ncbi:1-17000';", conn))
    display(pd.read_sql_query("SELECT * FROM label_probabilities WHERE read_kmer_id = '11089:ncbi:1-17000';", conn))

Unnamed: 0,id,readid,refseqid,refsource,refseq_strand,taxonomyid,lbl_true,lbl_pred,pos_true,pos_pred,read_pos,top_5_lbl_pred_0,top_5_lbl_pred_1,top_5_lbl_pred_2,top_5_lbl_pred_3,top_5_lbl_pred_4
0,1,11089:ncbi:1-17000,11089:ncbi:1,ncbi,-,11089,118,32,3164,7,,18,0,32,117,118


Unnamed: 0,id,read_kmer_id,read_50mer_nb,prob_000,prob_001,prob_002,prob_003,prob_004,prob_005,prob_006,...,prob_177,prob_178,prob_179,prob_180,prob_181,prob_182,prob_183,prob_184,prob_185,prob_186
0,1,11089:ncbi:1-17000,0,3.0779117e-09,4.6550543e-16,4.0059014e-09,5.709354e-14,3.985621e-13,1.4781354e-07,3.6368833e-13,...,6.536237e-14,1.3692335e-18,4.8284808e-14,4.28663e-27,1.7636006e-08,2.0901565e-20,1.2649406e-15,3.2826548e-25,1.3461108e-31,1.695962e-21
1,513,11089:ncbi:1-17000,1,7.798936e-12,3.0194006e-16,9.897284e-13,4.3037067e-11,8.870653e-16,5.2778373e-09,5.065814e-15,...,1.4544616e-15,1.7843624e-16,9.904177e-13,3.457488e-21,7.375788e-07,1.7063393e-21,6.5437836e-11,5.6732623e-25,4.585938e-27,1.069439e-21
2,1025,11089:ncbi:1-17000,2,1.0980089e-12,1.9275111e-15,1.2620867e-10,5.787828e-11,5.338331e-13,2.7942612e-11,5.7700195e-10,...,5.753128e-14,1.6327466e-15,4.4719354e-06,6.113753e-17,4.3007142e-07,1.6321036e-17,1.3829523e-09,7.2506775e-19,4.256189e-23,8.687584e-22
3,1537,11089:ncbi:1-17000,3,4.4603436e-09,4.5842836e-14,3.9782825e-12,1.9334909e-08,1.5190636e-10,2.5700672e-08,1.488595e-13,...,2.6302808e-07,7.203852e-11,4.3689414e-07,2.1555071e-17,0.000116948606,5.1792988e-18,1.0885436e-09,3.184696e-18,1.5609967e-19,3.963395e-22
4,2049,11089:ncbi:1-17000,4,8.554256e-06,2.554389e-13,0.0006343455,3.5571188e-07,2.3673221e-09,0.097307,7.587622e-09,...,3.750018e-07,8.665608e-11,6.433393e-09,5.5440495e-20,0.00476838,1.4658719e-17,5.7866656e-13,1.6711752e-17,5.3012568e-21,1.876316e-18
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
96,49153,11089:ncbi:1-17000,96,0.9953505,5.3798974e-20,1.0473234e-18,6.393374e-11,2.8581829e-10,3.2485372e-15,2.7151905e-12,...,5.3093893e-15,1.3183516e-18,3.1293095e-09,2.2244292e-16,1.6725651e-10,3.994397e-15,2.2199666e-19,1.2594476e-13,3.5062112e-13,1.599368e-17
97,49665,11089:ncbi:1-17000,97,0.7476607,8.0384427e-16,3.6662344e-12,1.3936913e-13,3.1677278e-12,3.3102418e-15,1.0410654e-08,...,2.8784293e-12,3.711314e-15,1.8920882e-05,3.1332e-16,4.1800877e-07,5.0522433e-12,1.0039908e-14,6.3892506e-13,5.0943116e-14,1.197173e-14
98,50177,11089:ncbi:1-17000,98,0.9361404,2.664984e-16,7.8721495e-17,3.706427e-12,1.3358595e-14,1.254372e-16,1.8388244e-10,...,1.06356905e-16,4.1285527e-16,1.4968358e-11,2.585157e-18,8.469594e-08,3.6794658e-11,2.3980224e-18,5.1243676e-15,1.01574735e-17,2.688872e-14
99,50689,11089:ncbi:1-17000,99,0.998154,1.2311882e-15,1.7796784e-17,5.404334e-11,1.2882388e-10,1.5932526e-10,8.845724e-11,...,1.8041767e-10,3.959757e-12,3.682707e-12,2.1554382e-12,4.0533246e-06,5.84147e-08,6.353569e-15,1.6207057e-10,1.1721044e-15,3.890514e-13


In [None]:
# with open_db(p2db, 150) as conn:
#     cursor = conn.cursor()
#     cursor.execute("DELETE FROM predictions;")
#     cursor.execute("DELETE FROM label_probabilities;")
#     conn.commit()

# with open_db(p2db, 150) as conn:
#     cursor = conn.cursor()
#     cursor.execute("DROP TABLE predictions")
#     cursor.execute("DROP TABLE label_probabilities")
#     conn.commit()

# conn.close()

# Test saving the database on NAS

In [None]:
p2db = Path('/mnt/k/metagentools') / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
assert p2db.parent.exists()
p2db

PosixPath('/mnt/k/metagentools/ncbi/infer_results/yf-ncbi/single_1seq_150bp.db')

In [None]:
b = 32   # number of k-mer in a batch
k = 150
true_label = 118
top_n = 5   # n for top-n prediction to keep
run_all_batches = False
nb_batches_to_run = 4

uid = datetime.today().strftime('%Y-%m-%d_%H_%M_%S')
# p2db = Path('/mnt/k') / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
nb_50mer = k - 49

print(f"Run prediction loop with the following parameters:")
print(f"   {b} k-mer per batch; {k} bp per sequence; keep top-{top_n} predictions")

aln.reset_iterator()
model = create_model_original(path2parameters=p2model)
print(f"Model loaded and ready to run ...")

# open connection to sqlite db
conn = open_db(p2db, k=k)
cursor = conn.cursor()

# create a dataframe to store results
pred_cols_str = 'readid refseqid refsource refseq_strand taxonomyid'.split(' ')
pred_cols_int = 'lbl_true lbl_pred pos_true pos_pred'.split(' ')
top_pred_cols = [f"top_{top_n}_lbl_pred_{i}" for i in range(top_n)]
prob_cols = [f"prob_{i:03d}" for i in range(187)]


def tprint(string):
    print(f"{datetime.now().strftime('%H:%M:%S')}    {string}")

print(f"Starting prediction loop ...")
for i,(metadata_batch, reads_batch) in enumerate(aln.cnn_virus_input_generator(bs=b, label=true_label)):
    loop_start = datetime.now()
    tprint(f"Batch {i+1:3d} ...")

    reads_kmer, (labels_true, position_true) = string_input_batch_to_tensors(reads_batch, k=k)
    reads_50mer = split_kmer_batch_into_50mers(reads_kmer)
    assert reads_50mer.shape == ((reads_kmer.shape[1]-49) * b, 50, 5), f"Problem with shape in batch {i+1}: {reads_50mer.shape}"

    tprint(f'  Starting prediction for {b:,} kmer reads ...')
    label_probs, position_probs = model.predict(reads_50mer)

    tprint('  Reshaping predictions ...')
    label_probs_kmer = tf.reshape(label_probs, shape=(b,nb_50mer,-1))
    position_probs_kmer = tf.reshape(position_probs, shape=(b,nb_50mer,-1))
    # successive_preds = count_successive_label_preds(label_probs_kmer)

    tprint('  Combining predictions ...')
    combined_predictions = tf.map_fn(
        fn=combine_prediction_batch,
        elems=[label_probs_kmer, position_probs_kmer], 
        fn_output_signature=tf.int64
        )

    label_predictions = combined_predictions[:,0]
    position_predictions = combined_predictions[:,1]
    top_preds = top_predictions(label_probs_kmer, n=top_n)

    # Add results for batch
    tprint('  Preparing prediction report ...')
    # 'readid', 'refseqid', 'refsource', 'refseq_strand', 'taxonomyid',
    # 'lbl_true', 'lbl_pred', 'pos_true', 'pos_pred'
    # 'top_5_lbl_pred_0', 'top_5_lbl_pred_1', 'top_5_lbl_pred_2',
    # 'top_5_lbl_pred_3', 'top_5_lbl_pred_4'
    preds_report = np.concatenate(
        [
            np.expand_dims(np.array(metadata_batch['readid']), axis=1),         # readid 
            np.expand_dims(np.array(metadata_batch['refseqid']), axis=1),       # refseqid
            np.expand_dims(np.array(metadata_batch['refsource']), axis=1),      # refsource
            np.expand_dims(np.array(metadata_batch['refseq_strand']), axis=1),  # refseq_strand
            np.expand_dims(np.array(metadata_batch['reftaxonomyid']), axis=1),  # taxonomyid
            np.expand_dims(np.array([true_label]*b), axis=1),                   # lbl_true
            np.expand_dims(label_predictions, axis=1),                          # lbl_pred
            np.expand_dims(np.array(metadata_batch['aln_start_pos']), axis=1),  # pos_true
            np.expand_dims(position_predictions, axis=1),                       # pos_pred
            top_preds[:, ::-1],                                                 # top_5_lbl_pred_0, top_5_lbl_pred_1, top_5_lbl_pred_2, top_5_lbl_pred_3, top_5_lbl_pred_4
        ],
        axis=1
    )

    df_preds = pd.DataFrame(
        data=preds_report, 
        columns=pred_cols_str + pred_cols_int + top_pred_cols
        )
    tprint('  Saving batch prediction report to db...')
    df_preds.to_sql('predictions', conn, if_exists='append', index=False)

    tprint('  Preparing label probabilities report ...')
    df_probs = None
    for read_50mer_nb in range(nb_50mer):
        probs_report = np.concatenate(
            [
                np.expand_dims(np.array(metadata_batch['readid']), axis=1),        # readid 
                np.expand_dims(np.array([read_50mer_nb]*b), axis=1),               # read_50mer_nb
                label_probs_kmer[:, read_50mer_nb, :]
            ],
            axis=1
        )

        df = pd.DataFrame(
            data=probs_report, 
            columns=['read_kmer_id', 'read_50mer_nb'] + prob_cols
            )
        df_probs = df if df_probs is None else pd.concat([df_probs, df], axis=0)

    tprint('  Saving batch label probabilities report to db...')
    df_probs.to_sql('label_probabilities', conn, if_exists='append', index=False)

    tprint(f"  Batch processing time: {(datetime.now() - loop_start).total_seconds():.2f} sec")
    if not run_all_batches and i+1 >= nb_batches_to_run: 
        print('Stopping')
        break

conn.close()

Run prediction loop with the following parameters:
   32 k-mer per batch; 150 bp per sequence; keep top-5 predictions
Creating CNN Model (Original)
Loading parameters from pretrained_model.h5
Created pretrained model
Model loaded and ready to run ...
Starting prediction loop ...
21:26:08    Batch   1 ...
21:26:08      Starting prediction for 32 kmer reads ...
21:26:12      Reshaping predictions ...
21:26:12      Combining predictions ...
21:26:13      Preparing prediction report ...
21:26:13      Saving batch prediction report to db...
21:26:13      Preparing label probabilities report ...
21:26:14      Saving batch label probabilities report to db...
21:26:17      Batch processing time: 9.23 sec
21:26:17    Batch   2 ...
21:26:18      Starting prediction for 32 kmer reads ...
21:26:19      Reshaping predictions ...
21:26:19      Combining predictions ...
21:26:20      Preparing prediction report ...
21:26:20      Saving batch prediction report to db...
21:26:20      Preparing label pr

In [None]:
with open_db(p2db=p2db, k=150) as conn:
    display(pd.read_sql_query("SELECT * FROM predictions WHERE readid = '11089:ncbi:1-17000';", conn))
    display(pd.read_sql_query("SELECT * FROM label_probabilities WHERE read_kmer_id = '11089:ncbi:1-17000';", conn))
    

Unnamed: 0,id,readid,refseqid,refsource,refseq_strand,taxonomyid,lbl_true,lbl_pred,pos_true,pos_pred,top_5_lbl_pred_0,top_5_lbl_pred_1,top_5_lbl_pred_2,top_5_lbl_pred_3,top_5_lbl_pred_4
0,1,11089:ncbi:1-17000,11089:ncbi:1,ncbi,-,11089,118,32,3164,7,18,0,32,117,118


Unnamed: 0,id,read_kmer_id,read_50mer_nb,prob_000,prob_001,prob_002,prob_003,prob_004,prob_005,prob_006,...,prob_177,prob_178,prob_179,prob_180,prob_181,prob_182,prob_183,prob_184,prob_185,prob_186
0,1,11089:ncbi:1-17000,0,3.0779117e-09,4.6550543e-16,4.0059014e-09,5.709354e-14,3.985621e-13,1.4781354e-07,3.6368833e-13,...,6.536237e-14,1.3692335e-18,4.8284808e-14,4.28663e-27,1.7636006e-08,2.0901565e-20,1.2649406e-15,3.2826548e-25,1.3461108e-31,1.695962e-21
1,33,11089:ncbi:1-17000,1,7.798936e-12,3.0194006e-16,9.897284e-13,4.3037067e-11,8.870653e-16,5.2778373e-09,5.065814e-15,...,1.4544616e-15,1.7843624e-16,9.904177e-13,3.457488e-21,7.375788e-07,1.7063393e-21,6.5437836e-11,5.6732623e-25,4.585938e-27,1.069439e-21
2,65,11089:ncbi:1-17000,2,1.0980089e-12,1.9275111e-15,1.2620867e-10,5.787828e-11,5.338331e-13,2.7942612e-11,5.7700195e-10,...,5.753128e-14,1.6327466e-15,4.4719354e-06,6.113753e-17,4.3007142e-07,1.6321036e-17,1.3829523e-09,7.2506775e-19,4.256189e-23,8.687584e-22
3,97,11089:ncbi:1-17000,3,4.4603436e-09,4.5842836e-14,3.9782825e-12,1.9334909e-08,1.5190636e-10,2.5700672e-08,1.488595e-13,...,2.6302808e-07,7.203852e-11,4.3689414e-07,2.1555071e-17,0.000116948606,5.1792988e-18,1.0885436e-09,3.184696e-18,1.5609967e-19,3.963395e-22
4,129,11089:ncbi:1-17000,4,8.554256e-06,2.554389e-13,0.0006343455,3.5571188e-07,2.3673221e-09,0.097307,7.587622e-09,...,3.750018e-07,8.665608e-11,6.433393e-09,5.5440495e-20,0.00476838,1.4658719e-17,5.7866656e-13,1.6711752e-17,5.3012568e-21,1.876316e-18
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
96,3073,11089:ncbi:1-17000,96,0.9953505,5.3798974e-20,1.0473234e-18,6.393374e-11,2.8581829e-10,3.2485372e-15,2.7151905e-12,...,5.3093893e-15,1.3183516e-18,3.1293095e-09,2.2244292e-16,1.6725651e-10,3.994397e-15,2.2199666e-19,1.2594476e-13,3.5062112e-13,1.599368e-17
97,3105,11089:ncbi:1-17000,97,0.7476607,8.0384427e-16,3.6662344e-12,1.3936913e-13,3.1677278e-12,3.3102418e-15,1.0410654e-08,...,2.8784293e-12,3.711314e-15,1.8920882e-05,3.1332e-16,4.1800877e-07,5.0522433e-12,1.0039908e-14,6.3892506e-13,5.0943116e-14,1.197173e-14
98,3137,11089:ncbi:1-17000,98,0.9361404,2.664984e-16,7.8721495e-17,3.706427e-12,1.3358595e-14,1.254372e-16,1.8388244e-10,...,1.06356905e-16,4.1285527e-16,1.4968358e-11,2.585157e-18,8.469594e-08,3.6794658e-11,2.3980224e-18,5.1243676e-15,1.01574735e-17,2.688872e-14
99,3169,11089:ncbi:1-17000,99,0.998154,1.2311882e-15,1.7796784e-17,5.404334e-11,1.2882388e-10,1.5932526e-10,8.845724e-11,...,1.8041767e-10,3.959757e-12,3.682707e-12,2.1554382e-12,4.0533246e-06,5.84147e-08,6.353569e-15,1.6207057e-10,1.1721044e-15,3.890514e-13


In [None]:
with open_db(p2db, 150) as conn:
    cursor = conn.cursor()
    # query="PRAGMA index_list('predictions');"
    # res = cursor.execute(query)
    # print(res.fetchall())
    query="PRAGMA index_info('idx_preds')"
    print(cursor.execute(query).fetchall())
    query="PRAGMA index_info('idx_probs')"
    print(cursor.execute(query).fetchall())

[(0, 1, 'readid'), (1, 2, 'refseqid'), (2, 8, 'pos_true')]
[(0, 1, 'read_kmer_id'), (1, 2, 'read_50mer_nb')]


# Inference for 25% of the simreads on 69 sequences

Estimate required space for full inference with probabilities is *470 Gb*. This is too large to save even on my NAS. Will first build a table with 25% of the total reads.

In [25]:
p2aln = pfs.data / 'ncbi/simreads/yf/single_69seq_150bp/single_69seq_150bp.aln'
assert p2aln.exists()
p2db = Path('/mnt/k/metagentools') / 'ncbi/infer_results/yf-ncbi' / f'{p2aln.stem}.db'
assert p2db.parent.exists()
files_in_tree(p2db.parent);

infer_results
  |--yf-ncbi
  |    |--single_1seq_150bp.db (0)
  |    |--single_69seq_150bp-bck.db (1)
  |    |--single_69seq_150bp.db (2)


In [26]:
aln = AlnFileReader(p2aln)
for nb_kmer_reads, o in enumerate(aln):
    pass
print(f"Total number of reads: {nb_kmer_reads:,d}")

Total number of reads: 1,161,033


In [43]:
def skip_existing_predictions(gen, p2db, bs):
    # Identify the readnb for the last saved prediction
    print('Checking predictions already in database...')
    with open_db(p2db=p2db, k=150) as conn:
        last_predictions_id = conn.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
        nb_predictions = conn.execute("SELECT COUNT(*) FROM predictions").fetchone()[0]
        last_readid = conn.execute(f"SELECT readid FROM predictions WHERE id = {last_predictions_id};").fetchone()[0]
        # regex = re.compile(r'\d*:ncbi:(?P<kmer_read>\d-\d*)')
        # match = regex.search(last_readid)
        # kmer_read = match.group('read_nb')

    print(f"Database includes {nb_predictions:,d} predictions, corresponding to {nb_predictions//bs:,d} batches")
    print(f"Last prediction id: {last_predictions_id:,d} for kmer read '{last_readid}'")
    print(f"Skipping already processed predictions ...")

    for i, (batch_meta, batch_reads) in enumerate(gen):
        if i%100 == 0: 
            print(f"   Skipped first {(i+1)*bs:,d} kmer reads ({i+1:,d} batches)")
 
        if last_readid in batch_meta['readid']:
            print(f"   Reached last batch of saved prediction (batch {i+1:,d})")
            print('Can procees with normal prediction inference')
            nb_kmer_reads_skipped = nb_predictions
            nb_batches_skipped = i+1
            break
    return nb_batches_skipped, nb_kmer_reads_skipped

# aln.reset_iterator()
# gen2 = aln.cnn_virus_input_generator(bs=512, label=118)
# nbs, nrs = skip_existing_predictions(gen=gen2, p2db=p2db, bs=512)

In [None]:
b = 512   # number of k-mer in a batch
k = 150
true_label = 118
top_n = 5   # n for top-n prediction to keep
run_all_batches = False
nb_batches_to_run = int(nb_kmer_reads / b * 0.25)

uid = datetime.today().strftime('%Y-%m-%d_%H_%M_%S')
nb_50mer = k - 49

print(f"Run prediction loop with the following parameters:")
print(f"   {b} k-mer per batch; {k} bp per sequence; keep top-{top_n} predictions")

aln.reset_iterator()

model = create_model_original(path2parameters=p2model)
print(f"Model loaded and ready to run ...")

# open connection to sqlite db
conn = open_db(p2db, k=k)
cursor = conn.cursor()

# create a dataframe to store results
pred_cols_str = 'readid refseqid refsource refseq_strand taxonomyid'.split(' ')
pred_cols_int = 'lbl_true lbl_pred pos_true pos_pred'.split(' ')
top_pred_cols = [f"top_{top_n}_lbl_pred_{i}" for i in range(top_n)]
prob_cols = [f"prob_{i:03d}" for i in range(187)]


def tprint(string):
    print(f"{datetime.now().strftime('%H:%M:%S')}    {string}")

print(f"Starting prediction loop ...")
gen = aln.cnn_virus_input_generator(bs=b, label=true_label)

# Skip kmer reads that are already processed
nb_batches_skipped, nb_kmer_reads_skipped = skip_existing_predictions(gen=gen, p2db=p2db, bs=b)

# Proceed with prediction inference 
for i,(metadata_batch, reads_batch) in enumerate(gen):
    loop_start = datetime.now()
    tprint(f"Batch {i+1:3,d} (aln batch {nb_batches_skipped+i+1:3,d}) ...")

    reads_kmer, (labels_true, position_true) = string_input_batch_to_tensors(reads_batch, k=k)
    reads_50mer = split_kmer_batch_into_50mers(reads_kmer)
    assert reads_50mer.shape == ((reads_kmer.shape[1]-49) * b, 50, 5), f"Problem with shape in batch {i+1}: {reads_50mer.shape}"

    tprint(f'  Starting prediction for {b:,} kmer reads ...')
    label_probs, position_probs = model.predict(reads_50mer)

    tprint('  Reshaping predictions ...')
    label_probs_kmer = tf.reshape(label_probs, shape=(b,nb_50mer,-1))
    position_probs_kmer = tf.reshape(position_probs, shape=(b,nb_50mer,-1))

    tprint('  Combining predictions ...')
    combined_predictions = tf.map_fn(
        fn=combine_prediction_batch,
        elems=[label_probs_kmer, position_probs_kmer], 
        fn_output_signature=tf.int64
        )

    label_predictions = combined_predictions[:,0]
    position_predictions = combined_predictions[:,1]
    top_preds = top_predictions(label_probs_kmer, n=top_n)

    # Add results for current batch
    tprint('  Preparing prediction report ...')
    preds_report = np.concatenate(
        [
            np.expand_dims(np.array(metadata_batch['readid']), axis=1),         # readid 
            np.expand_dims(np.array(metadata_batch['refseqid']), axis=1),       # refseqid
            np.expand_dims(np.array(metadata_batch['refsource']), axis=1),      # refsource
            np.expand_dims(np.array(metadata_batch['refseq_strand']), axis=1),  # refseq_strand
            np.expand_dims(np.array(metadata_batch['reftaxonomyid']), axis=1),  # taxonomyid
            np.expand_dims(np.array([true_label]*b), axis=1),                   # lbl_true
            np.expand_dims(label_predictions, axis=1),                          # lbl_pred
            np.expand_dims(np.array(metadata_batch['aln_start_pos']), axis=1),  # pos_true
            np.expand_dims(position_predictions, axis=1),                       # pos_pred
            top_preds[:, ::-1],                                                 # top_5_lbl_pred_0, top_5_lbl_pred_1, top_5_lbl_pred_2, top_5_lbl_pred_3, top_5_lbl_pred_4
        ],
        axis=1
    )

    df_preds = pd.DataFrame(
        data=preds_report, 
        columns=pred_cols_str + pred_cols_int + top_pred_cols
        )
    tprint('  Saving batch prediction report to db...')
    df_preds.to_sql('predictions', conn, if_exists='append', index=False)

    tprint('  Preparing label probabilities report ...')
    df_probs = None
    for read_50mer_nb in range(nb_50mer):
        probs_report = np.concatenate(
            [
                np.expand_dims(np.array(metadata_batch['readid']), axis=1),        # readid 
                np.expand_dims(np.array([read_50mer_nb]*b), axis=1),               # read_50mer_nb
                label_probs_kmer[:, read_50mer_nb, :]
            ],
            axis=1
        )

        df = pd.DataFrame(
            data=probs_report, 
            columns=['read_kmer_id', 'read_50mer_nb'] + prob_cols
            )
        df_probs = df if df_probs is None else pd.concat([df_probs, df], axis=0)

    tprint('  Saving batch label probabilities report to db...')
    df_probs.to_sql('label_probabilities', conn, if_exists='append', index=False)

    tprint(f"  Batch processing time: {(datetime.now() - loop_start).total_seconds():.2f} sec")
    if not run_all_batches and i+1 >= nb_batches_to_run: 
        print('Stopping')
        break

conn.close()

In [47]:
# with open_db(p2db=p2db, k=150) as conn:
#     cursor=conn.cursor()
#     display(table_columns(cursor=cursor, table='predictions'))

In [45]:
with open_db(p2db=p2db, k=150) as conn:
    display(pd.read_sql_query("SELECT * FROM predictions WHERE readid = '11089:ncbi:1-17000';", conn))
    display(pd.read_sql_query("SELECT * FROM label_probabilities WHERE read_kmer_id = '11089:ncbi:1-17000';", conn))

Unnamed: 0,id,readid,refseqid,refsource,refseq_strand,taxonomyid,lbl_true,lbl_pred,pos_true,pos_pred,top_5_lbl_pred_0,top_5_lbl_pred_1,top_5_lbl_pred_2,top_5_lbl_pred_3,top_5_lbl_pred_4
0,1,11089:ncbi:1-17000,11089:ncbi:1,ncbi,-,11089,118,10,7804,0,118,32,10,62,94


Unnamed: 0,id,read_kmer_id,read_50mer_nb,prob_000,prob_001,prob_002,prob_003,prob_004,prob_005,prob_006,...,prob_177,prob_178,prob_179,prob_180,prob_181,prob_182,prob_183,prob_184,prob_185,prob_186
0,1,11089:ncbi:1-17000,0,9.452837e-16,1.8033755e-24,3.709243e-18,3.2934015e-16,1.11393635e-13,4.909559e-09,3.2027174e-19,...,4.620669e-10,2.1956457e-23,4.973716e-17,1.2870734e-19,6.685301e-13,2.7962992e-13,5.0206733e-22,2.3932547e-15,1.2987661e-21,3.557021e-18
1,513,11089:ncbi:1-17000,1,9.987287e-13,9.640493e-23,6.864665e-13,5.297059e-17,5.2123617e-09,1.9526384e-07,6.811955e-14,...,0.00029164975,9.457506e-20,4.247617e-15,2.2502056e-20,1.3939608e-09,4.1289537e-12,1.0989764e-16,5.3598934e-11,1.09030056e-14,1.850025e-18
2,1025,11089:ncbi:1-17000,2,4.1150392e-14,1.0409154e-21,3.1731182e-18,4.837065e-13,1.8273633e-12,5.6225136e-13,1.5127502e-14,...,2.7924674e-08,9.881187e-25,1.2789581e-19,5.2265872e-23,3.6289507e-16,5.854743e-18,4.597412e-22,2.7675505e-16,3.6334886e-19,7.191011e-22
3,1537,11089:ncbi:1-17000,3,2.4102365e-13,3.9195824e-19,4.2152547e-16,5.363886e-15,7.9354416e-13,1.9598592e-10,2.0204985e-14,...,2.3507164e-12,3.4494873e-22,1.4030163e-15,5.3813346e-20,2.1669545e-18,3.9168753e-20,4.6393943e-18,2.162604e-17,5.868014e-17,3.671587e-12
4,2049,11089:ncbi:1-17000,4,5.402446e-15,1.3982798e-24,5.1237304e-19,2.6393469e-12,7.070969e-12,4.5762413e-08,3.941572e-13,...,2.2205168e-10,2.4770586e-20,2.2223746e-16,1.3230725e-16,1.473535e-13,1.0368611e-15,7.554584e-18,6.905381e-15,6.823121e-19,5.281248e-14
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
96,49153,11089:ncbi:1-17000,96,1.266811e-12,1.7694516e-13,6.9926794e-11,7.442386e-16,2.1827995e-12,2.2899127e-11,6.6601347e-13,...,2.1055688e-09,2.6956715e-11,4.02884e-10,3.6730546e-06,7.332998e-13,9.214267e-17,4.41918e-18,1.0378567e-09,3.9226554e-16,1.512387e-09
97,49665,11089:ncbi:1-17000,97,3.6751678e-11,5.2148962e-11,1.1142021e-07,1.3979555e-13,0.00023110698,1.8854205e-11,1.128952e-07,...,3.426913e-08,3.5629816e-17,2.7815642e-11,7.9123413e-07,9.112657e-12,1.8065088e-14,2.2273976e-15,2.6934228e-07,3.2327324e-15,1.599009e-12
98,50177,11089:ncbi:1-17000,98,3.987377e-10,3.995657e-08,4.0260434e-06,4.5705413e-09,0.00026610115,3.9758719e-10,1.0729461e-07,...,6.970071e-08,1.210523e-17,2.2871915e-10,3.9424535e-06,1.6976076e-09,4.1861083e-09,3.708444e-10,6.0206723e-05,4.0571067e-09,5.209376e-10
99,50689,11089:ncbi:1-17000,99,3.773077e-12,6.944977e-13,6.4031043e-09,3.3788874e-09,7.6050675e-08,4.657467e-08,3.358606e-07,...,3.0940438e-08,1.2781211e-18,2.6286571e-08,1.7933372e-08,4.844457e-09,1.0032285e-13,2.9824456e-15,3.3517054e-08,4.4160476e-16,2.173458e-09


In [52]:
with open_db(p2db=p2db, k=150) as conn:
    tprint('Retrieve last prediction id and total nb predictions...')
    time_starts = datetime.now()
    results = conn.execute("SELECT MAX(id) FROM predictions").fetchone()
    tprint(f"{results[0]:,d} ({(datetime.now()-time_starts).total_seconds():.2f} sec)")
    time_starts = datetime.now()
    results = conn.execute("SELECT COUNT(readid) FROM predictions WHERE readid IS NOT NULL;").fetchone()
    tprint(f"{results[0]:,d} ({(datetime.now()-time_starts).total_seconds():.2f} sec)")
    time_starts = datetime.now()
    results = conn.execute("SELECT COUNT(*) FROM predictions").fetchone()
    tprint(f"{results[0]:,d} ({(datetime.now()-time_starts).total_seconds():.2f} sec)")


13:21:22    Retrieve last prediction id and total nb predictions...
13:21:22    388,096 (0.03 sec)
13:21:47    388,096 (24.83 sec)
13:21:48    388,096 (0.68 sec)


In [60]:
with open_db(p2db=p2db, k=150) as conn:
    last_readid = conn.execute("SELECT MAX(id) FROM predictions").fetchone()[0]
    readid = conn.execute(f"SELECT readid FROM predictions WHERE id = {last_readid};").fetchone()[0]
    regex = re.compile(r'\d*:ncbi:(?P<read_nb>\d*-\d*)')
    m = regex.search(readid)
    read_nb = '???' if m is None else m.group('read_nb')

last_readid, readid, read_nb



(388096, '11089:ncbi:23-2905', '23-2905')

## Technical note to accelerate `skip_existing_predictions`

The slow step in `skip_existing_predictions` is the query to get the total number of rows in the table. We can accelerate this by maintaining an accurate row count in a separate column when inserting new rows into a SQLite table named "predictions" that has a primary key "id" and an indexed column "readid". 

This can be done with the following steps:

1. **Create a trigger** that fires after an INSERT operation on the "predictions" table. The trigger will update the row count in a separate table or column.

2. **Create a table** to store the row count, for example:
```sql
    CREATE TABLE table_stats (
    id INTEGER PRIMARY KEY,
    row_count INTEGER NOT NULL
    );
```

3. **Create the trigger** to update the row count after an INSERT:
```sql
    CREATE TRIGGER update_prediction_count
    AFTER INSERT ON predictions
    BEGIN
    UPDATE table_stats 
    SET row_count = row_count + 1
    WHERE id = 1;
    END;
```

This trigger assumes there is only one row in the "prediction_stats" table with an id of 1. If you want to store the count per "readid", you can modify the trigger to:

```sql
    CREATE TRIGGER update_prediction_count
    AFTER INSERT ON predictions  
    BEGIN
    INSERT INTO table_stats (readid, row_count)
    VALUES (NEW.readid, 1)
    ON CONFLICT(readid) DO UPDATE SET row_count = row_count + 1;
    END;
```

This will insert a new row for each unique "readid" with an initial count of 1, and update the row_count if the "readid" already exists.

4. **Insert a row** into the "prediction_stats" table with an initial count:
```sql
    INSERT INTO table_stats (id, row_count) VALUES (1, 0);
```

Now, whenever a new row is inserted into the "predictions" table, the trigger will automatically update the row count in the "prediction_stats" table. This maintains an accurate count without needing to perform a full table scan with COUNT(*).

Remember to handle deletions as well by creating a BEFORE DELETE trigger that decrements the row count accordingly.

Citations:
- 1. https://www.sqlitetutorial.net/sqlite-insert/
- 2. https://www.sqlitetutorial.net/sqlite-count-function/
- 3. https://stackoverflow.com/questions/55007800/dynamic-way-to-insert-data-into-sqlite-table-when-column-counts-change
- 4. https://docs.python.org/es/3/library/sqlite3.html
- 5. https://www.sql-easy.com/learn/sqlite-count/
- 6. https://stackoverflow.com/questions/4474873/what-is-the-most-efficient-way-to-count-rows-in-a-table-in-sqlite
- 7. https://sqlite.org/forum/info/57c04743e1b6aa10
- 8. https://sqlite.org/forum/forumpost/f832398c19

# End of Section