In [3]:
import torch
import torch.nn as NN
from sentence_transformers import SentenceTransformer

class SentenceClassifier(NN.Module):
    def __init__(self, labels):
        
        super().__init__()
        
        self.transformer = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
        for params in self.transformer.parameters():
            params.requires_grad = False
        
        self.fc = NN.Linear(768, len(labels))
        self.logits = NN.Softmax()
        self.labels = labels
    
    def forward(self, x):
        return self.logits(self.fc(torch.tensor(self.transformer.encode(x))))
    

In [9]:
from transformers import RobertaModel, RobertaTokenizerFast
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

def find_attention_span(sentences, idx,tokenizer):

    inp_len = 513

    sents = [sent for sent in sentences]

    removed_from_back = 0
    removed_from_front = 0
    
    while True:
        tokenizer_inp = " ".join(sents)
        tokenized = tokenizer(tokenizer_inp, padding=True, return_offsets_mapping=True, return_tensors='pt')
        #print("Token shape:", tokenized["input_ids"].shape)
        if tokenized["input_ids"].shape[1] > 512:
            if idx < len(sents)-1:
                sents = sents[:-1]
                removed_from_back += 1
            else:
                sents = sents[1:]
                idx -= 1
                removed_from_front += 1
        else:
            break
    
    print("Removed", removed_from_back, "sentences from back")
    print("Removed", removed_from_front, "sentences from front")
    print("Attention Span:", len(sents), "sentences")
    return sents, idx

In [10]:
import time

sentences = ["i like apples."]*500


find_attention_span(sentences, 200, tokenizer)

Token indices sequence length is longer than the specified maximum sequence length for this model (2002 > 512). Running this sequence through the model will result in indexing errors


Token shape: torch.Size([1, 2002])
Token shape: torch.Size([1, 1998])
Token shape: torch.Size([1, 1994])
Token shape: torch.Size([1, 1990])
Token shape: torch.Size([1, 1986])
Token shape: torch.Size([1, 1982])
Token shape: torch.Size([1, 1978])
Token shape: torch.Size([1, 1974])
Token shape: torch.Size([1, 1970])
Token shape: torch.Size([1, 1966])
Token shape: torch.Size([1, 1962])
Token shape: torch.Size([1, 1958])
Token shape: torch.Size([1, 1954])
Token shape: torch.Size([1, 1950])
Token shape: torch.Size([1, 1946])
Token shape: torch.Size([1, 1942])
Token shape: torch.Size([1, 1938])
Token shape: torch.Size([1, 1934])
Token shape: torch.Size([1, 1930])
Token shape: torch.Size([1, 1926])
Token shape: torch.Size([1, 1922])
Token shape: torch.Size([1, 1918])
Token shape: torch.Size([1, 1914])
Token shape: torch.Size([1, 1910])
Token shape: torch.Size([1, 1906])
Token shape: torch.Size([1, 1902])
Token shape: torch.Size([1, 1898])
Token shape: torch.Size([1, 1894])
Token shape: torch.S

Token shape: torch.Size([1, 1054])
Token shape: torch.Size([1, 1050])
Token shape: torch.Size([1, 1046])
Token shape: torch.Size([1, 1042])
Token shape: torch.Size([1, 1038])
Token shape: torch.Size([1, 1034])
Token shape: torch.Size([1, 1030])
Token shape: torch.Size([1, 1026])
Token shape: torch.Size([1, 1022])
Token shape: torch.Size([1, 1018])
Token shape: torch.Size([1, 1014])
Token shape: torch.Size([1, 1010])
Token shape: torch.Size([1, 1006])
Token shape: torch.Size([1, 1002])
Token shape: torch.Size([1, 998])
Token shape: torch.Size([1, 994])
Token shape: torch.Size([1, 990])
Token shape: torch.Size([1, 986])
Token shape: torch.Size([1, 982])
Token shape: torch.Size([1, 978])
Token shape: torch.Size([1, 974])
Token shape: torch.Size([1, 970])
Token shape: torch.Size([1, 966])
Token shape: torch.Size([1, 962])
Token shape: torch.Size([1, 958])
Token shape: torch.Size([1, 954])
Token shape: torch.Size([1, 950])
Token shape: torch.Size([1, 946])
Token shape: torch.Size([1, 942])


(['i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',
  'i like apples.',


In [4]:
import json


all_spans = {}

with open("../../mfc_v4.0/spans_no_context.json", "r") as f:
    all_spans = json.load(f)

keys = list(all_spans.keys())

for key in keys:
    if key[-2:] != '.0':
        del all_spans[key]
        
keys = list(all_spans.keys())

In [5]:
model_reloaded = SentenceClassifier(keys)
model_reloaded.load_state_dict(torch.load("./distilberta-mfc-no-context.pt"))

<All keys matched successfully>

In [9]:
model_reloaded("the president's ratings plummeted")



tensor([0.5195, 0.0260, 0.0730, 0.0143, 0.0745, 0.0224, 0.0543, 0.0840, 0.0103,
        0.0207, 0.0104, 0.0228, 0.0060, 0.0600, 0.0018],
       grad_fn=<SoftmaxBackward0>)

In [10]:
keys[4]

'12.0'