In [1]:
import transformers
import torch as t

In [25]:
model = transformers.AutoModelForSequenceClassification.from_pretrained("EleutherAI/gpt-j-6B")
model.eval();

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForSequenceClassification: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPTJForSequenceClassification were not initialized from the model checkpoint at EleutherAI/gpt-j-6B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
model

GPTJForSequenceClassification(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (1): GPTJBlock(
        (ln_1): LayerNorm

In [10]:
model.transformer.h[0](t.randn((1,2,4096)))[0].shape

torch.Size([1, 2, 4096])

In [11]:
class SaneBlock(t.nn.Module):
    def __init__(self, h_block: t.nn.Module):
        super().__init__()
        self.h_block = h_block

    def forward(self, x):
        return self.h_block(x)[0]

In [26]:
stage_1 = t.nn.Sequential(
    model.transformer.wte,
    model.transformer.drop,
    *[SaneBlock(h_block) for h_block in model.transformer.h[:14]],
)

In [27]:
stage_2 = t.nn.Sequential(
    *[SaneBlock(h_block) for h_block in model.transformer.h[14:]],
    model.transformer.ln_f,
    model.score
)

In [28]:
t.save(stage_1, "stage_1.pt")
t.save(stage_2, "stage_2.pt")

In [31]:
input = t.randint(low=0, high=1232, size=(1,3))
stage_1_out = stage_1(input)
stage_2_out = stage_2(stage_1_out)
print(input.shape, stage_1_out.shape, stage_2_out.shape)

torch.Size([1, 3]) torch.Size([1, 3, 4096]) torch.Size([1, 3, 2])


In [None]:
model_out = model(input)


In [39]:
stage_2_out

tensor([[[ 0.3735,  0.2450],
         [ 0.1385, -1.4714],
         [ 3.8115, -0.7507]]], grad_fn=<UnsafeViewBackward0>)

In [37]:
t.allclose(stage_2_out[:,-1,:], model_out.logits)

True

In [2]:
stage_1 = t.nn.Linear(128, 128)
stage_2 = t.nn.Linear(128, 128)

t.save(stage_1, 'mlp1.pt')
t.save(stage_2, 'mlp2.pt')