In [1]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# 3 stages

text -> (tokenizer) -> input ids -> (model) -> logits -> (softmax) -> predictions

In [2]:
# Checkpoint identifier; the tokenizer is loaded from it here and the model is
# loaded from the same identifier further down.
checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
# Tokenizer matching that checkpoint (turns raw text into input ids).
tokenizer = AutoTokenizer.from_pretrained(checkpoint)



In [4]:
# Two sample sentences of different lengths to push through the pipeline.
raw_inputs = [
    "I have been waiting to bike around hawk hills my entire life",
    "Biking is good for the mind",
]

In [6]:
# The two sentences tokenize to different lengths, so the shorter one is padded
# (and over-long input would be truncated) to produce one rectangular batch,
# returned as PyTorch tensors.
inputs = tokenizer(
    raw_inputs,
    padding=True,
    truncation=True,
    return_tensors="pt",
)

In [8]:
# Sanity check: after padding, every row of token ids has the same length.
for token_ids in inputs["input_ids"]:
    print(len(token_ids))

14
14


In [12]:
# Inspect the tokenizer output: the padded `input_ids` plus the matching
# `attention_mask` (0 marks the padding positions).
inputs

{'input_ids': tensor([[  101,  1045,  2031,  2042,  3403,  2000,  7997,  2105,  9881,  4564,
          2026,  2972,  2166,   102],
        [  101, 28899,  2003,  2204,  2005,  1996,  2568,   102,     0,     0,
             0,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])}

### Now that the input is tokenized, we pass the tokenized data to the model

In [9]:
# Load the sequence-classification model from the same `checkpoint` the
# tokenizer was built from (defined once, next to the tokenizer), so the
# vocabulary and the model weights cannot drift apart.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

# Unpack the tokenizer output (input_ids, attention_mask) as keyword arguments.
outputs = model(**inputs)

# Raw, unnormalized scores: one row per sentence, one column per label.
print(outputs.logits)



tensor([[-2.6748,  2.7583],
        [-3.9115,  4.1768]], grad_fn=<AddmmBackward0>)


### To make sense of the logits, we convert them to probabilities by applying softmax

In [10]:
# Softmax over the last (label) dimension turns each row of logits into
# probabilities that sum to 1.
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

print(predictions)

tensor([[4.3509e-03, 9.9565e-01],
        [3.0701e-04, 9.9969e-01]], grad_fn=<SoftmaxBackward0>)


### Look up which label each position (column) in a row of logits corresponds to

In [11]:
# Mapping from each logit/probability column index to its class label
model.config.id2label

{0: 'NEGATIVE', 1: 'POSITIVE'}

### In each row, the first value is the probability of NEGATIVE sentiment and the second is the probability of POSITIVE sentiment