### Notebook to extract trained feature vectors from enzyme sequences

In [2]:
# Check whether PyTorch can see a CUDA GPU; the model is moved onto it further down when one is available
import torch
torch.cuda.is_available()

True

In [39]:
# The custom layers used to fine-tune the model must be importable for the RoBERTa checkpoint
# to load, so put the directory two levels up ('../..') on the Python path first.
import sys
sys.path.append('../..')
from go_annotation.ontology import Ontology

In [7]:
from fairseq.models.roberta import RobertaModel

# Load the fine-tuned RoBERTa checkpoint together with its fairseq data directory.
roberta = RobertaModel.from_pretrained(
    '/projects/deepgreen/pstjohn/roberta_base_checkpoint',
    data_name_or_path='/projects/deepgreen/pstjohn/swissprot_go_annotation/fairseq_swissprot/',
    checkpoint_file='roberta.base_go_swissprot.pt')

# Switch the model to evaluation mode before extracting features.
_ = roberta.eval()  # disable dropout

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assigning to `_` keeps the returned module from being echoed as the cell output.
_ = roberta.to(device)

In [16]:
# Example protein sequence in one-letter amino-acid code. The line breaks inside
# the literal are only for readability; encode() strips them before tokenizing.
example_sequence = \
"""
MSKPHSEAGTAFIQTQQLHAAMADTFLEHMCRLDIDSPPITARNTGIICTIGPASRSVET
LKEMIKSGMNVARLNFSHGTHEYHAETIKNVRTATESFASDPILYRPVAVALDTKGPEIR
TGLIKGSGTAEVELKKGATLKITLDNAYMEKCDENILWLDYKNICKVVEVGSKIYVDDGL
ISLQVKQKGADFLVTEVENGGSLGSKKGVNLPGAAVDLPAVSEKDIQDLKFGVEQDVDMV
FASFIRKASDVHEVRKVLGEKGKNIKIISKIENHEGVRRFDEILEASDGIMVARGDLGIE
IPAEKVFLAQKMMIGRCNRAGKPVICATQMLESMIKKPRPTRAEGSDVANAVLDGADCIM
LSGETAKGDYPLEAVRMQHLIAREAEAAIYHLQLFEELRRLAPITSDPTEATAVGAVEAS
FKCCSGAIIVLTKSGRSAHQVARYRPRAPIIAVTRNPQTARQAHLYRGIFPVLCKDPVQE
AWAEDVDLRVNFAMNVGKARGFFKKGDVVIVLTGWRPGSGFTNTMRVVPVP
"""

def encode(sequence):
    """Convert an amino-acid sequence string into a tensor of token indices.

    Newlines are removed and the ambiguous residue codes B, Z and J are mapped
    to D, E and L respectively. The residues are then joined with single spaces
    so that each one is encoded as its own token by the model's source
    dictionary.

    Args:
        sequence: protein sequence in one-letter code; may contain newlines.

    Returns:
        A 1-D integer (``long``) tensor of dictionary indices, truncated to at
        most ``roberta.model.max_positions()`` entries.
    """
    # The substitutions must be applied to the sequence itself. Previously they
    # were chained onto the ' ' separator string, where they had no effect, so
    # B/Z/J residues reached the dictionary unmodified.
    residues = (
        sequence.replace('\n', '')
        .replace('B', 'D')
        .replace('Z', 'E')
        .replace('J', 'L')
    )
    input_sequence = ' '.join(residues)

    token_ids = roberta.task.source_dictionary.encode_line(
        input_sequence, add_if_not_exist=False)

    # NOTE(review): truncation happens after encoding, so anything encode_line
    # appends at the end (presumably an end-of-sentence symbol, to be confirmed)
    # is cut off for sequences longer than the model's maximum length.
    return token_ids[:roberta.model.max_positions()].long()

# Tokenize the example sequence and display the resulting index tensor
tokens = encode(example_sequence)
tokens

tensor([20,  8, 15, 14, 21,  8,  9,  5,  6, 11,  5, 17, 12, 16, 11, 16, 16,  4,
        21,  5,  5, 20,  5, 13, 11, 17,  4,  9, 21, 20, 23, 10,  4, 13, 12, 13,
         8, 14, 14, 12, 11,  5, 10, 18, 11,  6, 12, 12, 23, 11, 12,  6, 14,  5,
         8, 10,  8,  7,  9, 11,  4, 15,  9, 20, 12, 15,  8,  6, 20, 18,  7,  5,
        10,  4, 18, 17,  8, 21,  6, 11, 21,  9, 19, 21,  5,  9, 11, 12, 15, 18,
         7, 10, 11,  5, 11,  9,  8, 17,  5,  8, 13, 14, 12,  4, 19, 10, 14,  7,
         5,  7,  5,  4, 13, 11, 15,  6, 14,  9, 12, 10, 11,  6,  4, 12, 15,  6,
         8,  6, 11,  5,  9,  7,  9,  4, 15, 15,  6,  5, 11,  4, 15, 12, 11,  4,
        13, 18,  5, 19, 20,  9, 15, 23, 13,  9, 18, 12,  4, 22,  4, 13, 19, 15,
        18, 12, 23, 15,  7,  7,  9,  7,  6,  8, 15, 12, 19,  7, 13, 13,  6,  4,
        12,  8,  4, 16,  7, 15, 16, 15,  6,  5, 13, 17,  4,  7, 11,  9,  7,  9,
        18,  6,  6,  8,  4,  6,  8, 15, 15,  6,  7, 18,  4, 14,  6,  5,  5,  7,
        13,  4, 14,  5,  7,  8,  9, 15, 

In [29]:
# This calculates features for a single enzyme
def calc_enzyme_features(tokens):
    """Return one feature vector for a single tokenized enzyme sequence.

    Args:
        tokens: token-index tensor for one sequence, as produced by ``encode``.

    Returns:
        A NumPy array holding the slice ``[0, 0, :]`` of the model's extracted
        features (presumably the first position's embedding of the only
        sequence in the batch; confirm the layout against fairseq).
    """
    # Gradients are not needed for feature extraction.
    with torch.no_grad():
        features = roberta.extract_features(tokens)

    # Bring the result back to host memory as a NumPy array before slicing.
    features_np = features.detach().cpu().numpy()
    return features_np[0, 0, :]

In [30]:
# Preview the first ten entries of the example sequence's feature vector
calc_enzyme_features(tokens)[:10]

array([ 0.33626685, -0.15387164,  0.17733084, -0.3071365 , -0.7067522 ,
        0.0974492 ,  0.39190924,  0.2440962 , -0.58304805,  0.01171669],
      dtype=float32)

In [33]:
# Load swissprot sequences and annotations
import pandas as pd
import os
swissprot_dir = '/projects/deepgreen/pstjohn/swissprot_go_annotation'

# Two parquet tables are read from the same directory: the parsed SwissProt entries
# (with UniRef cluster IDs) and the QuickGO annotation table.
swissprot = pd.read_parquet(os.path.join(swissprot_dir, 'parsed_swissprot_uniref_clusters.parquet'))
go_terms = pd.read_parquet(os.path.join(swissprot_dir, 'swissprot_quickgo.parquet'))

In [37]:
# Preview the first rows of the SwissProt sequence table
swissprot.head()

Unnamed: 0,UniRef100 ID,UniRef90 ID,UniRef50 ID,accession,EMBL,RefSeq,KEGG,InterPro,Pfam,NCBI Taxonomy,length,sequence,subcellularLocalization
0,UniRef100_Q9Q8J2,UniRef90_Q9Q8J2,UniRef50_P16712,Q9Q8J2,AF170726,NP_051822.1,vg:932054,IPR027417,PF04851,31530,478,MSVCSEIDYALYTELKKFLNSQPLFLFNADKNFVEVVPSSSFKFYI...,Virion
1,UniRef100_P14197,UniRef90_P14197,UniRef50_P14197,P14197,X16524,XP_643326.1,ddi:DDB_G0276031,IPR036322,PF00400,44689,478,MGSRLNPSSNMYIPMNGPRGGYYGMPSMGQLQHPLFNYQFPPGGFQ...,
2,UniRef100_A6VUT8,UniRef90_A6VUT8,UniRef50_Q65UI5,A6VUT8,CP000749,WP_012069002.1,mmw:Mmwyl1_1288,IPR011763,PF03255,400668,315,MNLDYLPFEQPIAELEQKIEELRLVGNDNELNISDEISRLEDKKIA...,Cytoplasm
3,UniRef100_A4QKB4,UniRef90_P56765,UniRef50_P56765,A4QKB4,AP009370,YP_001123295.1,,IPR011762,PF01039,50458,487,MEKSWFNLMFSKGELEYRGELSKAMDSFAPSEKTTISQDRFIYDMD...,Plastid
4,UniRef100_Q9SQR4,UniRef90_Q9SQR4,UniRef50_Q9SQR4,Q9SQR4,CP002686,NP_187048.1,ath:AT3G03980,IPR002347,,3702,270,MSTHSSISQPPLPLAGRVAIVTGSSRGIGRAIAIHLAELGARIVIN...,Plastid


In [35]:
# Preview the first rows of the GO annotation table
go_terms.head()

Unnamed: 0,GENE PRODUCT DB,GENE PRODUCT ID,SYMBOL,QUALIFIER,GO TERM,GO ASPECT,ECO ID,GO EVIDENCE CODE,REFERENCE,WITH/FROM,TAXON ID,ASSIGNED BY,ANNOTATION EXTENSION,DATE
0,UniProtKB,A2CKF6,A2CKF6,part_of,GO:0005576,C,ECO:0000256,IEA,GO_REF:0000002,InterPro:IPR003571|InterPro:IPR018354,8613,InterPro,,20200613
1,UniProtKB,A2CKF6,A2CKF6,involved_in,GO:0009405,P,ECO:0000256,IEA,GO_REF:0000002,InterPro:IPR003571,8613,InterPro,,20200613
2,UniProtKB,A2CKF6,A2CKF6,part_of,GO:0005576,C,ECO:0000322,IEA,GO_REF:0000043,UniProtKB-KW:KW-0964,8613,UniProt,,20200613
3,UniProtKB,A2CKF6,A2CKF6,enables,GO:0090729,F,ECO:0000322,IEA,GO_REF:0000043,UniProtKB-KW:KW-0800,8613,UniProt,,20200613
4,UniProtKB,A2CKF6,A2CKF6,part_of,GO:0035792,C,ECO:0000322,IEA,GO_REF:0000043,UniProtKB-KW:KW-0629,8613,UniProt,,20200613


In [51]:
# Flatten the ontology graph into a table: one row per node, with the node
# identifier stored under 'id' alongside that node's attribute dictionary.
ont = Ontology()
ontology_data = pd.DataFrame(
    [{'id': node_id, **attrs} for node_id, attrs in ont.G.nodes(data=True)]
)
ontology_data.head()

Unnamed: 0,id,name,namespace,index
0,GO:0000001,mitochondrion inheritance,biological_process,0.0
1,GO:0048308,organelle inheritance,biological_process,1.0
2,GO:0048311,mitochondrion distribution,biological_process,2.0
3,GO:0000002,mitochondrial genome maintenance,biological_process,3.0
4,GO:0007005,mitochondrion organization,biological_process,4.0


In [55]:
# This might be one way to get at NADH or NADPH dependent enzymes? But really we'll want pairs of enzymes
# NOTE(review): the pattern 'NADH' is not a substring of 'NADPH' or 'NAD(P)H', so this
# filter only returns terms that mention NADH. Broaden the pattern (e.g. a regex such as
# r'NAD\(?P?\)?H') if NADPH-dependent terms are wanted as well.
ontology_data[ontology_data.name.str.contains('NADH')].head()

Unnamed: 0,id,name,namespace,index
3891,GO:0003954,NADH dehydrogenase activity,molecular_function,3209.0
4192,GO:0004318,enoyl-[acyl-carrier-protein] reductase (NADH) ...,molecular_function,3503.0
4475,GO:0004589,orotate reductase (NADH) activity,molecular_function,3786.0
5846,GO:0006116,NADH oxidation,biological_process,5103.0
5847,GO:0006734,NADH metabolic process,biological_process,5104.0
