In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch

In [2]:
cache_dir = './cache'

In [3]:
model = AutoModelForSequenceClassification.from_pretrained("facebook/bart-large-mnli", cache_dir=cache_dir)
device = torch.device(f'cuda:6' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli", cache_dir=cache_dir)

In [4]:
classifier = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer, device=device)

In [5]:
sequence_to_classify = "A photo of young"
candidate_labels = ["hair", "woman", "costume", "girl", "person", "face", "dress", "pose", "blue", "child", "kid", "computer", "laptop", "mouse", "code", "love"]
res = classifier(sequence_to_classify, candidate_labels, multi_label=True)
print(res['sequence'])
print(res['labels'])
print(res['scores'])

A photo of young
['child', 'kid', 'girl', 'woman', 'person', 'face', 'pose', 'blue', 'love', 'hair', 'dress', 'mouse', 'laptop', 'costume', 'code', 'computer']
[0.9987421035766602, 0.9986479878425598, 0.9812822341918945, 0.8993655443191528, 0.8792117238044739, 0.6589139699935913, 0.4070006310939789, 0.17151901125907898, 0.11285580694675446, 0.03831982612609863, 0.02029099501669407, 0.006456996314227581, 0.0019525282550603151, 0.0011518665123730898, 0.0011386561673134565, 0.0004380836326163262]


In [6]:
filtered_scores_labels = [(score, label) for score, label in zip(res['scores'], res['labels']) if score <= 0.99]
res['scores'], res['labels'] = zip(*filtered_scores_labels) if filtered_scores_labels else ([], [])

In [7]:
print(res['sequence'])
print(res['labels'])
print(res['scores'])

A photo of young
('girl', 'woman', 'person', 'face', 'pose', 'blue', 'love', 'hair', 'dress', 'mouse', 'laptop', 'costume', 'code', 'computer')
(0.9812822341918945, 0.8993655443191528, 0.8792117238044739, 0.6589139699935913, 0.4070006310939789, 0.17151901125907898, 0.11285580694675446, 0.03831982612609863, 0.02029099501669407, 0.006456996314227581, 0.0019525282550603151, 0.0011518665123730898, 0.0011386561673134565, 0.0004380836326163262)
