In [1]:
from ndart.modeling_ndart import NDartForConditionalGeneration
from ndart.processing_ndart import NDartProcessor
from ndart.configuration_ndart import NDartConfig

import torch
from transformers import (
    AutoModelForTextEncoding,
    AutoModelForCausalLM,
    AutoTokenizer,
)

In [2]:
encoder_model = "intfloat/multilingual-e5-small"
decoder_model = "p1atdev/dart-v3-llama-8L-241018_241020-sft-use-group"
bert = AutoModelForTextEncoding.from_pretrained(encoder_model)
dart = AutoModelForCausalLM.from_pretrained(decoder_model)
processor = NDartProcessor(
    encoder_tokenizer=AutoTokenizer.from_pretrained(encoder_model),
    decoder_tokenizer=AutoTokenizer.from_pretrained(decoder_model),
    natural_token="<|natural|>",
)

model = NDartForConditionalGeneration._from_config(
    NDartConfig(
        natural_config=bert.config,
        tag_config=dart.config,
        natural_token_index=processor.natural_token_id,
    )
)
model.encoder_model = bert
model.decoder_model = dart
print(model)

NDartForConditionalGeneration(
  (encoder_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(250037, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384

In [3]:
tag_prompt = [
    "<|bos|><projection><|natural|></projection>",
    "<projection><|natural|></projection><|translate:exact|><|input_end|>",
]
natural_text = [
    "an image",
    "黒髪ロング猫耳美少女JK",
]

In [4]:
with torch.no_grad():
    with torch.autocast(device_type="cpu"):
        inputs = processor(
            natural_text=natural_text,
            tag_text=tag_prompt,
            return_tensors="pt",
        )

        print(
            f"Encoder input_ids: {inputs.input_ids}",
        )
        encoder_embeds = model.encoder_model(
            input_ids=inputs.encoder_input_ids,
            attention_mask=inputs.encoder_attention_mask,
        ).last_hidden_state
        print(
            f"Encoder embeddings: {encoder_embeds.shape} {encoder_embeds[:, :, 0]}",
        )
        print(
            f"Natural attention mask: {inputs.encoder_attention_mask}",
        )

        projected_embeds = model.projection(encoder_embeds)
        print(
            f"Projected embeddings: {projected_embeds[:, :, 0]}",
        )

        print(
            f"Decoder input_ids: {inputs.input_ids}",
        )
        decoder_embeds = model.decoder_model.get_input_embeddings()(
            inputs.input_ids,
        )
        print(
            f"Decoder embeddings: {decoder_embeds.shape} {decoder_embeds[:, :, 0]}",
        )

        replaced_decoder_embeds = model._replace_natural_token_embeddings(
            encoder_embeds=projected_embeds,
            decoder_input_ids=inputs.input_ids,
            decoder_embeds=decoder_embeds,
            encoder_attention_mask=inputs.encoder_attention_mask,
        )
        print(
            f"Replaced decoder embeddings: {replaced_decoder_embeds.shape} {replaced_decoder_embeds[:, :, 0]}",
        )

Encoder input_ids: tensor([[ 0, 65, 60, 60, 60, 60, 67,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [65, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 67, 71, 40]])
Encoder embeddings: torch.Size([2, 12, 384]) tensor([[0.1989, 0.0757, 0.1789, 0.1089, 0.2332, 0.2270, 0.2442, 0.2437, 0.2468,
         0.2508, 0.2663, 0.2650],
        [0.2803, 0.2220, 0.3636, 0.2078, 0.0874, 0.4443, 0.3044, 0.2226, 0.3171,
         0.4064, 0.4431, 0.2801]])
Natural attention mask: tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
Projected embeddings: tensor([[ 0.0603,  0.0737,  0.0830,  0.0786,  0.0654,  0.0620,  0.0645,  0.0669,
          0.0659,  0.0659,  0.0649,  0.0635],
        [ 0.0332, -0.0253,  0.0244,  0.0299,  0.0270,  0.0084,  0.0425,  0.0276,
          0.0240,  0.0254,  0.0270,  0.0332]], dtype=torch.bfloat16)
Decoder input_ids: tensor([[ 0, 65, 60, 60, 60, 60, 67,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [65, 60, 60, 60, 60, 60, 60, 60, 60, 60

In [5]:
inputs = processor(
    natural_text=natural_text,
    tag_text=tag_prompt,
    return_tensors="pt",
)
inputs

{'input_ids': tensor([[ 0, 65, 60, 60, 60, 60, 67,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [65, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 67, 71, 40]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'encoder_input_ids': tensor([[    0,   142, 29569,     2,     1,     1,     1,     1,     1,     1,
             1,     1],
        [    0,     6, 74496, 92356, 11119, 76337, 35076, 21645,  2655, 48701,
         63859,     2]]), 'encoder_attention_mask': tensor([[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [6]:
inputs.attention_mask.shape

torch.Size([2, 16])

In [9]:
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=16, use_cache=True)

", ".join(
    [
        token
        for token in processor.batch_decode(outputs[0], skip_special_tokens=True)
        if token != ""
    ]
)

'amagi yukiko, :d, closed eyes, ^_^, round teeth, glasses, xd, profile, :o, beret, profile, profile, :o, profile, drill hair, facing viewer'