In [7]:
import transformers
import torch
import torchvision
import einops

In [2]:
# GPT-J ships without a classification head, so `score.weight` is freshly (randomly)
# initialised when this checkpoint is loaded (see the warning below). Seed torch's RNG
# first so that head -- and therefore every logit in this notebook -- is reproducible
# across kernel restarts.
# NOTE(review): presumably loads in float32 (~24 GB for 6B parameters) -- confirm the
# host has enough memory before running.
torch.manual_seed(0)
model = transformers.AutoModelForSequenceClassification.from_pretrained("EleutherAI/gpt-j-6B")

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForSequenceClassification: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPTJForSequenceClassification were not initialized from the model checkpoint at EleutherAI/gpt-j-6B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Display the module tree to see the submodule names the pipeline split below relies on
# (e.g. `transformer.wte`, `transformer.h`, and the `score` head).
model

GPTJForSequenceClassification(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (1): GPTJBlock(
        (ln_1): LayerNorm

In [43]:
# Intended pipeline layout (indices refer to the DEVICES list below):
#   DEVICES[0] (cuda:0) -- token embedding + dropout + GPT-J blocks 0-6
#   DEVICES[1] (cuda:1) -- GPT-J blocks 7-13
#   DEVICES[2] (cuda:2) -- GPT-J blocks 14-20
#   DEVICES[3] (cuda:3) -- GPT-J blocks 21-27 + final LayerNorm + linear `score` head
# GPT-J-6B has 28 blocks, so each of the four stages gets 7.

DEVICES = ["cuda:0", "cuda:1", "cuda:2", "cuda:3"]

class WrappedGPTJBlock(torch.nn.Module):
    """Adapt a GPT-J block into a plain tensor -> tensor stage for ``nn.Sequential``.

    The wrapped block is called with the hidden states only (no attention mask,
    position ids or cache). It must return a one-element sequence; that single
    element is returned. A result of any other length raises ``ValueError`` from
    the unpacking.
    """

    def __init__(self, block):
        super().__init__()
        # Stored as an attribute of this nn.Module, so a Module `block` is registered
        # as a submodule and follows .to()/.eval()/state_dict() of the wrapper.
        self.block = block

    def forward(self, x):
        block_outputs = self.block(x)
        # NOTE(review): a one-element tuple is what GPT-J blocks return with
        # use_cache/output_attentions left at their defaults -- confirm for the
        # installed transformers version (some versions may also need position_ids).
        # If a block ever returned a bare tensor with leading dim 1, this unpacking
        # would silently strip that dimension.
        (hidden_states,) = block_outputs
        return hidden_states

class MultiGPUGPTJ(torch.nn.Module):
    """Pipeline-style split of a GPT-J sequence classifier into four sequential stages.

    Stages (exposed as ``gpu_0_block`` ... ``gpu_3_block``, run in that order):

    * ``gpu_0_block``: token embedding (``wte``) + dropout (``drop``) + first chunk of blocks
    * ``gpu_1_block`` / ``gpu_2_block``: middle chunks of GPT-J blocks
    * ``gpu_3_block``: last chunk of blocks + final LayerNorm (``ln_f``) + ``score`` head

    The GPT-J blocks are divided into four contiguous, near-equal chunks computed from
    ``len(transformer.transformer.h)``. For GPT-J-6B's 28 blocks this gives 7/7/7/7,
    exactly as before. Models of other depths are split without dropping any block.

    Submodules are *shared* with ``transformer``, not copied: moving a stage to a
    device also moves the corresponding modules of the original model.

    Args:
        transformer: model to split. Must expose ``.transformer.wte``,
            ``.transformer.drop``, ``.transformer.h`` (indexable blocks),
            ``.transformer.ln_f`` and ``.score``.
        devices: optional sequence of exactly four devices (e.g. ``DEVICES``). When
            given, stage ``i`` is moved to ``devices[i]`` and activations are moved to
            each stage's device in ``forward``. When ``None`` (the default), nothing
            is moved and everything runs wherever the weights already live.

    Raises:
        ValueError: if ``devices`` is given but does not have exactly four entries.
    """

    NUM_STAGES = 4

    def __init__(self, transformer, devices=None):
        super().__init__()
        gptj = transformer.transformer
        blocks = [WrappedGPTJBlock(block) for block in gptj.h]

        # Contiguous near-equal split; for 28 blocks the bounds are 0, 7, 14, 21, 28.
        n_blocks = len(blocks)
        bounds = [i * n_blocks // self.NUM_STAGES for i in range(self.NUM_STAGES + 1)]
        chunks = [blocks[lo:hi] for lo, hi in zip(bounds, bounds[1:])]

        self.gpu_0_block = torch.nn.Sequential(gptj.wte, gptj.drop, *chunks[0])
        self.gpu_1_block = torch.nn.Sequential(*chunks[1])
        self.gpu_2_block = torch.nn.Sequential(*chunks[2])
        self.gpu_3_block = torch.nn.Sequential(*chunks[3], gptj.ln_f, transformer.score)

        if devices is None:
            # No placement: behave like a single-device model.
            self._stage_devices = [None] * self.NUM_STAGES
        else:
            devices = [torch.device(d) for d in devices]
            if len(devices) != self.NUM_STAGES:
                raise ValueError(
                    f"devices must contain exactly {self.NUM_STAGES} entries, got {len(devices)}"
                )
            for stage, device in zip(self._stages(), devices):
                stage.to(device)
            self._stage_devices = devices

    def _stages(self):
        """Return the four pipeline stages in execution order."""
        return (self.gpu_0_block, self.gpu_1_block, self.gpu_2_block, self.gpu_3_block)

    def forward(self, x):
        """Run token ids ``x`` through all stages and return last-position logits.

        NOTE(review): taking the last position matches HF's sequence-classification
        pooling only when there is no padding / ``config.pad_token_id`` is unset --
        confirm before feeding padded batches.
        """
        for stage, device in zip(self._stages(), self._stage_devices):
            if device is not None:
                # Hand activations (or the input ids, for stage 0) to this stage's device.
                x = x.to(device)
            x = stage(x)
        return x[:, -1]  # index into the last sequence position only


In [44]:
# Sanity check: the four-stage split must reproduce the original model's logits.
# MultiGPUGPTJ is constructed without device placement here, so everything stays on the
# model's current device and this check does not need GPUs.
SEQ_LEN = 8  # several tokens, so attention across positions is actually exercised

with torch.no_grad():
    model.eval()
    split_model = MultiGPUGPTJ(model)
    split_model.eval()

    # Seeded generator -> the same input on every run; ids span the whole embedding table.
    # NOTE(review): batch size 1 because HF's GPT-J classifier typically rejects larger
    # batches when config.pad_token_id is unset -- confirm if you change it.
    generator = torch.Generator().manual_seed(0)
    vocab_size = model.transformer.wte.num_embeddings
    inp = torch.randint(0, vocab_size, (1, SEQ_LEN), generator=generator)

    expected_outputs = model(inp).logits  # shape: (batch, num_labels)
    actual_outputs = split_model(inp)     # shape: (batch, num_labels) -- last position only

    assert actual_outputs.shape == expected_outputs.shape, (
        f"Shape mismatch: got {tuple(actual_outputs.shape)}, expected {tuple(expected_outputs.shape)}"
    )
    assert torch.allclose(expected_outputs, actual_outputs), f"Got {actual_outputs} but expected {expected_outputs}"