In [None]:
from transformers import GPT2Model, GPT2Config
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from typing import Optional, Tuple, Union # Import Optional, Tuple, and Union

ALU implementation

In [None]:
class ALU(torch.nn.Module):
    """Soft, differentiable arithmetic unit driven by a feature vector.

    An MLP maps each input vector to ``internal_dim * 2 + 4`` values: two
    operands (``internal_dim`` "digits" each) and four operation logits.
    Each operand is collapsed to a scalar by a dot product with the place
    values ``1, 2, 4, ...``. The result is the softmax-weighted mix of
    ``a + b``, ``a - b``, ``a * b`` and ``a / (b + eps)``.

    Args:
        model_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layers of the MLPs.
        internal_dim: Number of digits per operand.
        use_output_projection: If True, the scalar result is projected back
            to ``model_dim`` features by a second MLP.
    """

    def __init__(self, model_dim=768, hidden_dim=512, internal_dim=10, use_output_projection=False):
        super(ALU, self).__init__()

        # Remembered so forward() can slice operands for any internal_dim.
        self.internal_dim = internal_dim

        # input mlp does model_dim -> hidden_dim -> hidden_dim -> (internal_dim * 2 + 4)
        self.input_mlp = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, internal_dim * 2 + 4),
            nn.LeakyReLU()
        )

        if use_output_projection:
            # output projection does 1 -> internal_dim -> hidden_dim -> model_dim
            self.output_projection = nn.Sequential(
                nn.Linear(1, internal_dim),
                nn.ReLU(),
                nn.Linear(internal_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, model_dim)
            )

        self.eps = 1e-8
        # Place values 2**0 .. 2**(internal_dim - 1). Registered as a
        # non-persistent buffer: it follows the module across devices but
        # adds no key to the state_dict.
        self.register_buffer(
            "base",
            torch.pow(2.0, torch.arange(internal_dim, dtype=torch.float32)),
            persistent=False,
        )

    def forward(self, x):
        """Apply the ALU to the last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., model_dim)``; any number of leading
                dimensions is accepted.

        Returns:
            Tensor of shape ``(..., 1)``, or ``(..., model_dim)`` when the
            output projection is enabled.
        """
        x = self.input_mlp(x)

        # Slice along the feature (last) axis so leading batch/sequence
        # dimensions are left intact.
        k = self.internal_dim
        a = x[..., :k]
        b = x[..., k:2 * k]
        op = x[..., 2 * k:2 * k + 4]

        base = self.base.to(device=x.device, dtype=x.dtype)
        a = torch.matmul(a, base)  # Shape: (...)
        b = torch.matmul(b, base)  # Shape: (...)

        op_weights = F.softmax(op, dim=-1)  # Shape: (..., 4)

        add = a + b
        sub = a - b
        mul = a * b
        div = a / (b + self.eps)

        op_outs = torch.stack([add, sub, mul, div], dim=-1)  # Shape: (..., 4)
        result = torch.sum(op_outs * op_weights, dim=-1, keepdim=True)  # Shape: (..., 1)

        if hasattr(self, 'output_projection'):
            result = self.output_projection(result)

        return result

Standard GPT-2

In [None]:
from transformers import (
    AutoConfig,
    AutoTokenizer,
    GPT2Config,
    GPT2LMHeadModel,
)

# Reference model: GPT-2 with a language-modelling head, built from the
# default GPT2Config (no pretrained checkpoint is loaded in this cell).
# Printing it shows the stock layer layout for comparison with the
# modified model.
configuration = GPT2Config()
model = GPT2LMHeadModel(configuration)
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


Modified GPT-2

In [None]:
class CustomGPT2Block(GPT2Block):
    """GPT-2 transformer block with an ALU side branch.

    The attention, optional cross-attention and MLP steps are the usual
    GPT2Block ones. In addition, the hidden states that enter the MLP are
    fed to an ``ALU``; its output is concatenated to the block output and
    a linear layer projects the result back to ``config.n_embd`` features.
    """

    def __init__(self, config):
        super().__init__(config)
        self.alu = ALU(model_dim=config.n_embd)
        # The ALU is built without its output projection, so it emits a
        # single value per position: the fused vector is n_embd + 1 wide.
        alu_output_dim = 1
        self.final_projection = nn.Linear(config.n_embd + alu_output_dim, config.n_embd)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run the block and return ``(hidden_states, present, (attentions, cross_attentions))``.

        ``present`` is included only when ``use_cache`` is true. The
        attention entries are whatever the attention modules return after
        their first two outputs.

        Raises:
            ValueError: If ``encoder_hidden_states`` is given but the block
                has no ``crossattention`` module.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        # NEW CODE: run the ALU on the pre-MLP hidden states. All leading
        # dimensions are collapsed so the ALU always sees a 2-D
        # (positions, n_embd) input, then restored on its output.
        flat_states = hidden_states.reshape(-1, hidden_states.size(-1))
        alu_output = self.alu(flat_states).reshape(*hidden_states.shape[:-1], -1)

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        # NEW CODE: concatenate the ALU output to the hidden states and project back to n_embd
        hidden_states = self.final_projection(torch.cat([hidden_states, alu_output], dim=-1))

        # `outputs` currently starts with `present`; prepend hidden_states
        # exactly once and drop `present` when caching is off.
        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)

In [None]:
class CustomGPT2Model(GPT2Model):
    """GPT2Model whose last transformer blocks are ``CustomGPT2Block`` instances.

    Args:
        config: GPT-2 configuration passed to ``GPT2Model`` and to every
            replacement block.
        num_alu_layers: How many blocks, counted from the end of ``self.h``,
            are replaced. Values larger than the number of layers replace
            all of them.

    Raises:
        ValueError: If ``num_alu_layers`` is negative.
    """

    def __init__(self, config, num_alu_layers=3):
        super().__init__(config)
        if num_alu_layers < 0:
            raise ValueError(f"num_alu_layers must be non-negative, got {num_alu_layers}")
        num_layers = len(self.h)
        # Clamp at 0 so a model with fewer layers than requested does not
        # produce negative indices.
        first_custom = max(0, num_layers - num_alu_layers)
        for i in range(first_custom, num_layers):
            # The replacement block is freshly constructed; weights of the
            # block it replaces are discarded.
            self.h[i] = CustomGPT2Block(config)

    def forward(self, input_ids=None, past_key_values=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None):
        """Delegate to ``GPT2Model.forward`` unchanged.

        Arguments are forwarded by keyword so they stay bound to the right
        parameter even if the parent signature is ordered differently.
        """
        return super().forward(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

In [None]:
# Build the modified model from `configuration` (defined in an earlier cell)
# and print its module tree; the last three entries of `h` are
# CustomGPT2Block instances.
model2 = CustomGPT2Model(configuration)
print(model2)

CustomGPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-8): 9 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (9-11): 3 x CustomGPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2SdpaAttention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=Fa

Loading the pretrained GPT-2 weights into the custom model

In [None]:
config = GPT2Config.from_pretrained('gpt2')
customModel = CustomGPT2Model(config)

# If you want to load pre-trained weights:
state_dict = GPT2Model.from_pretrained('gpt2').state_dict()
# strict=False is required because the ALU and final projection have no
# counterpart in the pretrained checkpoint.
load_result = customModel.load_state_dict(state_dict, strict=False)

# strict=False hides every mismatch, so check that the only parameters left
# uninitialised are the ones this notebook adds on purpose.
unexpected_missing = [
    key for key in load_result.missing_keys
    if ".alu." not in key and ".final_projection." not in key
]
assert not load_result.unexpected_keys, f"checkpoint keys not consumed: {load_result.unexpected_keys}"
assert not unexpected_missing, f"pretrained weights missing for: {unexpected_missing}"

load_result

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

_IncompatibleKeys(missing_keys=['h.9.extra_linear.weight', 'h.9.extra_linear.bias', 'h.10.extra_linear.weight', 'h.10.extra_linear.bias', 'h.11.extra_linear.weight', 'h.11.extra_linear.bias'], unexpected_keys=[])