In [1]:
from outlines.models import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import transformers


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_anneal_attn_mask(seq_len, bsz, dtype, device, attn_mask_ratio):
    mask = torch.full((seq_len, seq_len), 0, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 1)
    causal_mask = mask.to(dtype)

    random_mask = torch.bernoulli(torch.full((seq_len, seq_len), 0.0, device=device) + attn_mask_ratio)

    anneal_mask = torch.logical_or(causal_mask, random_mask)
    expanded_mask = anneal_mask[None, None, :, :].expand(bsz, 1, seq_len, seq_len)
    inverted_mask = 1.0 - expanded_mask.to(dtype)

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


In [3]:
# looks like Qwen using SDPA should be fine with an adapted attention mask, no


from functools import partial
from transformers import DynamicCache
from transformers.processing_utils import Unpack
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from cachetools import Cache
from typing import Optional, __all__
from transformers.utils import logging

logger = logging.get_logger(__name__)


def qwen_new_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> BaseModelOutputWithPast:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
    if not isinstance(past_key_values, (type(None), Cache)):
        raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if use_cache and past_key_values is None:
        past_key_values = DynamicCache()

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    ## Add by DiffuLLaMA, adapting for 4d attention-mask.
    if attention_mask is not None and len(attention_mask.shape) == 4:
        causal_mask = attention_mask
        print("logging....attention-mask for 4d")
    else:
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

    print(causal_mask)

    # causal_mask = self._update_causal_mask(
    #     attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
    # )

    hidden_states = inputs_embeds

    # create position embeddings to be shared across the decoder layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None

    for decoder_layer in self.layers[: self.config.num_hidden_layers]:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                partial(decoder_layer.__call__, **flash_attn_kwargs),
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                position_embeddings,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **flash_attn_kwargs,
            )

        hidden_states = layer_outputs[0]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values if use_cache else None,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


In [10]:
transformers.models.qwen3.modeling_qwen3.Qwen3Model.forward = qwen_new_forward

In [11]:
model_id = "Goedel-LM/Goedel-Prover-V2-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, trust_remote_code=True)


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 112.03it/s]


In [16]:
formal_statement = """
import Mathlib
import Aesop

set_option maxHeartbeats 0

open BigOperators Real Nat Topology Rat


theorem square_equation_solution {x y : ℝ} (h : x^2 + y^2 = 2*x - 4*y - 5) : x + y = -1 := by
  sorry
""".strip()

prompt = """
Complete the following Lean 4 code:

```lean4
{}```

Before producing the Lean 4 code to formally prove the given theorem, provide a detailed proof plan outlining the main proof steps and strategies.
The plan should highlight key ideas, intermediate lemmas, and proof structures that will guide the construction of the final formal proof.
""".strip()

chat = [
    [{"role": "user", "content": 'ABC'}],
    [{'role': 'user', 'content': 'ABCDEF'}]
]

inputs = tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, padding='longest',
                                       return_dict=True, return_tensors='pt')

In [7]:
inputs

{'input_ids': tensor([[151644,    872,    198,  25411, 151645,    198, 151644,  77091,    198,
         151643],
        [151644,    872,    198,  25411,  13649, 151645,    198, 151644,  77091,
            198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [8]:
inputs['attention_mask']

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

In [13]:
model.forward(inputs['input_ids'], attention_mask=inputs['attention_mask'])


AttributeError: 'Qwen3Model' object has no attribute '_update_causal_mask'

In [17]:
attn_mask = get_anneal_attn_mask(inputs['input_ids'].shape[1], inputs['input_ids'].shape[0], dtype=torch.bfloat16, device=inputs['input_ids'].device, attn_mask_ratio=0.5)

In [18]:
logits = model.forward(inputs['input_ids'], attention_mask=attn_mask, output_attentions=True)


logging....attention-mask for 4d
tensor([[[[ 0.0000e+00, -3.3895e+38, -3.3895e+38,  0.0000e+00, -3.3895e+38,
           -3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00, -3.3895e+38, -3.3895e+38,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00, -3.3895e+38, -3.3895e+38,  0.0000e+00, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00, -3.3895e+38,  0.0000e+00, -3.3895e+38,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -3.3895e+38,  0.0000e+00,  0.0000e+00, -3.3895e+38,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00,  0.0000e+00, -3.3895e+38, -3.3895e+38, -3.3895e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,

RuntimeError: The size of tensor a (32) must match the size of tensor b (128) at non-singleton dimension 3

In [None]:
print (logits.attentions[0][0][1])
# assert attentions are 0 at attn_mask, and nonzero elsewhere

new_mask = attn_mask.expand(-1, logits.attentions[0].shape[1], -1,-1)

for layer in logits.attentions:
    for x_ in layer:
        for attns in x_:
            print (attns)
            break

# comparing to attn_mask, looks like the attentions are correctly working. I.e. attentions are 0 for masked, and non-zero elsewhere
# Interestingly, high attention values often found for subsequent tokens even though model only used to seeing past tokens.
# Suggests annealing worthwhile to limit influence earlier in adaptation process

In [None]:
# todo:
# - set up dataset, test edit flow data prep code (couplings, alignments etc. ).
# - Look at generating from empty coupling as well as error correcting coupling (i.e. add error to context, modify the code directly)
# - Set up model wrapper (rate predictions etc.), Quantization/LoRA. Unlike MDM adaptation, we can probably keep the logits unshifted, as we are predicting both substitute and insert next token probs, as well as delete.
