# Inference using the original data from the paper's authors

In this notebook, we use the validation data from the CNN Virus paper to do inference using the pretrained model.

This notebook runs locally and should also run on Colab, as long as the file system follows the unified file system layout (see documentation).

# 1. Imports and setup environment

### Install and import packages

In [None]:
# Install required custom packages if not installed yet.
import importlib.util
if not importlib.util.find_spec('eccore'):
    print('installing package: `eccore`')
    ! pip install -qqU eccore
else:
    print('`eccore` already installed')
if not importlib.util.find_spec('metagentorch'):
    print('installing package: `metagentorch`')
    ! pip install -qqU metagentorch
else:
    print('`metagentorch` already installed')

`eccore` already installed
`metagentorch` already installed


In [None]:
# Import all required packages
import os

os.environ['KERAS_BACKEND'] = 'torch'

from functools import partial
from pathlib import Path
# from IPython.display import display, Markdown, HTML
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from eccore.core import files_in_tree
from eccore.ipython import nb_setup
from tqdm.notebook import tqdm, trange

# Setup the notebook for development
nb_setup()

import keras
import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset

print(f"Pytorch version: {torch.__version__}")
print(f"Keras version: {keras.__version__}\n")

from metagentorch.cnn_virus.architecture import create_model_original
from metagentorch.cnn_virus.data import TextFileDataset
from metagentorch.core import (ProjectFileSystem, TextFileBaseReader,
                               list_available_devices)

Set autoreload mode
Pytorch version: 2.5.1
Keras version: 3.8.0



List all computing devices available on the machine

In [None]:
list_available_devices()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CUDA available: True
Number of CUDA devices: 1
CUDA Device 0: NVIDIA GeForce GTX 1050
CPU available: cpu


# 2. Setup paths to files

Key folders and system information

In [None]:
pfs = ProjectFileSystem()
pfs.info()

Running linux on local computer
Device's home directory: /home/vtec
Project file structure:
 - Root ........ /home/vtec/projects/bio/metagentorch 
 - Data Dir .... /home/vtec/projects/bio/metagentorch/data 
 - Notebooks ... /home/vtec/projects/bio/metagentorch/nbs


In [None]:
pfs.readme()

ReadMe file for directory `data`:

### Data structure for this project
This directory includes all the data required for the project.

```text
data
 |--- CNN_Virus_data 
 |--- ncbi                
 |--- saved         
 |--- yf-reads
 |--- ....           
     
```
#### Sub-directories
- `CNN_Virus_data`: includes all the data related to the original CNN Virus paper, i.e. training data and validation data in a format that can be used by the CNN Virus code.
- `ncbi`: includes data related to the use of viral sequences from NCBI: reference sequences, simulated reads, inference datasets, inference results.
- `saved`: includes model saved parameters and preprocessing datasets.
- `yf-reads`: includes all data related to real yellow fever reads, from "wet" samples

Also available on AWS S3 at `https://s3.ap-southeast-1.amazonaws.com/bio.cnn-virus.data/data/...`

- `p2model`: path to the file with the saved original pretrained model
- `p2virus_labels`: path to the file with the mapping between virus names and labels for the original model
- `p2original`: path to the folder with the original CNN Virus data files

In [None]:
p2model = pfs.data / 'saved/cnn_virus_original/pretrained_model.h5'
assert p2model.is_file(), f"No file found at {p2model.absolute()}"

p2virus_labels = pfs.data / 'CNN_Virus_data/virus_name_mapping'
assert p2virus_labels.is_file(), f"No file found at {p2virus_labels.absolute()}"

p2original = pfs.data / 'CNN_Virus_data'
assert p2original.is_dir(), f"No directory found at {p2original.absolute()}"

In [None]:
pfs.readme(dir_path=p2original)

ReadMe file for directory `data/CNN_Virus_data`:

### CNN Virus data

This directory includes data used to train and validate the initial CNN Virus model, as well as a few smaller datasets for experimenting. 


#### File list and description:
##### 50-mer 
50-mer reads and their labels, in *text format* with one line per sample. Each line consists of three components, separated by tabs: the 50-mer read or sequence, the virus species label and the position label:
```text
'TTACNAGCTCCAGTCTAAGATTGTAACTGGCCTTTTTAAAGATTGCTCTA    94    5\n'
``` 
Files:
- `50mer_training`: dataset with 50,903,296 reads for training
- `50mer_validating`: dataset with 1,000,000 reads for validation
- `50mer_ds_100_reads`: small subset of 100 reads from the validating dataset for experiments

##### 150-mer
150-mer reads and their labels in *text format* in a similar format as above:
```text
'TTCTTTCACCACCACAACCAGTCGGCCGTGGAGAGGCGTCGCCGCGTCTCGTTCGTCGAGGCCGATCGACTGCCGCATGAGAGCGGGTGGTATTCTTCCGAAGACGACGGAGACCGGGACGGTGATGAGGAAACTGGAGAGAGCCACAAC    6    0\n'
```
Files:
- `ICTV_150mer_benchmarking`: dataset with 10,0000 read
- `150mer_ds_100_reads`: small subset of 100 reads from `ICTV_150mer_benchmarking`

##### Longer reads
Reads of various length with no labels, in simple *fasta format*. Each read sequence is preceded by a definition line: `> Sequence n`, where `n` is the sequence number.

Files:
- `training_sequences_300bp.fasta`: dataset with 9,000 300-mer reads
- `training_sequences_500bp.fasta`: dataset with 9,000 500-mer reads
- `validation_sequences.fasta`: dataset with 564 reads of mixed lengths ranging from 163-mer to 497-mer

##### Other files:
- `virus_name_mapping`: mapping between virus species and their numerical label
- `weight_of_classes`:  weights for each virus species class in the training dataset



In [None]:
files_in_tree(path=p2original);

data
  |--CNN_Virus_data
  |    |--50mer_validating (0)
  |    |--50mer_ds_100_reads (1)
  |    |--validation_sequences.fasta (2)
  |    |--50mer_training_yf_ncbi.fa (3)
  |    |--ICTV_150mer_benchmarking (4)
  |    |--readme.md (5)
  |    |--50mer_training (6)
  |    |--50mer_training_yf (7)
  |    |--training_sequences_500bp.fasta (8)
  |    |--weight_of_classes (9)
  |    |--150mer_ds_100_reads (10)
  |    |--virus_name_mapping (11)
  |    |--training_sequences_300bp.fasta (12)
  |    |--50mer_training-yf.fa (13)


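For this experiment, we use the dataset:
- `50mer_validating` (the cell below selects the smaller `50mer_ds_100_reads` subset to test the code; swap the commented line to run on the full dataset)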

In [None]:
# p2ds = p2original / '50mer_validating'     # full dataset
p2ds = p2original / '50mer_ds_100_reads'  # smaller dataset to test code
assert p2ds.is_file()
p2ds.absolute()

PosixPath('/home/vtec/projects/bio/metagentorch/data/CNN_Virus_data/50mer_ds_100_reads')

# 3. Create inference dataset

The model expects a dataset file in the following format:

```text
    AAAAAGATTTTGAGAGAGGTCGACCTGTCCTCCTAAAACGTTTACAAAAG
    CATGTAACGCAGCTTAGTCCGATCGTGGCTATAATCCGTCTTTCGATTTG
    AACAACATCTTGTTGATGATAACCGTCAAAGTGTTTTGGGTCTGGAGGGA
    AGTACCTGGAGAGCGTTAAGAAACACAAACGGCTGGATGTAGTGCCGCGC
    CCACGTCGATGAAGCTCCGACGAGAGTCGGCGCTGAGCCCGCGCACCTCC
```

`50mer_validating` is already in the correct format

The mapping between numerical codes and virus species names is in the file `virus_name_mapping` (path `p2virus_labels`).

In [None]:
ds = TextFileDataset(p2file=p2ds)

for i,(seq, (lbl, pos)) in enumerate(ds):
    print(f"Sequence {i+1}:\n Seq:      {seq.shape}\n Label:    {lbl.shape}\n Position: {pos.shape}\n")
    if i+1 == 1: break

Sequence 1:
 Seq:      torch.Size([50, 5])
 Label:    torch.Size([187])
 Position: torch.Size([10])



The first read is `AAAAAGATTTTGAGAGAGGTCGACCTGTCCTCCTAAAACGTTTACAAAAG`. Its first six bases (`AAAAAG`) are encoded as follows:

In [None]:
seq[0:6, :]

tensor([[1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 1, 0, 0]])

## Create the data loader for the model 

Define the batch size and create a data loader that reads batches from the dataset built on the text file. The batch size can be adjusted to the memory available on the GPU. For reference, `bs = 4096` was used with a 4GB GPU.
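As a rough check when choosing a batch size, the number of batches the loader yields is the dataset size divided by `bs`, rounded up, because `DataLoader` keeps the last partial batch by default (`drop_last=False`). The sketch below uses the dataset sizes quoted in this notebook; the helper name `n_batches` is hypothetical.

```python
import math

def n_batches(n_reads: int, bs: int) -> int:
    """Number of batches yielded when the last partial batch is kept."""
    return math.ceil(n_reads / bs)

# Small test dataset used in this notebook: 100 reads, bs = 20
small = n_batches(100, 20)

# Full validation dataset: 1,000,000 reads, bs = 4096 (reference value above)
full = n_batches(1_000_000, 4096)

print(small, full)
```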

In [None]:
bs = 20

dl = DataLoader(ds, batch_size=bs)

for i, (seq_b, (lbl_b, pos_b)) in enumerate(dl):
    print(f"Batch {i+1}:\n Seq:      {seq_b.shape}\n Label:    {lbl_b.shape}\n Position: {pos_b.shape}\n")
    print(seq_b[:2, :8, :])
    if i+1 == 1: break

Batch 1:
 Seq:      torch.Size([20, 50, 5])
 Label:    torch.Size([20, 187])
 Position: torch.Size([20, 10])

tensor([[[1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0]],

        [[0, 1, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [1, 0, 0, 0, 0],
         [1, 0, 0, 0, 0],
         [0, 1, 0, 0, 0]]])


Each base in a read is encoded as a 5-dimensional one-hot vector, as the model expects.

In this example, each 50 bp read is converted into a tensor of shape [50, 5].
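To make the encoding concrete, here is a minimal sketch of one-hot encoding a read with NumPy. It is not the code used by `TextFileDataset`. The channel order `A, C, G, T` is inferred from the tensors printed above; placing `N` in the fifth channel and mapping any unexpected character to it are assumptions. The names `BASES` and `one_hot_encode` are hypothetical.

```python
import numpy as np

# Channel order: A, C, G, T inferred from the printed tensors; N as 5th channel is assumed
BASES = "ACGTN"
BASE_TO_INDEX = {base: i for i, base in enumerate(BASES)}

def one_hot_encode(read: str) -> np.ndarray:
    """Return a (len(read), 5) one-hot array for a DNA read."""
    encoded = np.zeros((len(read), len(BASES)), dtype=np.int64)
    for position, base in enumerate(read.upper()):
        # Unexpected characters fall back to the N channel (an assumption)
        encoded[position, BASE_TO_INDEX.get(base, BASE_TO_INDEX["N"])] = 1
    return encoded

read = "AAAAAGATTTTGAGAGAGGTCGACCTGTCCTCCTAAAACGTTTACAAAAG"
encoded = one_hot_encode(read)
print(encoded.shape)
print(encoded[:6])
```

Under these assumptions, the first six rows match the `seq[0:6, :]` output shown earlier for the same read.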

# 4. Inference

Load and review the pretrained model

In [None]:
model = create_model_original(path2parameters=p2model).to(device)

Creating CNN Model (Original)
Loading parameters from pretrained_model.h5
Created pretrained model


In [None]:
model.summary()

Present the inference dataset to the model and collect the predictions.

The model returns two sets of probabilities:
- `prob_preds_species`: for each input read, a vector of 187 values giving the probability that each of the 187 species is the correct one
- `prob_preds_pos`: for each input read, a vector of 10 values giving the probability that the read comes from the corresponding segment of the original sequence (1 to 10)

In [None]:
prob_preds_species, prob_preds_pos = model.predict(dl, verbose=1)

5/5 ━━━━━━━━━━━━━━━━━━━━ 7s 28ms/step




In [None]:
prob_preds_species.shape, prob_preds_pos.shape

((100, 187), (100, 10))

To obtain the prediction for each read, we take the argmax over the species probabilities, which gives the index (code) of the predicted virus species.

In [None]:
prob_preds_species

array([[1.5632741e-20, 1.6047729e-34, 6.7372035e-25, ..., 1.4837560e-23,
        1.5598299e-29, 7.1682802e-37],
       [5.5474295e-14, 9.9999917e-01, 4.1924051e-11, ..., 1.1497376e-14,
        1.8635946e-17, 2.7175434e-15],
       [1.2425527e-13, 1.9970270e-27, 2.0033034e-15, ..., 1.1757105e-23,
        2.1326144e-27, 5.3270406e-27],
       ...,
       [8.8232195e-01, 1.6157335e-24, 7.3480706e-16, ..., 7.7248165e-29,
        7.5520697e-27, 5.1577789e-24],
       [7.8338002e-05, 7.1431316e-27, 5.0639932e-23, ..., 1.5747985e-27,
        1.2241149e-31, 4.4118172e-34],
       [2.7330055e-22, 6.1275908e-22, 3.9321292e-18, ..., 7.4173405e-26,
        1.2151302e-22, 5.7734036e-22]], dtype=float32)

In [None]:
label_preds = np.argmax(prob_preds_species, axis=1)
label_preds.shape, label_preds[:10]

((100,), array([ 71,   1, 158,   6,  71,  87,  10, 178,  71,  22]))
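The same argmax step can be applied to the position probabilities, and the maximum probability can serve as a simple confidence score. The sketch below uses made-up probabilities (3 reads, 4 classes) rather than the real model output; with the notebook's arrays the equivalent call would be `np.argmax(prob_preds_pos, axis=1)`. Whether index 0 corresponds to segment 1 is an assumption to verify against the dataset labels.

```python
import numpy as np

# Made-up probabilities for 3 reads and 4 classes (illustration only);
# the real arrays have 187 (species) and 10 (position) columns
toy_probs = np.array([
    [0.10, 0.70, 0.15, 0.05],
    [0.05, 0.05, 0.10, 0.80],
    [0.60, 0.20, 0.10, 0.10],
])

toy_preds = toy_probs.argmax(axis=1)  # index of the most probable class per read
toy_conf = toy_probs.max(axis=1)      # probability assigned to that class

for pred, conf in zip(toy_preds, toy_conf):
    print(f"predicted class {pred} with probability {conf:.2f}")
```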

In [None]:
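# Count correct species predictions batch by batch (TP = number of correct predictions).
# This relies on `dl` yielding batches in the same order as during `model.predict`.
TP = 0
for i, (_, (label_target_b, _)) in enumerate(dl):
    # Convert one-hot targets to class indices
    lbl_target_b = np.argmax(label_target_b, axis=1)
    # Compare with the predictions for the matching slice of the dataset
    TP_batch = np.equal(label_preds[i*bs:(i+1)*bs], lbl_target_b).sum()
    TP = TP + TP_batch
    print(f"Batch Accuracy: {TP_batch.numpy()/len(lbl_target_b):.1%}")
print(f"Full Dataset Accuracy: {TP.numpy()/len(label_preds):.1%}")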

Batch Accuracy: 95.0%
Batch Accuracy: 95.0%
Batch Accuracy: 95.0%
Batch Accuracy: 80.0%
Batch Accuracy: 90.0%
Full Dataset Accuracy: 91.0%
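The batch loop above can also be expressed as one vectorized comparison once all targets are available in a single array. The sketch below shows the idea on made-up data (4 reads, 3 classes); the function name `accuracy_from_one_hot` is hypothetical, and collecting the notebook's targets into one array is left out.

```python
import numpy as np

def accuracy_from_one_hot(prob_preds: np.ndarray, one_hot_targets: np.ndarray) -> float:
    """Fraction of rows where the argmax prediction equals the one-hot target."""
    pred = np.argmax(prob_preds, axis=1)
    target = np.argmax(one_hot_targets, axis=1)
    return float(np.mean(pred == target))

# Made-up probabilities and one-hot targets (illustration only)
toy_probs = np.array([
    [0.9, 0.1, 0.0],
    [0.2, 0.7, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
toy_targets = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 1, 0],
    [1, 0, 0],
])

accuracy = accuracy_from_one_hot(toy_probs, toy_targets)
print(f"Accuracy: {accuracy:.1%}")
```

Computing accuracy this way avoids slicing `label_preds` by batch index, at the cost of holding all targets in memory at once.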


## end of section