In [3]:
import torch
import torch.nn as NN
from sentence_transformers import SentenceTransformer

class SentenceClassifier(NN.Module):
    """Linear classification head on top of a frozen sentence encoder.

    Input text is embedded with the pretrained
    ``sentence-transformers/all-distilroberta-v1`` model, whose parameters are
    frozen. The embedding is mapped to one score per label by a single linear
    layer and then normalised with a softmax.

    Args:
        labels: Sized collection of label identifiers. Its length sets the
            number of output classes; it is stored unchanged on ``self.labels``.
    """

    def __init__(self, labels):
        super().__init__()

        self.transformer = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        # Freeze the encoder so that only the linear head has trainable weights.
        for params in self.transformer.parameters():
            params.requires_grad = False

        # 768 must match the width of the vectors returned by the encoder.
        self.fc = NN.Linear(768, len(labels))
        # Name the class axis explicitly: without `dim`, PyTorch emits a
        # deprecation warning and infers the axis from the input rank.
        # dim=-1 is the class axis for a single embedding (1-D) and for a
        # batch of embeddings (2-D) alike. Despite the attribute name, this
        # layer outputs probabilities, not logits.
        self.logits = NN.Softmax(dim=-1)
        self.labels = labels

    def forward(self, x):
        """Return softmax class probabilities for ``x``.

        Args:
            x: Text input handed straight to ``SentenceTransformer.encode``
                (the notebook passes a single sentence string).

        Returns:
            torch.Tensor: Probabilities over ``self.labels`` along the last
            axis.
        """
        # encode() returns a NumPy array; build the tensor on the same device
        # as the linear head so the model also works after `.to(device)`.
        embeddings = torch.tensor(self.transformer.encode(x), device=self.fc.weight.device)
        return self.logits(self.fc(embeddings))
    

In [11]:
from transformers import RobertaModel, RobertaTokenizerFast
# Fast RoBERTa tokenizer loaded from the "roberta-base" checkpoint; it is
# passed to find_attention_span, which uses it to measure token counts.
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

def find_attention_span(sentences, idx, tokenizer, max_tokens=512):
    """Trim a list of sentences around a target until it fits a token budget.

    The sentences are joined with single spaces and tokenized as one string.
    While the token count exceeds ``max_tokens``, one sentence is dropped per
    iteration: from the end of the list as long as the target is not the last
    sentence, and from the front once it is. A summary of what was removed is
    printed before returning.

    Args:
        sentences: Iterable of sentence strings. It is copied, not modified.
        idx: Position of the target sentence within ``sentences``.
        tokenizer: Callable invoked as
            ``tokenizer(text, padding=True, return_tensors='pt')`` whose
            result has an ``"input_ids"`` entry exposing a 2-D ``shape``.
        max_tokens: Largest accepted value of ``input_ids.shape[1]``.
            Defaults to 512, the limit that was previously hard-coded.

    Returns:
        tuple: ``(kept_sentences, new_idx)`` where ``new_idx`` is ``idx``
        shifted down by the number of sentences removed from the front.

    Note:
        Trimming stops once a single sentence remains, so a target sentence
        that is over budget on its own is returned as-is (and may still
        exceed ``max_tokens``) instead of being dropped, which would leave an
        empty span and an index of -1.
    """
    sents = list(sentences)

    removed_from_back = 0
    removed_from_front = 0

    # With one sentence (or none) left there is nothing more that can be
    # trimmed without losing the target, so the loop ends there.
    while len(sents) > 1:
        tokenized = tokenizer(" ".join(sents), padding=True, return_tensors='pt')
        if tokenized["input_ids"].shape[1] <= max_tokens:
            break
        if idx < len(sents) - 1:
            # Context after the target is still available: drop the last sentence.
            sents.pop()
            removed_from_back += 1
        else:
            # The target is the last sentence: drop from the front and shift
            # the index so it keeps pointing at the target.
            del sents[0]
            idx -= 1
            removed_from_front += 1

    print("Removed", removed_from_back, "sentences from back")
    print("Removed", removed_from_front, "sentences from front")
    print("Attention Span:", len(sents), "sentences")
    return sents, idx

In [14]:
import time

# Synthetic input: 500 copies of one short sentence, with the target at position 200.
sentences = ["i like apples."]*500

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time; time.time() follows the wall clock, which can be adjusted mid-run.
t0 = time.perf_counter()
_,idx = find_attention_span(sentences, 200, tokenizer)
print("operation took", time.perf_counter()-t0)
print("Final index:", idx)

Removed 299 sentences from back
Removed 74 sentences from front
Attention Span: 127 sentences
operation took 1.056912899017334
Final index: 126


In [4]:
import json

# Read the span annotations. The encoding is given explicitly so that parsing
# does not depend on the platform's default text encoding.
with open("../../mfc_v4.0/spans_no_context.json", "r", encoding="utf-8") as f:
    all_spans = json.load(f)

# Keep only the entries whose key ends in '.0'. Building a new dict avoids
# deleting from the mapping while walking a snapshot of its keys.
all_spans = {key: value for key, value in all_spans.items() if key.endswith('.0')}

# The order of this list fixes the order of the labels handed to the classifier.
keys = list(all_spans.keys())

In [5]:
# Rebuild the classifier with one output per retained key, then restore the saved weights.
model_reloaded = SentenceClassifier(keys)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; load_state_dict copies the values into the module's
# existing parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
model_reloaded.load_state_dict(torch.load("./distilberta-mfc-no-context.pt", map_location="cpu"))

<All keys matched successfully>

In [9]:
# Classify a single sentence; the result holds one softmax probability per entry of `keys`.
model_reloaded("the president's ratings plummeted")



tensor([0.5195, 0.0260, 0.0730, 0.0143, 0.0745, 0.0224, 0.0543, 0.0840, 0.0103,
        0.0207, 0.0104, 0.0228, 0.0060, 0.0600, 0.0018],
       grad_fn=<SoftmaxBackward0>)

In [10]:
# Show the label stored at position 4 of `keys`.
# NOTE(review): the largest probability in the recorded prediction output is at
# position 0, not 4 - confirm which label this lookup was meant to inspect.
keys[4]

'12.0'