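This notebook follows some of the examples from the fairseq RoBERTa README (https://github.com/pytorch/fairseq/tree/master/examples/roberta) to load a pretrained RoBERTa checkpoint, extract features for a protein sequence, and prototype a GO-term prediction head.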

In [1]:
from fairseq.models.roberta import RobertaModel

In [2]:
# Load the pretrained checkpoint through fairseq, then switch to eval mode to
# disable dropout (leave it in train mode instead if fine-tuning).
roberta = RobertaModel.from_pretrained(
    '/projects/deepgreen/pstjohn/roberta_base_checkpoint/',
    checkpoint_file='checkpoint_best.pt',
)
roberta.eval();

In [3]:
# Residues of https://www.uniprot.org/uniprot/P14618.fasta, kept in the file's
# 60-per-line layout. The string starts with a newline and every line ends with
# one; `encode` removes them before tokenizing.
example_sequence = '\n'.join([
    '',
    'MSKPHSEAGTAFIQTQQLHAAMADTFLEHMCRLDIDSPPITARNTGIICTIGPASRSVET',
    'LKEMIKSGMNVARLNFSHGTHEYHAETIKNVRTATESFASDPILYRPVAVALDTKGPEIR',
    'TGLIKGSGTAEVELKKGATLKITLDNAYMEKCDENILWLDYKNICKVVEVGSKIYVDDGL',
    'ISLQVKQKGADFLVTEVENGGSLGSKKGVNLPGAAVDLPAVSEKDIQDLKFGVEQDVDMV',
    'FASFIRKASDVHEVRKVLGEKGKNIKIISKIENHEGVRRFDEILEASDGIMVARGDLGIE',
    'IPAEKVFLAQKMMIGRCNRAGKPVICATQMLESMIKKPRPTRAEGSDVANAVLDGADCIM',
    'LSGETAKGDYPLEAVRMQHLIAREAEAAIYHLQLFEELRRLAPITSDPTEATAVGAVEAS',
    'FKCCSGAIIVLTKSGRSAHQVARYRPRAPIIAVTRNPQTARQAHLYRGIFPVLCKDPVQE',
    'AWAEDVDLRVNFAMNVGKARGFFKKGDVVIVLTGWRPGSGFTNTMRVVPVP',
    '',
])

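Unlike the fairseq RoBERTa example, this model does not use the GPT-2 byte-pair encoder, so the standard `roberta.encode` and `roberta.decode` methods won't work. Instead, the sequence is split into individual residues, which are looked up directly in the model's source dictionary.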

In [4]:
def encode(sequence, prepend_bos=False):
    """Convert an amino-acid string into a 1-D tensor of dictionary indices.

    The BPE-based ``roberta.encode`` does not apply to this model, so each
    residue is treated as its own whitespace-separated token and looked up
    directly in the model's source dictionary. ``encode_line`` appends the
    end-of-sentence token.

    Args:
        sequence: residue string; newlines and any other whitespace (e.g. from
            a FASTA body) are ignored.
        prepend_bos: if True, also prepend the dictionary's beginning-of-sentence
            token. NOTE(review): fairseq's own ``RobertaHubInterface.encode``
            starts every sequence with ``<s>``, and its classification head
            reads position 0. If this checkpoint was pretrained the same way,
            pass True, and also update the hard-coded length used when building
            ``inputs``. Defaults to False to keep the existing behaviour.

    Returns:
        1-D integer tensor of token indices.
    """
    dictionary = roberta.task.source_dictionary
    residues = ''.join(sequence.split())
    # add_if_not_exist=False maps an unexpected character to <unk>. The default
    # (True) silently appends it to the dictionary, producing an index that the
    # model's embedding table, sized when the model was loaded, does not cover.
    tokens = dictionary.encode_line(' '.join(residues), add_if_not_exist=False)
    if prepend_bos:
        import torch
        tokens = torch.cat([tokens.new_tensor([dictionary.bos()]), tokens])
    return tokens


tokens = encode(example_sequence)
tokens

tensor([20,  8, 15, 14, 21,  8,  9,  5,  6, 11,  5, 17, 12, 16, 11, 16, 16,  4,
        21,  5,  5, 20,  5, 13, 11, 17,  4,  9, 21, 20, 23, 10,  4, 13, 12, 13,
         8, 14, 14, 12, 11,  5, 10, 18, 11,  6, 12, 12, 23, 11, 12,  6, 14,  5,
         8, 10,  8,  7,  9, 11,  4, 15,  9, 20, 12, 15,  8,  6, 20, 18,  7,  5,
        10,  4, 18, 17,  8, 21,  6, 11, 21,  9, 19, 21,  5,  9, 11, 12, 15, 18,
         7, 10, 11,  5, 11,  9,  8, 17,  5,  8, 13, 14, 12,  4, 19, 10, 14,  7,
         5,  7,  5,  4, 13, 11, 15,  6, 14,  9, 12, 10, 11,  6,  4, 12, 15,  6,
         8,  6, 11,  5,  9,  7,  9,  4, 15, 15,  6,  5, 11,  4, 15, 12, 11,  4,
        13, 18,  5, 19, 20,  9, 15, 23, 13,  9, 18, 12,  4, 22,  4, 13, 19, 15,
        18, 12, 23, 15,  7,  7,  9,  7,  6,  8, 15, 12, 19,  7, 13, 13,  6,  4,
        12,  8,  4, 16,  7, 15, 16, 15,  6,  5, 13, 17,  4,  7, 11,  9,  7,  9,
        18,  6,  6,  8,  4,  6,  8, 15, 15,  6,  7, 18,  4, 14,  6,  5,  5,  7,
        13,  4, 14,  5,  7,  8,  9, 15, 

In [5]:
import torch

# Put the model on the GPU when one is available.
gpu_available = torch.cuda.is_available()
if gpu_available:
    print("Using the GPU")
    roberta.cuda()

# Feature extraction only, so no autograd bookkeeping is needed.
with torch.no_grad():
    features = roberta.extract_features(tokens.long())

In [6]:
# Create the ontology object and backpropagate GO labels

import sys

GO_ANNOTATION_DIR = '/home/pstjohn/Research/20201119_fairseq/go_annotation'
# Guard the append so re-running this cell doesn't keep growing sys.path.
if GO_ANNOTATION_DIR not in sys.path:
    sys.path.append(GO_ANNOTATION_DIR)

from ontology import Ontology
ont = Ontology()
ont.total_nodes

32012

In [7]:
from scipy.sparse import coo_matrix

In [9]:
import numpy as np

In [29]:
# Dense lookup table built from `ont.iter_ancestor_array()`: each yielded triple
# (k, i, j) stores the value k + 1 at row i, column j (presumably
# (ancestor, term, slot) given the method name -- confirm). The +1 offset
# reserves 0 for empty slots, which the +inf pad column of the logits absorbs
# in the min-reduction further down.
arr = np.array(list(ont.iter_ancestor_array()))

# coo_matrix sums duplicate (row, col) entries on conversion, which would
# silently corrupt an index value, so refuse duplicates explicitly.
assert np.unique(arr[:, 1:], axis=0).shape[0] == len(arr), "duplicate (row, slot) pairs"

index_tensor = torch.as_tensor(
    coo_matrix(
        (arr[:, 0] + 1, (arr[:, 1], arr[:, 2])),
        # Fix the row count explicitly so every ontology term gets a row, even
        # if the highest-numbered terms contribute no triples.
        shape=(ont.total_nodes, arr[:, 2].max() + 1),
    ).toarray(),  # plain ndarray rather than the np.matrix `.todense()` returns
    dtype=torch.int64,  # gather / advanced indexing need int64 indices
)

In [31]:
# The lookup table needs one row per ontology term so that reducing over it
# yields one value per classifier output.
assert index_tensor.shape[0] == ont.total_nodes, (index_tensor.shape, ont.total_nodes)
index_tensor.shape

torch.Size([32012, 88])

In [15]:
# Attach a new classification head named 'go_prediction' to the model, with one
# output per node of the ontology.
roberta.model.register_classification_head('go_prediction', num_classes=ont.total_nodes)

In [16]:
# Stack the single encoded sequence into a batch of two identical rows.
# Out-of-place `unsqueeze` (not `unsqueeze_`) leaves `tokens` untouched even when
# `.to(torch.int64)` is a no-op, so this cell can be re-run. `expand(2, -1)`
# keeps whatever length `tokens` has instead of hard-coding 532.
inputs = tokens.to(torch.int64).unsqueeze(0).expand(2, -1)

In [17]:
# Run the encoder plus the 'go_prediction' head on the batch. `inputs` is built
# on the CPU, so move it to wherever the model's weights live; otherwise this
# call fails whenever the model was moved to the GPU earlier.
model_device = next(roberta.model.parameters()).device
with torch.no_grad():
    logits, _ = roberta.model(
        inputs.to(model_device),
        features_only=True,
        classification_head_name='go_prediction')

In [18]:
import torch

In [42]:
# Interactive check of `torch.min`. When called with a dimension (as in the
# reduction that follows), it returns a (values, indices) pair, which is why
# that cell unpacks two names.
# NOTE(review): leftover inspection cell -- safe to delete.
torch.min

<function _VariableFunctionsClass.min>

In [44]:
# Min-reduce the logits over the lookup table: for every term i and slot j, read
# logit index_tensor[i, j] - 1, then take the minimum across the slots j.
# Left-padding the logits with +inf makes table value 0 (an empty slot) point at
# the +inf column, so empty slots never win the minimum.
padded_logits = torch.nn.functional.pad(logits, (1, 0), value=float('inf'))
# Advanced indexing with the (n_terms, n_slots) table broadcasts over the batch
# dimension, giving (batch, n_terms, n_slots). This avoids hard-coding the batch
# size and slot count (previously 2 and 88) in an expand + gather.
gathered_logits = padded_logits[:, index_tensor.to(padded_logits.device)]
normed_logits, _ = torch.min(gathered_logits, -1)

In [45]:
# Reducing over the lookup table must give exactly one value per classifier
# output, i.e. the same shape as `logits`.
assert normed_logits.shape == logits.shape, (normed_logits.shape, logits.shape)
normed_logits

tensor([[-0.1956, -0.1956, -0.0722,  ..., -0.1196, -0.0788, -0.0881],
        [-0.1956, -0.1956, -0.0722,  ..., -0.1196, -0.0788, -0.0881]])

In [20]:
import numpy as np

In [None]:
from torch_scatter import scatter

In [None]:
# Sparse alternative to the dense min-reduction: scatter-min the logits over the
# flat list of (read column, write column) pairs instead of a padded table.
# NOTE(review): `term_tensor` and `ancestor_tensor` were not defined anywhere in
# this notebook, so this cell raised NameError. They are derived here from `arr`
# so the result matches the dense computation: column 0 of `arr` is the logit
# read and column 1 the row written, as in the `coo_matrix` construction.
# Confirm this was the intent.
term_tensor = torch.as_tensor(arr[:, 0], dtype=torch.int64, device=logits.device)
ancestor_tensor = torch.as_tensor(arr[:, 1], dtype=torch.int64, device=logits.device)
normed_logits = scatter(
    logits.index_select(1, term_tensor),  # (batch, n_pairs)
    ancestor_tensor,                      # 1-D index, broadcast over the batch
    dim=-1,
    dim_size=logits.shape[-1],            # one output column per term
    reduce='min',
)

In [None]:
# Index lists for the descendants of each of the first three head nodes returned
# by `get_head_node_indices()`.
# NOTE(review): which GO namespace each head corresponds to depends on that
# method's ordering -- confirm. The `_bf` / `_mp` prefixes may have been meant
# as `_bp` / `_mf`.
head_nodes = ont.get_head_node_indices()
_bf_index, _mp_index, _cc_index = (
    ont.terms_to_indices(ont.get_descendants(ont.term_index[head_nodes[position]]))
    for position in range(3)
)

In [None]:
# Normed logits restricted to the `_cc_index` subset of terms, for every sequence
# in the batch.
# NOTE(review): `convert_and_resize` is not defined anywhere in this notebook, so
# this cell raised NameError. Selecting columns with `index_select` replaces the
# gather-with-expanded-index it presumably built. This assumes `_cc_index` is a
# sequence of integer column indices -- confirm.
cc_columns = torch.as_tensor(_cc_index, dtype=torch.int64, device=normed_logits.device)
normed_logits.index_select(-1, cc_columns).shape