In [1]:
!pip install datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTNeoXConfig
from datasets import load_dataset
from torch.profiler import profile, record_function, ProfilerActivity
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np
import time
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXLayer, GPTNeoXAttention, GPTNeoXForCausalLM, GPTNeoXModel
import math
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

Successfully installed datasets-2.19.1 dill-0.3.8 huggingface-hub-0.23.0 multiprocess-0.70.16 xxhash-3.4.1


In [None]:
model_name = "EleutherAI/pythia-160m"
model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [3]:
model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 768)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=768, out_features=2304, bias=True)
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=768, out_features=3072, bias=True)
          (dense_4h_to_h): Linear(in_features=3072, out_features=768, bias=True)
          

Just some experimenting...trying to modify the GPTNeoXAttention with a customized one that mask a percentage of tokens.

In [18]:
class GPTNeoXAttentionWithPruning(GPTNeoXAttention):
    def __init__(self, config):
        super().__init__(config)

    def top_k_tokens(self, attention_scores, k_percent):
        """
        Mask tokens based on Top-K attention scores.
        """
        k = int(attention_scores.size(-1) * k_percent)
        top_k_scores, _ = torch.topk(attention_scores, k, dim=-1)
        min_top_k_scores = top_k_scores[..., -1, None]
        mask = attention_scores >= min_top_k_scores
        #print(sum(mask))
        return mask

    def _attn(self, query, key, value, attention_mask=None, head_mask=None, k_percent=0.9):
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        if key_length > self.bias.shape[-1]:
            self._init_bias(key_length, device=key.device)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.zeros(
            batch_size * num_attention_heads,
            query_length,
            key_length,
            dtype=query.dtype,
            device=key.device,
        )
        attn_scores = torch.baddbmm(
            attn_scores,
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=self.norm_factor,
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

        mask_value = torch.finfo(attn_scores.dtype).min
        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        #reduction
        token_mask = self.top_k_tokens(attn_scores, k_percent)
        #print(token_mask)
        attn_scores = attn_scores.masked_fill(~token_mask, mask_value)
        #print(attn_scores)
        attn_weights = nn.functional.softmax(attn_scores, dim=-1)

        attn_weights = attn_weights.to(value.dtype)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_weights = self.attention_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights


In [19]:
class GPTNeoXLayerWithPruning(GPTNeoXLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = GPTNeoXAttentionWithPruning(config)

class CustomGPTNeoXModel(GPTNeoXModel):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList([GPTNeoXLayerWithPruning(config) for _ in range(config.num_hidden_layers)])

class CustomGPTNeoXForCausalLM(GPTNeoXForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.gpt_neox = CustomGPTNeoXModel(config)


In [20]:
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
config = GPTNeoXConfig.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, config=config)
model.resize_token_embeddings((len(tokenizer) + 7) // 8 * 8)
model.to('cuda')

#preprocessing for the dataset
def preprocess(example):
    return tokenizer(example["text"], padding="max_length", truncation=True, max_length=128)

#apply preprocessing to dataset
dataset = load_dataset("ag_news", split="test[:1%]")
def preprocess(example):
    return tokenizer(example["text"], padding="max_length", truncation=True, max_length=128)

dataset = dataset.map(preprocess, batched=True)

input_ids_list = [torch.tensor(item['input_ids']) for item in dataset]
input_ids = torch.stack(input_ids_list).cuda()
attention_masks = torch.stack([torch.tensor(item['attention_mask']) for item in dataset]).cuda()

#generate autoregressive way taking count of execution time
def generate_autoregressively(model, tokenizer, input_ids, attention_mask, max_new_tokens=20):
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        generated_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id)
        end_time = time.time()
        generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        execution_time = end_time - start_time
    return generated_text, execution_time

#just using pythia
baseline_generated_texts = []
baseline_times = []
for i in range(input_ids.size(0)):
    single_input_ids = input_ids[i:i+1]
    single_attention_mask = attention_masks[i:i+1]
    generated_text, single_execution_time = generate_autoregressively(model, tokenizer, single_input_ids, single_attention_mask, max_new_tokens=20)
    baseline_generated_texts.append(generated_text)
    baseline_times.append(single_execution_time)

In [21]:
#let's test now the customized version
custom_tokenizer = AutoTokenizer.from_pretrained(model_name)
custom_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
custom_config = GPTNeoXConfig.from_pretrained(model_name)
custom_model = CustomGPTNeoXForCausalLM.from_pretrained(model_name, config=custom_config)
custom_model.resize_token_embeddings((len(custom_tokenizer) + 7) // 8 * 8)
custom_model.to('cuda')

# Apply preprocessing to dataset using custom tokenizer
custom_input_ids_list = [torch.tensor(item['input_ids']) for item in dataset]
custom_input_ids = torch.stack(custom_input_ids_list).cuda()
custom_attention_masks = torch.stack([torch.tensor(item['attention_mask']) for item in dataset]).cuda()

# Test with custom model
custom_generated_texts = []
custom_times = []
for i in range(custom_input_ids.size(0)):
    single_input_ids = custom_input_ids[i:i+1]
    single_attention_mask = custom_attention_masks[i:i+1]
    generated_text, single_execution_time = generate_autoregressively(custom_model, custom_tokenizer, single_input_ids, single_attention_mask, max_new_tokens=20)
    custom_generated_texts.append(generated_text)
    custom_times.append(single_execution_time)

#print("Baseline Model Results:")
#for text, time_taken in zip(baseline_generated_texts, baseline_times):
#    print(f"Generated Text: {text}")
#    print(f"Time Taken: {time_taken:.4f} seconds")
#

#print("\nCustom Model Results:")
#for text, time_taken in zip(custom_generated_texts, custom_times):
#    print(f"Generated Text: {text}")
#    print(f"Time Taken: {time_taken:.4f} seconds")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
print(sum(custom_times), sum(baseline_times))

27.010892391204834 34.501551389694214


It seems to be more efficient, but the results are a bit too similar, I will need to check if the mask is correctly done, if is really computationally cheaper and if the [PAD] tokens are considered during the topk selection

In [23]:
print(custom_generated_texts[4])

Calif. Aims to Limit Farm-Related Smog (AP) AP - Southern California's smog-fighting agency went after emissions of the bovine variety Friday, adopting the nation's first rules to reduce air pollution from dairy cow manure.- AP

The agency's new rules, which are expected to be finalized by the end


In [24]:
print(baseline_generated_texts[4])

Calif. Aims to Limit Farm-Related Smog (AP) AP - Southern California's smog-fighting agency went after emissions of the bovine variety Friday, adopting the nation's first rules to reduce air pollution from dairy cow manure.- AP

The agency's new rules, which are expected to be finalized by the end


Let's try to implement how usually is performed the reduction. The attention between the spatial tokens and the CLS token is used.

In [27]:
class GPTNeoXAttentionWithPruning(GPTNeoXAttention):
    def top_k_pruning(self, cls_scores, k_percent):
        k = int(cls_scores.size(-1) * k_percent)
        top_k_scores, _ = torch.topk(cls_scores, k, dim=-1, largest=False)
        min_top_k_scores = top_k_scores[..., -1, None]
        mask = cls_scores <= min_top_k_scores
        return mask

    def _attn(self, query, key, value, attention_mask=None, head_mask=None, k_percent=0.1):
        batch_size, num_attention_heads, query_length, attn_head_size = query.size()
        key_length = key.size(-2)

        if key_length > self.bias.shape[-1]:
            self._init_bias(key_length, device=key.device)
        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

        query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
        key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
        attn_scores = torch.baddbmm(
            torch.zeros(
                batch_size * num_attention_heads,
                query_length,
                key_length,
                dtype=query.dtype,
                device=key.device,
            ),
            query,
            key.transpose(1, 2),
            beta=1.0,
            alpha=self.norm_factor,
        )
        attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)

        mask_value = torch.finfo(attn_scores.dtype).min
        mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
        attn_scores = torch.where(causal_mask, attn_scores, mask_value)

        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        cls_scores = attn_scores[:, :, 0, :]

        token_mask = self.top_k_pruning(cls_scores, k_percent)

        token_mask = token_mask.unsqueeze(2).expand(batch_size, num_attention_heads, query_length, key_length)
        attn_scores = attn_scores.masked_fill(~token_mask, mask_value)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_weights = self.attention_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights, token_mask

    def forward(self, hidden_states, attention_mask=None, head_mask=None, layer_past=None, use_cache=False, output_attentions=False, k_percent=0.1):
        qkv = self.query_key_value(hidden_states)

        query, key, value = torch.chunk(qkv, 3, dim=-1)

        query = self._split_heads(query, self.num_attention_heads, self.head_size)
        key = self._split_heads(key, self.num_attention_heads, self.head_size)
        value = self._split_heads(value, self.num_attention_heads, self.head_size)

        attn_output, attn_weights, token_mask = self._attn(query, key, value, attention_mask, head_mask, k_percent)
        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)

        attn_output = self.dense(attn_output)

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs, token_mask

class GPTNeoXLayerWithPruning(GPTNeoXLayer):
    def __init__(self, config):
        super().__init__(config)
        self.attention = GPTNeoXAttentionWithPruning(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        layer_past=None,
        use_cache=False,
        output_attentions=False,
        token_mask=None,
        **kwargs
    ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        layer_past = layer_past or [None, None]

        attn_outputs, token_mask = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            layer_past=layer_past,
            use_cache=use_cache,
            output_attentions=output_attentions,
            k_percent=0.1
        )

        attn_output = attn_outputs[0]
        attn_weights = attn_outputs[1] if len(attn_outputs) > 1 else None

        if use_cache:
            present = (attn_outputs[1], layer_past)
        else:
            present = None

        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = hidden_states + residual

        outputs = (hidden_states, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs, token_mask

class CustomGPTNeoXModel(GPTNeoXModel):
    def __init__(self, config):
        super().__init__(config)
        self.layers = nn.ModuleList([GPTNeoXLayerWithPruning(config) for _ in range(config.num_hidden_layers)])
        self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, input_ids, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None, use_cache=None, **kwargs):
        if inputs_embeds is None:
            inputs_embeds = self.embed_in(input_ids)

        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        token_mask = None
        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs, token_mask = layer(
                hidden_states,
                attention_mask,
                head_mask=head_mask[idx] if head_mask is not None else None,
                layer_past=next_decoder_cache[idx] if next_decoder_cache is not None else None,
                use_cache=use_cache,
                output_attentions=output_attentions,
                token_mask=token_mask,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[1],)
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)

        hidden_states = self.final_layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

class CustomGPTNeoXForCausalLM(GPTNeoXForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.gpt_neox = CustomGPTNeoXModel(config)

In [None]:
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
config.num_hidden_layers = 12
custom_model = CustomGPTNeoXForCausalLM.from_pretrained(model_name, config=config).cuda()

print(custom_model)
# add CLS token
if tokenizer.cls_token is None:
    tokenizer.add_special_tokens({'cls_token': '[CLS]'})
    custom_model.resize_token_embeddings(len(tokenizer))

def preprocess(example):
    text = example["text"]
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    inputs['input_ids'][0, 0] = tokenizer.cls_token_id
    return {
        'input_ids': inputs['input_ids'][0].tolist(),
        'attention_mask': inputs['attention_mask'][0].tolist()
    }

dataset = load_dataset("ag_news", split="test[:1%]")

input_ids_list = []
attention_masks_list = []

for example in dataset:
    processed = preprocess(example)
    input_ids_list.append(processed['input_ids'])
    attention_masks_list.append(processed['attention_mask'])

input_ids = torch.tensor(input_ids_list).cuda()
attention_masks = torch.tensor(attention_masks_list).cuda()

custom_generated_texts = []
custom_times = []
for i in range(input_ids.size(0)):
    single_input_ids = input_ids[i:i+1]
    single_attention_mask = attention_masks[i:i+1]
    generated_text, single_execution_time = generate_autoregressively(custom_model, tokenizer, single_input_ids, single_attention_mask, max_new_tokens=20)
    custom_generated_texts.append(generated_text)
    custom_times.append(single_execution_time)

print("\nCustom Model Results:")
for text, time_taken in zip(custom_generated_texts, custom_times):
    print(f"Generated Text: {text}")

In [30]:
sum(custom_times)

40.7024359703064

not working....