In [1]:
import numpy as np
import torch
from transformers import LlamaTokenizer, AutoModelForCausalLM, AutoTokenizer

#model_id = "stabilityai/japanese-stablelm-instruct-alpha-7b"    
#model_id = "stabilityai/japanese-stablelm-base-alpha-7b"
model_id = "stabilityai/japanese-stablelm-base-gamma-7b"

#tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1", additional_special_tokens=['▁▁'])
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
)
#model.half()
model.eval()

def build_prompt(user_query, inputs="", sep="\n\n### "):
    sys_msg = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"
    p = sys_msg
    roles = ["指示", "応答"]
    msgs = [": \n" + user_query, ": "]
    if inputs:
        roles.insert(1, "入力")
        msgs.insert(1, ": \n" + inputs)
    for role, msg in zip(roles, msgs):
        p += sep + role + msg
    return p


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:31<00:00, 15.68s/it]


In [2]:

# this is for reproducibility.
# feel free to change to get different result
seed = 42
torch.manual_seed(seed)

# Infer with prompt without any additional input
user_inputs = {
    "user_query": "VR とはどのようなものですか？",
    "inputs": ""
}
prompt = build_prompt(**user_inputs)

tokens = tokenizer(
    prompt, 
    add_special_tokens=False, 
    return_tensors="pt"
)

pos = tokens.input_ids.shape[-1]

In [3]:
inputs = {
    'input_ids'           : tokens.input_ids,
    'attention_mask'      : tokens.attention_mask,
    'position_ids'        : torch.tensor([n for n in range(pos)], dtype=torch.int64),
    #'past_key_values'     : None,
    #'output_attentions'   : True,
    #'output_hidden_states': True,
    'use_cache'           : True,
    'return_dict'         : True
}


In [6]:

num_generate_tokens = 20

for nn in range(num_generate_tokens):

    res = model(**inputs)
    # outpus
    #  logits [1, 3, 65536]
    #  past_key_values [n][0|1][ 1, 32, seq_len, 128]
    
    #print(res.keys())
    #print(res['attentions'][0].shape)
    #odict_keys(['logits', 'past_key_values', 'hidden_states', 'attentions'])

    logits = res['logits'][0,-1,:].to('cpu').detach().numpy().copy()
    past_key_values = res['past_key_values']
    print(len(past_key_values), len(past_key_values[0]), past_key_values[0][0].shape)
    predicted_id = np.argmax(logits)
    if predicted_id == tokenizer.eos_token_id:
        print('** EOS token is generated.')
        break
    token = tokenizer.decode([predicted_id], skip_special_tokens=False)

    past_key_values = res['past_key_values']
    print(predicted_id, token, len(past_key_values), past_key_values[0][0].shape, past_key_values[0][1].shape)

    inputs = {
        'past_key_values'     : past_key_values,
        'input_ids'           : torch.tensor([[predicted_id]], dtype=torch.int64),
        'attention_mask'      : torch.tensor([[1]], dtype=torch.int64),
        'position_ids'        : torch.tensor([[pos]], dtype=torch.int64),
        #'past_key_values'     : past_key_values,
        #'output_attentions'   : True,
        #'output_hidden_states': True,
        'use_cache'           : True,
        'return_dict'         : True
    }
    pos += 1

# past_key_values = ( ( [1,32,seq,128], [1,32,seq,128] ) *32 ) 

32 2 torch.Size([1, 8, 131, 128])
30692 じ 32 torch.Size([1, 8, 131, 128]) torch.Size([1, 8, 131, 128])
32 2 torch.Size([1, 8, 132, 128])
29943 よ 32 torch.Size([1, 8, 132, 128]) torch.Size([1, 8, 132, 128])
32 2 torch.Size([1, 8, 133, 128])
29620 う 32 torch.Size([1, 8, 133, 128]) torch.Size([1, 8, 133, 128])
32 2 torch.Size([1, 8, 134, 128])
29174 に 32 torch.Size([1, 8, 134, 128]) torch.Size([1, 8, 134, 128])
32 2 torch.Size([1, 8, 135, 128])
30599 見 32 torch.Size([1, 8, 135, 128]) torch.Size([1, 8, 135, 128])
32 2 torch.Size([1, 8, 136, 128])
30145 え 32 torch.Size([1, 8, 136, 128]) torch.Size([1, 8, 136, 128])
32 2 torch.Size([1, 8, 137, 128])
29116 る 32 torch.Size([1, 8, 137, 128]) torch.Size([1, 8, 137, 128])
32 2 torch.Size([1, 8, 138, 128])
29696 画 32 torch.Size([1, 8, 138, 128]) torch.Size([1, 8, 138, 128])
32 2 torch.Size([1, 8, 139, 128])
29663 像 32 torch.Size([1, 8, 139, 128]) torch.Size([1, 8, 139, 128])
32 2 torch.Size([1, 8, 140, 128])
29078 を 32 torch.Size([1, 8, 140, 128])