# Inference with original cnn_virus trained model

- Inference only, no training.
- Experiments with cov_reads files
- Trying locally

# Imports and setup environment

### Install and import packages

In [5]:
try:
    from ecutilities.ipython import nb_setup
    print('`ecutilities` already installed')
except ModuleNotFoundError as e:
    print('installing ecutilities')
    !pip install -qqU ecutilities
    from ecutilities.ipython import nb_setup

nb_setup()

installing ecutilities
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
[?25hSet autoreload mode


In [6]:
import numpy as np
import pandas as pd
import os
import tensorflow as tf

from pathlib import Path
from pprint import pprint
from tensorflow.python.client import device_lib
from tensorflow.keras.models import load_model

In [7]:
print(f"Tensorflow version: {tf.__version__}\n")

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}

devices = device_lib.list_local_devices()
print('\nDevices:')
for d in devices:
    t = d.device_type
    name = d.physical_device_desc
    l = [item.split(':', 1) for item in name.split(', ')]
    name_attr = dict([x for x in l if len(x)==2])
    dev = name_attr.get('name', ' ')
    print(f"  - {t}  {d.name} {dev:25s}")

Tensorflow version: 2.9.2


Devices:
  - CPU  /device:CPU:0                          


## Install and import custom code, mount gdrive

In [8]:
# MODIFY AFTER INSTALLING ECUTILITIES >= 1.2.3
from __future__ import annotations
from ecutilities.core import validate_path

def safe_path(
    path:str|Path, # path to validate
)-> Path:          # validated path as a  pathlib.Path
    """"""
    validate_path(path, path_type='dir', raise_error=True)
    if isinstance(path, str): 
        path = Path(path)
    return path

def nbs_root_dir(
    path:str|Path|None = None, # path from where to seek for notebook parent directory
    pattern:str = 'nbs',       # pattern to identify the nbs directory
)-> Path:                      # path of the parent directory
    """Climb the directory tree up to the notebook directory and return its path"""    
    if path is None: path = Path()
    path = safe_path(path).absolute()
    tree = [path.name] + [p.name for p in path.parents]
    mask = [True if n.startswith(pattern) else False for n in tree]
    nbs = Path(f"{'../' * (mask.index(True))}").resolve()
    return nbs
    
# nbs = nbs_root_dir(Path())
# nbs

In [9]:
try:
    from google.colab import drive
    ON_COLAB = True
    print('Running on colab')
    print('Installing custom project code')   
    # !pip install -qqU git+https://github.com/vtecftwy/metagentools.git@cnn_virus
    !pip install -qqU metagentools
    drive.mount('/content/gdrive')
    
    p2drive = Path('/content/gdrive/MyDrive/Metagenonics')
    assert p2drive.is_dir()
    p2data =  p2drive / 'data'
    assert p2data.is_dir()

except ModuleNotFoundError as e:
    # print(e)
    ON_COLAB = False
    print('Running locally')
    print('Make sure you have installed the custom project code in your environment')
    p2data = nbs_root_dir().parent / 'data/'
    assert p2data.is_dir()
    print(p2data.absolute())

Running on colab
Installing custom project code
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.3/42.3 KB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.0/184.0 KB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m178.9/178.9 KB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 KB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone
Mounted at /content/gdrive


## Access AWS

https://realpython.com/python-boto3-aws-s3/#common-operations

https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#guide-configuration

In [1]:
!pip install -qqU urllib3>1.26.6

In [2]:
import urllib3
urllib3.__version__

'1.26.14'

In [3]:
!pip install -qqU boto3

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.7/132.7 KB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 KB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [4]:
import boto3

In [10]:
from ecutilities.core import get_config_value

In [14]:
import uuid

In [11]:
p2config = Path('/content/gdrive/MyDrive/private-across-accounts/config-api-keys.cfg')

In [12]:
os.environ['AWS_ACCESS_KEY_ID'] = get_config_value(section='aws', key='aws_access_key_id',path_to_config_file=p2config)
os.environ['AWS_SECRET_ACCESS_KEY'] = get_config_value(section='aws', key='aws_secret_access_key',path_to_config_file=p2config)

In [13]:
s3_client = boto3.client('s3')
s3_resource = boto3.resource('s3')

In [23]:
def create_temp_file(size, file_name, file_content):
    random_file_name = ''.join([str(uuid.uuid4().hex[:6]), file_name])
    with open(random_file_name, 'w') as f:
        f.write(str(file_content) * size)
    return random_file_name

In [24]:
bucket_name = 'bio.cnn-virus.data'
first_bucket = s3_resource.Bucket(name=bucket_name)

In [25]:
fname_1 = create_temp_file(300, 'file1.txt', 'f' )
s3_resource.Bucket(bucket_name).upload_file(
    Filename=fname_1, 
    Key=fname_1
    )

In [26]:
fname_2 = create_temp_file(300, 'file2.txt', 'f' )
s3_resource.meta.client.upload_file(
    Filename=fname_2, 
    Bucket=bucket_name,
    Key=fname_2)

In [29]:
os.makedirs('/content/dwld', exist_ok=True)
s3_resource.Object(
    bucket_name, 
    fname_1).download_file(
    f'/content/dwld/{fname_1}')

In [10]:
from metagentools.cnn_virus.data import strings_to_tensors
from metagentools.cnn_virus.data import FastaFileReader, FastqFileReader, AlnFileReader, parse_metadata_art_read_aln
from metagentools.core import TextFileBaseIterator

# Experiments with simulated reads

## Setup paths

This assumes that the shared gdrive directory is accessible through a shortcut called `Metagenomics` under the root of gdrive.

In [None]:
# p2drive = Path('/content/gdrive/MyDrive/Metagenonics')
# assert p2drive.is_dir()

# p2data =  p2drive / 'CNN_Virus_data'
# assert p2data.is_dir()

In [None]:
# path for original trained model
p2saved = p2data / 'saved/cnn_virus_original/pretrained_model.h5'
p2simreads = p2data / 'cov_simreads/single_10seq_50bp'
p2virus_labels = p2data / 'CNN_Virus_data/virus_name_mapping'
assert p2saved.is_file()
assert p2simreads.is_dir()
assert p2virus_labels.is_file()

#path for the learning weights file
# filepath_weights=p2data / "weight_of_classes"
# assert filepath_weights.is_file()

## Explore simread output files

In [None]:
p2fasta = p2data / 'cov_data/cov_virus_sequences_ten.fa'
p2fastq = p2simreads / f"{p2simreads.stem}.fq"
p2aln = p2simreads / f"{p2simreads.stem}.aln"
assert p2fastq.is_file()
assert p2aln.is_file()
assert p2fasta.is_file()

In [None]:
from metagentools.cnn_virus.data import FastaFileReader, FastqFileReader, AlnFileReader, parse_metadata_art_read_aln
from metagentools.core import TextFileBaseIterator

In [None]:
fasta = FastaFileReader(p2fasta)
fastq = FastqFileReader(p2fastq)
aln = AlnFileReader(p2aln)

### Exploring the simreads in fastaq

In [None]:
for i, (k,v) in enumerate(fastq.parse_fastq(add_seq=True).items()):
    print(k)
    print()
    pprint(v)
    if i+1 >= 3: break

2591237:ncbi:1-60400

{'read_nbr': '60400',
 'readid': '2591237:ncbi:1-60400',
 'seq_nbr': '1',
 'seqid': '2591237',
 'sequence': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'source': 'ncbi'}
2591237:ncbi:1-60399

{'read_nbr': '60399',
 'readid': '2591237:ncbi:1-60399',
 'seq_nbr': '1',
 'seqid': '2591237',
 'sequence': 'GATCAATGTGGCATCTACAATACAGACAGCATGAAGCACCACCAAAGGAC',
 'source': 'ncbi'}
2591237:ncbi:1-60398

{'read_nbr': '60398',
 'readid': '2591237:ncbi:1-60398',
 'seq_nbr': '1',
 'seqid': '2591237',
 'sequence': 'ATCTACCAGTGGTAGATGGGTTCTTAATAATGAACATTATAGAGCTCTAC',
 'source': 'ncbi'}


### Exploring ALN header (reference sequences)

In [None]:
pprint(aln.ref_sequences)

{'11120': {'refseq_accession': 'MN987231',
           'refseq_length': '27617',
           'refseq_nbr': '5',
           'refseqid': '11120',
           'source': 'ncbi',
           'species': 'Infectious bronchitis virus  scientific name'},
 '11128': {'refseq_accession': 'LC494191',
           'refseq_length': '30942',
           'refseq_nbr': '2',
           'refseqid': '11128',
           'source': 'ncbi',
           'species': 'Bovine coronavirus  scientific name'},
 '1699095': {'refseq_accession': 'KT368904',
             'refseq_length': '27395',
             'refseq_nbr': '10',
             'refseqid': '1699095',
             'source': 'ncbi',
             'species': 'Camel alphacoronavirus  scientific name'},
 '2591237': {'refseq_accession': 'MK211378',
             'refseq_length': '30213',
             'refseq_nbr': '1',
             'refseqid': '2591237',
             'source': 'ncbi',
             'species': 'Coronavirus BtRs-BetaCoV/YN2018D  scientific name'},
 '277944': {

### Exploring read's metadata

In [None]:
for i, (k,v) in enumerate(aln.parse_aln(add_ref_seq_aligned=True).items()):
    print(k)
    pprint(v)
    print()
    if i+1>=3: break

2591237:ncbi:1-60400
{'aln_start_pos': '14770',
 'read_nbr': '60400',
 'readid': '2591237:ncbi:1-60400',
 'ref_seq_aligned': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'ref_seq_strand': '+',
 'refseq_nbr': '1',
 'refseqid': '2591237',
 'source': 'ncbi'}

2591237:ncbi:1-60399
{'aln_start_pos': '17012',
 'read_nbr': '60399',
 'readid': '2591237:ncbi:1-60399',
 'ref_seq_aligned': 'GATCAATGTGGCATCTACAATACAGACAGCATGAAGCACCACCAAAGGAC',
 'ref_seq_strand': '-',
 'refseq_nbr': '1',
 'refseqid': '2591237',
 'source': 'ncbi'}

2591237:ncbi:1-60398
{'aln_start_pos': '9188',
 'read_nbr': '60398',
 'readid': '2591237:ncbi:1-60398',
 'ref_seq_aligned': 'ATCTACCAGTGGTAGATGGGTTCTTAATAATGAACATTATAGAGCTCTAC',
 'ref_seq_strand': '+',
 'refseq_nbr': '1',
 'refseqid': '2591237',
 'source': 'ncbi'}



In [None]:
fasta = FastaFileReader(p2fasta)
fastq = FastqFileReader(p2fastq)
aln = AlnFileReader(p2aln)

In [None]:
refseqs_fasta = fasta.parse_fasta(add_seq=True)
simreads = fastq.parse_fastq(add_seq=True)
refseqs_aln = aln.ref_sequences
simread_align = aln.parse_aln(add_ref_seq_aligned=True, add_read_seq_aligned=True)

Check consistency between refseqs from fasta and from aln 

In [None]:
# utility functions
def opposite_strand(seq):
    conv = {'A':'T', 'C':'G', 'G':'C', 'T':'A'}
    return ''.join([conv[base] for base in seq])

def reverse_sequence(seq):
    return seq[::-1]

opposite_strand('ACGT'), reverse_sequence('abcdef')

('TGCA', 'fedcba')

Check aln refseq information

In [None]:
# refseqid = '2591237'
refseqid = '11120'
original_seq = refseqs_fasta[refseqid]['sequence']
original_seq_accession = refseqs_fasta[refseqid]['accession']
original_seq_accession, len(original_seq)

('MN987231', 27617)

In [None]:
refseqs_aln[refseqid]['refseq_accession'], int(refseqs_aln[refseqid]['refseq_length'])

('MN987231', 27617)

In [None]:
assert original_seq_accession == refseqs_aln[refseqid]['refseq_accession']
assert len(original_seq) == int(refseqs_aln[refseqid]['refseq_length'])

### Check alignment information

In [None]:
pprint(simread_align['2591237:ncbi:1-60400'])

{'aln_start_pos': '14770',
 'read_nbr': '60400',
 'read_seq_aligned': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'readid': '2591237:ncbi:1-60400',
 'ref_seq_aligned': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'ref_seq_strand': '+',
 'refseq_nbr': '1',
 'refseqid': '2591237',
 'source': 'ncbi'}


Select all reads generated from a single reference sequence

In [None]:
print(f"Select all reads from reference sequence '{refseqid}''")
reads_from_refseq = {k:v for k,v in simread_align.items() if v['refseqid']==refseqid}
nbr_generated_reads = len(reads_from_refseq)
print(f"Total of {nbr_generated_reads:,d} reads")

Select all reads from reference sequence '11120''
Total of 55,200 reads


In [None]:
n = -1
selected_simread = [v for k,v in reads_from_refseq.items()][n]
pprint(selected_simread)

{'aln_start_pos': '1286',
 'read_nbr': '1',
 'read_seq_aligned': 'TCTTTGAAGAACTTCCAAATGGTTTTATGGGTGCGAAAATTTTCTCAACA',
 'readid': '11120:ncbi:5-1',
 'ref_seq_aligned': 'TCTTTGAAGAACTTCCAAATGGTTTTATGGGTGCGAAAATTTTCTCAACA',
 'ref_seq_strand': '+',
 'refseq_nbr': '5',
 'refseqid': '11120',
 'source': 'ncbi'}


In [None]:
def check_alignment(n, reads):
    selected_simread = [v for k,v in reads.items()][n]
    print(f"{'='*80}")
    print(f"checking read {selected_simread['readid']}")
    start = int(selected_simread['aln_start_pos'])
    strand = selected_simread['ref_seq_strand']
    print(f"simread info:")
    print(f" - from `{strand}` strand")
    print(f" - position: {start:,d}")

    if strand == '+':
        segment_from_refseq = original_seq[start:start+50]
    else:
        segment_from_refseq = opposite_strand(reverse_sequence(original_seq)[start:start+50])

    print('sequences:')
    print(f'- simread seq          :', selected_simread['read_seq_aligned'])
    print(f'- refseq aligned       :', selected_simread['ref_seq_aligned'])
    print(f'- segment in orig. seq :', segment_from_refseq)

In [None]:
for n in range(nbr_generated_reads-1, nbr_generated_reads-6, -1):
    check_alignment(n, reads_from_refseq)

checking read 11120:ncbi:5-1
simread info:
 - from `+` strand
 - position: 1,286
sequences:
- simread seq          : TCTTTGAAGAACTTCCAAATGGTTTTATGGGTGCGAAAATTTTCTCAACA
- refseq aligned       : TCTTTGAAGAACTTCCAAATGGTTTTATGGGTGCGAAAATTTTCTCAACA
- segment in orig. seq : TCTTTGAAGAACTTCCAAATGGTTTTATGGGTGCGAAAATTTTCTCAACA
checking read 11120:ncbi:5-2
simread info:
 - from `-` strand
 - position: 3,836
sequences:
- simread seq          : AAAGTTGTGTAGTAAGAAGATTTCTTACCACACTTACTCATTAAAGGAAT
- refseq aligned       : AAAGTTGTGTAGTAAGAAGATTTCTTACCACACTTACTCATTAAAGGAAT
- segment in orig. seq : AAAGTTGTGTAGTAAGAAGATTTCTTACCACACTTACTCATTAAAGGAAT
checking read 11120:ncbi:5-3
simread info:
 - from `-` strand
 - position: 26,257
sequences:
- simread seq          : AACAGCTTCTTTAAAGAAAGCCAATGTTGAGAAAATTTTCGCACCCATAA
- refseq aligned       : AACAGCTTCTTTAAAGAAAGCCAATGTTGAGAAAATTTTCGCACCCATAA
- segment in orig. seq : AACAGCTTCTTTAAAGAAAGCCAATGTTGAGAAAATTTTCGCACCCATAA
checking read 11120:ncbi:5-4
simread in

## Create test dataset

In [None]:
p2original_set = p2data / 'CNN_Virus_data/50mer_validating'
assert p2original_set.is_file()

Original model dataset has one input (reads) and two outputs (label for the read and relative position)

In [None]:
orig_ds = TextFileBaseIterator(p2original_set, nlines=2)
orig_ds.print_first_chuncks(1)

2-line chunk 1
AAAAAGATTTTGAGAGAGGTCGACCTGTCCTCCTAAAACGTTTACAAAAG	71	0
CATGTAACGCAGCTTAGTCCGATCGTGGCTATAATCCGTCTTTCGATTTG	1	7



In [None]:
next(orig_ds)

'CCACGTCGATGAAGCTCCGACGAGAGTCGGCGCTGAGCCCGCGCACCTCC\t71\t6\nAGCTCGTGGATCTCCCCTCCTTCTGCAGTTTCAACATCAGAAGCCCTGAA\t87\t1\n'

In a first step, we will only infer from the model, so we only need a text file with sequences

```python
2591237:ncbi:1-60400

{'read_nbr': '60400',
 'readid': '2591237:ncbi:1-60400',
 'seq_nbr': '1',
 'seqid': '2591237',
 'sequence': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'source': 'ncbi'}
 ```

```
'11120:ncbi:5-1': {'aln_start_pos': '1286',
                    'read_nbr': '1',
                    'readid': '11120:ncbi:5-1',
                    'ref_seq_strand': '+',
                    'refseq_nbr': '5',
                    'refseqid': '11120',
                    'source': 'ncbi'},
```

In [None]:
def get_infer_ds_from_fastq(p2fastq, overwrite_ds=False, nsamples=None):
    """Build a dataset file from fastq read data for inference only"""
    fastq = FastqFileReader(p2fastq)
    aln = AlnFileReader(p2fastq.parent / f"{p2fastq.stem}.aln")
    
    p2dataset = Path(f"{p2fastq.stem}_ds")
    if p2dataset.is_file():
        if overwrite_ds: 
            p2dataset.unlink()
        else:
            raise ValueError(f"{p2dataset.name} already exists in {p2dataset.absolute()}")
    p2dataset.touch()
    
    read_ids = []
    read_refseqs = []
    read_start_pos = []
    read_strand = []
    
    with open(p2dataset, 'a') as fp:
        i = 1
        for fastq_chunck, aln_chunck in zip(fastq.it, aln.it):
            seq = fastq_chunck['sequence']
            fp.write(f"{seq}\t{0}\t{0}\n")
#             print(f"{seq}\t{refseqid}\t{start_pos}")
            
            aln_meta = parse_metadata_art_read_aln(aln_chunck['definition line'])
            read_ids.append(aln_meta['readid'])
            read_refseqs.append(aln_meta['refseqid'])
            read_start_pos.append(aln_meta['aln_start_pos'])
            read_strand.append(aln_meta['ref_seq_strand'])

            i += 1
            if nsamples:

                if i > nsamples: break
    print(f"Dataset with {i-1:,d} reads")    
    return p2dataset, np.array(list(zip(read_ids, read_refseqs, read_start_pos, read_strand)))

In [None]:
nsamples = None
p2ds, reads_info = get_infer_ds_from_fastq(p2fastq, overwrite_ds=True, nsamples=nsamples)

Dataset with 571,980 reads


In [None]:
reads_info[:4, :]

array([['2591237:ncbi:1-60400', '2591237', '14770', '+'],
       ['2591237:ncbi:1-60399', '2591237', '17012', '-'],
       ['2591237:ncbi:1-60398', '2591237', '9188', '+'],
       ['2591237:ncbi:1-60397', '2591237', '6764', '-']], dtype='<U21')

In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model

from metagentools.cnn_virus.data import strings_to_tensors

In [None]:
text_ds = tf.data.TextLineDataset(p2ds).batch(32)
ds = text_ds.map(strings_to_tensors)

# for xb, (y1b, y2b) in ds.take(1):
#     print(xb.shape, y1b.shape, y2b.shape)

## Inference

In [None]:
model = load_model(p2saved)

In [None]:
# model.summary()

In [None]:
prob_preds = model.predict(ds)

In [None]:
prob_preds[0].shape, prob_preds[1].shape

((571980, 187), (571980, 10))

In [None]:
class_preds = np.argmax(prob_preds[0], axis=1)
class_preds.shape
class_preds[:10]

array([117, 117, 117, 117,  32,  89, 117, 117,  94, 117])

In [None]:
# ground truth - not relevant here
# class_labels_ohe = None
# for xb, (yb, _) in ds:
#     class_labels_ohe = tf.concat([class_labels_ohe, yb], axis=0) if class_labels_ohe is not None else yb
# # print(class_labels_ohe.shape)
# class_labels = np.argmax(class_labels_ohe, axis=1)

## Evaluate Model for cov

Original model was trained with 187 different virus species.

In [None]:
with open(p2virus_labels, 'r') as fp:
    i, c = 0, 0
    cov = []
    while True:
        line = fp.readline()
        if line == '': break
        elif 'corona' in line: 
            c += 1
            line = line.replace('\t', '    \t')
            cov.append(f" - {line}")
        i += 1
print(f"Original model is trained to detect {i} virus species, including {c} coronavirus species:")
print(''.join(cov))

Original model is trained to detect 187 virus species, including 2 coronavirus species:
 - Middle_East_respiratory_syndrome-related_coronavirus    	94
 - Severe_acute_respiratory_syndrome-related_coronavirus    	117



In our case we only care about whether the model detects coronavirus species out of the sequences. We create two new functions:

In [None]:
def is_cov(y_preds):
    """Return 1 if the corresponding prediction is a corona virus, 0 otherwise"""
    return (y_preds == 94).astype(int) + (y_preds == 117).astype(int)

def cov_acc(y_true, y_preds):
    """Evaluates the accuracy of the model assuming all evaluated reads are from corona virus"""
    return is_cov(y_preds).sum()/y_preds.shape[0]

cov_acc(None, class_preds)

0.13603622504283366

In [None]:
aln = AlnFileReader(p2fastq.parent / f"{p2fastq.stem}.aln")
acc_per_refseq = {}

for refseqid in np.unique(reads_info[:,1]):
    mask = reads_info[:,1] == refseqid
    acc = cov_acc(None, class_preds[mask])
    aln_refseq_meta = aln.ref_sequences[refseqid]
    print(f"Reference Sequence: {aln_refseq_meta['species']}:")
    print(f"  Nbr reads: {class_preds[mask].shape[0]:,d}")
    print(f"  Accuracy:  {acc:.3f}")

Reference Sequence: Infectious bronchitis virus  scientific name:
  Nbr reads: 55,200
  Accuracy:  0.065
Reference Sequence: Bovine coronavirus  scientific name:
  Nbr reads: 61,800
  Accuracy:  0.055
Reference Sequence: Camel alphacoronavirus  scientific name:
  Nbr reads: 54,700
  Accuracy:  0.074
Reference Sequence: Coronavirus BtRs-BetaCoV/YN2018D  scientific name:
  Nbr reads: 60,400
  Accuracy:  0.733
Reference Sequence: Human coronavirus NL63  scientific name:
  Nbr reads: 55,000
  Accuracy:  0.067
Reference Sequence: Porcine epidemic diarrhea virus  scientific name:
  Nbr reads: 223,800
  Accuracy:  0.069
Reference Sequence: Human coronavirus OC43  scientific name:
  Nbr reads: 61,080
  Accuracy:  0.057


In [None]:
for refseqid in np.unique(reads_info[:,1]):
    mask_refseq = reads_info[:,1] == refseqid
    mask_strand_coding = reads_info[:,3] == '+'
    mask_strand_template = reads_info[:,3] == '-'
    mask_coding = (mask_strand_coding.astype(int) * mask_refseq.astype(int)).astype(bool)
    mask_template = (mask_strand_template.astype(int) * mask_refseq.astype(int)).astype(bool)

    aln_refseq_meta = aln.ref_sequences[refseqid]
    acc = cov_acc(None, class_preds[mask_refseq])
    acc_coding = cov_acc(None, class_preds[mask_coding])
    acc_template = cov_acc(None, class_preds[mask_template])
       
    print(f"Ref. Sequence: {aln_refseq_meta['species'].replace('scientific name', '').strip()}:")
    print(f"  Accuracy :............... {acc:.3f}")
    print(f"  Acc. coding strand: ..... {acc_coding:.3f}")
    print(f"  Acc. template strand: ... {acc_info:.3f}")
    print(f"  Nbr reads: {class_preds[mask_refseq].shape[0]:,d}, incl. {mask_coding.sum():,d} from coding strand and {mask_template.sum():,d} from template strand")
    print()

Ref. Sequence: Infectious bronchitis virus:
  Accuracy :............... 0.065
  Acc. coding strand: ..... 0.068
  Acc. template strand: ... 0.051
  Nbr reads: 55,200, incl. 27,632 from coding strand and 27,568 from template strand

Ref. Sequence: Bovine coronavirus:
  Accuracy :............... 0.055
  Acc. coding strand: ..... 0.058
  Acc. template strand: ... 0.051
  Nbr reads: 61,800, incl. 30,928 from coding strand and 30,872 from template strand

Ref. Sequence: Camel alphacoronavirus:
  Accuracy :............... 0.074
  Acc. coding strand: ..... 0.070
  Acc. template strand: ... 0.051
  Nbr reads: 54,700, incl. 27,313 from coding strand and 27,387 from template strand

Ref. Sequence: Coronavirus BtRs-BetaCoV/YN2018D:
  Accuracy :............... 0.733
  Acc. coding strand: ..... 0.733
  Acc. template strand: ... 0.051
  Nbr reads: 60,400, incl. 30,099 from coding strand and 30,301 from template strand

Ref. Sequence: Human coronavirus NL63:
  Accuracy :............... 0.067
  Acc. c

**Confusion Matrix**
The matrix **columns** represent the prediction labels and the **rows** represent the real labels. The confusion matrix is always a 2-D array of shape [n, n], where n is the number of valid labels for a given classification task.

In [None]:
is_cov(class_preds)

array([1, 1, 1, ..., 0, 0, 0])

In [None]:
confmat = tf.math.confusion_matrix(np.ones_like(class_preds), is_cov(class_preds))
confmat.shape

TensorShape([2, 2])

In [None]:
confmat.numpy()

array([[     0,      0],
       [494170,  77810]], dtype=int32)

# Simreads from 25 sequences

In [None]:
p2saved = p2data / 'saved/cnn_virus_original/pretrained_model.h5'
p2simreads = p2data / 'cov_simreads/single_25seq_50bp'
p2virus_labels = p2data / 'CNN_Virus_data/virus_name_mapping'
assert p2saved.is_file()
assert p2simreads.is_dir()
assert p2virus_labels.is_file()

In [None]:
p2fastq = p2simreads / f"{p2simreads.stem}.fq"
p2aln = p2simreads / f"{p2simreads.stem}.aln"
assert p2fastq.is_file()
assert p2aln.is_file()

nsamples = 1_000_000
p2ds, reads_info = get_infer_ds_from_fastq(p2fastq, overwrite_ds=True, nsamples=nsamples)

text_ds = tf.data.TextLineDataset(p2ds).batch(32)
ds = text_ds.map(strings_to_tensors)

Dataset with 1,000,000 reads


In [None]:
model = load_model(p2saved)

### handle GPU with tf

In [None]:
# device = tf.config.list_physical_devices('GPU')[0]
# device

PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')

In [None]:
# tf.config.experimental.get_memory_info(device='GPU:0')

{'current': 301555968, 'peak': 330915840}

In [None]:
# physical_devices = tf.config.list_physical_devices('GPU')
# try:
#     tf.config.experimental.set_memory_growth(physical_devices[0], True)
#     assert tf.config.experimental.get_memory_growth(physical_devices[0])
# except:
#     print('Invalid device or cannot modify virtual devices once initialized.')

Invalid device or cannot modify virtual devices once initialized.


In [None]:
# tf.keras.backend.clear_session()

In [None]:
# model.summary()

In [None]:
# try:
#     del model
# except:
#     pass
# import gc
# gc.collect()

In [None]:
# gpus = tf.config.list_physical_devices('GPU')
# if gpus:
#   # Restrict TensorFlow to only use the first GPU
#   try:
#     tf.config.set_visible_devices(gpus[0], 'GPU')
#     logical_gpus = tf.config.list_logical_devices('GPU')
#     print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
#   except RuntimeError as e:
#     # Visible devices must be set before GPUs have been initialized
#     print(e)

1 Physical GPUs, 1 Logical GPU


### end of section

In [None]:
prob_preds = model.predict(ds, verbose=1)



In [None]:
prob_preds[0].shape, prob_preds[1].shape

((1000000, 187), (1000000, 10))

In [None]:
class_preds = np.argmax(prob_preds[0], axis=1)
class_preds.shape

(1000000,)

In [None]:
def is_cov(y_preds):
    """Return 1 if the corresponding prediction is a corona virus, 0 otherwise"""
    return (y_preds == 94).astype(int) + (y_preds == 117).astype(int)

def cov_acc(y_true, y_preds):
    """Evaluates the accuracy of the model assuming all evaluated reads are from corona virus"""
    return is_cov(y_preds).sum()/y_preds.shape[0]

cov_acc(None, class_preds)

0.324955

In [None]:
np.unique(reads_info[:,1]).shape[0]

11

In [None]:
for refseqid in np.unique(reads_info[:,1]):
    mask_refseq = reads_info[:,1] == refseqid
    mask_strand_coding = reads_info[:,3] == '+'
    mask_strand_template = reads_info[:,3] == '-'
    mask_coding = (mask_strand_coding.astype(int) * mask_refseq.astype(int)).astype(bool)
    mask_template = (mask_strand_template.astype(int) * mask_refseq.astype(int)).astype(bool)

    aln_refseq_meta = aln.ref_sequences[refseqid]
    acc = cov_acc(None, class_preds[mask_refseq])
    acc_coding = cov_acc(None, class_preds[mask_coding])
    acc_template = cov_acc(None, class_preds[mask_template])
       
    print(f"Ref. Sequence: {aln_refseq_meta['species'].replace('scientific name', '').strip()}:")
    print(f"  Accuracy :............... {acc:.3f}")
    print(f"  Acc. coding strand: ..... {acc_coding:.3f}")
    print(f"  Acc. template strand: ... {acc_template:.3f}")
    print(f"  Nbr reads: {class_preds[mask_refseq].shape[0]:,d}, incl. {mask_coding.sum():,d} from coding strand and {mask_template.sum():,d} from template strand")
    print()

Ref. Sequence: Infectious bronchitis virus:
  Accuracy :............... 0.059
  Acc. coding strand: ..... 0.061
  Acc. template strand: ... 0.058
  Nbr reads: 110,500, incl. 55,405 from coding strand and 55,095 from template strand

Ref. Sequence: Bovine coronavirus:
  Accuracy :............... 0.054
  Acc. coding strand: ..... 0.058
  Acc. template strand: ... 0.051
  Nbr reads: 61,800, incl. 30,850 from coding strand and 30,950 from template strand

Ref. Sequence: Murine hepatitis virus:
  Accuracy :............... 0.052
  Acc. coding strand: ..... 0.053
  Acc. template strand: ... 0.051
  Nbr reads: 62,200, incl. 31,030 from coding strand and 31,170 from template strand

Ref. Sequence: Bat coronavirus HKU10:
  Accuracy :............... 0.073
  Acc. coding strand: ..... 0.072
  Acc. template strand: ... 0.074
  Nbr reads: 10,700, incl. 5,403 from coding strand and 5,297 from template strand

Ref. Sequence: Middle East respiratory syndrome-related coronavirus:
  Accuracy :............

# New Section