In [5]:
import warnings
# Silences every Python warning for the rest of the kernel session. This keeps
# the notebook output tidy but also hides deprecation notices from the libraries
# used below.
# NOTE(review): consider narrowing the filter to specific categories/modules.
warnings.filterwarnings('ignore')

from torch import nn
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

### 1. Quickstart

In [6]:
# Shared settings reused by the cells below.

# Task name handed to `pipeline(...)`.
task = 'sentiment-analysis'

# NOTE: `model` is the checkpoint *name* (a string), not a loaded model object;
# it is passed to `pipeline(...)` and to the `from_pretrained(...)` calls.
model = 'distilbert/distilbert-base-uncased-finetuned-sst-2-english'

# Sample sentence used for the single-input examples.
text = 'Lawrence Pritchard Waterhouse was happy to see his friend Turing after a long time.'

In [None]:
# Build the pipeline for `task` using the checkpoint named by `model`.
#
# A hard-coded `device=0` assumes an accelerator is present, which makes the
# notebook fail (or behave differently) on CPU-only machines. Use accelerator
# index 0 when CUDA or Apple MPS is available and fall back to CPU (-1)
# otherwise, so Restart & Run All works everywhere.
import torch  # cell-local: the import cell only brings in `torch.nn`

has_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
classifier = pipeline(task=task, model=model, device=0 if has_accelerator else -1)

device: mps:0


In [8]:
# Classify the sample sentence with the pipeline built above. The recorded
# output is a list holding one {'label', 'score'} dict for the input.
result = classifier(text)
# Bare name as the last expression so the notebook displays the value.
result

[{'label': 'POSITIVE', 'score': 0.9996441602706909}]

### 2. The `Tokenizer` object

In [9]:
# Load the tokenizer for the same checkpoint name used by the pipeline.
tokenizer = AutoTokenizer.from_pretrained(model)

# Encode the sample sentence, asking for PyTorch tensors back.
encoding = tokenizer(text, return_tensors='pt')

# Display the token ids and the attention mask side by side as a tuple.
tuple(encoding[field] for field in ('input_ids', 'attention_mask'))

(tensor([[  101,  5623, 26927, 10649,  4232,  2300,  4580,  2001,  3407,  2000,
           2156,  2010,  2767, 28639,  2044,  1037,  2146,  2051,  1012,   102]]),
 tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]))

In [10]:
# Two example sentences that are tokenized together as a single batch.
batch_sentences = [
    'Lawrence Pritchard Waterhouse was happy to see his friend Turing after a long time.',
    'The Enigma machine has been cracked.',
]

# Request PyTorch tensors, with padding and truncation enabled and the
# sequence length capped at 512.
pt_batch = tokenizer(
    batch_sentences,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=512,
)

### 3. The `AutoModel` object

In [11]:
# Load the sequence-classification model for the same checkpoint name that the
# tokenizer and pipeline use (`model` is the name string defined earlier).
pt_model = AutoModelForSequenceClassification.from_pretrained(model)

In [12]:
# Forward pass: the tokenized batch is unpacked into keyword arguments.
# The model outputs the final activations in the `logits` attribute.
# NOTE(review): gradients are tracked here (the recorded output carries a
# grad_fn); for inference-only use, consider wrapping the call in
# `torch.no_grad()`.
output = pt_model(**pt_batch)
# Bare name as the last expression so the notebook displays the output object.
output

SequenceClassifierOutput(loss=None, logits=tensor([[-3.8251,  4.1154],
        [ 4.0819, -3.3393]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [13]:
# Turn the logits into probabilities by taking a softmax over the last
# dimension (one probability per class for each input sentence).
pt_predictions = output.logits.softmax(dim=-1)
print(pt_predictions)

tensor([[3.5588e-04, 9.9964e-01],
        [9.9940e-01, 5.9809e-04]], grad_fn=<SoftmaxBackward0>)


### 4. Saving the model

In [14]:
# Write the tokenizer first and then the model into the same local directory.
pt_save_directory = "./pt_save_pretrained"
for artifact in (tokenizer, pt_model):
    artifact.save_pretrained(pt_save_directory)