In [1]:
import torch
from typing import Tuple

import torch.nn as nn

class EntailmentMemory(nn.Module):
    def __init__(self, num_slots: int, hidden_dim: int):
        super().__init__()
        self.num_slots = num_slots
        self.hidden_dim = hidden_dim
        self.memory = nn.Parameter(torch.randn(num_slots, hidden_dim))  # [K, D]
        self.proj = nn.Linear(hidden_dim, num_slots)  # Maps BART hidden states to memory slots

    def forward(self, hidden_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_state: [batch_size, hidden_dim] (e.g., BART's [z] token embedding)
        Returns:
            z: Entailment representation [batch_size, hidden_dim]
            attn_weights: Memory attention scores [batch_size, num_slots]
        """
        attn_weights = torch.softmax(self.proj(hidden_state), dim=-1)  # [batch_size, K]
        z = torch.einsum('bk,kd->bd', attn_weights, self.memory)      # [batch_size, D]
        return z, attn_weights
    



In [2]:
import torch
from typing import Tuple

import torch.nn as nn

class DiscourseMemory(nn.Module):
    def __init__(self, num_slots: int, hidden_dim: int):
        super().__init__()
        self.num_slots = num_slots
        self.hidden_dim = hidden_dim
        self.memory = nn.Parameter(torch.randn(num_slots, hidden_dim))  # [L, D]
        self.proj = nn.Linear(hidden_dim, num_slots)

    def forward(self, hidden_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Same as ERM but for discourse features."""
        attn_weights = torch.softmax(self.proj(hidden_state), dim=-1)  # [batch_size, L]
        z_d = torch.einsum('bl,ld->bd', attn_weights, self.memory)     # [batch_size, D]
        return z_d, attn_weights

In [5]:
import torch
from typing import Optional, Tuple, Dict, Any
import torch.nn as nn
from transformers import BartForConditionalGeneration, BartConfig
from transformers.modeling_outputs import Seq2SeqLMOutput

class BARTWithMemory(nn.Module):
    def __init__(self, 
                 bart_model_name: str = "facebook/bart-base",
                 erm_slots: int = 10, 
                 ddm_slots: int = 5):
        super().__init__()
        # Load full model to preserve components
        full_bart = BartForConditionalGeneration.from_pretrained(bart_model_name)
        self.bart = full_bart.model
        self.lm_head = full_bart.lm_head
        self.final_logits_bias = full_bart.final_logits_bias
        self.config: BartConfig = full_bart.config
        # self.generate = full_bart.generate

        # Initialize memory modules
        self.erm = EntailmentMemory(erm_slots, self.config.d_model)
        self.ddm = DiscourseMemory(ddm_slots, self.config.d_model)
        self.ortho_loss_coeff = 0.1

        # Register token IDs from config
        self.register_buffer("sop_token_id", torch.tensor([self.config.bos_token_id]))
        self.register_buffer("eop_token_id", torch.tensor([self.config.eos_token_id]))

    def orthogonal_loss(self) -> torch.Tensor:
        """Orthogonality constraint between memory matrices"""
        return torch.norm(torch.mm(self.erm.memory, self.ddm.memory.T)) ** 2

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Seq2SeqLMOutput:
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode inputs
        if encoder_outputs is None:
            encoder_outputs = self.bart.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        
        last_hidden = encoder_outputs.last_hidden_state

        # Memory operations
        z_token_embedding = last_hidden[:, 0]  # First token embedding
        z, _ = self.erm(z_token_embedding)
        z_d, _ = self.ddm(z_token_embedding)

        # Prepare decoder inputs
        if labels is not None:
            if decoder_input_ids is None:
                decoder_input_ids = self._shift_right(labels)
        
        # Embed decoder inputs
        decoder_inputs_embeds = self.bart.decoder.embed_tokens(decoder_input_ids)
        decoder_inputs_embeds[:, 0] += z + z_d  # Modify first token

        # Decode
        decoder_outputs = self.bart.decoder(
            input_ids=None,
            inputs_embeds=decoder_inputs_embeds,
            encoder_hidden_states=last_hidden,
            encoder_attention_mask=attention_mask,
            attention_mask=decoder_attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Calculate logits and losses
        lm_logits = self.lm_head(decoder_outputs.last_hidden_state) + self.final_logits_bias
        loss = None
        ortho_loss = self.orthogonal_loss()

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
            ce_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
            loss = ce_loss + self.ortho_loss_coeff * ortho_loss

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def _shift_right(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Shift labels to create decoder inputs (like original BART)"""
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[:, 1:] = input_ids[:, :-1].clone()
        shifted[:, 0] = self.sop_token_id
        return shifted

In [6]:
from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BARTWithMemory()
input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
labels = tokenizer("I am fine, thank you.", return_tensors="pt").input_ids
model.generate(input_ids=input_ids, labels=labels)

AttributeError: 'BARTWithMemory' object has no attribute 'generate'

In [14]:
from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BARTWithMemory()

# Structured example
input_text = "<latent><persona>p1,p2<query>q<response>"
output_text = "r<eos>"

# Tokenize (make sure to add special tokens to your tokenizer first!)
inputs = tokenizer(input_text, return_tensors="pt")
labels = tokenizer(output_text, return_tensors="pt").input_ids

# Forward pass
outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    labels=labels
)

loss = outputs.loss
loss.backward()


In [15]:
loss

tensor(3052.6052, grad_fn=<AddBackward0>)

In [22]:
# Generate response

inputs = tokenizer("<latent> <persona> Hi I am College Buddy <query> WHat is Your Name <responce> ", return_tensors="pt")
decoder_input_ids = tokenizer("<latent> <persona> Hi I am College Buddy <query> WHat is Your Name <responce> ", return_tensors="pt").input_ids
output_ids = model.generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    decoder_input_ids=decoder_input_ids,
    max_length=200,
    do_sample=True,
    temperature=0.7,
    top_p=0.7,
    top_k=50,
    num_return_sequences=1,
    pad_token_id=tokenizer.eos_token_id
)

# Decode and extract response
full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
# response = full_output.split("<response>")[-1].split("<eos>")[0].strip()
print(full_output)

<latent> <persona> Hi I am College Buddy <query> WHat is Your Name <responce> 


### NExt

In [7]:
import torch
from typing import Optional, Tuple, Dict, Any
import torch.nn as nn
from transformers import BartForConditionalGeneration, BartConfig, BartModel
from transformers.modeling_outputs import Seq2SeqLMOutput

class EntailmentMemory(nn.Module):
    def __init__(self, num_slots: int, embedding_dim: int):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(num_slots, embedding_dim))
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: [batch_size, embedding_dim]
        projected = self.proj(x)  # [batch_size, embedding_dim]
        scores = torch.matmul(projected, self.memory.T)  # [batch_size, num_slots]
        weights = self.softmax(scores)
        output = torch.matmul(weights, self.memory)  # [batch_size, embedding_dim]
        return output, weights

class DiscourseMemory(nn.Module):
    def __init__(self, num_slots: int, embedding_dim: int):
        super().__init__()
        self.memory = nn.Parameter(torch.randn(num_slots, embedding_dim))
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        projected = self.proj(x)
        scores = torch.matmul(projected, self.memory.T)
        weights = self.softmax(scores)
        output = torch.matmul(weights, self.memory)
        return output, weights

class BartModelWithMemory(BartModel):
    def __init__(self, config: BartConfig, erm_slots: int = 10, ddm_slots: int = 5):
        super().__init__(config)
        self.erm = EntailmentMemory(erm_slots, config.d_model)
        self.ddm = DiscourseMemory(ddm_slots, config.d_model)
        self.ortho_loss_coeff = 0.1

    def orthogonal_loss(self) -> torch.Tensor:
        return torch.norm(torch.mm(self.erm.memory, self.ddm.memory.T)) ** 2

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Original BART forward
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Memory operations only on initial forward pass
        if past_key_values is None:
            encoder_hidden = outputs.encoder_last_hidden_state
            z_token = encoder_hidden[:, 0]  # Use first token embedding
            
            # Get memory outputs
            erm_out, _ = self.erm(z_token)
            ddm_out, _ = self.ddm(z_token)
            
            # Modify decoder inputs
            if decoder_inputs_embeds is None:
                decoder_inputs_embeds = self.decoder.embed_tokens(decoder_input_ids)
            
            # Apply memory to decoder's first token embedding
            decoder_inputs_embeds[:, 0] += erm_out + ddm_out
            
            # Re-run decoder with modified inputs
            outputs = super().forward(
                encoder_outputs=encoder_outputs,
                decoder_inputs_embeds=decoder_inputs_embeds,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        return outputs

class BARTWithMemory(BartForConditionalGeneration):
    def __init__(self, config: BartConfig, erm_slots: int = 10, ddm_slots: int = 5):
        super().__init__(config)
        self.model = BartModelWithMemory(config, erm_slots, ddm_slots)
        self.post_init()  # Important for loading pretrained weights

    def forward(self, **kwargs):
        outputs = super().forward(**kwargs)
        
        # Add orthogonal loss
        if kwargs.get('labels') is not None:
            ortho_loss = self.model.orthogonal_loss()
            outputs.loss += self.model.ortho_loss_coeff * ortho_loss
            
        return outputs

In [8]:
# Usage example
if __name__ == "__main__":
    from transformers import BartTokenizer

    # Initialize model
    config = BartConfig.from_pretrained("facebook/bart-base")
    model = BARTWithMemory.from_pretrained(
        "facebook/bart-base",
        config=config,
        erm_slots=10,
        ddm_slots=5
    )
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

    # Sample generation
    text = "The quick brown fox jumps over the lazy dog."
    inputs = tokenizer(text, return_tensors="pt")

    # Generate with memory
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=100,
        num_beams=5,
        early_stopping=True
    )
    
    print("\nGenerated text:")
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))


Some weights of BARTWithMemory were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['model.ddm.memory', 'model.ddm.proj.bias', 'model.ddm.proj.weight', 'model.erm.memory', 'model.erm.proj.bias', 'model.erm.proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Generated text:
The quick brown fox jumps over the lazy dog.
