In [1]:
import math
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from moondream.modeling_phi import apply_rotary_pos_emb

# Load moondream2 and its tokenizer at one pinned Hub revision so reruns see the same weights.
# trust_remote_code=True runs modeling code fetched from the Hub repo; pinning `revision`
# also pins that code (see the Hugging Face `from_pretrained` docs).
model_id = "vikhyatk/moondream2"
revision = "2024-04-02"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
    torch_dtype=torch.float16,  # load weights in half precision
)
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)


  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [16]:
# Encode a short text prompt and look up its token embeddings in the text model.
tokens = tokenizer("Describe the image.", return_tensors="pt").input_ids
inputs_embeds = model.text_model.get_input_embeddings()(
    tokens
)

In [49]:
# Re-derive the first decoder layer's attention output by hand from the prompt embeddings.
# NOTE(review): `scaled_dot_product_attention` is defined in a later cell (In [48]). This
# cell only works if that cell already ran; move it above this one so Restart & Run All works.
mixer = model.text_model.transformer.h[0].mixer
x = model.text_model.transformer.h[0].ln(inputs_embeds)

# Fused QKV projection, split into three parts, each reshaped to (batch, heads, seq, head_dim).
bsz, q_len, _ = x.shape
query_states, key_states, value_states = torch.chunk(mixer.Wqkv(x), 3, dim=-1)
query_states = query_states.view(bsz, q_len, mixer.num_heads, mixer.head_dim).transpose(1, 2)
key_states, value_states = (
    states.view(bsz, q_len, mixer.num_key_value_heads, mixer.head_dim).transpose(1, 2)
    for states in (key_states, value_states)
)

# Rotary embedding is applied only to the first `mixer.rotary_emb.dim` channels of each head;
# the remaining channels are passed through and concatenated back unchanged.
kv_seq_len = key_states.shape[-2]
cos, sin = mixer.rotary_emb(value_states, seq_len=kv_seq_len)
query_rot, query_pass = query_states[..., : mixer.rotary_emb.dim], query_states[..., mixer.rotary_emb.dim :]
key_rot, key_pass = key_states[..., : mixer.rotary_emb.dim], key_states[..., mixer.rotary_emb.dim :]
# position_ids=None: presumably means positions 0..seq_len-1 — TODO confirm against moondream.modeling_phi.
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, None)
query_states, key_states = (
    torch.cat((query_rot, query_pass), dim=-1),
    torch.cat((key_rot, key_pass), dim=-1),
)

# Causal attention per head, merge heads back to (batch, seq, hidden), then output projection.
attn_output = mixer.out_proj(
    scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
    .transpose(1, 2)
    .contiguous()
    .reshape(bsz, q_len, mixer.hidden_size)
)

print(attn_output.shape)
attn_output

torch.Size([1, 5, 2048])


tensor([[[ 0.1456,  0.5635,  0.6240,  ...,  0.4666,  0.2549, -0.2764],
         [ 0.6284,  0.3845,  0.8750,  ...,  0.9077,  0.0452, -0.3027],
         [ 0.3618,  0.6597,  0.6064,  ...,  0.4902,  0.1268, -0.3398],
         [ 0.1552,  0.3916,  0.4509,  ...,  0.1307, -0.0458, -0.0320],
         [ 0.0200,  0.0423, -0.5195,  ...,  0.1262,  0.0666, -0.0076]]],
       dtype=torch.float16, grad_fn=<ViewBackward0>)

In [48]:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

In [55]:
# Text-only greedy generation: no image is passed to `text_model.generate`, so the
# "answer" is the language model's plain continuation of the prompt.
# `max_length=50` caps prompt + generated tokens; generation ran on well past "No" here.
tokens = tokenizer(
    "Question: Is this a dog? Yes or no.\n\nAnswer:", return_tensors="pt"
).input_ids
output = model.text_model.generate(
    tokens,
    max_length=50,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
)
tokenizer.decode(output[0], skip_special_tokens=True)

'Question: Is this a dog? Yes or no.\n\nAnswer: No\n\n4. Is this a cat or a dog? No\n\n5. Is this a cat or a dog? No\n\n6. Is'