In [1]:
# Standard library
import os
import sys
from importlib import import_module

# PyTorch; report its version immediately so the environment is visible in the output.
import torch
print('Using torch version', torch.__version__)

# Project-local dataset wrapper, then numerical helpers.
from phosking.dataset import ESM_Embeddings_test
import numpy as np

Using torch version 1.13.0+cu117


This notebook is a simplified copy of the main script "phosking/test_model.py": it runs the best model on the sequences found in the FASTA file set by `fasta_file` below.


Parameters for testing:

In [2]:
# Settings for this run. They are equivalent to the command-line invocation:
# > python3 phosking/test_model.py -i data/test/small_test.fsa -p 1280 -m models/CNN_RNN.py -n CNN_RNN_FFNN -a 1280,512,1024 -aaw 16 -sd states_dicts/CNN_RNN.pth

# Model: source file (-m), class name (-n), constructor arguments (-a) and trained weights (-sd)
model_file, model_name = 'models/CNN_RNN.py', 'CNN_RNN_FFNN'
model_args = (1280, 512, 1024)
state_dict = 'states_dicts/CNN_RNN.pth'

# Input sequences (-i) and dataset options (-p, -aaw)
fasta_file = 'data/test/small_test.fsa'  # example file containing 3 sequences
params = 1280
aa_window = 16
two_dims = False
mode = 'phospho'

# Set to True to run on the CPU even when CUDA is available
force_cpu = False

# To try other sequences, edit 'data/test/small_test.fsa' or point fasta_file at another file.

Load the model, read the sequences, compute their embeddings and build the test dataset:

In [3]:
print(f'Using python env in {sys.executable}')

# Import the model class from the file and class names configured in the parameters cell:
# the model file's directory is put on sys.path so the file can be imported as a module.
model_dir = os.path.dirname(model_file)
if model_dir not in sys.path:  # re-running this cell must not keep appending the same entry
    sys.path.append(model_dir)
# splitext strips whatever extension the file has instead of assuming a 3-character '.py'
model_module_name = os.path.splitext(os.path.basename(model_file))[0]
# NOTE(review): import_module caches modules in sys.modules, so edits to the model file are
# not picked up by re-running this cell; restart the kernel (or importlib.reload) after editing.
model_module = import_module(model_module_name)
model_class = getattr(model_module, model_name)

# Use CUDA when it is available, unless force_cpu is set.
device = torch.device('cuda' if not force_cpu and torch.cuda.is_available() else 'cpu')
device_description = device.type
if device.type == 'cuda':
    device_description += ': ' + torch.cuda.get_device_name(device)
print(f'Using torch device of type {device_description}')

# Build the model (passing model_args to the constructor when it is non-empty) on the device.
model: torch.nn.Module = model_class(*model_args) if model_args else model_class()
model = model.to(device)

# Load the trained weights into a separate variable. Overwriting `state_dict` (the checkpoint
# path) with the loaded weights made this cell fail when it was re-run.
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code --
# only load checkpoint files from trusted sources.
model_state = torch.load(state_dict, map_location=device)
model.load_state_dict(model_state)
model.eval()  # switch layers such as dropout/batch-norm to inference behaviour

# Read the FASTA file, compute the embeddings and build the test dataset.
dataset = ESM_Embeddings_test(
    fasta_file=fasta_file,
    params=params,
    device=device,
    aa_window=aa_window,
    two_dims=two_dims,
    mode=mode,
)

Using python env in /usr/bin/python3.11
Using torch device of type cuda: NVIDIA GeForce GTX 1660 Ti with Max-Q Design
Reading fasta...
Found 3 sequences!
Computing embeddings...
3 embeddings computed!

Run the model and print its output: for each sequence, the predicted phosphorylation sites are marked along the sequence, followed by the list of phosphorylatable amino acids with their scores.

Note that the threshold chosen for a highly significant score is 0.99 (' \* '), followed by the less significant thresholds 0.9 (' + ') and 0.75 (' . '); a score must be strictly above a threshold to receive its mark. This decision is based on the accuracies observed at various thresholds.

In [4]:
# Score thresholds for the marks printed above each residue, from the strictest down.
# A score must be strictly greater than a threshold to receive its mark.
SIGNIFICANCE_MARKS = ((0.99, '*'), (0.9, '+'), (0.75, '.'))
LINE_WIDTH = 80     # residues per printed sequence line (keep a multiple of 10 for the ruler)
TABLE_COLUMNS = 5   # (position, score) entries per row of the score table


def significance_mark(score):
    """Return '*', '+' or '.' for the strictest threshold ``score`` exceeds, else ' '."""
    for threshold, mark in SIGNIFICANCE_MARKS:
        if score > threshold:
            return mark
    return ' '


def print_sequence_header(seq_ID):
    """Print the separator line and the ``> seq_ID`` header for one sequence."""
    print('- ' * 41 + '\n > ' + seq_ID)


def print_sequence_report(seq_ID, seq, idxs, preds):
    """Print the marked sequence and the table of scores for one sequence.

    seq_ID: sequence identifier, printed in the header.
    seq: the amino-acid sequence as a string.
    idxs: positions of the scored residues, matched against 1-based sequence positions.
    preds: 1-D array holding one score per entry of ``idxs``, in the same order.
    """
    # Look scores up by position instead of advancing a running index while scanning the
    # sequence: that misassigned scores unless idxs was sorted, and `pos + 1 in idxs`
    # rescanned idxs for every residue. int() turns numpy/tensor scalars into usable dict keys.
    score_at = {int(pos): score for pos, score in zip(idxs, preds)}
    track = ''.join(
        significance_mark(score_at[pos]) if pos in score_at else ' '
        for pos in range(1, len(seq) + 1)
    )

    print_sequence_header(seq_ID)

    # Sequence in lines of LINE_WIDTH residues: marks, residues, then a ruler ticking every 10
    # positions. Stepping by LINE_WIDTH avoids printing an extra empty block (and ruler) when
    # len(seq) is an exact multiple of LINE_WIDTH.
    for start in range(0, len(seq), LINE_WIDTH):
        print(track[start:start + LINE_WIDTH])
        print(seq[start:start + LINE_WIDTH])
        print(' ' * 9 + '|' + '|'.join(
            '{:<9}'.format(start + offset) for offset in range(10, LINE_WIDTH + 1, 10)))

    # Table of every scored position, TABLE_COLUMNS entries per row separated by '|'.
    print('')
    print('Pos.  Score       ' * TABLE_COLUMNS)
    for n, (pos, score) in enumerate(zip(idxs, preds)):
        if n % TABLE_COLUMNS:
            print('|', end='  ')
        elif n:
            print('')
        print('{:<6}{:<5.3g} {}'.format(pos, round(score, 3), significance_mark(score)), end='  ')
    print('\n')


predictions = {}  # seq_ID -> {position: score}
# Query the dataset's IDs once instead of once per sequence.
scored_IDs = set(dataset.IDs())
for seq_ID, seq in dataset.seq_data:
    if seq_ID not in scored_IDs:
        print_sequence_header(seq_ID)
        print('No phosphorylable aminoacids in this sequence...')
        continue

    with torch.no_grad():
        idxs, inputs = dataset[seq_ID]
        preds = model(inputs.to(device)).detach().cpu().numpy().flatten()
    # zip() would silently drop unmatched entries, so check the model output lines up with idxs.
    assert len(preds) == len(idxs), f'{seq_ID}: {len(preds)} scores for {len(idxs)} positions'

    predictions[seq_ID] = dict(zip(idxs, preds))
    print_sequence_report(seq_ID, seq, idxs, preds)

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 
 > P53004
                    *           **         +     *       *     **               
MNAEPERKFGVVVVGVGRAGSVRMRDLRNPHPSSAFLNLIGFVSRRELGSIDGVQQISLEDALSSQEVEVAYICSESSSH
         |10       |20       |30       |40       |50       |60       |70       |80       
  *                                                                 *           
EDYIRQFLNAGKHVLVEYPMTLSLAAAQELWELAEQKGKVLHEEHVELLMEEFAFLKKEVVGKDLLKGSLLFTAGPLEEE
         |90       |100      |110      |120      |130      |140      |150      |160      
             *         .             *        +   +                * *    * *   
RFGFPAFSGISRLTWLVSLFGELSLVSATLEERKEDQYMKMTVCLETEKKSPLSWIEEKGPGLKRNRYLSFHFKSGSLEN
         |170      |180      |190      |200      |210      |220      |230      |240      
                          *                       *     
VPNVGVNKNIFLKDQNIFVQKLLGQFSEKELAAEKKRILHCLGLAEEIQKYCCSRK
         |250      |260      |270    