In [3]:
import torch.nn as nn
import torch
from transformers.models.bert.modeling_bert import BertIntermediate, BertOutput, BertAttention

BertLayerNorm = nn.LayerNorm

In [52]:
class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.crossattention = BertAttention(config)
        # BertIntermediate: linear + gelu
        self.intermediate = BertIntermediate(config)
        # BertOutput: layer+dropout+residual+layer_norm
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states,
        encoder_attention_mask
    ):
        self_attention_outputs = self.attention(hidden_states)
        # print("the shape of hidden_states: ", hidden_states.shape)
        # print("the shape of self_attention_outputs: ", self_attention_outputs[0].shape)

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
        elif encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
                    encoder_hidden_shape, encoder_attention_mask.shape
                )
            )
        # The mask is also transformed so that positions with a value of 1 (indicating they should be masked) are set to
        # a large negative value (-10000.0), making them effectively zero when passed through a softmax.
        
        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0

        
        cross_attention_outputs = self.crossattention(
            hidden_states=attention_output, encoder_hidden_states=encoder_hidden_states,  encoder_attention_mask=encoder_extended_attention_mask
        )
        attention_output = cross_attention_outputs[0]
        # print("the shape of cross_attention_outputs: ", cross_attention_outputs[0].shape)
        outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights

        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs
        return outputs

In [165]:
class SetDecoder_regressive_output(nn.Module):
    def __init__(self, config, num_generated_triples, num_layers, num_classes, return_intermediate=False, use_ILP=False):
        super().__init__()
        self.return_intermediate = return_intermediate
        self.num_generated_triples = num_generated_triples
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(num_layers)])
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.query_embed = nn.Embedding(num_generated_triples, config.hidden_size)
        if use_ILP:
            self.decoder2class = nn.Linear(config.hidden_size, num_classes)
            self.class2hidden = nn.Linear(num_classes, config.hidden_size)
        else:
            self.decoder2class = nn.Linear(config.hidden_size, num_classes + 1)
            self.class2hidden = nn.Linear(num_classes + 1, config.hidden_size)
        # self.decoder2span = nn.Linear(config.hidden_size, 4)

        self.head_start_metric_1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.head_end_metric_1 = nn.Linear(config.hidden_size, config.hidden_size)

        self.tail_start_metric_1 = nn.Linear(config.hidden_size, config.hidden_size)
        self.tail_end_metric_1 = nn.Linear(config.hidden_size, config.hidden_size)

        self.head_start_metric_2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.head_end_metric_2 = nn.Linear(config.hidden_size, config.hidden_size)

        self.tail_start_metric_2 = nn.Linear(config.hidden_size, config.hidden_size)
        self.tail_end_metric_2 = nn.Linear(config.hidden_size, config.hidden_size)

        torch.nn.init.orthogonal_(self.head_start_metric_1.weight, gain=1)
        torch.nn.init.orthogonal_(self.head_end_metric_1.weight, gain=1)
        torch.nn.init.orthogonal_(self.tail_start_metric_1.weight, gain=1)
        torch.nn.init.orthogonal_(self.tail_end_metric_1.weight, gain=1)
        torch.nn.init.orthogonal_(self.head_start_metric_2.weight, gain=1)
        torch.nn.init.orthogonal_(self.head_end_metric_2.weight, gain=1)
        torch.nn.init.orthogonal_(self.tail_start_metric_2.weight, gain=1)
        torch.nn.init.orthogonal_(self.tail_end_metric_2.weight, gain=1)
        torch.nn.init.orthogonal_(self.query_embed.weight, gain=1)
        
        self.regressive_decoder = nn.MultiheadAttention(embed_dim=config.hidden_size, num_heads=1, batch_first=True)
        self.output_linear = nn.Linear(config.hidden_size * 4, 4, bias=False)
        



    def forward(self, encoder_hidden_states, encoder_attention_mask):
        bsz = encoder_hidden_states.size()[0]
        # print("the shape of query_embed.weight: ", self.query_embed.weight.shape)
        hidden_states = self.query_embed.weight.unsqueeze(0).repeat(bsz, 1, 1)
        # print("the shape of hidden_states: ", hidden_states.shape)
        # hidden_state: [bsz, num_generated_triples, hidden_size]
        hidden_states = self.dropout(self.LayerNorm(hidden_states))
        all_hidden_states = ()
        
        print("=====================================")
        # print(f"encoder_attention_mask:\n{encoder_attention_mask}")
                
        for i, layer_module in enumerate(self.layers):
            # print(f"hidden_states shape:\n{hidden_states.shape}")
            # print(f"encoder_hidden_states shape:\n{encoder_hidden_states.shape}")
            # print("=====================================")
            
            if self.return_intermediate:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_outputs = layer_module(
                hidden_states, encoder_hidden_states, encoder_attention_mask
            )
            hidden_states = layer_outputs[0]

        class_logits = self.decoder2class(hidden_states)
        
        class_hidden_states = self.class2hidden(class_logits)
        # class_hidden_states = class_hidden_states.unsqueeze(-2).repeat(1, 1, encoder_hidden_states.shape[-2], 1)
        # print(f"class_hidden_states after repeat shape:\n{class_hidden_states.shape}")
        # have a binary tensor encorder_extended_attention_mask_binary, shape is the same as encoder_extended_attention_mask, and the value is True for 0, and False for 1.
        if encoder_attention_mask.dim() == 3:
            encoder_extended_attention_mask = encoder_attention_mask.reshape(-1, encoder_hidden_states.shape[-2])
        elif encoder_attention_mask.dim() == 2:
            encoder_extended_attention_mask = encoder_attention_mask
        encoder_extended_attention_mask_binary = (encoder_extended_attention_mask == 0)
        # print(f"encoder_extended_attention_mask_binary: \n{encoder_extended_attention_mask_binary}")
        # print(f"encoder_extended_attention_mask_binary shape: \n{encoder_extended_attention_mask_binary.shape}")
        
        encoder_extended_attention_mask_binary =  encoder_extended_attention_mask_binary.unsqueeze(1).repeat(1, self.num_generated_triples, 1)
        # print(f"encoder_extended_attention_mask_binary shape: \n{encoder_extended_attention_mask_binary.shape}")
        
        encoder_extended_attention_mask_binary = encoder_extended_attention_mask_binary.reshape(bsz * self.num_generated_triples, encoder_hidden_states.shape[-2])
        # print(f"encoder_extended_attention_mask_binary shape: \n{encoder_extended_attention_mask_binary.shape}")
        
        # print(f"encoder_hidden_states shape:\n{encoder_hidden_states.shape}")
        encoder_hidden_states = encoder_hidden_states.repeat(self.num_generated_triples, 1, 1)
        # print(f"encoder_hidden_states shape after:\n{encoder_hidden_states.shape}")
        
        # print(f"encoder_extended_attention_mask_binary shape:\n{encoder_extended_attention_mask_binary.shape}")
        
        
        class_hidden_states = class_hidden_states.reshape(-1, class_hidden_states.shape[-1]).unsqueeze(1)
        # print(f"class_hidden_states shape after:\n{class_hidden_states.shape}")
        
        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1]).unsqueeze(1)
        # print(f"hidden_states shape after:\n{hidden_states.shape}")
        
        # print(f"encoder_extended_attention_mask_binary.unsqueeze(-1) shape:\n{encoder_extended_attention_mask_binary.unsqueeze(-1).shape}")
        head_start_logits_mh = self.regressive_decoder(query=encoder_hidden_states, key=torch.tanh(self.head_start_metric_1(class_hidden_states) + self.head_start_metric_2(hidden_states)), value=torch.tanh(self.head_start_metric_1(class_hidden_states) + self.head_start_metric_2(hidden_states)), attn_mask=encoder_extended_attention_mask_binary.unsqueeze(-1))
        head_start_logits_mh = head_start_logits_mh[0]
        # print(f"head_start_logits_mh shape:\n{head_start_logits_mh.shape}")
        
        
        
        head_end_logits_mh = self.regressive_decoder(head_start_logits_mh, key=torch.tanh(self.head_end_metric_1(class_hidden_states) + self.head_end_metric_2(hidden_states)), value=torch.tanh(self.head_end_metric_1(class_hidden_states) + self.head_end_metric_2(hidden_states)), attn_mask=encoder_extended_attention_mask_binary.unsqueeze(-1))
        head_end_logits_mh = head_end_logits_mh[0]
                
        # print(f"head_end_logits_mh shape:\n{head_end_logits_mh.shape}")


        tail_start_logits_mh = self.regressive_decoder(head_end_logits_mh, key=torch.tanh(self.tail_start_metric_1(class_hidden_states) + self.tail_start_metric_2(hidden_states)), value=torch.tanh(self.tail_start_metric_1(class_hidden_states) + self.tail_start_metric_2(hidden_states)), attn_mask=encoder_extended_attention_mask_binary.unsqueeze(-1))
        tail_start_logits_mh = tail_start_logits_mh[0]
        
        # print(f"tail_start_logits_mh shape:\n{tail_start_logits_mh.shape}")
        
        
        
        tail_end_logits_mh = self.regressive_decoder(tail_start_logits_mh, key=torch.tanh(self.tail_end_metric_1(class_hidden_states) + self.tail_end_metric_2(hidden_states)), value=torch.tanh(self.tail_end_metric_1(class_hidden_states) + self.tail_end_metric_2(hidden_states)), attn_mask=encoder_extended_attention_mask_binary.unsqueeze(-1))
        tail_end_logits_mh = tail_end_logits_mh[0]
        
        # print(f"tail_end_logits_mh shape:\n{tail_end_logits_mh.shape}")
        # print("=============linear==================")
        # Stack the four tensors along a new dimension in-place
        input_tensor = torch.stack([head_start_logits_mh, head_end_logits_mh, tail_start_logits_mh, tail_end_logits_mh], dim=3)
        
        # Flatten the input tensor along the last dimension
        input_tensor_flat = input_tensor.view(input_tensor.size(0), input_tensor.size(1), -1)
        # print(f"input_tensor_flat shape:\n{input_tensor_flat.shape}")
        
        # Apply the linear layer
        output_flat = self.output_linear(input_tensor_flat)
        # print(f"output_flat shape:\n{output_flat.shape}")
        
        # Reshape the output tensor to separate it into four tensors
        output = output_flat.view(bsz, self.num_generated_triples, output_flat.size(1), 4)

        # Split the output tensor in-place into the four individual tensors
        head_start_logits  = output[:, :, :, 0].clone()
        head_end_logits = output[:, :, :, 1].clone()
        tail_start_logits = output[:, :, :, 2].clone()
        tail_end_logits = output[:, :, :, 3].clone()

        return class_logits, head_start_logits, head_end_logits, tail_start_logits, tail_end_logits

# test

In [168]:
from models.seq_encoder import SeqEncoder

from transformers import BertTokenizer

class Args:
    def __init__(self, **kwargs):
        # self.bert_directory = "bert-base-cased"
        self.bert_directory = "SpanBERT/spanbert-base-cased"
        self.fix_bert_embeddings = True


In [169]:
bert_args = Args()

encoder = SeqEncoder(bert_args)

tokenizer = BertTokenizer.from_pretrained(bert_args.bert_directory)

decoder = SetDecoder(encoder.config, 5, 2, 10, return_intermediate=False)

text = ["Hello, my dog is cute", "Do you want to play with me? I think we have some commons."]
tokenized_text = tokenizer(text, return_tensors="pt", padding=True)
last_hidden_state, pooler_output = encoder(input_ids=tokenized_text["input_ids"], attention_mask=tokenized_text["attention_mask"])

print("last_hidden_state: ", last_hidden_state.shape)
print("pooler_output: ", pooler_output.shape, "\n====================\n")


decoder_output = decoder(encoder_hidden_states=last_hidden_state, encoder_attention_mask=tokenized_text["attention_mask"])
class_logits, head_start_logits, head_end_logits, tail_start_logits, tail_end_logits  = decoder_output

print(f"head_start_logits shape: {head_start_logits.shape}")
print(f'tokenized_text["attention_mask"] shape: {tokenized_text["attention_mask"].shape}')
head_start_logits = head_start_logits.squeeze(-1).masked_fill((1 - tokenized_text["attention_mask"].unsqueeze(1)).bool(), -10000.0)
head_end_logits = head_end_logits.squeeze(-1).masked_fill((1 - tokenized_text["attention_mask"].unsqueeze(1)).bool(), -10000.0)
tail_start_logits = tail_start_logits.squeeze(-1).masked_fill((1 - tokenized_text["attention_mask"].unsqueeze(1)).bool(), -10000.0)
tail_end_logits = tail_end_logits.squeeze(-1).masked_fill((1 - tokenized_text["attention_mask"].unsqueeze(1)).bool(), -10000.0)# [bsz, num_generated_triples, seq_len]

print("class_logits: ", class_logits.shape)
print("head_start_logits: ", head_start_logits.shape)
print("head_end_logits: ", head_end_logits.shape)
print("tail_start_logits: ", tail_start_logits.shape)
print("tail_end_logits: ", tail_end_logits.shape)

# have the argmax of the class_logits
class_logits_argmax = class_logits.argmax(-1)
print("class_logits_argmax: ", class_logits_argmax)

Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


last_hidden_state:  torch.Size([2, 18, 768])
pooler_output:  torch.Size([2, 768]) 

class_hidden_states after repeat shape:
torch.Size([2, 5, 768])
encoder_extended_attention_mask_binary: 
tensor([[False, False, False, False, False, False, False, False,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False]])
encoder_extended_attention_mask_binary shape: 
torch.Size([2, 18])
encoder_extended_attention_mask_binary shape: 
torch.Size([2, 5, 18])
encoder_extended_attention_mask_binary shape: 
torch.Size([10, 18])
encoder_hidden_states shape:
torch.Size([2, 18, 768])
encoder_hidden_states shape after:
torch.Size([10, 18, 768])
encoder_extended_attention_mask_binary shape:
torch.Size([10, 18])
class_hidden_states shape after:
torch.Size([10, 1, 768])
hidden_states shape after:
torch.Size([10, 1, 768])
encoder_extended_attent

In [49]:
decoder.head_start_metric_3

Linear(in_features=768, out_features=1, bias=False)