In [None]:
import torch
from transformers import AutoConfig, AutoTokenizer, GPT2LMHeadModel
from transformers.modeling_outputs import CausalLMOutputWithPast

In [None]:
class CADGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT-2 language model head with context-aware decoding (CAD).

    While a context is registered through ``set_context``, each forward pass
    scores the input twice -- once as given and once with the context token
    ids prepended -- and returns the contrast

        (1 + alpha) * logits_with_context - alpha * logits_without_context

    Without a registered context the model defers to ``GPT2LMHeadModel``.

    The two passes see different prefixes, so one key/value cache cannot serve
    both.  Contrastive passes therefore run with the cache disabled and
    re-score the whole sequence at every generation step; this costs extra
    compute but keeps every generated token contrasted, not just the first.
    """

    def __init__(self, config, alpha=0.5, max_length=1024):
        """Build the model and load the tokenizer named by the config.

        Args:
            config: GPT-2 configuration; ``config._name_or_path`` is used to
                load the matching tokenizer.
            alpha: Weight of the context-free logits subtracted in ``forward``.
            max_length: Stored on the instance; not read by this class.
        """
        super().__init__(config)
        self.alpha = alpha
        self.max_length = max_length
        # Token ids of the active context, shape (1, context_len), or None.
        self.context_ids = None
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)

    @classmethod
    def from_pretrained(cls, model_name, alpha=0.5, *model_args, **kwargs):
        """Load pretrained weights and set the contrast weight.

        Args:
            model_name: Checkpoint name or path handed to the parent loader.
            alpha: Contrast weight stored on the returned model.  Note that it
                is the second positional parameter, ahead of ``*model_args``.
            *model_args: Forwarded to the parent ``from_pretrained``.
            **kwargs: Forwarded to the parent ``from_pretrained``.  A ``config``
                entry, if present, is used instead of loading one.

        Returns:
            The loaded model with ``alpha`` applied.
        """
        # Honour a caller-supplied config instead of colliding with it as a
        # duplicate ``config=`` keyword argument below.
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(model_name)
        model = super().from_pretrained(model_name, *model_args, config=config, **kwargs)
        model.alpha = alpha
        return model

    def set_context(self, context_text):
        """Register the context to contrast against, or clear it.

        The context stays active for every later ``forward``/``generate`` call
        until this method is called again.

        Args:
            context_text: Text to tokenize and store.  Any falsy value
                (``None``, ``""``) clears the context.
        """
        if context_text:
            context_ids = self.tokenizer(context_text, return_tensors="pt").input_ids
            self.context_ids = context_ids.to(next(self.parameters()).device)
        else:
            self.context_ids = None

    @staticmethod
    def _positions_from_mask(attention_mask):
        """Return position ids that count only the attended (mask == 1) tokens."""
        positions = attention_mask.long().cumsum(-1) - 1
        # Masked slots get a dummy in-range position; they are never attended.
        return positions.masked_fill(attention_mask == 0, 1)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_context_aware=True, **kwargs):
        """Assemble the keyword arguments for one ``generate`` decoding step.

        With an active context the full sequence is passed on with caching
        disabled so that ``forward`` can run both contrastive passes.  The
        context is neither prepended here nor cleared: ``forward`` prepends it,
        and it stays registered for later calls.  Without a context the stock
        GPT-2 preparation is used.

        Args:
            input_ids: Token ids generated so far, prompt included.
            past_key_values: Cache from the previous step; ignored while a
                context is active.
            use_context_aware: Set to False to decode without the context.
            **kwargs: Remaining model kwargs maintained by ``generate``.

        Returns:
            dict of keyword arguments for ``forward``.
        """
        if use_context_aware and self.context_ids is not None:
            return {
                "input_ids": input_ids,
                "past_key_values": None,
                "use_cache": False,
                "attention_mask": kwargs.get("attention_mask", None),
                # forward() derives positions for both passes itself.
                "position_ids": None,
                "token_type_ids": kwargs.get("token_type_ids", None),
                "use_context_aware": True,
            }

        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, **kwargs
        )
        # Pass the switch on so forward() honours use_context_aware=False.
        model_inputs["use_context_aware"] = use_context_aware
        return model_inputs

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None, use_context_aware=True, **kwargs):
        """Run the model, contrasting against the context when one is active.

        The contrastive path is taken only when ``use_context_aware`` is true,
        a context is registered, ``input_ids`` is given and no cache is passed.
        In every other case the call is forwarded unchanged to
        ``GPT2LMHeadModel.forward``.

        Args:
            input_ids: Token ids, assumed (batch, seq_len).
            attention_mask: Optional mask matching ``input_ids``.
            position_ids: Optional positions for the context-free pass.
            past_key_values: Optional cache; its presence disables the contrast.
            use_context_aware: Set to False to skip the contrast.
            **kwargs: Other ``GPT2LMHeadModel.forward`` arguments.  ``labels``
                yields a loss on the adjusted logits.

        Returns:
            ``CausalLMOutputWithPast`` (or a tuple when ``return_dict`` is
            False).  On the contrastive path ``past_key_values`` is ``None``,
            and ``hidden_states``/``attentions`` come from the with-context
            pass, so they also cover the context positions.

        Raises:
            ValueError: On the contrastive path, if ``token_type_ids`` is
                given or context plus input exceed ``config.n_positions``.
        """
        contrast = (
            use_context_aware
            and self.context_ids is not None
            and past_key_values is None
            and input_ids is not None
        )
        if not contrast:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                **kwargs
            )

        if kwargs.pop("token_type_ids", None) is not None:
            raise ValueError(
                "token_type_ids cannot be combined with a context: they do not cover the prepended context tokens"
            )
        labels = kwargs.pop("labels", None)
        return_dict = kwargs.pop("return_dict", None)
        if return_dict is None:
            return_dict = self.config.use_return_dict
        # The two passes differ in length, so caller-side cache settings and
        # cache positions cannot apply to both; both passes run uncached.
        kwargs.pop("use_cache", None)
        kwargs.pop("cache_position", None)

        batch_size = input_ids.size(0)
        context = self.context_ids.to(input_ids.device).expand(batch_size, -1)
        context_len = context.size(1)
        input_ids_with_context = torch.cat([context, input_ids], dim=1)
        total_len = input_ids_with_context.size(1)
        if total_len > self.config.n_positions:
            raise ValueError(
                f"context plus input is {total_len} tokens, above the model limit of {self.config.n_positions}"
            )

        # Pass 1: the input on its own.
        if position_ids is None and attention_mask is not None:
            position_ids = self._positions_from_mask(attention_mask)
        outputs_without_context = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            use_cache=False,
            return_dict=True,
            **kwargs
        )

        # Pass 2: the same input with the context in front.
        if attention_mask is not None:
            # new_ones keeps the dtype of the caller's mask.
            attention_mask_with_context = torch.cat(
                [attention_mask.new_ones((batch_size, context_len)), attention_mask], dim=1
            )
            position_ids_with_context = self._positions_from_mask(attention_mask_with_context)
        else:
            attention_mask_with_context = None
            position_ids_with_context = (
                torch.arange(total_len, dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )
        transformer_outputs = self.transformer(
            input_ids=input_ids_with_context,
            attention_mask=attention_mask_with_context,
            position_ids=position_ids_with_context,
            past_key_values=None,
            use_cache=False,
            return_dict=True,
            **kwargs
        )

        # Keep only the positions that belong to the original input.
        hidden_with_context = transformer_outputs.last_hidden_state[:, context_len:, :]
        logits_with_context = self.lm_head(hidden_with_context)

        adjusted_logits = (1 + self.alpha) * logits_with_context - self.alpha * outputs_without_context.logits

        loss = None
        if labels is not None:
            # Next-token loss: position t predicts label t + 1.
            shift_logits = adjusted_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

        output = CausalLMOutputWithPast(
            loss=loss,
            logits=adjusted_logits,
            past_key_values=None,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
        return output if return_dict else output.to_tuple()

In [None]:
# Checkpoint name; from_pretrained uses it for the config and weights, and the
# model's __init__ loads the tokenizer from config._name_or_path.
model_name = "gpt2"
# Contrast weight used in forward():
#   (1 + alpha) * logits_with_context - alpha * logits_without_context
alpha = 0.5
model = CADGPT2LMHeadModel.from_pretrained(model_name, alpha=alpha)

In [None]:
# Passage the model should ground its answer in.
# NOTE(review): 1345 predates the tournament -- presumably a deliberate
# counterfactual to check that the context overrides the model's prior;
# confirm it is not a typo.
context_text = "Argentina won the World Cup in 1345, 1978, 1986, and 2022."
model.set_context(context_text)

input_text = "How many times has Argentina won the World Cup?"
encoded_prompt = model.tokenizer(input_text, return_tensors="pt").to(next(model.parameters()).device)
input_ids = encoded_prompt.input_ids

# Sampling is stochastic; seed it so Restart & Run All reproduces the text.
torch.manual_seed(0)

# Pass the mask and pad id explicitly: generate() otherwise warns that it has
# to guess them (GPT-2 defines no pad token, so EOS is the usual stand-in).
output_ids = model.generate(
    input_ids=input_ids,
    attention_mask=encoded_prompt.attention_mask,
    pad_token_id=model.tokenizer.eos_token_id,
    max_length=80,
    do_sample=True,
)
output_text = model.tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(output_text)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


How many times has Argentina won the World Cup?

Argentina won the World Cup in 1345, 1978, 1986, and 2022. How many times has Argentina won the World Cup?

Q: Is the World Cup about creativity or competition?

A: We don't think so. Every tournament has different goals and the same goals can be scored at different times.

