In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaTokenizer, LlamaForCausalLM
import torch

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [23]:
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_id, truncation_side='left', padding_side='right')
tokenizer.pad_token = tokenizer.eos_token
# model = LlamaForCausalLM.from_pretrained(model_id, device_map={"":0},use_flash_attention_2=True)
model = LlamaForCausalLM.from_pretrained(model_id, device_map={"":0})
flash_attn_model = LlamaForCausalLM.from_pretrained(model_id, device_map={"":0},use_flash_attention_2=True)


Loading checkpoint shards: 100%|██████████| 2/2 [00:49<00:00, 24.74s/it]
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Loading checkpoint shards: 100%|██████████| 2/2 [00:49<00:00, 24.68s/it]


In [82]:
tokenizer = LlamaTokenizer.from_pretrained(model_id, truncation_side='left', padding_side='right')


In [76]:
model.eval()
flash_attn_model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )


In [94]:
tokenizer.pad_token = tokenizer.eos_token

input_prompt = "I've got a lovely bunch of coconuts do do do dooo"
random = " ".join(["do" for _ in range(50)])
batched_input_prompt = ["I've got a lovely bunch of coconuts do do do dooo", " ".join(["do" for _ in range(50)])]
input_prompt_tokenized = tokenizer(input_prompt, return_tensors="pt").to('cuda')
input_prompt_padded_tokenized = tokenizer(input_prompt, return_tensors="pt", padding="max_length", max_length=50).to('cuda')
batched_input_prompt_tokenized = tokenizer(batched_input_prompt, return_tensors="pt", padding=True, max_length=50).to('cuda')
random_tokenized = tokenizer(random, return_tensors="pt").to('cuda')

print(input_prompt_tokenized)
print(input_prompt_padded_tokenized)
print(batched_input_prompt_tokenized)

{'input_ids': tensor([[    1,   306, 29915,   345,  2355,   263, 12355,   873, 14928,   310,
          1302,   535,  8842,   437,   437,   437,   437,  3634]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')}
{'input_ids': tensor([[    1,   306, 29915,   345,  2355,   263, 12355,   873, 14928,   310,
          1302,   535,  8842,   437,   437,   437,   437,  3634,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             2,     2,     2,     2,     2,     2,     2,     2,     2,     2]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]], device='cuda:0')}
{'input_ids': tensor([[    1,   306, 29915,   345,  2355



In [84]:
input_ids_masked = torch.zeros(input_prompt_padded_tokenized.input_ids.shape, dtype=torch.int64).to('cuda')
torch.where(input_prompt_padded_tokenized.input_ids == tokenizer.pad_token_id,
            torch.tensor(-100, dtype=torch.int64),
            input_prompt_padded_tokenized.input_ids,
            out=input_ids_masked)

loss1 = model(
    input_prompt_tokenized.input_ids,
    attention_mask=input_prompt_tokenized.attention_mask,
    labels=input_prompt_tokenized.input_ids
).loss

logits1 = model(
    input_prompt_tokenized.input_ids,
    attention_mask=input_prompt_tokenized.attention_mask,
    labels=input_prompt_tokenized.input_ids
).logits

print(f"Normal Model + Non-Padded Input")
print(f"Loss (no padding): {loss1}")
print(f"Logits (no padding): {logits1}")

Normal Model + Non-Padded Input
Loss (no padding): 2.8300979137420654
Logits (no padding): tensor([[[ 0.1357, -0.1206,  0.3125,  ...,  1.3594,  1.8984,  0.6641],
         [-8.3750, -9.8125, -0.3691,  ..., -3.4844, -8.0000, -2.8594],
         [-3.7812, -2.7656,  4.0000,  ..., -1.4297, -2.8125, -0.3926],
         ...,
         [-3.4688, -3.0625,  8.3125,  ...,  0.2559, -2.6875, -2.8125],
         [-3.0000, -3.0938,  8.7500,  ...,  0.0806, -2.5156, -2.7656],
         [-2.6562, -1.6953,  7.7812,  ..., -0.6484, -3.1562, -2.5000]]],
       device='cuda:0', dtype=torch.float32, grad_fn=<ToCopyBackward0>)


In [96]:
loss1_falsh_attn = flash_attn_model(
    input_prompt_tokenized.input_ids,
    attention_mask=input_prompt_tokenized.attention_mask,
    labels=input_prompt_tokenized.input_ids
).loss

logits1_falsh_attn = flash_attn_model(
    input_prompt_tokenized.input_ids,
    attention_mask=input_prompt_tokenized.attention_mask,
    labels=input_prompt_tokenized.input_ids
).logits

loss_random = model(
    random_tokenized.input_ids,
    attention_mask=random_tokenized.attention_mask,
    labels=random_tokenized.input_ids
).loss

print(f"Flash Attn Model + Non-Padded Input")
print(f"Loss (no padding) with flash attn: {loss1_falsh_attn}")
print(f"Logits (no padding) with flash attn: {logits1_falsh_attn}")
print(f"Loss (no padding) with flash attn: {loss_random + loss1_falsh_attn}")

Flash Attn Model + Non-Padded Input
Loss (no padding) with flash attn: 2.8371522426605225
Logits (no padding) with flash attn: tensor([[[ 0.1357, -0.1206,  0.3125,  ...,  1.3594,  1.8984,  0.6641],
         [-8.4375, -9.8750, -0.3555,  ..., -3.5000, -8.0625, -2.9062],
         [-3.7812, -2.7812,  4.0625,  ..., -1.3828, -2.7969, -0.3438],
         ...,
         [-3.4844, -3.0781,  8.3750,  ...,  0.2871, -2.6719, -2.8125],
         [-3.0625, -3.1250,  8.6875,  ...,  0.0928, -2.5781, -2.7812],
         [-2.6562, -1.6719,  7.7812,  ..., -0.6914, -3.1406, -2.5156]]],
       device='cuda:0', dtype=torch.float32, grad_fn=<ToCopyBackward0>)
Loss (no padding) with flash attn: 8.834760665893555


In [105]:
loss2 = model(
    input_prompt_padded_tokenized.input_ids,
    attention_mask=input_prompt_padded_tokenized.attention_mask,
    labels=input_prompt_padded_tokenized.input_ids
).loss

logits2 = model(
    input_prompt_padded_tokenized.input_ids,
    attention_mask=input_prompt_padded_tokenized.attention_mask,
    labels=input_prompt_padded_tokenized.input_ids
).logits

print(logits2.shape)
# decode the logits
print(tokenizer.decode(logits2.argmax(dim=-1)[0]))

print(f"Normal Model + Padded Input")
print(f"Loss (padding): {loss2}")
print(f"Logits (padding): {logits2}")
print(f"Decoded Logits (padding): {decoded_logits}")

torch.Size([1, 50, 32000])
Unterscheidung'm been a bitely little of coconuts
o do do doo<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>'''
Normal Model + Padded Input
Loss (padding): 8.399846076965332
Logits (padding): tensor([[[ 0.1357, -0.1206,  0.3125,  ...,  1.3594,  1.8984,  0.6641],
         [-8.3750, -9.8125, -0.3691,  ..., -3.4844, -8.0000, -2.8594],
         [-3.7812, -2.7656,  4.0000,  ..., -1.4297, -2.8125, -0.3926],
         ...,
         [-4.9062,  8.0000,  5.2500,  ..., -2.2812, -2.2812, -2.7344],
         [-5.4062,  3.4219,  5.3125,  ..., -2.6094, -1.1250, -3.4062],
         [-5.3125,  2.5312,  5.9375,  ..., -2.2656, -0.9961, -3.6719]]],
       device='cuda:0', dtype=torch.float32, grad_fn=<ToCopyBackward0>)
Decoded Logits (padding): []


In [106]:
loss2_falsh_attn = flash_attn_model(
    input_prompt_padded_tokenized.input_ids,
    attention_mask=input_prompt_padded_tokenized.attention_mask,
    labels=input_prompt_padded_tokenized.input_ids
).loss

logits2_falsh_attn = flash_attn_model(
    input_prompt_padded_tokenized.input_ids,
    attention_mask=input_prompt_padded_tokenized.attention_mask,
    labels=input_prompt_padded_tokenized.input_ids
).logits

print(tokenizer.decode(logits2_falsh_attn.argmax(dim=-1)[0]))

print(f"Flash Attn Model + Padded Input")
print(f"Loss (padding) with flash attn: {loss2_falsh_attn}")
print(f"Logits (padding) with flash attn: {logits2_falsh_attn}")

Unterscheidung'm been a bitely little of coconuts
o do do doo<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>
Flash Attn Model + Padded Input
Loss (padding) with flash attn: 16.86910057067871
Logits (padding) with flash attn: tensor([[[ 0.1357, -0.1206,  0.3125,  ...,  1.3594,  1.8984,  0.6641],
         [-8.4375, -9.8750, -0.3555,  ..., -3.5000, -8.0625, -2.9062],
         [-3.7812, -2.7812,  4.0625,  ..., -1.3828, -2.7969, -0.3438],
         ...,
         [ 8.5000, 27.0000,  2.2188,  ...,  4.2188,  1.9766,  3.6250],
         [ 8.5000, 27.0000,  2.2188,  ...,  4.2188,  1.9766,  3.6250],
         [ 8.5000, 27.0000,  2.2188,  ...,  4.2188,  1.9766,  3.6250]]],
       device='cuda:0', dtype=torch.float32, grad_fn=<ToCopyBackward0>)


In [113]:
loss2 = model(
    batched_input_prompt_tokenized.input_ids,
    attention_mask=batched_input_prompt_tokenized.attention_mask,
    labels=batched_input_prompt_tokenized.input_ids
).loss

logits2 = model(
    batched_input_prompt_tokenized.input_ids,
    attention_mask=batched_input_prompt_tokenized.attention_mask,
    labels=batched_input_prompt_tokenized.input_ids
).logits

print('Decoded Logits')
print('------------------')
print(tokenizer.decode(logits2.argmax(dim=-1)[0]))
print('------------------\n')

print(f"Normal Model + Padded Input")
print(f"Loss (with padding): {loss2}")
print(f"Logits (with padding): {logits2[0]}")

Decoded Logits
------------------
Unterscheidung'm been a bitely little of coconuts
o do do doo<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>''''
------------------

Normal Model + Padded Input
Loss (with padding): 7.247009754180908
Logits (with padding): tensor([[ 0.0608, -0.1953,  0.3203,  ...,  1.3281,  1.8359,  0.6094],
        [-8.3750, -9.8125, -0.3574,  ..., -3.5000, -8.0625, -2.8750],
        [-3.8281, -2.8281,  4.0000,  ..., -1.4375, -2.8438, -0.4102],
        ...,
        [-5.3438,  3.5625,  5.4062,  ..., -2.5781, -1.0469, -3.3594],
        [-5.3750,  2.4375,  5.8438,  ..., -2.2969, -1.0547, -3.7031],
        [-5.1250,  3.9844,  5.8750,  ..., -1.9766, -0.9297, -3.1250]],
       device='cuda:0', dtype=torch.float32, grad_fn=<SelectBackward0>)


In [114]:
flash_attn_loss = flash_attn_model(
    batched_input_prompt_tokenized.input_ids,
    attention_mask=batched_input_prompt_tokenized.attention_mask,
    labels=batched_input_prompt_tokenized.input_ids
).loss

flash_attn_logits = flash_attn_model(
    batched_input_prompt_tokenized.input_ids,
    attention_mask=batched_input_prompt_tokenized.attention_mask,
    labels=batched_input_prompt_tokenized.input_ids
).logits

print('Decoded Logits')
print('------------------')
print(tokenizer.decode(flash_attn_logits.argmax(dim=-1)[0]))
print('------------------\n')

print(f"Flash Attn Model + Padded Input")
print(f"Loss (padding input & masking labels): {flash_attn_loss}")
print(f"Logits (padding input & masking labels): {flash_attn_logits[0]}")

Decoded Logits
------------------
Unterscheidung'm been a bitely little of coconuts
o do do doo<s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s><s>
------------------

Flash Attn Model + Padded Input
Loss (padding input & masking labels): 11.508078575134277
Logits (padding input & masking labels): tensor([[ 0.0608, -0.1953,  0.3203,  ...,  1.3281,  1.8359,  0.6094],
        [-8.4375, -9.8750, -0.3477,  ..., -3.5312, -8.0625, -2.8906],
        [-3.7656, -2.8125,  4.0625,  ..., -1.3672, -2.7812, -0.3457],
        ...,
        [ 8.5000, 27.0000,  2.2344,  ...,  4.2500,  2.0000,  3.6562],
        [ 8.5000, 27.0000,  2.2344,  ...,  4.2500,  2.0000,  3.6562],
        [ 8.5000, 27.0000,  2.2344,  ...,  4.2500,  2.0000,  3.6562]],
       device='cuda:0', dtype=torch.float32, grad_fn=<SelectBackward0>)


In [9]:
a = torch.randint(0, 5, (3, 4))
b = torch.randint(0, 5, (2, 4))

print(a, b)

tensor([[3, 1, 1, 3],
        [0, 0, 0, 1],
        [4, 1, 3, 2]]) tensor([[0, 4, 3, 1],
        [2, 3, 2, 2]])


In [17]:
# convert a and b to dtype=torch.float32
a = a.float()
b = b.float()


In [18]:
b1 = torch.tensor([[0, 4, 3, 1]])
b1 = b1.float()
b1.shape

torch.Size([1, 4])

In [23]:
torch.mean((a[:, None, :] - b1) * (a[:, None, :]), dim=2).transpose(0, 1)

tensor([[2.5000, 0.0000, 3.7500]])

In [22]:
print(a[:, None, :])
torch.mean((a[:, None, :] - b) * (a[:, None, :]), dim=2).transpose(0, 1)

tensor([[[3., 1., 1., 3.]],

        [[0., 0., 0., 1.]],

        [[4., 1., 3., 2.]]])


tensor([[ 2.5000,  0.0000,  3.7500],
        [ 0.7500, -0.2500,  2.2500]])

In [32]:
tensor_example = [[4.46, 7.57, 7.58]]
# tensor_example = tensor_example
tensor_example = torch.tensor(tensor_example)

# plot the softmax output
import matplotlib.pyplot as plt
import numpy as np


# def plot_softmax(tensor_example, knn_T=10):
#     tensor_example = tensor_example / knn_T
#     softmax = torch.nn.Softmax(dim=1)
#     softmax_output = softmax(tensor_example)
#     print(softmax_output)
#     plt.plot(softmax_output.detach().numpy().flatten(), 'o')
#     # shot the x axis as discrete values
#     plt.xticks(np.arange(16))
#     # plot the y from 0 to 1
#     plt.ylim(0, 1)
#     plt.show()

# plot_softmax(tensor_example, 100)
torch.softmax(tensor_example, dim=-1)

tensor([[0.0217, 0.4867, 0.4916]])

In [5]:
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
tokenizer(" positive")

{'input_ids': [3967], 'attention_mask': [1]}

In [2]:
'''
Choose sentiment from terrible or great.

Review: i would recommend big bad love only to winger fans who have missed her since 1995 's forget paris.
Sentiment: terrible
Review:  suspenseful enough for older kids but not . 
Sentiment: great

Review: the subtle strength of elling is that it never squandering touch with the reality of the grim situation . 
Sentiment:
'''
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, use_flash_attention_2=True)

# gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
# gpt_model = AutoModelForCausalLM.from_pretrained("gpt2-xl").to('cuda')

# mistral_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# mistral_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1").to('cuda')

You are attempting to use Flash Attention 2.0 with a model initialized on CPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


KeyboardInterrupt: 

In [4]:
model = model.to('cuda')

In [11]:
gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
gpt_model = AutoModelForCausalLM.from_pretrained("gpt2-xl").to('cuda')

In [20]:
from itertools import permutations
from tqdm import tqdm
import numpy as np
from collections import Counter
instructions = "Classify the sentiment of negative and positive."
icl_samples = ["Review: is a step down for director gary fleder . \nSentiment: negative",
               "Review: the director , tom dey , had spliced together bits and pieces of midnight run and 48 hours ( and , for that matter , shrek ) . \nSentiment: positive",
                "Review: from two fatal ailments -- a dearth of vitality and a story that 's shapeless and uninflected . \nSentiment: negative",
                "Review: results that are sometimes bracing . \nSentiment: positive",
                "Review: plodding soap opera . \nSentiment: negative",
                "Review: all-star salute . \nSentiment: positive",
                "Review: fit all of pootie tang in between its punchlines . \nSentiment: negative",
                "Review: award-winning . \nSentiment: positive",
                "Review: deserve better . \nSentiment: negative",
                "Review: you actually buy into \nSentiment: positive",
                "Review: of cliches that shoplifts shamelessly from farewell-to-innocence movies like the wanderers and a bronx tale without cribbing any of their intelligence . \nSentiment: negative",
                "Review: real-life basis is , in fact , so interesting that no embellishment is \nSentiment: positive",
                "Review: to insulting the intelligence of anyone who has n't been living under a rock \nSentiment: negative",
                "Review: immensely ambitious \nSentiment: positive",
                "Review: into the modern rut of narrative banality \nSentiment: negative",
                "Review: user-friendly \nSentiment: positive"]
sample = """or doing last year 's taxes with your ex-wife . \nSentiment:"""

tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
tokenizer.pad_token = tokenizer.eos_token

prompt_ids_list = []
i = 0
# permutation = np.random.permutation(len(icl_samples))
# icl_samples = [icl_samples[i] for i in permutation]

np.random.seed(1)
for order in range(100):
    permutation = np.random.permutation(len(icl_samples))
    order = [icl_samples[i] for i in permutation]
    prompt = '\n\n'.join(order) + '\n\n' + sample
    # prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
    # prompt_ids = prompt_ids.to('cuda')

    prompt_ids_list.append((prompt))
    i += 1

print('Length of prompt_ids_list: ', len(prompt_ids_list))

batch_size = 8
prompt_ids_list = [prompt_ids_list[i:i + batch_size] for i in range(0, len(prompt_ids_list), batch_size)]

# convert prompt_ids_list to dataloader
# prompt_ids_list = torch.utils.data.DataLoader(prompt_ids_list, batch_size=batch_size, shuffle=False)

outputs = []
for batch in tqdm(prompt_ids_list):
    batch = tokenizer.batch_encode_plus(batch, return_tensors="pt", padding=True, truncation=True).to('cuda')
    output = model(batch['input_ids'], return_dict=True).logits
    output = output[:, -1, :].detach().cpu()
    # outputs += tokenizer.batch_decode(torch.argmax(output, dim=-1).squeeze().tolist())
    outputs += (tokenizer.batch_decode(torch.argmax(output, dim=-1).squeeze().tolist()))
    del batch
    
print(Counter(outputs))
# prompt = '\n\n'.join(icl_samples) + '\n\n' + sample
# prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
# prompt_ids = prompt_ids.to('cuda')

# output = model(prompt_ids, return_dict=True).logits
# print(output.shape)
# output = output[:, -1, :]
# tokenizer.decode(torch.argmax(output, dim=-1).squeeze().tolist())

Length of prompt_ids_list:  100


100%|██████████| 13/13 [00:03<00:00,  3.67it/s]

Counter({'positive': 97, 'negative': 3})





In [50]:
tokenizer("terrible")

{'input_ids': [1, 16403], 'attention_mask': [1, 1]}

In [24]:
prompt = """Classify the sentiment of positive and negative.
Review: i would recommend big bad love only to winger fans who have missed her since 1995 's forget paris . .
Sentiment: negative
Review: suspenseful enough for older kids but not .
Sentiment: positive
Review: another run-of-the-mill disney sequel intended for the home video market .
Sentiment: negative
Review: has never been smoother or more confident .
Sentiment: positive
Review: bad-movie .
Sentiment: negative
Review: sweetly .
Sentiment: positive
Review: like an extended dialogue exercise in retard 101 .
Sentiment: negative
Review: bouquet gives a performance that is masterly . .
Sentiment: positive"""

input1 = "Review: one of creepiest, scariest movies to come along in a long, long time, easily rivaling blair witch or the others. \nSentiment:"
input2 = "Review: good movie . \nSentiment:"

prompt_ids = gpt_tokenizer.encode(prompt, return_tensors="pt")
inputs1 = prompt + input1
inputs2 = prompt + input2

inputs = [inputs1, inputs2]
# inputs = [inputs1, inputs1]

gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
input_ids = gpt_tokenizer.batch_encode_plus(inputs, return_tensors="pt", padding=True, truncation=True)['input_ids'].to('cuda')

attention_mask = input_ids.ne(gpt_tokenizer.pad_token_id).float().to('cuda')
output = gpt_model(input_ids, attention_mask=attention_mask, return_dict=True).logits

last_non_pad_indices = torch.ne(input_ids, gpt_tokenizer.pad_token_id).sum(-1) - 1
print(last_non_pad_indices)
output = output[range(output.shape[0]), last_non_pad_indices, :]

gpt_tokenizer.batch_decode(torch.argmax(output, dim=-1).squeeze().tolist())

tensor([193, 167], device='cuda:0')


[' negative', ' positive']

In [33]:
prompt = "Classify the sentiment of positive and negative.\n Review: i would recommend big bad love only to winger fans who have missed her since 1995 's forget paris.Sentiment: negative\nReview:  suspenseful enough for older kids but not . \nSentiment: positive\n"

input = "Review: the cd is not suitable for children . \nSentiment:"
# try with the mistral tokenizer and model

prompt_ids = mistral_tokenizer.encode(prompt, return_tensors="pt")
input_ids = mistral_tokenizer.encode(input, return_tensors="pt")

input_ids = torch.cat([prompt_ids, input_ids[:, 1:]], dim=-1).to('cuda')

# add 10 padding tokens to the end of the input
input_ids = torch.cat([input_ids, torch.ones((1, 10)).long().to('cuda')], dim=-1)

attention_mask = input_ids != 1

output = mistral_model(input_ids, attention_mask=attention_mask, return_dict=True).logits
print(output.shape)
output = output[:, len(input_ids[0]) - 11, :]
mistral_tokenizer.decode(torch.argmax(output, dim=-1).squeeze().tolist())

torch.Size([1, 90, 32000])


'negative'

In [22]:
mistral_tokenizer("great")

{'input_ids': [1, 1598], 'attention_mask': [1, 1]}

In [51]:
gpt_tokenizer(" terrible")

{'input_ids': [353, 5547], 'attention_mask': [1, 1]}