# Test false positives on one non-CoV sequence

**Objective**: 

Evaluate the performance of CNN-V when presented with reads simulated from a non-CoV reference sequence that nevertheless shares some similarities. We want to identify whether the model generates many false positives.

# Setup


In [None]:
# Notebook setup helper; the recorded output below indicates it turns on the autoreload extension.
from ecutilities.ipython import nb_setup, pandas_nrows_ncols
nb_setup()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Set autoreload mode


In [None]:
import numpy as np
import os
import pandas as pd
import tempfile

from nbdev import show_doc
from pathlib import Path
from pprint import pprint
from tqdm.notebook import trange, tqdm

from ecutilities.ipython import pandas_nrows_ncols
from metagentools.art import ArtIllumina
from metagentools.cnn_virus.data import FastqFileReader

# Imports and paths

In [None]:
from ecutilities.core import validate_path
from metagentools.art import ArtIllumina, _run
from metagentools.core import TextFileBaseIterator
from metagentools.cnn_virus.architecture import load_model
from metagentools.cnn_virus.data import strings_to_tensors
# Set before the tensorflow import below so the value is in place when TensorFlow is loaded.
# NOTE(review): presumably controls the verbosity of TensorFlow's native logging — confirm the meaning of each level.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}
from tensorflow.data import TextLineDataset

In [None]:
# Root data directory, resolved to an absolute path
p2data = Path("../../../data").resolve()
assert p2data.is_dir()
print(p2data)

# Directory holding the non corona virus sequence data
p2ncov_data = p2data.joinpath('ncov_data')
assert p2ncov_data.is_dir()
print(p2ncov_data)

# Directory holding the input fasta files
p2inputs = p2ncov_data.joinpath('rhinolophus_ferrumequinum/dna')
assert p2inputs.is_dir()
print(p2inputs)

# Reference sequence fasta file
p2refs = p2inputs.joinpath('Rhinolophus_ferrumequinum.mRhiFer1_v1.p.dna_rm.primary_assembly.1.clean.fa')
assert p2refs.is_file()
print(p2refs)

# Directory holding the simulated reads and the model input files
p2simreads = p2ncov_data.joinpath('ncov_simreads/mRhiFer1_v1.p.dna_rm.primary_assembly.1')
assert p2simreads.is_dir()
print(p2simreads)

/home/vtec/projects/bio/metagentools/data
/home/vtec/projects/bio/metagentools/data/ncov_data
/home/vtec/projects/bio/metagentools/data/ncov_data/rhinolophus_ferrumequinum/dna
/home/vtec/projects/bio/metagentools/data/ncov_data/rhinolophus_ferrumequinum/dna/Rhinolophus_ferrumequinum.mRhiFer1_v1.p.dna_rm.primary_assembly.1.clean.fa
/home/vtec/projects/bio/metagentools/data/ncov_data/ncov_simreads/mRhiFer1_v1.p.dna_rm.primary_assembly.1


# Create simreads

In [None]:
# ART Illumina handler: reads reference files from `p2inputs`, writes simulated reads to `p2simreads`.
# The executable is looked up by name, as it is expected to be in the system path.
app = 'art_illumina'
art = ArtIllumina(path2app=app, input_dir=p2inputs, output_dir=p2simreads, app_in_system_path=True)

Ready to operate with art: art_illumina
Input files from : /home/vtec/projects/bio/metagentools/data/ncov_data/rhinolophus_ferrumequinum/dna
Output files to :  /home/vtec/projects/bio/metagentools/data/ncov_data/ncov_simreads/mRhiFer1_v1.p.dna_rm.primary_assembly.1


In [None]:
# Read simulation call, left commented out so it is not executed when the notebook runs.
# NOTE(review): presumably the reads were generated once and are reused from `p2simreads` — confirm before re-enabling.
# art.sim_reads(
#     input_file='Rhinolophus_ferrumequinum.mRhiFer1_v1.p.dna_rm.primary_assembly.1.clean.fa', 
#     output_seed='mRhiFer1_v1.p.dna_rm.1',
#     sim_type='single', 
#     read_length=50,
#     fold=100, 
#     ss='HS25',
#     overwrite=False
# )

In [None]:
# Define the fastq path in this cell so that displaying it does not depend on a later cell
# having been executed first (the notebook must survive Restart & Run All).
p2fastq = p2simreads / 'mRhiFer1_v1.p.dna_rm.primary_assembly.1.fq'
p2fastq

Path('/home/vtec/projects/bio/metagentools/data/ncov_data/ncov_simreads/mRhiFer1_v1.p.dna_rm.primary_assembly.1/mRhiFer1_v1.p.dna_rm.primary_assembly.1.fq')

# Create input file from fastq

In [None]:
def modified_create_infer_ds_from_fastq(
    p2fastq: str|Path,             # Path to the fastq file
    output_dir:str|Path|None=None, # Path to directory where ds file will be saved, current directory if None
    overwrite_ds:bool=False,       # If True, overwrite existing ds file. If False, error is raised if ds file exists
    nsamples:int|None=None         # Used to limit the number of reads to use for inference, use all if None
)-> Path:                          # Path to the dataset file
    """Build a dataset file for inference only, from simreads fastq to text format ready for the CNN Virus model

    The dataset file is named `<fastq stem>_ds` and is created in `output_dir`.
    Each read is written on its own line as the read sequence followed by two
    tab-separated `0` placeholder fields.

    Raises `ValueError` if the dataset file already exists and `overwrite_ds` is False.
    """
    # Accept both str and Path for the fastq file
    p2fastq = Path(p2fastq)

    if output_dir is None:
        p2outdir = Path()
    else:
        validate_path(output_dir, path_type='dir', raise_error=True)
        p2outdir = Path(output_dir)

    p2dataset = p2outdir / f"{p2fastq.stem}_ds"
    if p2dataset.is_file() and not overwrite_ds:
        raise ValueError(f"{p2dataset.name} already exists in {p2dataset.parent.absolute()}")

    fastq = FastqFileReader(p2fastq)

    n_written = 0
    # Mode 'w' creates the file, or truncates the existing one when overwriting
    with open(p2dataset, 'w') as fp:
        for fastq_chunck in fastq.it:
            fp.write(f"{fastq_chunck['sequence']}\t0\t0\n")
            n_written += 1
            # A falsy `nsamples` (None or 0) means no limit
            if nsamples and n_written >= nsamples: break

    print(f"Dataset with {n_written:,d} reads")
    return p2dataset

In [None]:
# Alignment (.aln) and fastq (.fq) files located in the simreads directory
p2aln = p2simreads.joinpath('mRhiFer1_v1.p.dna_rm.primary_assembly.1.aln')
p2fastq = p2simreads.joinpath('mRhiFer1_v1.p.dna_rm.primary_assembly.1.fq')


# modified_create_infer_ds_from_fastq(p2fastq, p2simreads, overwrite_ds=False)

# Get predictions on the dataset

In [None]:
# Saved pretrained model file, located under the data directory
p2saved = p2data.joinpath('saved/cnn_virus_original/pretrained_model.h5')

We have 8.3 million reads

In [None]:
# Commented-out read count: iterates over the dataset file and leaves the last 0-based index in `i`.
# NOTE(review): `p2ds` is only defined in a later cell — define it first if this is re-enabled.
# it = TextFileBaseIterator(p2ds)
# for i, line in enumerate(it): pass
# i

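This is too much to handle at once. Quick solution: read the file in chunks of `nreads` reads, run inference on each chunk, and accumulate the false positive counts.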

In [None]:
nreads = 250_000  # number of lines (reads) per chunk pulled from the dataset file
bs = 32  # batch size applied to the small dataset before inference
nbatches = 10  # maximum number of chunks processed by `run_model_per_chunck`
# NOTE(review): the 20/1000 factor presumably stands for about 20 ms per batch — confirm
print(f"Inference will require {nreads//bs:,d} batches, i.e. about {int(20/1000 * nreads//bs):,d} sec per iteration")

# Full dataset file; it must already exist
p2ds = p2simreads / 'mRhiFer1_v1.p.dna_rm.primary_assembly.1_ds'
assert p2ds.is_file()

def run_model_per_chunck():
    """Run the pretrained model on the dataset chunk by chunk and count false positives.

    Relies on notebook-level names: `p2data`, `p2ds`, `p2simreads`, `nreads` (lines per
    chunk), `bs` (batch size) and `nbatches` (maximum number of chunks to process).

    Each chunk is written to a "small" dataset file in `p2simreads`, which is then fed to
    the model. Predictions whose label is 117 are counted as SARS false positives and those
    whose label is 94 as MERS false positives.

    Returns three lists with one entry per processed chunk: SARS false positive counts,
    MERS false positive counts, and number of samples.
    """
    p2saved = p2data / 'saved/cnn_virus_original/pretrained_model.h5'
    model = load_model(p2saved)
    it = TextFileBaseIterator(p2ds, nlines=nreads)

    # `Path.suffix` already includes the leading dot, so no extra '.' is inserted
    p2ds_small = p2simreads / f"{p2ds.stem}-small{p2ds.suffix}"

    # Label ids counted as false positives
    sars_label, mers_label = 117, 94

    sars_fp_all = []
    mers_fp_all = []
    nsamples_all = []
    for i, chunck in enumerate(it):
        print(i)
        print(f">>> Creating small dataset with {nreads:,d} reads")
        # Write the chunk yielded by the loop itself: calling `next(it)` here would discard
        # it, skip every other chunk and could raise StopIteration at the end of the file.
        # Mode 'w' truncates the file left by the previous iteration.
        with open(p2ds_small, 'w') as fp:
            fp.write(chunck)

        print(f">>> Running original model in inference on small dataset")
        text_ds = TextLineDataset(p2ds_small).batch(bs)
        ds = text_ds.map(strings_to_tensors)

        prob_preds = model.predict(ds, verbose=1)
        # First model output is used for the label prediction
        preds_label = np.argmax(prob_preds[0], axis=1)
        sars_fp = (preds_label == sars_label).sum()
        mers_fp = (preds_label == mers_label).sum()
        total_samples = preds_label.shape[0]
        sars_fp_all.append(sars_fp)
        mers_fp_all.append(mers_fp)
        nsamples_all.append(total_samples)
        print(f"False Positives")
        print(f"CoV:  Total: {sars_fp:5,d}, Ratio {sars_fp/total_samples:1.3f}")
        print(f"MERS: Total: {mers_fp:5,d}, Ratio {mers_fp/total_samples:1.3f}")

        if i+1 >= nbatches: break

    return sars_fp_all, mers_fp_all, nsamples_all

# Run the chunked inference; each returned list holds one entry per processed chunk
sars_fp_all, mers_fp_all, nsamples_all = run_model_per_chunck()

Inference will require 7,812 batches, i.e. about 156 sec per iteration
0
>>> Creating small dataset with 250,000 reads
>>> Running original model in inference on small dataset
False Positives
CoV:  Total: 2,791, Ratio 0.011
MERS: Total: 2,488, Ratio 0.010
1
>>> Creating small dataset with 250,000 reads
>>> Running original model in inference on small dataset
False Positives
CoV:  Total: 2,772, Ratio 0.011
MERS: Total: 2,519, Ratio 0.010
2
>>> Creating small dataset with 250,000 reads
>>> Running original model in inference on small dataset
False Positives
CoV:  Total: 2,691, Ratio 0.011
MERS: Total: 2,580, Ratio 0.010
3
>>> Creating small dataset with 250,000 reads
>>> Running original model in inference on small dataset
False Positives
CoV:  Total: 2,721, Ratio 0.011
MERS: Total: 2,468, Ratio 0.010
4
>>> Creating small dataset with 250,000 reads
>>> Running original model in inference on small dataset
False Positives
CoV:  Total: 2,867, Ratio 0.011
MERS: Total: 2,540, Ratio 0.010
5
>>

In [None]:
# Totals accumulated over all processed chunks
print(f"FP sars: {sum(sars_fp_all):,d}, FP mers: {sum(mers_fp_all):,d}, Total Reads: {sum(nsamples_all):,d}")

FP sars: 27,431, FP mers: 25,436, Total Reads: 2,500,000


In [None]:
# Overall false positive ratios (sars, mers) over all processed chunks
tuple(sum(fp_counts) / sum(nsamples_all) for fp_counts in (sars_fp_all, mers_fp_all))

(0.0109724, 0.0101744)

# Others