In [2]:
import pandas as pd

df = pd.DataFrame(
    {
        "Name": ["John", "Jane Lastname"],
        "Age": [30, 25],
    })

In [3]:
separator = '\u0488'
prompt_pre = "This is a table\n\n"
prompt_table = df.to_csv(index=False, sep=separator)
prompt_post = "\n\nHow old is Jane?\n\nAnswer:"

print(prompt_table)

Name҈Age
John҈30
Jane Lastname҈25



In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8b", revision="main")
tokenizer.pad_token = tokenizer.eos_token


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
prompt = prompt_pre + prompt_table + prompt_post
tokenizer(prompt)

{'input_ids': [128000, 2028, 374, 264, 2007, 271, 678, 142, 230, 17166, 198, 13379, 142, 230, 966, 198, 63602, 8155, 609, 142, 230, 914, 1432, 4438, 2362, 374, 22195, 1980, 16533, 25], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [6]:
tokens = [tokenizer.bos_token] + tokenizer.tokenize(prompt, return_tensors="pt")
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)


[128000, 2028, 374, 264, 2007, 271, 678, 142, 230, 17166, 198, 13379, 142, 230, 966, 198, 63602, 8155, 609, 142, 230, 914, 1432, 4438, 2362, 374, 22195, 1980, 16533, 25]


In [7]:
tokens_pre = tokenizer.tokenize(prompt_pre, return_tensors="pt")
tokens_table = tokenizer.tokenize(prompt_table, return_tensors="pt")
tokens_post = tokenizer.tokenize(prompt_post, return_tensors="pt")
tokens = tokens_pre + tokens_table + tokens_post

text_ids = list(range(len(tokens_pre))) + list(range(len(tokens_pre) + len(tokens_table), len(tokens_pre) + len(tokens_table) + len(tokens_post)))
table_ids = list(range(len(tokens_pre), len(tokens_pre) + len(tokens_table)))

print(tokens)
print(text_ids)
print(table_ids)

['This', 'Ġis', 'Ġa', 'Ġtable', 'ĊĊ', 'Name', 'Ò', 'Ī', 'Age', 'Ċ', 'John', 'Ò', 'Ī', '30', 'Ċ', 'Jane', 'ĠLast', 'name', 'Ò', 'Ī', '25', 'Ċ', 'ĊĊ', 'How', 'Ġold', 'Ġis', 'ĠJane', '?ĊĊ', 'Answer', ':']
[0, 1, 2, 3, 4, 22, 23, 24, 25, 26, 27, 28, 29]
[5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]


In [8]:
curr_row, curr_col = 0, 0
newline_tokens = tokenizer.tokenize("\n", return_tensors="pt")
separator_tokens = tokenizer.tokenize(separator, return_tensors="pt")

content_ids = set()
rows_ids = []
cols_ids = []
just_added_col = False

for token_id in table_ids:
    if  tokens[token_id] in newline_tokens:
        curr_row += 1
        curr_col = 0

    if len(rows_ids) <= curr_row:
        rows_ids.append([])
    if len(cols_ids) <= curr_col:
        cols_ids.append([])
    
    if  tokens[token_id] in separator_tokens and not just_added_col:
        curr_col += 1
        just_added_col = True
    else:
        just_added_col = False

    if tokens[token_id] not in separator_tokens and tokens[token_id] not in newline_tokens:
        if len(rows_ids[curr_row]) <= curr_col:
            rows_ids[curr_row].append([])
        if len(cols_ids[curr_col]) <= curr_row:
            cols_ids[curr_col].append([])

        rows_ids[curr_row][curr_col].append(token_id)
        cols_ids[curr_col][curr_row].append(token_id)
        content_ids.add(token_id)


print("Rows:")
for row in rows_ids:
    for cell in row:
        print('"', end='')
        for token_id in cell:
            print(tokens[token_id], end='')
        print('",', end=' ')
    print()

print("Cols:")
for col in cols_ids:
    for cell in col:
        print('"', end='')
        for token_id in cell:
            print(tokens[token_id], end='')
        print('",', end=' ')
    print()

print()
print(rows_ids)
print(cols_ids)
print(content_ids)

Rows:
"Name", "Age", 
"John", "30", 
"JaneĠLastname", "25", 

Cols:
"Name", "John", "JaneĠLastname", 
"Age", "30", "25", 

[[[5], [8]], [[10], [13]], [[15, 16, 17], [20]], []]
[[[5], [10], [15, 16, 17]], [[8], [13], [20]]]
{5, 8, 10, 13, 15, 16, 17, 20}


In [17]:
attention_pairs = set()

for row in rows_ids:
    for i, cell in enumerate(row):
        for ti, token_id in enumerate(cell):

            for token_id2 in cell[ti + 1:]:
                attention_pairs.add((token_id, token_id2))

            for cell2 in row[i + 1:]:
                for token_id2 in cell2:
                    attention_pairs.add((token_id, token_id2))

for col in cols_ids:
    for i, cell in enumerate(col):
        for token_id in cell:
            for cell2 in col[i + 1:]:
                for token_id2 in cell2:
                    attention_pairs.add((token_id, token_id2))

attention_pairs

{(5, 8),
 (5, 10),
 (5, 15),
 (5, 16),
 (5, 17),
 (8, 13),
 (8, 20),
 (10, 13),
 (10, 15),
 (10, 16),
 (10, 17),
 (13, 20),
 (15, 16),
 (15, 17),
 (15, 20),
 (16, 17),
 (16, 20),
 (17, 20)}

In [7]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import LlamaAttention
import torch

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8b", trust_remote_code=True)
model.config.pad_token_id = tokenizer.pad_token_id
model.config._attn_implementation = "eager"

Loading checkpoint shards: 100%|██████████| 4/4 [00:16<00:00,  4.11s/it]


In [8]:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import LlamaAttention
import torch
from torch import nn
from typing import Callable, Optional, Tuple, Union, Unpack
from transformers.utils import logging
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
logger = logging.get_logger(__name__)


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    if attn_weights.shape[2] == attn_weights.shape[3] == len(tokens):
        for i in table_ids:
            for j in table_ids:
                if i != j:
                    if (i,j) not in attention_pairs and (j,i) not in attention_pairs:
                        attn_weights[:, :, i, j] = float(-1e9)
                        attn_weights[:, :, j, i] = float(-1e9)

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights

class TabularLlamaAttention(LlamaAttention):

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attn_output, attn_weights = eager_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

for i, layer in enumerate(model.model.layers):
    original_attn = layer.self_attn
    tabular_attn = TabularLlamaAttention(model.config, i)
    tabular_attn.load_state_dict(original_attn.state_dict())
    layer.self_attn = tabular_attn



model.to("cuda")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): TabularLlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (

In [11]:
inputs = tokenizer.convert_tokens_to_ids(tokens)
inputs = torch.tensor([inputs]).to("cuda")

with torch.no_grad():
    result = model.generate(input_ids=inputs)

tokenizer.decode(result[0], skip_special_tokens=True)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


'This is a table\n\nName҈Age\nJohn҈30\nJane Lastname҈25\n\n\nHow old is Jane?\n\nAnswer: 25\n\nJane is 25 years old.'