In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from fairseq.data import data_utils
from fairseq.models.roberta import RobertaModel

In [3]:
import sys
sys.path.append('../..')
from go_annotation.ontology import Ontology

%matplotlib inline
sns.set(context='talk', style='ticks',
        color_codes=True, rc={'legend.frameon': False})

In [4]:
roberta = RobertaModel.from_pretrained(
    '/projects/deepgreen/pstjohn/roberta_base_checkpoint',
    data_name_or_path='/projects/deepgreen/pstjohn/swissprot_go_annotation/fairseq_swissprot/',
    checkpoint_file='roberta.base_go_swissprot.pt')

_ = roberta.eval()  # disable dropout (or leave in train mode to finetune)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_ = roberta.to(device)

In [31]:
ont = Ontology()
_ancestor_array = ont.ancestor_array()


def normalize_logits(logits):
    bsz = logits.shape[0]
    index_tensor = logits.new_tensor(_ancestor_array, dtype=torch.int64)
    index_tensor = index_tensor.unsqueeze(0).expand((bsz, -1, -1))  # Array of ancestors, offset by one
    padded_logits = torch.nn.functional.pad(logits, (1, 0), value=float('inf'))  # Make 0 index return inf
    padded_logits = padded_logits.unsqueeze(-1).expand((-1, -1, index_tensor.shape[2]))
    normed_logits = torch.gather(padded_logits, 1, index_tensor)
    normed_logits, _ = torch.min(normed_logits, -1)

    return normed_logits

import requests
fasta = requests.get('https://www.uniprot.org/uniprot/P00362.fasta').text

def encode(fasta):
    input_sequence = ''.join(fasta.split('\n')[1:]).replace('', ' ')
    return roberta.task.source_dictionary.encode_line(input_sequence)

tokens = encode(fasta)

In [19]:
''.join(fasta.split('\n')[1:])

'MAVKVGINGFGRIGRNVFRAALKNPDIEVVAVNDLTDANTLAHLLKYDSVHGRLDAEVSVNGNNLVVNGKEIIVKAERDPENLAWGEIGVDIVVESTGRFTKREDAAKHLEAGAKKVIISAPAKNEDITIVMGVNQDKYDPKAHHVISNASCTTNCLAPFAKVLHEQFGIVRGMMTTVHSYTNDQRILDLPHKDLRRARAAAESIIPTTTGAAKAVALVLPELKGKLNGMAMRVPTPNVSVVDLVAELEKEVTVEEVNAALKAAAEGELKGILAYSEEPLVSRDYNGSTVSSTIDALSTMVIDGKMVKVVSWYDNETGYSHRVVDLAAYIASKGL'

In [22]:
tokens.shape

torch.Size([336])

In [36]:
logits, _ = roberta.model(tokens.long().cuda().unsqueeze(0),
                          features_only=True,
                          classification_head_name='go_prediction')

normalize_logits(logits)

tensor([[-12.6296, -12.6296, -12.6296,  ..., -15.5963, -14.9519, -14.9519]],
       device='cuda:0', grad_fn=<MinBackward0>)

In [46]:
np.array(ont.G.nodes)[(logits > 0).detach().cpu().numpy().flatten()]

IndexError: boolean index did not match indexed array along dimension 0; dimension is 44232 but corresponding boolean dimension is 32012

In [50]:
ont.G.nodes['GO:0000001']

{'name': 'mitochondrion inheritance',
 'namespace': 'biological_process',
 'index': 0}

In [53]:
import pandas as pd

In [None]:
ont.nod

In [59]:
df = pd.DataFrame(data_iter())

In [None]:
node_attr_dict_factory