In [119]:
from transformers import AutoModelForCausalLM, GPT2Config, GPT2ForQuestionAnswering, GPT2Model, GPT2LMHeadModel, AutoTokenizer
import torch
from typing import Optional, Tuple, Union
from datasets import load_from_disk

In [2]:
model = AutoModelForCausalLM.from_pretrained("gpt2")


In [120]:
tok = AutoTokenizer.from_pretrained("gpt2")

In [None]:
model = AutoModelForCausalLM.from_pretrained("gpt2")

In [77]:
to_select = ["attn", "mlp", "lm_head"]

In [6]:
# set([n for n, p in model.named_parameters()])

In [78]:
params = []
shapes = []
for n, p in model.named_parameters():
    if any(i in n for i in to_select):
        shapes.append(p.shape)
        params.append(p.flatten())

In [79]:
flat_params = torch.cat(params)

In [81]:
flat_params.shape

torch.Size([85017600])

In [23]:
def get_num_params(model):
    return sum(p.numel() for _, p in model.named_parameters())


In [82]:
flat_params.numel()#.view(-1, 768)

85017600

In [83]:
N = int(flat_params.numel() / (model.config.n_embd + model.config.vocab_size))
shape1 = (model.config.n_embd, N)
shape2 = (N, model.config.vocab_size)

In [101]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [145]:
class Model(GPT2LMHeadModel):
    def __init__(self) -> None:
        super().__init__(model.config)
        self.transformer.wte = model.transformer.wte
        self.transformer.wpe = model.transformer.wpe
        self.config = model.config
        self.transformer.h = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(*shape1),
            )
        ])
        self.transformer.ln_f = torch.nn.LayerNorm((shape1[1],), eps=1e-05, elementwise_affine=True)
        self.lm_head = torch.nn.Linear(*shape2)
    
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(input_ids)
        hidden_states = transformer_outputs
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

In [146]:
model2 = Model()

In [147]:
model2

Model(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=768, out_features=1666, bias=True)
      )
    )
    (ln_f): LayerNorm((1666,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1666, out_features=50257, bias=True)
)

Embedding(50257, 768)

In [121]:
dataset_dict = load_from_disk("../data/raw/wikitext-103-raw-v1/")

In [128]:
tok.pad_token_id = tok.eos_token_id
batch = tok(dataset_dict["train"][:10]["text"], padding=True, truncation=True, return_tensors="pt")

In [130]:
batch["labels"] = batch["input_ids"].clone()

In [132]:
model(**batch).loss

tensor(8.2907, grad_fn=<NllLossBackward0>)

In [148]:
model2(**batch)

TypeError: Sequential.forward() got an unexpected keyword argument 'layer_past'