### Notebook to extract trained feature vectors from enzyme sequences

In [1]:
# Confirm a CUDA device is visible; the model-loading cell falls back to CPU otherwise.
import torch

torch.cuda.is_available()

True

In [2]:
# The fine-tuned checkpoint needs the custom layers used during fine-tuning to be
# importable, so the directory two levels up (presumably the repo root holding
# the `go_annotation` package) is put on sys.path before the RoBERTa model is loaded.
# NOTE(review): '../..' is resolved against the kernel's working directory, so this
# only works when the notebook is started from its own folder — confirm.
import sys
sys.path.append('../..')
from go_annotation.ontology import Ontology

In [5]:
from fairseq.models.roberta import RobertaModel

# Fine-tuned checkpoint location and the fairseq data directory it was trained with.
CHECKPOINT_DIR = '/projects/deepgreen/pstjohn/20210121_go_checkpoints/'
FAIRSEQ_DATA_DIR = '/projects/deepgreen/pstjohn/swissprot_go_annotation/fairseq_swissprot/'
CHECKPOINT_FILE = 'swissprot_preinit.pt'

# Use the first GPU when one is visible, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

roberta = RobertaModel.from_pretrained(
    CHECKPOINT_DIR,
    data_name_or_path=FAIRSEQ_DATA_DIR,
    checkpoint_file=CHECKPOINT_FILE,
)

roberta.eval()  # inference mode: disables dropout
roberta.to(device);

In [21]:
def count_parameters(model):
    """Count the trainable parameters of a torch module.

    Returns a ``(table, total_params)`` pair:
    - ``table`` lists one ``(name, numel)`` tuple per parameter whose
      ``requires_grad`` is set, in ``named_parameters()`` order.
    - ``total_params`` is the sum of those element counts (0 when there are none).
    """
    table = [
        (name, tensor.numel())
        for name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]
    total_params = sum(numel for _, numel in table)
    return table, total_params
    
# Returns (table, total): one (name, numel) row per trainable tensor, plus their sum.
count_parameters(roberta)

([('model.encoder.sentence_encoder.embed_tokens.weight', 25344),
  ('model.encoder.sentence_encoder.embed_positions.weight', 787968),
  ('model.encoder.sentence_encoder.layers.0.self_attn.k_proj.weight', 589824),
  ('model.encoder.sentence_encoder.layers.0.self_attn.k_proj.bias', 768),
  ('model.encoder.sentence_encoder.layers.0.self_attn.v_proj.weight', 589824),
  ('model.encoder.sentence_encoder.layers.0.self_attn.v_proj.bias', 768),
  ('model.encoder.sentence_encoder.layers.0.self_attn.q_proj.weight', 589824),
  ('model.encoder.sentence_encoder.layers.0.self_attn.q_proj.bias', 768),
  ('model.encoder.sentence_encoder.layers.0.self_attn.out_proj.weight',
   589824),
  ('model.encoder.sentence_encoder.layers.0.self_attn.out_proj.bias', 768),
  ('model.encoder.sentence_encoder.layers.0.self_attn_layer_norm.weight', 768),
  ('model.encoder.sentence_encoder.layers.0.self_attn_layer_norm.bias', 768),
  ('model.encoder.sentence_encoder.layers.0.fc1.weight', 2359296),
  ('model.encoder.sent

In [8]:
# Example enzyme sequence (one-letter residue codes, line-wrapped); encode() strips the newlines.
example_sequence = \
"""
MSKPHSEAGTAFIQTQQLHAAMADTFLEHMCRLDIDSPPITARNTGIICTIGPASRSVET
LKEMIKSGMNVARLNFSHGTHEYHAETIKNVRTATESFASDPILYRPVAVALDTKGPEIR
TGLIKGSGTAEVELKKGATLKITLDNAYMEKCDENILWLDYKNICKVVEVGSKIYVDDGL
ISLQVKQKGADFLVTEVENGGSLGSKKGVNLPGAAVDLPAVSEKDIQDLKFGVEQDVDMV
FASFIRKASDVHEVRKVLGEKGKNIKIISKIENHEGVRRFDEILEASDGIMVARGDLGIE
IPAEKVFLAQKMMIGRCNRAGKPVICATQMLESMIKKPRPTRAEGSDVANAVLDGADCIM
LSGETAKGDYPLEAVRMQHLIAREAEAAIYHLQLFEELRRLAPITSDPTEATAVGAVEAS
FKCCSGAIIVLTKSGRSAHQVARYRPRAPIIAVTRNPQTARQAHLYRGIFPVLCKDPVQE
AWAEDVDLRVNFAMNVGKARGFFKKGDVVIVLTGWRPGSGFTNTMRVVPVP
"""

def encode(sequence):
    """Convert an amino-acid string into a LongTensor of model token ids.

    Steps:
    - Strip newlines, so line-wrapped sequences work.
    - Map the ambiguity codes B, Z and J to D, E and L.
    - Make each residue its own whitespace-separated token after a leading '<s>'.
    - Encode with the model's source dictionary (without adding new symbols).
    - Truncate to ``roberta.model.max_positions()`` tokens.
    """
    # NOTE(review): B/Z/J are presumably remapped because the fine-tuning
    # dictionary has no entries for them — confirm against the dictionary.
    # The substitutions must be applied to the residues themselves; applying
    # them to the ' ' separator string leaves B/Z/J in the sequence.
    residues = (
        sequence.replace('\n', '')
        .replace('B', 'D')
        .replace('Z', 'E')
        .replace('J', 'L')
    )
    input_sequence = '<s> ' + ' '.join(residues)
    token_ids = roberta.task.source_dictionary.encode_line(
        input_sequence, add_if_not_exist=False)
    return token_ids[:roberta.model.max_positions()].long()

# Token ids for the example sequence; the leading id is the '<s>' token encode() prepends.
tokens = encode(example_sequence)
tokens

tensor([ 0, 20,  8, 15, 14, 21,  8,  9,  5,  6, 11,  5, 17, 12, 16, 11, 16, 16,
         4, 21,  5,  5, 20,  5, 13, 11, 17,  4,  9, 21, 20, 23, 10,  4, 13, 12,
        13,  8, 14, 14, 12, 11,  5, 10, 18, 11,  6, 12, 12, 23, 11, 12,  6, 14,
         5,  8, 10,  8,  7,  9, 11,  4, 15,  9, 20, 12, 15,  8,  6, 20, 18,  7,
         5, 10,  4, 18, 17,  8, 21,  6, 11, 21,  9, 19, 21,  5,  9, 11, 12, 15,
        18,  7, 10, 11,  5, 11,  9,  8, 17,  5,  8, 13, 14, 12,  4, 19, 10, 14,
         7,  5,  7,  5,  4, 13, 11, 15,  6, 14,  9, 12, 10, 11,  6,  4, 12, 15,
         6,  8,  6, 11,  5,  9,  7,  9,  4, 15, 15,  6,  5, 11,  4, 15, 12, 11,
         4, 13, 18,  5, 19, 20,  9, 15, 23, 13,  9, 18, 12,  4, 22,  4, 13, 19,
        15, 18, 12, 23, 15,  7,  7,  9,  7,  6,  8, 15, 12, 19,  7, 13, 13,  6,
         4, 12,  8,  4, 16,  7, 15, 16, 15,  6,  5, 13, 17,  4,  7, 11,  9,  7,
         9, 18,  6,  6,  8,  4,  6,  8, 15, 15,  6,  7, 18,  4, 14,  6,  5,  5,
         7, 13,  4, 14,  5,  7,  8,  9, 

In [13]:
def calc_enzyme_features(tokens):
    """Return the feature vector at position 0 of the first batch entry.

    ``tokens`` is the tensor produced by ``encode``, whose first token is '<s>'.
    Autograd is disabled during the forward pass. The result is a NumPy view,
    taken from ``roberta.extract_features`` after moving it to the CPU.
    """
    with torch.no_grad():
        hidden = roberta.extract_features(tokens)
        as_numpy = hidden.detach().cpu().numpy()
    return as_numpy[0, 0, :]

In [14]:
# Peek at the first 10 entries of the example sequence's feature vector.
calc_enzyme_features(tokens)[:10]

array([ 0.11360448,  0.30674773, -0.01675967,  0.4423397 ,  0.01667783,
       -0.18440159, -0.18239962,  0.2038919 , -0.5497034 ,  0.28601235],
      dtype=float32)

In [15]:
# Load swissprot sequences and annotations
import os

import pandas as pd

swissprot_dir = '/projects/deepgreen/pstjohn/swissprot_go_annotation'

# Swiss-Prot sequence table and its QuickGO annotation table, both stored as parquet.
swissprot_path = os.path.join(swissprot_dir, 'parsed_swissprot_uniref_clusters.parquet')
go_terms_path = os.path.join(swissprot_dir, 'swissprot_quickgo.parquet')

swissprot = pd.read_parquet(swissprot_path)
go_terms = pd.read_parquet(go_terms_path)

In [None]:
# First rows of the Swiss-Prot sequence table.
swissprot.head()

In [None]:
# First rows of the GO annotation table.
go_terms.head()

In [None]:
ont = Ontology()

# Flatten the ontology graph into a table: one record per node, holding the
# node id plus every attribute stored on that node.
ontology_records = [{'id': id_, **attrs} for id_, attrs in ont.G.nodes(data=True)]
ontology_data = pd.DataFrame(ontology_records)
ontology_data.head()

In [None]:
# Candidate route to NADH- or NADPH-dependent enzymes: ontology nodes whose name
# mentions NADH, NADPH or NAD(P)H. A plain 'NADH' substring search misses the
# 'NADPH' and 'NAD(P)H' spellings. na=False keeps any nodes without a name from
# producing NaN in the boolean mask.
# TODO: what we ultimately want is pairs of enzymes, not just this term list.
nad_cofactor_terms = ontology_data[
    ontology_data.name.str.contains(r'NAD(?:P|\(P\))?H', regex=True, na=False)
]
nad_cofactor_terms.head()