In [1]:
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel, BertTokenizer
import torch

In [2]:
from transformers import BertModel

In [3]:
# Low-level BERT modules used by the cells below.
# transformers >= 4.0 moved this module to `transformers.models.bert.modeling_bert`;
# fall back to the flat pre-4.0 path so the import works on either layout.
try:
    from transformers.models.bert.modeling_bert import BertEncoder, BertPooler
except ImportError:
    from transformers.modeling_bert import BertEncoder, BertPooler

In [4]:
# Two default BERT configurations: one for the encoder side, one for the decoder side.
config_encoder, config_decoder = BertConfig(), BertConfig()

In [5]:
# Combine the two BERT configs into a single encoder-decoder configuration.
config = EncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config=config_encoder,
    decoder_config=config_decoder,
)

In [6]:
# Display the encoder sub-config held by the combined config.
config.encoder

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [7]:
# Display the decoder sub-config; the output below additionally lists
# "add_cross_attention": true and "is_decoder": true.
config.decoder

BertConfig {
  "add_cross_attention": true,
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": true,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [8]:
# Stand-alone BertEncoder built from the encoder sub-config.
be = BertEncoder(config=config.encoder)

In [9]:
#be

In [10]:
# Stand-alone BertPooler built from the same encoder sub-config.
bp = BertPooler(config=config.encoder)

In [11]:
# Full BertModel built from the encoder sub-config (a separate instance from `be` and `bp`).
bm = BertModel(config=config.encoder)

In [12]:
#bm

In [13]:
# Tokenizer for the `bert-base-uncased` checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [14]:
# Tokenize one sentence and request PyTorch tensors back.
inputs = tokenizer(text="Hello, my dog is cute", return_tensors="pt")

In [15]:
# Show the tokenizer output: input_ids, token_type_ids and attention_mask.
inputs

{'input_ids': tensor([[  101,  7592,  1010,  2026,  3899,  2003, 10140,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [16]:
# Inspect the `output_hidden_states` flag on the model's config (False in the output below).
bm.config.output_hidden_states

False

In [17]:
# Run only the embedding layer of `bm` on the tokenized sentence.
# `position_ids` and `inputs_embeds` are passed as None explicitly.
input_embedded = bm.embeddings(
    input_ids=inputs["input_ids"],
    token_type_ids=inputs["token_type_ids"],
    position_ids=None,
    inputs_embeds=None,
)

In [18]:
# Shape of the embedded sentence; the output below shows (1, 8, 768).
input_embedded.shape

torch.Size([1, 8, 768])

In [19]:
# Feed the embeddings through `be`, the separately constructed encoder
# (this is not the encoder that lives inside `bm`).
input_encodings = be(hidden_states=input_embedded)

In [20]:
# Shape of the first element returned by the encoder; same as the input, (1, 8, 768).
input_encodings[0].shape

torch.Size([1, 8, 768])

In [21]:
# Apply the stand-alone pooler to the encoder's first output element.
pooled_output = bp(hidden_states=input_encodings[0])

In [22]:
# Shape of the pooled result; the output below shows (1, 768), i.e. the sequence axis is gone.
pooled_output.shape

torch.Size([1, 768])

In [23]:
# Build the encoder-decoder model from the combined config (no pretrained checkpoint is named here).
edm = EncoderDecoderModel(config)

In [24]:
# Encode the source sentence and wrap the id list in an outer list so the
# resulting tensor has a leading batch axis of size 1.
input_ids = torch.tensor(
    [tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)]
)

In [39]:
# Encode a second, longer sentence (used as decoder input) with a leading batch axis of size 1.
# NOTE(review): this cell and the next carry execution counts 39/40 while their
# neighbours are 24 and 27, so they were run out of order; re-run the notebook
# top to bottom on a fresh kernel to confirm it still works.
input_ids2 = torch.tensor(
    [tokenizer.encode("Hello, my dog is cute ha ?", add_special_tokens=True)]
)

In [40]:
# Forward pass through the encoder-decoder: the first sentence goes to the
# encoder, the second to the decoder. The result is displayed unassigned.
# NOTE(review): the "encode" / "Bertselfattention" lines in the recorded output
# are not printed by this cell; they look like debug prints inside the
# installed transformers source. Confirm before relying on them.
edm(
    input_ids=input_ids,
    decoder_input_ids=input_ids2,
)

----------------------------------encode------------------------------
tensor([[  101,  7592,  1010,  2026,  3899,  2003, 10140,   102]]) 
 None 
 None 
 None 
 None 
 False
others: 
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1,

(tensor([[[-0.2635,  0.3335, -0.5793,  ..., -0.6256,  0.3102, -0.3588],
          [-0.0058,  0.7079,  0.1599,  ..., -0.5209,  0.2779, -0.7250],
          [-0.0678,  0.3428,  0.1493,  ..., -0.0747,  0.2034, -0.4901],
          ...,
          [ 0.0296, -0.4989,  0.1196,  ...,  0.0709, -0.2126,  0.2090],
          [-0.0108,  0.5508,  0.1749,  ..., -0.0189, -0.1595, -0.0454],
          [ 0.0542, -0.0384, -0.0409,  ...,  0.2816, -0.1454, -0.0560]]],
        grad_fn=<AddBackward0>),
 tensor([[[ 0.7444,  0.1622, -0.1934,  ...,  0.9731, -1.8561,  1.0474],
          [ 1.1234, -1.7583,  0.0234,  ...,  2.3432,  0.0813, -0.2300],
          [ 0.1358, -1.9884,  0.3587,  ...,  2.1168,  0.0220,  0.9787],
          ...,
          [ 0.4982, -1.2333,  0.3616,  ...,  1.4725, -0.9568,  0.9959],
          [ 0.4878, -1.5515, -0.7099,  ...,  1.6839,  0.5675,  0.9859],
          [ 2.2642, -0.7312, -0.5531,  ...,  1.9092, -0.4106, -0.0130]]],
        grad_fn=<NativeLayerNormBackward>),
 tensor([[-4.7333e-02, -7

In [27]:
# Print the module tree of the encoder-decoder model (the recorded output is truncated).
edm

EncoderDecoderModel(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [28]:
# Display the decoder sub-config again; the output matches the earlier `config.decoder` cell.
config.decoder

BertConfig {
  "add_cross_attention": true,
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": true,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [29]:
# Stand-alone BertModel built from the decoder sub-config.
decoder_bm = BertModel(config=config.decoder)

In [30]:
# Stand-alone BertEncoder built from the decoder sub-config.
# NOTE(review): not referenced again in the cells shown here; may be leftover.
decoder_be = BertEncoder(config=config.decoder)

In [31]:
# Run the stand-alone decoder model with a random tensor standing in for the
# encoder output. The stand-in's shape is derived rather than hard-coded as
# (1, 8, 768): batch and sequence length follow `input_ids`, and the last
# axis follows the decoder model's configured hidden size.
decoder_outputs = decoder_bm(
    input_ids=input_ids,
    encoder_hidden_states=torch.randn(
        *input_ids.shape, decoder_bm.config.hidden_size
    ),
)

bertmodel encoder_extended_attention_mask torch.Size([1, 1, 1, 8]) tensor([[[[-0., -0., -0., -0., -0., -0., -0., -0.]]]])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 8, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 8, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 8, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 8, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 8, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 8, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.Size([1, 1, 1, 8])
Bertselfattention torch.Size([1, 12, 8, 8]) torch.

In [32]:
# Shape of the first returned element; the output below shows (1, 8, 768).
decoder_outputs[0].shape

torch.Size([1, 8, 768])

In [33]:
# Shape of the second returned element; the output below shows (1, 768).
# Presumably this is the pooled output; TODO confirm.
decoder_outputs[1].shape

torch.Size([1, 768])

In [34]:
# Print the `cls` head attached to the decoder; its final layer in the output
# below is a Linear mapping 768 features to 30522 (the configured vocab_size).
edm.decoder.cls

BertOnlyMLMHead(
  (predictions): BertLMPredictionHead(
    (transform): BertPredictionHeadTransform(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (decoder): Linear(in_features=768, out_features=30522, bias=True)
  )
)

In [35]:
# Display a default BertConfig for reference; the fields printed below match
# the `config.encoder` output shown earlier.
BertConfig()

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

In [36]:
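# Disabled alternative: build the encoder-decoder from the 'bert-base-uncased'
# checkpoints instead of from a bare config as `edm` does. Uncomment to try it.
#model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')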

In [37]:
#model