In [None]:
class BartAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Decoder cross-attention layers whose ``idx`` is listed in
    ``decoder_use_layer`` do not use a plain softmax. Instead a 0/1 "keep"
    mask is derived from the attention logits (thresholded sigmoid at
    inference, thresholded RelaxedBernoulli sample with a straight-through
    gradient during training), positions listed in ``keyword_position`` are
    pushed towards being kept, and ``exp(logits) * keep`` is renormalised.
    Every other layer uses the regular softmax followed by dropout.
    """

    # Values of `idx` for which decoder cross-attention uses the sparse path.
    decoder_use_layer = (0, 4, 9)

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
    ):
        """Create the q/k/v/out projections.

        Args:
            embed_dim: Size of the model dimension; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the softmax attention probabilities.
            is_decoder: When True, ``forward`` returns the key/value states for caching.
            bias: Whether the linear projections carry a bias term.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})."
        self.scaling = self.head_dim ** -0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Scales how much gradient flows through the straight-through sample
        # in the sparse training path (divided by ``idx + 1`` at use site).
        self.scale_factor = 1

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape (bsz, seq_len, embed_dim) into (bsz, num_heads, seq_len, head_dim)."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    @staticmethod
    def _keyword_indices(kp: torch.Tensor) -> torch.Tensor:
        """Return the entries of ``kp`` that are not the ``-100`` padding value, as long indices."""
        # `.long()` (rather than `.int()`) so the result is a valid index
        # tensor on every PyTorch version.
        return kp[kp != -100].long()

    @staticmethod
    def _masked_normalize(attn_weights: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
        """Normalise ``exp(attn_weights) * keep`` over the last dimension.

        The denominator is clamped at 1e-4, so a row in which nothing is kept
        yields all zeros instead of NaN.
        NOTE(review): ``exp`` is applied without subtracting the row maximum,
        exactly as before; very large logits may overflow in low precision.
        """
        attn_weights_exp = torch.exp(attn_weights) * keep
        attn_weights_sum = torch.sum(attn_weights_exp, dim=-1, keepdim=True)
        return attn_weights_exp / torch.clamp(attn_weights_sum, min=1e-4)

    def _eval_keep_mask(
        self,
        attn_weights: torch.Tensor,
        keyword_position: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Build the 0/1 keep mask used at inference time.

        A position is kept when its sigmoid score exceeds the mean score of
        its row; keyword positions are forced to score 1 before the comparison.
        """
        # The mask is non-differentiable, so work on a detached copy; this
        # also keeps the in-place keyword write out of the autograd graph.
        prob = torch.sigmoid(attn_weights.detach())
        thresholds = torch.mean(prob, dim=-1, keepdim=True)

        if keyword_position is not None and keyword_position.size(0) > 0:
            # Rows of `prob` are split evenly between the rows of
            # `keyword_position` (presumably num_heads * num_beams rows per
            # source example during beam search — confirm against caller).
            step = prob.size(0) // keyword_position.size(0)
            for row, kp in enumerate(keyword_position):
                start = row * step
                prob[start:start + step, :, self._keyword_indices(kp)] = 1

        return (prob > thresholds).type_as(prob)

    def _train_keep_mask(
        self,
        attn_weights: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        keyword_position: Optional[torch.Tensor],
        idx: int,
        bsz: int,
        tgt_len: int,
        src_len: int,
    ) -> torch.Tensor:
        """Build the keep mask used during training.

        Samples ``y`` from ``RelaxedBernoulli(0.5, logits=attn_weights)`` and
        keeps positions where ``y`` exceeds ``mean + std`` of its row, with
        the statistics taken over non-masked positions only. The returned
        tensor has the hard 0/1 values in the forward pass and the gradient of
        ``y`` (scaled by ``scale_factor / (idx + 1)``) in the backward pass.
        """
        y = RelaxedBernoulli(0.5, logits=attn_weights).rsample()

        # Row statistics must ignore masked-out source positions.
        if attention_mask is not None:
            at_mask = (
                torch.repeat_interleave(attention_mask, self.num_heads, dim=1)
                .view(bsz * self.num_heads, tgt_len, src_len)
                .detach()
            )
            y_masked = y.detach() + at_mask
            # Masked positions are recognised by the additive mask value
            # `finfo(dtype).min` dominating the sum.
            valid_mask = y_masked != torch.finfo(attn_weights.dtype).min
            # masked_fill keeps device and dtype of `y_masked`.
            y_final = y_masked.masked_fill(~valid_mask, float("nan"))
        else:
            # No mask supplied: every position counts towards the statistics.
            y_final = y.detach()

        m = torch.nanmean(y_final, dim=-1, keepdim=True)
        v = torch.nanmean((y_final - m) ** 2, dim=-1, keepdim=True)
        thresholds = (m + torch.sqrt(v)).detach()

        if keyword_position is not None:
            # For keyword positions, set y to a value slightly larger than the
            # global maximum of y (intended to exceed the row thresholds).
            keyword_adjustment = torch.max(y) + 0.01
            for row, kp in enumerate(keyword_position):
                kp_idx = self._keyword_indices(kp)
                if kp_idx.size(0) != 0:
                    start = row * self.num_heads
                    block = y[start:start + self.num_heads]
                    mask = torch.zeros_like(block)
                    mask[:, :, kp_idx] = 1
                    y_adjusted = block * (1 - mask) + keyword_adjustment * mask
                    # Straight-through estimator: the forward value becomes
                    # `y_adjusted`, the gradient still flows through `y`.
                    y[start:start + self.num_heads] = (y_adjusted - block).detach() + block

        y_hard = (y > thresholds).type_as(y)
        # Straight-through: hard 0/1 value forward, gradient of y backward.
        sample = (y_hard - y).detach() + y
        # Blend with a detached copy so only a fraction of the gradient
        # passes; the forward value is unchanged.
        grad_scale = self.scale_factor * 1 / (idx + 1)
        return sample * grad_scale + sample.detach() * (1 - grad_scale)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        idx = None,
        keyword_position=None,
    ):
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: Query input of shape (bsz, tgt_len, embed_dim).
            key_value_states: When given, the layer acts as cross-attention and
                keys/values are projected from this tensor.
            past_key_value: Cached (key_states, value_states) to reuse.
            attention_mask: Additive mask of shape (bsz, 1, tgt_len, src_len).
            layer_head_mask: Per-head multiplier of shape (num_heads,).
            output_attentions: Whether to return the attention probabilities.
            idx: Compared against ``decoder_use_layer`` to select the sparse
                path (presumably the decoder layer index — confirm against caller).
            keyword_position: 2-D tensor of source positions, padded with
                ``-100``; ``None`` means no keyword positions.

        Returns:
            Tuple of (attn_output, attention probabilities reshaped to
            (bsz, num_heads, tgt_len, src_len) or None, past_key_value).

        Raises:
            ValueError: If an intermediate tensor or a mask has an unexpected size.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # `idx=None` is never in the tuple, so such calls use the softmax path.
        use_sparse = self.is_decoder and is_cross_attention and idx in self.decoder_use_layer

        if use_sparse:
            if self.training:
                keep = self._train_keep_mask(
                    attn_weights, attention_mask, keyword_position, idx, bsz, tgt_len, src_len
                )
            else:
                keep = self._eval_keep_mask(attn_weights, keyword_position)
            # Store the probabilities in `attn_weights` so the head mask and
            # `output_attentions` below operate on them rather than on logits.
            attn_weights = self._masked_normalize(attn_weights, keep)
        else:
            attn_weights = F.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        if use_sparse:
            # Sparse layers are used without dropout.
            attn_probs = attn_weights
        else:
            attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value