# Inference with original cnn_virus trained model

- Inference only, no training.
- Experiments with cov_reads files
- Trying locally

# Imports and setup environment

### Install and import packages

In [None]:
try:
    from ecutilities.ipython import nb_setup
    print('`ecutilities` already installed')
except ModuleNotFoundError as e:
    print('will install ecutilities')
    !pip install -qqU ecutilities
    from ecutilities.ipython import nb_setup

nb_setup()

`ecutilities` already installed
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Set autoreload mode


In [None]:
import numpy as np
import pandas as pd
import os
import tensorflow as tf

from pathlib import Path
from pprint import pprint
from tensorflow.python.client import device_lib
from tensorflow.keras.models import load_model

from metagentools.cnn_virus.data import strings_to_tensors, create_infer_ds_from_fastq
from metagentools.cnn_virus.data import FastaFileReader, FastqFileReader, AlnFileReader, parse_metadata_art_read_aln
from metagentools.core import TextFileBaseIterator

In [None]:
print(f"Tensorflow version: {tf.__version__}\n")

# Requested log level for TensorFlow's C++ backend.
# NOTE(review): tensorflow is imported in an earlier cell, so this may be set too late to have an effect — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # or any {'0', '1', '2'}

devices = device_lib.list_local_devices()
print('\nDevices:')
for device in devices:
    # The description is split on ', ' into 'key: value' items; items without a ':' are ignored.
    fields = [chunk.split(':', 1) for chunk in device.physical_device_desc.split(', ')]
    attributes = {pair[0]: pair[1] for pair in fields if len(pair) == 2}
    label = attributes.get('name', ' ')
    print(f"  - {device.device_type}  {device.name} {label:25s}")

Tensorflow version: 2.8.2


Devices:
  - CPU  /device:CPU:0                          
  - GPU  /device:GPU:0  NVIDIA GeForce GTX 1050 


## Install and import custom code, mount gdrive

In [None]:
# MODIFY AFTER INSTALLING ECUTILITIES >= 1.2.3

from ecutilities.core import validate_path

def safe_path(
    path:str|Path, # path to validate
)-> Path:          # validated path as a  pathlib.Path
    """Validate `path` as a directory (`validate_path` is asked to raise otherwise) and return it as a `Path`."""
    validate_path(path, path_type='dir', raise_error=True)
    # strings are converted, anything else is handed back untouched
    return Path(path) if isinstance(path, str) else path

def nbs_root_dir(
    path:str|Path|None = None, # path from where to seek for notebook parent directory
    pattern:str = 'nbs',       # pattern to identify the nbs directory
)-> Path:                      # path of the parent directory
    """Climb the directory tree up to the notebook directory and return its path

    Starting at `path` (the current working directory when `None`), return the closest
    directory -- `path` itself or one of its ancestors -- whose name starts with `pattern`.
    Raises `ValueError` when no such directory exists.
    """
    if path is None: path = Path()
    path = safe_path(path).absolute()
    # Walk up from `path` itself, so the result relates to the given path and not to
    # the current working directory (both are the same only when `path` is the cwd).
    for candidate in [path, *path.parents]:
        if candidate.name.startswith(pattern):
            return candidate.resolve()
    raise ValueError(f"No directory with a name starting with '{pattern}' found in or above {path}")
    
nbs = nbs_root_dir(Path())
nbs

PosixPath('/home/vtec/projects/bio/metagentools/nbs')

In [None]:
try:
    from google.colab import drive
    ON_COLAB = True
    print('Running on colab')
    print('Installing custom project code')   
    !pip install -U git+https://github.com/vtecftwy/metagenomics.git@refactor_cnn_virus
    drive.mount('/content/gdrive')
    
    p2drive = Path('/content/gdrive/MyDrive/Metagenonics')
    assert p2drive.is_dir()
    p2data =  p2drive / 'CNN_Virus_data'
    assert p2data.is_dir()

except ModuleNotFoundError as e:
#     print(e)
    ON_COLAB = False
    print('Running locally')
    print('Make sure you have installed the custom project code in your environment')
    p2data = nbs_root_dir().parent / 'data/'
    assert p2data.is_dir()
    print(p2data.absolute())

Running locally
Make sure you have installed the custom project code in your environment
/home/vtec/projects/bio/metagentools/data


# Experiments with simulated reads

## Setup paths

This assumes that the shared gdrive directory is accessible through a shortcut called `Metagenomics` under the root of gdrive.

In [None]:
# p2drive = Path('/content/gdrive/MyDrive/Metagenonics')
# assert p2drive.is_dir()

# p2data =  p2drive / 'CNN_Virus_data'
# assert p2data.is_dir()

In [None]:
# Paths used in this section; each assert stops the notebook early when an expected file/directory is missing.
# path for original trained model
p2saved = p2data / 'saved/cnn_virus_original/pretrained_model.h5'
# directory holding the simulated reads (name suggests 10 reference sequences and 50 bp reads — TODO confirm)
p2simreads = p2data / 'cov_simreads/single_10seq_50bp'
# file listing virus names with their class label (read in the "Evaluate Model for cov" section)
p2virus_labels = p2data / 'CNN_Virus_data/virus_name_mapping'
assert p2saved.is_file()
assert p2simreads.is_dir()
assert p2virus_labels.is_file()

#path for the learning weights file
# filepath_weights=p2data / "weight_of_classes"
# assert filepath_weights.is_file()

## Explore simread output files

In [None]:
# Reference sequences (fasta) used to generate the reads
p2fasta = p2data / 'cov_data/cov_virus_sequences_ten.fa'
# Simulated reads (.fq) and their alignment info (.aln); both files are named after the simreads directory
p2fastq = p2simreads / f"{p2simreads.stem}.fq"
p2aln = p2simreads / f"{p2simreads.stem}.aln"
assert p2fastq.is_file()
assert p2aln.is_file()
assert p2fasta.is_file()

In [None]:
from metagentools.cnn_virus.data import FastaFileReader, FastqFileReader, AlnFileReader, parse_metadata_art_read_aln
from metagentools.core import TextFileBaseIterator

In [None]:
fasta = FastaFileReader(p2fasta)
fastq = FastqFileReader(p2fastq)
aln = AlnFileReader(p2aln)

### Exploring the simreads in fastq

TODO: 
- change `add_seq` into `add_read_seq`
- change `seqid` into `refseqid`
- change seq_nbr into refseq_nbr

- parse: seqid and taxonomyid


In [None]:
# Show the parsed metadata of the first three reads in the fastq file
for idx, (readid, read_meta) in enumerate(fastq.parse_fastq(add_readseq=True).items(), start=1):
    print('readid:',readid)
    pprint(read_meta)
    print()

    if idx >= 3: break

readid: 2591237:ncbi:1-60400
{'readid': '2591237:ncbi:1-60400',
 'readnb': '60400',
 'readseq': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}

readid: 2591237:ncbi:1-60399
{'readid': '2591237:ncbi:1-60399',
 'readnb': '60399',
 'readseq': 'GATCAATGTGGCATCTACAATACAGACAGCATGAAGCACCACCAAAGGAC',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}

readid: 2591237:ncbi:1-60398
{'readid': '2591237:ncbi:1-60398',
 'readnb': '60398',
 'readseq': 'ATCTACCAGTGGTAGATGGGTTCTTAATAATGAACATTATAGAGCTCTAC',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}



### Exploring ALN header (reference sequences)

In [None]:
pprint(aln.ref_sequences)

{'11120:ncbi:5': {'refseq_accession': 'MN987231',
                  'refseq_length': '27617',
                  'refseqid': '11120:ncbi:5',
                  'refseqnb': '5',
                  'refsource': 'ncbi',
                  'reftaxonomyid': '11120',
                  'species': 'Infectious bronchitis virus  scientific name'},
 '11128:ncbi:2': {'refseq_accession': 'LC494191',
                  'refseq_length': '30942',
                  'refseqid': '11128:ncbi:2',
                  'refseqnb': '2',
                  'refsource': 'ncbi',
                  'reftaxonomyid': '11128',
                  'species': 'Bovine coronavirus  scientific name'},
 '1699095:ncbi:10': {'refseq_accession': 'KT368904',
                     'refseq_length': '27395',
                     'refseqid': '1699095:ncbi:10',
                     'refseqnb': '10',
                     'refsource': 'ncbi',
                     'reftaxonomyid': '1699095',
                     'species': 'Camel alphacoronavirus

### Exploring read's metadata

In [None]:
# Show the alignment record of the first three reads in the aln file
for idx, (readid, aln_meta) in enumerate(aln.parse_aln(add_ref_seq_aligned=True, add_read_seq_aligned=True).items(), start=1):
    print(readid)
    pprint(aln_meta)
    print()
    if idx >= 3: break

2591237:ncbi:1-60400
{'aln_start_pos': '14770',
 'read_seq_aligned': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'readid': '2591237:ncbi:1-60400',
 'readnb': '60400',
 'ref_seq_aligned': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'refseq_strand': '+',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}

2591237:ncbi:1-60399
{'aln_start_pos': '17012',
 'read_seq_aligned': 'GATCAATGTGGCATCTACAATACAGACAGCATGAAGCACCACCAAAGGAC',
 'readid': '2591237:ncbi:1-60399',
 'readnb': '60399',
 'ref_seq_aligned': 'GATCAATGTGGCATCTACAATACAGACAGCATGAAGCACCACCAAAGGAC',
 'refseq_strand': '-',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}

2591237:ncbi:1-60398
{'aln_start_pos': '9188',
 'read_seq_aligned': 'ATCTACCAGTGGTAGATGGGTTCTTAATAATGAACATTATAGAGCTCTAC',
 'readid': '2591237:ncbi:1-60398',
 'readnb': '60398',
 'ref_seq_aligned': 'ATCTACCAGTGGTAGATGGGTTCTTAATAATGAACATTATAGAGCTCTAC

In [None]:
fasta = FastaFileReader(p2fasta)
fastq = FastqFileReader(p2fastq)
aln = AlnFileReader(p2aln)

In [None]:
# Parse the three files once; the results are dictionaries reused by the cells below.
# reference sequences from the fasta file, indexed by refseqid in later cells
refseqs_fasta = fasta.parse_fasta(add_seq=True)
# simulated reads from the fastq file, keyed by readid
simreads = fastq.parse_fastq(add_readseq=True)
# reference sequence metadata from the header of the aln file, keyed by refseqid
refseqs_aln = aln.ref_sequences
# alignment record of each read from the aln file, keyed by readid
simread_align = aln.parse_aln(add_ref_seq_aligned=True, add_read_seq_aligned=True)

Check consistency between refseqs from fasta and from aln 

In [None]:
# utility functions
def opposite_strand(seq):
    """Return the base-by-base complement of the string `seq` (A<->T, C<->G), in the same order.

    Upper and lower case bases are both handled. Any other character (e.g. `N` or a
    gap `-`) is returned unchanged instead of raising a `KeyError`.
    """
    complement = str.maketrans('ACGTacgt', 'TGCAtgca')
    return seq.translate(complement)

def reverse_sequence(seq):
    """Return `seq` read from its last element to its first (same type as the input)."""
    backwards = slice(None, None, -1)
    return seq[backwards]

opposite_strand('ACGT'), reverse_sequence('a b c d e f')

('TGCA', 'f e d c b a')

Check aln refseq information

In [None]:
# Pick one reference sequence and look up its sequence and accession in the fasta data.
# `original_seq` is also read as a global by `check_alignment` further down.
# refseqid = '2591237:ncbi:1'
refseqid = '11128:ncbi:2'
original_seq = refseqs_fasta[refseqid]['sequence']
original_seq_accession = refseqs_fasta[refseqid]['accession']
# cell output: accession and length of the selected sequence
original_seq_accession, len(original_seq)

('LC494191', 30942)

In [None]:
refseqs_aln[refseqid]['refseq_accession'], int(refseqs_aln[refseqid]['refseq_length'])

('LC494191', 30942)

In [None]:
assert original_seq_accession == refseqs_aln[refseqid]['refseq_accession']
assert len(original_seq) == int(refseqs_aln[refseqid]['refseq_length'])

### Check alignment information

In [None]:
pprint(simread_align['2591237:ncbi:1-60400'])

{'aln_start_pos': '14770',
 'read_seq_aligned': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'readid': '2591237:ncbi:1-60400',
 'readnb': '60400',
 'ref_seq_aligned': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'refseq_strand': '+',
 'refseqid': '2591237:ncbi:1',
 'refseqnb': '1',
 'refsource': 'ncbi',
 'reftaxonomyid': '2591237'}


Select all reads generated from a single reference sequence

In [None]:
# Keep only the alignment records of reads generated from the reference sequence `refseqid`
print(f"Select all reads from reference sequence '{refseqid}'")
reads_from_refseq = {k:v for k,v in simread_align.items() if v['refseqid']==refseqid}
nbr_generated_reads = len(reads_from_refseq)
print(f"Total of {nbr_generated_reads:,d} reads")

Select all reads from reference sequence '11128:ncbi:2''
Total of 61,800 reads


In [None]:
# Look at one of the selected reads; -1 picks the last one in iteration order
n = -1
selected_simread = list(reads_from_refseq.values())[n]
pprint(selected_simread)

{'aln_start_pos': '21629',
 'read_seq_aligned': 'GCTATGCTGGATATTCAGCCTGAAGACTACAGAAGTGTTGATGTTGCTAT',
 'readid': '11128:ncbi:2-1',
 'readnb': '1',
 'ref_seq_aligned': 'GCTATGCTGGATATTCAGCCTGAAGACTACAGAAGTGTTGATGTTGCTAT',
 'refseq_strand': '+',
 'refseqid': '11128:ncbi:2',
 'refseqnb': '2',
 'refsource': 'ncbi',
 'reftaxonomyid': '11128'}


In [None]:
def check_alignment(n, reads, refseq=None):
    """Print the n-th read of `reads` next to the matching segment of the reference sequence.

    n:      position of the read in `reads` (iteration order; negative values count from the end)
    reads:  dict of alignment records; each record must provide the keys `readid`,
            `aln_start_pos`, `refseq_strand`, `read_seq_aligned` and `ref_seq_aligned`
    refseq: full reference sequence the reads come from; when `None`, the notebook-level
            `original_seq` is used so that existing calls keep working

    Nothing is returned: the three printed sequences are meant to be compared by eye.
    """
    if refseq is None:
        refseq = original_seq
    selected_simread = list(reads.values())[n]
    print(f"{'='*80}")
    print(f"checking read {selected_simread['readid']}")
    start = int(selected_simread['aln_start_pos'])
    strand = selected_simread['refseq_strand']
    print("simread info:")
    print(f" - from `{strand}` strand")
    print(f" - position: {start:,d}")

    # Number of reference bases covered by the alignment, instead of a hard-coded 50.
    # NOTE(review): assumes alignment gaps, if any, are written as '-' — confirm against the .aln format.
    span = len(selected_simread['ref_seq_aligned'].replace('-', ''))

    if strand == '+':
        segment_from_refseq = refseq[start:start+span]
    else:
        # for the `-` strand, the position is counted on the reversed sequence and the bases are complemented
        segment_from_refseq = opposite_strand(reverse_sequence(refseq)[start:start+span])

    print('sequences:')
    print('- simread seq          :', selected_simread['read_seq_aligned'])
    print('- refseq aligned       :', selected_simread['ref_seq_aligned'])
    print('- segment in orig. seq :', segment_from_refseq)

In [None]:
# Check five reads of the selection, walking backwards from the last one
for n in range(nbr_generated_reads-1, nbr_generated_reads-6, -1):
    check_alignment(n, reads_from_refseq)

checking read 11128:ncbi:2-1
simread info:
 - from `+` strand
 - position: 21,629
sequences:
- simread seq          : GCTATGCTGGATATTCAGCCTGAAGACTACAGAAGTGTTGATGTTGCTAT
- refseq aligned       : GCTATGCTGGATATTCAGCCTGAAGACTACAGAAGTGTTGATGTTGCTAT
- segment in orig. seq : GCTATGCTGGATATTCAGCCTGAAGACTACAGAAGTGTTGATGTTGCTAT
checking read 11128:ncbi:2-2
simread info:
 - from `-` strand
 - position: 14,882
sequences:
- simread seq          : CTTCTTAAATACATGTTCTTGTAAAAGGACTCATCAGTAAACTTTTGTCC
- refseq aligned       : CTTCTTAAATACATGTTCTTGTAAAAGGACTCATCAGTAAACTTTTGTCC
- segment in orig. seq : CTTCTTAAATACATGTTCTTGTAAAAGGACTCATCAGTAAACTTTTGTCC
checking read 11128:ncbi:2-3
simread info:
 - from `+` strand
 - position: 16,050
sequences:
- simread seq          : ATTTAAGAAGTGCAGTTATGCAGAGTGTTGGAGCTTGCGTGGTCTGCTCT
- refseq aligned       : ATTTAAGAAGTGCAGTTATGCAGAGTGTTGGAGCTTGCGTGGTCTGCTCT
- segment in orig. seq : ATTTAAGAAGTGCAGTTATGCAGAGTGTTGGAGCTTGCGTGGTCTGCTCT
checking read 11128:ncbi:2-4
simread 

## Create test dataset

In [None]:
p2original_set = p2data / 'CNN_Virus_data/50mer_validating'
assert p2original_set.is_file()

Original model dataset has one input (reads) and two outputs (label for the read and relative position)

In [None]:
orig_ds = TextFileBaseIterator(p2original_set, nlines=5)
orig_ds.print_first_chuncks(1)

5-line chunk 1
AAAAAGATTTTGAGAGAGGTCGACCTGTCCTCCTAAAACGTTTACAAAAG	71	0
CATGTAACGCAGCTTAGTCCGATCGTGGCTATAATCCGTCTTTCGATTTG	1	7
AACAACATCTTGTTGATGATAACCGTCAAAGTGTTTTGGGTCTGGAGGGA	158	6
AGTACCTGGAGAGCGTTAAGAAACACAAACGGCTGGATGTAGTGCCGCGC	6	7
CCACGTCGATGAAGCTCCGACGAGAGTCGGCGCTGAGCCCGCGCACCTCC	71	6



In [None]:
next(orig_ds)

'ATGGTGCGCCTTCAGTATAAGGATGCTAATATTAGTATGTATCTGGCAAT\t0\t1\nCTGGTGGCGCACGTCGAGGCCCTGGCCAGCTGGTTGATGACTTTACCCTG\t25\t2\nCATGATGGAATCGGTCACGGAAAGCATTCTAAATGGATACGTACAACTAC\t4\t5\nTATTGTACATCTATTACGTCTTTTCGACTATCAATAGTAAATCGTCTGTC\t12\t4\nATTGCGTCTTTTGTAAAGATCACAACAAACATGTTTCGCAAGCCGGACAT\t4\t6\n'

In a first step, we will only infer from the model, so we only need a text file with sequences

```python
2591237:ncbi:1-60400

{'read_nbr': '60400',
 'readid': '2591237:ncbi:1-60400',
 'seq_nbr': '1',
 'seqid': '2591237',
 'sequence': 'ACAACTCCTATTCGTAGTTGAAGTTGTTGACAAATACTTTGATTGTTACG',
 'source': 'ncbi'}
 ```

In [None]:
# Build the inference dataset file from the fastq reads.
# `reads_info` is an array with one row per read; the recorded output below shows the
# columns readid, refseqid, alignment start position and strand.
# NOTE(review): `nsamples = None` presumably means "use all reads" (571,980 in the recorded output) — confirm.
nsamples = None
p2ds, reads_info = create_infer_ds_from_fastq(p2fastq, overwrite_ds=True, nsamples=nsamples)

Dataset with 571,980 reads


In [None]:
reads_info[:4, :]

array([['2591237:ncbi:1-60400', '2591237:ncbi:1', '14770', '+'],
       ['2591237:ncbi:1-60399', '2591237:ncbi:1', '17012', '-'],
       ['2591237:ncbi:1-60398', '2591237:ncbi:1', '9188', '+'],
       ['2591237:ncbi:1-60397', '2591237:ncbi:1', '6764', '-']],
      dtype='<U21')

In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model

from metagentools.cnn_virus.data import strings_to_tensors

In [None]:
# Read the dataset file line by line, in batches of 32 lines
text_ds = tf.data.TextLineDataset(p2ds).batch(32)
# Convert each batch of text lines into tensors with the project helper `strings_to_tensors`
ds = text_ds.map(strings_to_tensors)

# for xb, (y1b, y2b) in ds.take(1):
#     print(xb.shape, y1b.shape, y2b.shape)

## Inference

In [None]:
model = load_model(p2saved)

In [None]:
# model.summary()

In [None]:
prob_preds = model.predict(ds, verbose=1)



In [None]:
prob_preds[0].shape, prob_preds[1].shape

((571980, 187), (571980, 10))

In [None]:
# Predicted class of each read = index of the highest value in the first model output
class_preds = np.argmax(prob_preds[0], axis=1)
# NOTE(review): the next line displays nothing, only the last expression of a cell is shown
class_preds.shape
class_preds[:10]

array([117, 117, 117, 117,  32,  89, 117, 117,  94, 117])

## Evaluate Model for cov

Original model was trained with 187 different virus species.

In [None]:
p2virus_labels = p2data / 'CNN_Virus_data/virus_name_mapping'
# Go through the label file once: count every line and keep those naming a coronavirus
cov = []
i = 0
with open(p2virus_labels, 'r') as fp:
    for line in fp:
        i += 1
        if 'corona' in line or 'mers' in line:
            # widen the gap in front of the tab to make the printed list easier to read
            spaced = line.replace('\t', '    \t')
            cov.append(f" - {spaced}")
c = len(cov)
print(f"Original model is trained to detect {i} virus species, including {c} coronavirus species:")
print(''.join(cov))

Original model is trained to detect 187 virus species, including 2 coronavirus species:
 - Middle_East_respiratory_syndrome-related_coronavirus    	94
 - Severe_acute_respiratory_syndrome-related_coronavirus    	117



In [None]:
fasta.reset_iterator()
fasta.it.print_first_chuncks(10)


Sequence 1:
>2591237:ncbi:1 [MK211378]	2591237	ncbi	1 [MK211378] 2591237	Coronavirus BtRs-BetaCoV/YN2018D		scientific name

TATTAGGTTTTCTACCTACCCAGGAAAAGCCAACCAACCTCGATCTCTTGTAGATCTGTTCTCTAAACGAACTTTAAAAT ...

Sequence 2:
>11128:ncbi:2 [LC494191]	11128	ncbi	2 [LC494191] 11128	Bovine coronavirus		scientific name

CATCCCGCTTCACTGATCTCTTGTTAGATCTTTTCATAATCTAAACTTTATAAAAACATCCACTCCCTGTAGTCTATGCC ...

Sequence 3:
>31631:ncbi:3 [KY967361]	31631	ncbi	3 [KY967361] 31631	Human coronavirus OC43		scientific name

ATCTCTTGTTAGATCTTTTTGTAATCTAAACTTTATAAAAACATCCACTCCCTGTAATCTATGCTTGTGGGCGTAGATTT ...

Sequence 4:
>277944:ncbi:4 [LC654455]	277944	ncbi	4 [LC654455] 277944	Human coronavirus NL63		scientific name

ATTTTCTTATTTAGACTTTGTGTCTACTCTTCTCAACTAAACGAAATTTTTCTAGTGCTGTCATTTGTTATGGCAGTCCT ...

Sequence 5:
>11120:ncbi:5 [MN987231]	11120	ncbi	5 [MN987231] 11120	Infectious bronchitis virus		scientific name

TCCTAAGTGTGATATAAATATATATCATACACACTAGCCTTGCGCTAGATTTCTAACTTAACAAAACGGACTTAAATACC ...

Sequence 

In our case we only care about whether the model detects coronavirus species out of the sequences. We define a few helper functions for that:

In [None]:
# Class labels of the two coronavirus species known to the original model
# (values from the `virus_name_mapping` listing printed above).
MERS_LABEL = 94     # Middle_East_respiratory_syndrome-related_coronavirus
SARS_LABEL = 117    # Severe_acute_respiratory_syndrome-related_coronavirus

def is_cov(y_preds):
    """Return 1 if the corresponding prediction is a corona virus, 0 otherwise"""
    return (is_mers(y_preds) | is_sars(y_preds)).astype(int)

def is_mers(y_preds):
    """Return a boolean mask, True where the predicted label is the MERS-related coronavirus"""
    return y_preds == MERS_LABEL

def is_sars(y_preds):
    """Return a boolean mask, True where the predicted label is the SARS-related coronavirus"""
    return y_preds == SARS_LABEL

def _share_of(flags):
    """Fraction of set entries in `flags`; `nan` when `flags` is empty (avoids a 0/0 warning)"""
    nbr = flags.shape[0]
    return flags.sum()/nbr if nbr else float('nan')

def cov_acc(y_true, y_preds):
    """Fraction of reads predicted as one of the two coronavirus labels.

    This is the accuracy of the model assuming all evaluated reads are from corona virus.
    `y_true` is not used; it is kept for a metric-like `(y_true, y_preds)` signature.
    """
    return _share_of(is_cov(y_preds))

def mers_acc(y_true, y_preds):
    """Fraction of reads predicted as MERS-related coronavirus (`y_true` is not used)"""
    return _share_of(is_mers(y_preds))

def sars_acc(y_true, y_preds):
    """Fraction of reads predicted as SARS-related coronavirus (`y_true` is not used)"""
    return _share_of(is_sars(y_preds))

# cov_acc(None, class_preds)

In [None]:
aln = AlnFileReader(p2fastq.parent / f"{p2fastq.stem}.aln")
# Metrics per reference sequence, keyed by refseqid, so they can be reused without re-reading the printout
acc_per_refseq = {}

for refseqid in np.unique(reads_info[:,1]):
    # predictions for the reads generated from this reference sequence only
    mask = reads_info[:,1] == refseqid
    preds_refseq = class_preds[mask]
    aln_refseq_meta = aln.ref_sequences[refseqid]
    acc = cov_acc(None, preds_refseq)
    acc_mers = mers_acc(None, preds_refseq)
    acc_sars = sars_acc(None, preds_refseq)
    acc_per_refseq[refseqid] = {
        'species': aln_refseq_meta['species'],
        'nbr_reads': preds_refseq.shape[0],
        'acc': acc,
        'acc_mers': acc_mers,
        'acc_sars': acc_sars,
    }
    print(f"Reference Sequence: {aln_refseq_meta['species']}:")
    print(f"  Nbr reads: {preds_refseq.shape[0]:,d}")
    print(f"  Accuracy:       {acc:.3f}")
    print(f"  Accuracy MERS:  {acc_mers:.3f}")
    print(f"  Accuracy SARS:  {acc_sars:.3f}")

Reference Sequence: Infectious bronchitis virus  scientific name:
  Nbr reads: 55,200
  Accuracy:       0.065
  Accuracy MERS:  0.032
  Accuracy SARS:  0.034
Reference Sequence: Bovine coronavirus  scientific name:
  Nbr reads: 61,800
  Accuracy:       0.055
  Accuracy MERS:  0.030
  Accuracy SARS:  0.026
Reference Sequence: Camel alphacoronavirus  scientific name:
  Nbr reads: 54,700
  Accuracy:       0.074
  Accuracy MERS:  0.035
  Accuracy SARS:  0.039
Reference Sequence: Coronavirus BtRs-BetaCoV/YN2018D  scientific name:
  Nbr reads: 60,400
  Accuracy:       0.733
  Accuracy MERS:  0.014
  Accuracy SARS:  0.719
Reference Sequence: Human coronavirus NL63  scientific name:
  Nbr reads: 55,000
  Accuracy:       0.067
  Accuracy MERS:  0.031
  Accuracy SARS:  0.035
Reference Sequence: Porcine epidemic diarrhea virus  scientific name:
  Nbr reads: 56,000
  Accuracy:       0.068
  Accuracy MERS:  0.032
  Accuracy SARS:  0.036
Reference Sequence: Porcine epidemic diarrhea virus  scientifi

In [None]:
reads_info[:3, :]

array([['2591237:ncbi:1-60400', '2591237:ncbi:1', '14770', '+'],
       ['2591237:ncbi:1-60399', '2591237:ncbi:1', '17012', '-'],
       ['2591237:ncbi:1-60398', '2591237:ncbi:1', '9188', '+']],
      dtype='<U21')

In [None]:
# The strand of each read does not depend on the reference sequence: build these masks once
mask_strand_coding = reads_info[:,3] == '+'
mask_strand_template = reads_info[:,3] == '-'

for refseqid in np.unique(reads_info[:,1]):
    # reads of this reference sequence, then split by strand
    mask_refseq = reads_info[:,1] == refseqid
    mask_coding = mask_refseq & mask_strand_coding
    mask_template = mask_refseq & mask_strand_template

    aln_refseq_meta = aln.ref_sequences[refseqid]
    acc, acc_coding, acc_template = (
        cov_acc(None, class_preds[selection])
        for selection in (mask_refseq, mask_coding, mask_template)
    )

    print(f"Ref. Sequence: {aln_refseq_meta['species'].replace('scientific name', '').strip()}:")
    print(f"  Accuracy :............... {acc:.3f}")
    print(f"  Acc. coding strand: ..... {acc_coding:.3f}")
    print(f"  Acc. template strand: ... {acc_template:.3f}")
    print(f"  Nbr reads: {class_preds[mask_refseq].shape[0]:,d}, incl. {mask_coding.sum():,d} from coding strand and {mask_template.sum():,d} from template strand")
    print()

Ref. Sequence: Infectious bronchitis virus:
  Accuracy :............... 0.065
  Acc. coding strand: ..... 0.068
  Acc. template strand: ... 0.063
  Nbr reads: 55,200, incl. 27,632 from coding strand and 27,568 from template strand

Ref. Sequence: Bovine coronavirus:
  Accuracy :............... 0.055
  Acc. coding strand: ..... 0.058
  Acc. template strand: ... 0.053
  Nbr reads: 61,800, incl. 30,928 from coding strand and 30,872 from template strand

Ref. Sequence: Camel alphacoronavirus:
  Accuracy :............... 0.074
  Acc. coding strand: ..... 0.070
  Acc. template strand: ... 0.077
  Nbr reads: 54,700, incl. 27,313 from coding strand and 27,387 from template strand

Ref. Sequence: Coronavirus BtRs-BetaCoV/YN2018D:
  Accuracy :............... 0.733
  Acc. coding strand: ..... 0.733
  Acc. template strand: ... 0.733
  Nbr reads: 60,400, incl. 30,099 from coding strand and 30,301 from template strand

Ref. Sequence: Human coronavirus NL63:
  Accuracy :............... 0.067
  Acc. c

# Simreads from 25 sequences

In [None]:
# Same setup as for the first experiment, pointing to the `single_25seq_50bp` simreads directory
p2saved = p2data / 'saved/cnn_virus_original/pretrained_model.h5'
p2simreads = p2data / 'cov_simreads/single_25seq_50bp'
p2virus_labels = p2data / 'CNN_Virus_data/virus_name_mapping'
# fail early when an expected file/directory is missing
assert p2saved.is_file()
assert p2simreads.is_dir()
assert p2virus_labels.is_file()

In [None]:
# Reads (.fq) and alignment info (.aln) are both named after the simreads directory
p2fastq = p2simreads / f"{p2simreads.stem}.fq"
p2aln = p2simreads / f"{p2simreads.stem}.aln"
assert p2fastq.is_file()
assert p2aln.is_file()

fastq = FastqFileReader(p2fastq)
aln = AlnFileReader(p2aln)

# Build the inference dataset file from the fastq reads
# NOTE(review): `nsamples = None` presumably means "use all reads" (1,442,519 in the recorded output) — confirm
nsamples = None
p2ds, reads_info = create_infer_ds_from_fastq(p2fastq, overwrite_ds=True, nsamples=nsamples)

# Batches of 32 text lines, converted into tensors by the project helper
text_ds = tf.data.TextLineDataset(p2ds).batch(32)
ds = text_ds.map(strings_to_tensors)

Dataset with 1,442,519 reads


In [None]:
model = load_model(p2saved)

In [None]:
prob_preds = model.predict(ds, verbose=1)



In [None]:
prob_preds[0].shape, prob_preds[1].shape

((1442519, 187), (1442519, 10))

In [None]:
class_preds = np.argmax(prob_preds[0], axis=1)
class_preds.shape

(1442519,)

In [None]:
# NOTE(review): these functions repeat the definitions of the "Evaluate Model for cov" section.
# Labels 94 and 117 are the MERS- and SARS-related coronavirus entries of `virus_name_mapping`.
def is_cov(y_preds):
    """Return 1 if the corresponding prediction is a corona virus, 0 otherwise"""
    return (y_preds == 94).astype(int) + (y_preds == 117).astype(int)

def is_mers(y_preds):
    """Return a boolean mask, True where the predicted label is 94 (MERS-related coronavirus)"""
    return y_preds == 94

def is_sars(y_preds):
    """Return a boolean mask, True where the predicted label is 117 (SARS-related coronavirus)"""
    return y_preds == 117

def cov_acc(y_true, y_preds):
    """Fraction of reads predicted as a coronavirus, i.e. the accuracy assuming all evaluated reads are from corona virus (`y_true` is not used)"""
    return is_cov(y_preds).sum()/y_preds.shape[0]

def mers_acc(y_true, y_preds):
    """Fraction of reads predicted as MERS-related coronavirus (`y_true` is not used)"""
    return is_mers(y_preds).sum()/y_preds.shape[0]

def sars_acc(y_true, y_preds):
    """Fraction of reads predicted as SARS-related coronavirus (`y_true` is not used)"""
    return is_sars(y_preds).sum()/y_preds.shape[0]

# cov_acc(None, class_preds)

In [None]:
np.unique(reads_info[:,1]).shape[0]

25

In [None]:
# The strand of each read does not depend on the reference sequence: build these masks once
mask_strand_coding = reads_info[:,3] == '+'
mask_strand_template = reads_info[:,3] == '-'

for refseqid in np.unique(reads_info[:,1]):
    # reads of this reference sequence, then split by strand
    mask_refseq = reads_info[:,1] == refseqid
    mask_coding = mask_refseq & mask_strand_coding
    mask_template = mask_refseq & mask_strand_template

    aln_refseq_meta = aln.ref_sequences[refseqid]
    acc, acc_coding, acc_template = (
        cov_acc(None, class_preds[selection])
        for selection in (mask_refseq, mask_coding, mask_template)
    )

    species = aln_refseq_meta['species'].replace('scientific name', '').strip()
    # NOTE(review): `refid` is not used below, the header line prints `refseqid`
    refid = aln_refseq_meta['refseqid'].strip()
    refseq_accession = aln_refseq_meta['refseq_accession'].strip()

    print(f"Ref. Sequence: {species} ({refseqid} / {refseq_accession}):")
    print(f"  Accuracy :............... {acc:.3f}")
    print(f"  Accuracy MERS: .......... {mers_acc(None, class_preds[mask_refseq]):.3f}")
    print(f"  Accuracy SARS: .......... {sars_acc(None, class_preds[mask_refseq]):.3f}")
    print(f"  Acc. coding strand: ..... {acc_coding:.3f}")
    print(f"  Acc. template strand: ... {acc_template:.3f}")
    print(f"  Nbr reads: {class_preds[mask_refseq].shape[0]:,d}, incl. {mask_coding.sum():,d} from coding strand and {mask_template.sum():,d} from template strand")
    print()

Ref. Sequence: Infectious bronchitis virus (11120:ncbi:17 / MW792514):
  Accuracy :............... 0.055
  Accuracy MERS: .......... 0.026
  Accuracy SARS: .......... 0.029
  Acc. coding strand: ..... 0.054
  Acc. template strand: ... 0.056
  Nbr reads: 55,300, incl. 27,748 from coding strand and 27,552 from template strand

Ref. Sequence: Infectious bronchitis virus (11120:ncbi:19 / EU526388):
  Accuracy :............... 0.055
  Accuracy MERS: .......... 0.027
  Accuracy SARS: .......... 0.028
  Acc. coding strand: ..... 0.053
  Acc. template strand: ... 0.057
  Nbr reads: 55,400, incl. 27,670 from coding strand and 27,730 from template strand

Ref. Sequence: Infectious bronchitis virus (11120:ncbi:5 / MN987231):
  Accuracy :............... 0.063
  Accuracy MERS: .......... 0.031
  Accuracy SARS: .......... 0.032
  Acc. coding strand: ..... 0.068
  Acc. template strand: ... 0.059
  Nbr reads: 55,200, incl. 27,657 from coding strand and 27,543 from template strand

Ref. Sequence: Bovin

# Simreads from 100 sequences

In [None]:
# Same setup as for the previous experiments, pointing to the `single_100seq_50bp` simreads directory
p2saved = p2data / 'saved/cnn_virus_original/pretrained_model.h5'
p2simreads = p2data / 'cov_simreads/single_100seq_50bp'
p2virus_labels = p2data / 'CNN_Virus_data/virus_name_mapping'
# fail early when an expected file/directory is missing
assert p2saved.is_file()
assert p2simreads.is_dir()
assert p2virus_labels.is_file()

In [None]:
# Reads (.fq) and alignment info (.aln) are both named after the simreads directory
p2fastq = p2simreads / f"{p2simreads.stem}.fq"
p2aln = p2simreads / f"{p2simreads.stem}.aln"
assert p2fastq.is_file()
assert p2aln.is_file()

fastq = FastqFileReader(p2fastq)
aln = AlnFileReader(p2aln)

# Build the inference dataset file; here the number of reads is capped
# (the recorded output reports a dataset with 3,000,000 reads)
nsamples = 3_000_000
p2ds, reads_info = create_infer_ds_from_fastq(p2fastq, overwrite_ds=True, nsamples=nsamples)

# Batches of 32 text lines, converted into tensors by the project helper
text_ds = tf.data.TextLineDataset(p2ds).batch(32)
ds = text_ds.map(strings_to_tensors)

Dataset with 3,000,000 reads


In [None]:
model = load_model(p2saved)

In [None]:
# `model.predict(ds)` ran out of GPU memory while assembling the full output for 3M reads
# (see the recorded `ResourceExhaustedError`). Predict batch by batch instead:
# `predict_on_batch` returns numpy arrays, so the full result is assembled in host memory.
# NOTE(review): assumes `ds` yields `(inputs, labels)` batches and that the model has two outputs,
# as in the previous sections — confirm.
batch_outputs = [model.predict_on_batch(xb) for xb, _ in ds]
# regroup per model output, then stack the batches along the sample axis
prob_preds = [np.concatenate(parts, axis=0) for parts in zip(*batch_outputs)]
del batch_outputs   # free the per-batch copies



ResourceExhaustedError: OOM when allocating tensor with shape[3000000,187] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:ConcatV2] name: concat

It seems that the inference itself runs fine for all 93,750 steps (3M samples), but when the outputs are assembled, the `prob_preds` tensor of shape `[3000000,187]` cannot be allocated on the GPU.

Questions:
- why does this have to be on the GPU?
- is it possible to have it in main memory instead?


In [None]:
prob_preds[0].shape, prob_preds[1].shape

In [None]:
class_preds = np.argmax(prob_preds[0], axis=1)
class_preds.shape

In [None]:
# NOTE(review): these two functions repeat definitions made in the previous sections.
# Labels 94 and 117 are the MERS- and SARS-related coronavirus entries of `virus_name_mapping`.
def is_cov(y_preds):
    """Return 1 if the corresponding prediction is a corona virus, 0 otherwise"""
    return (y_preds == 94).astype(int) + (y_preds == 117).astype(int)

def cov_acc(y_true, y_preds):
    """Fraction of reads predicted as a coronavirus, i.e. the accuracy assuming all evaluated reads are from corona virus (`y_true` is not used)"""
    return is_cov(y_preds).sum()/y_preds.shape[0]

cov_acc(None, class_preds)

In [None]:
np.unique(reads_info[:,1]).shape[0]

In [None]:
for refseqid in np.unique(reads_info[:,1]):
    # reads of this reference sequence, then split by strand
    mask_refseq = reads_info[:,1] == refseqid
    mask_strand_coding = reads_info[:,3] == '+'
    mask_strand_template = reads_info[:,3] == '-'
    mask_coding = (mask_strand_coding.astype(int) * mask_refseq.astype(int)).astype(bool)
    mask_template = (mask_strand_template.astype(int) * mask_refseq.astype(int)).astype(bool)

    aln_refseq_meta = aln.ref_sequences[refseqid]
    acc = cov_acc(None, class_preds[mask_refseq])
    acc_coding = cov_acc(None, class_preds[mask_coding])
    acc_template = cov_acc(None, class_preds[mask_template])

    print(f"Ref. Sequence: {aln_refseq_meta['species'].replace('scientific name', '').strip()}:")
    print(f"  Accuracy :............... {acc:.3f}")
    # use this loop's own mask: `mask` is a leftover global from an earlier section
    # and does not select the reads of the current reference sequence
    print(f"  Accuracy MERS:  {mers_acc(None, class_preds[mask_refseq]):.3f}")
    print(f"  Accuracy SARS:  {sars_acc(None, class_preds[mask_refseq]):.3f}")
    print(f"  Acc. coding strand: ..... {acc_coding:.3f}")
    print(f"  Acc. template strand: ... {acc_template:.3f}")
    print(f"  Nbr reads: {class_preds[mask_refseq].shape[0]:,d}, incl. {mask_coding.sum():,d} from coding strand and {mask_template.sum():,d} from template strand")
    print()

# New Section

Access AWS from colab:
- https://colab.research.google.com/github/bytehub-ai/code-examples/blob/main/tutorials/04_using_cloud_storage.ipynb
- https://python.plainenglish.io/how-to-load-data-from-aws-s3-into-google-colab-7e76fbf534d2
- https://medium.com/@lily_su/accessing-s3-bucket-from-google-colab-16f7ee6c5b51
- 

## handle GPU with tf

In [None]:
# device = tf.config.list_physical_devices('GPU')[0]
# device

PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')

In [None]:
# tf.config.experimental.get_memory_info(device='GPU:0')

{'current': 301555968, 'peak': 330915840}

In [None]:
# physical_devices = tf.config.list_physical_devices('GPU')
# try:
#     tf.config.experimental.set_memory_growth(physical_devices[0], True)
#     assert tf.config.experimental.get_memory_growth(physical_devices[0])
# except:
#     print('Invalid device or cannot modify virtual devices once initialized.')

Invalid device or cannot modify virtual devices once initialized.


In [None]:
# tf.keras.backend.clear_session()

In [None]:
# model.summary()

In [None]:
# try:
#     del model
# except:
#     pass
# import gc
# gc.collect()

In [None]:
# gpus = tf.config.list_physical_devices('GPU')
# if gpus:
#   # Restrict TensorFlow to only use the first GPU
#   try:
#     tf.config.set_visible_devices(gpus[0], 'GPU')
#     logical_gpus = tf.config.list_logical_devices('GPU')
#     print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
#   except RuntimeError as e:
#     # Visible devices must be set before GPUs have been initialized
#     print(e)

1 Physical GPUs, 1 Logical GPU


## end of section