In [1]:
import os
import numpy as np
import pandas as pd

In [2]:
import scipy
import scipy.sparse

In [3]:
from fairseq.models.roberta import RobertaModel

import sys
sys.path.append('../..')
from go_annotation import fairseq_layers

In [4]:
roberta = RobertaModel.from_pretrained('/gpfs/alpine/scratch/pstjohn/bie108/fairseq-uniparc/roberta_base_checkpoint/',
                                       data_name_or_path='/ccs/home/pstjohn/project_work/swissprot_go_annotation/fairseq_swissprot/input0',
                                       checkpoint_file='roberta.base_with_go_bias.pt')
_ = roberta.eval()  # disable dropout (or leave in train mode to finetune)

In [5]:
with open('/ccs/home/pstjohn/project_work/swissprot_go_annotation/fairseq_swissprot_debug/input0/train.raw') as f:
    input0 = f.readline()

In [6]:
import torch
import numpy as np
import scipy.sparse

targets = scipy.sparse.load_npz('/ccs/home/pstjohn/project_work/swissprot_go_annotation/fairseq_swissprot_debug/label/train.npz')[0]
targets = torch.tensor(np.asarray(targets.todense()))

In [7]:
def encode(sequence):
    return roberta.task.source_dictionary.encode_line(sequence)

tokens = encode(input0)
tokens

tensor([20,  4,  9,  8,  4, 11,  4, 16, 14, 12,  5,  4,  7, 18,  6, 11,  7, 18,
         4, 14,  6,  8, 15,  8,  7,  8, 18, 10,  5,  4,  4,  4,  5,  5,  4,  5,
         9,  6, 11, 11, 16,  4, 18, 18,  7,  4, 13,  8, 13, 13, 12, 10, 21, 20,
         4, 18,  5,  4, 16,  5,  4,  6,  7, 18, 17, 10,  4,  8,  5, 13, 10, 11,
        23, 23,  9,  7, 13,  6,  4,  6,  6, 15,  4,  7,  5,  9, 16, 14,  4,  8,
         4, 17,  4,  6, 18,  5,  6, 11,  5, 20, 10, 14,  4,  5,  5,  7,  4, 23,
         4,  6, 18,  8, 13, 12,  7,  4, 11,  6,  9, 14, 10, 20, 15,  9, 10, 14,
        12,  6, 21,  4,  7, 13,  5,  4, 10, 16,  6,  6,  5, 16, 12, 13, 19,  4,
         9, 16,  9, 18, 19, 14, 14,  4, 10,  4, 10,  6,  6, 17, 10,  6,  6,  9,
         4, 11,  7, 13,  6, 10,  7,  8,  8, 16, 17,  4, 11,  5,  4,  4, 20, 11,
         5, 14,  4,  5,  9, 16, 13, 11, 11, 12, 10, 12, 20,  6, 13,  4,  7,  8,
        15, 14, 19, 12, 13, 12, 11,  4, 21,  4, 20, 15,  5, 17,  6, 12, 13,  7,
         6, 21,  9, 18, 19, 16, 12, 17, 

In [8]:
logits = roberta.model(tokens.unsqueeze(0).long(), classification_head_name='go_prediction')[0]

In [9]:
from go_annotation.ontology import Ontology
ont = Ontology()

_ancestor_array = ont.ancestor_array()

import torch

bsz = logits.shape[0]
index_tensor = logits.new_tensor(_ancestor_array, dtype=torch.int64)
index_tensor = index_tensor.unsqueeze(0).expand((bsz, -1, -1))  # Array of ancestors, offset by one
padded_logits = torch.nn.functional.pad(logits, (1, 0), value=float('inf'))  # Make 0 index return inf
padded_logits = padded_logits.unsqueeze(-1).expand((-1, -1, index_tensor.shape[2]))
normed_logits = torch.gather(padded_logits, 1, index_tensor)
normed_logits, _ = torch.min(normed_logits, -1)

In [10]:
normed_logits

tensor([[ -9.0975,  -8.1519,  -7.9774,  ..., -13.1126, -13.2209, -13.3508]],
       grad_fn=<MinBackward0>)

In [11]:
targets

tensor([[0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)

In [12]:
normed_logits.max()

tensor(2.6472, grad_fn=<MaxBackward1>)

In [14]:
normed_logits.shape

torch.Size([1, 32012])

In [15]:
targets.shape

torch.Size([1, 32012])

In [19]:
import torch.nn.functional as F

out = F.binary_cross_entropy_with_logits(normed_logits, targets, reduction='sum')

In [20]:
normed_logits.detach().numpy()[np.isnan(out.detach().numpy())]

array([], shape=(0, 1, 32012), dtype=float32)

In [21]:
out

tensor(113.3611, dtype=torch.float64,
       grad_fn=<BinaryCrossEntropyWithLogitsBackward>)