In [23]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config, T5EncoderModel
from transformers.models.t5.modeling_t5 import T5Stack
import torch

# Load pre-trained T5 configuration
config = T5Config.from_pretrained('google/flan-t5-base')

# Modify the configuration
config.num_layers = 1  # Number of encoder layers
# config.num_decoder_layers = 1  # Number of decoder layers
config.is_decoder = True

# Initialize the model with the custom configuration
encoder = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base').encoder
model = T5Stack(config)

In [25]:
model

T5Stack(
  (block): ModuleList(
    (0): T5Block(
      (layer): ModuleList(
        (0): T5LayerSelfAttention(
          (SelfAttention): T5Attention(
            (q): Linear(in_features=768, out_features=768, bias=False)
            (k): Linear(in_features=768, out_features=768, bias=False)
            (v): Linear(in_features=768, out_features=768, bias=False)
            (o): Linear(in_features=768, out_features=768, bias=False)
            (relative_attention_bias): Embedding(32, 12)
          )
          (layer_norm): T5LayerNorm()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (1): T5LayerCrossAttention(
          (EncDecAttention): T5Attention(
            (q): Linear(in_features=768, out_features=768, bias=False)
            (k): Linear(in_features=768, out_features=768, bias=False)
            (v): Linear(in_features=768, out_features=768, bias=False)
            (o): Linear(in_features=768, out_features=768, bias=False)
          )
          (layer_norm)

In [18]:
B, L, T, D = 4, 22, 128, config.d_model
q_token = torch.randint(0, config.vocab_size, (B, L))
q_feat = encoder.embed_tokens(q_token)
v_feat = torch.randn(B, T, D)
lm_input = torch.cat([q_feat, v_feat], dim=1)
lm_mask = torch.ones(B, L + T)
key_value_states = torch.randn(B, 48, D)
key_value_states_attention_mask = torch.ones(B, 48)


In [24]:
model.forward(
    input_ids=None,
    attention_mask=lm_mask,
    encoder_hidden_states=key_value_states,
    encoder_attention_mask=key_value_states_attention_mask,
    inputs_embeds=lm_input,
    head_mask=None,
    cross_attn_head_mask=None,
    past_key_values=None,
    use_cache=True,
    output_attentions=True,
    output_hidden_states=True,
    return_dict=True,
)

BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=tensor([[[-1.1007,  1.0935, -0.1241,  ...,  0.7386, -0.7188,  0.2768],
         [ 0.2894,  0.2786, -0.9207,  ...,  0.9717, -1.1071,  1.9101],
         [ 1.4272,  0.0720,  2.0160,  ..., -0.0329, -0.1990,  0.2328],
         ...,
         [-0.0842, -0.3765, -0.5411,  ..., -0.4022, -2.0643,  0.2778],
         [ 1.3436,  0.2336, -1.2332,  ...,  1.0112, -0.9556, -1.0569],
         [ 1.6795,  0.2401, -0.3354,  ..., -0.7944,  1.6792, -0.6674]],

        [[-1.7007, -0.8680,  0.0000,  ..., -1.0876, -0.2907,  0.4550],
         [ 0.0000,  0.5002, -0.0336,  ..., -1.1147, -0.2716,  0.4202],
         [ 1.9371,  0.5306,  0.5409,  ...,  1.1800,  1.1592, -0.1748],
         ...,
         [ 1.3863, -1.0693,  2.0788,  ..., -0.6507,  0.0824,  0.5754],
         [ 0.4726, -0.0000,  1.2138,  ...,  0.0000,  0.6428,  0.4845],
         [-0.1451,  0.0000,  1.7852,  ..., -0.8643,  0.2718, -0.4514]],

        [[ 0.1020, -0.0285,  0.2848,  ...,  0.2353, -0.9