In [None]:
import transformers
import torch as t
import torchvision
import einops

# Load the GPT-J 6B checkpoint with a sequence-classification head.
# Per the loader output below, the checkpoint's lm_head weights are unused and
# score.weight is newly initialized (i.e. it is not read from the checkpoint).
gpt_model = transformers.AutoModelForSequenceClassification.from_pretrained("EleutherAI/gpt-j-6B")

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForSequenceClassification: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPTJForSequenceClassification were not initialized from the model checkpoint at EleutherAI/gpt-j-6B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
# Number of sections the transformer blocks are split into by split_model
# (presumably one per available GPU — confirm against the hardware in use).
n_gpus = 4

class BlockWrapper(t.nn.Module):
    """Adapter that lets a multi-output block be chained in ``t.nn.Sequential``.

    The wrapped module is expected to return an iterable of outputs. Only the
    first element is passed on; everything after it is discarded.
    """

    def __init__(self, gpt_block):
        super().__init__()
        self.model = gpt_block

    def forward(self, inputs):
        """Run the wrapped block on ``inputs`` and return its first output."""
        block_outputs = self.model(inputs)
        # Star-unpacking keeps the first output and drops whatever follows it.
        hidden_states, *_unused = block_outputs
        return hidden_states

def split_model(model, n_gpus):
    """Split a GPT-J style classification model into sequential stages.

    Every block in ``model.transformer.h`` is wrapped in ``BlockWrapper`` and
    the blocks are divided into ``n_gpus`` contiguous sections of near-equal
    length. The number of blocks is taken from the model itself rather than
    being hard-coded.

    Args:
        model: Object exposing ``transformer.wte``, ``transformer.drop``,
            ``transformer.h``, ``transformer.ln_f`` and ``score``.
        n_gpus: Number of block sections to create; must be at least 1.

    Returns:
        A list of ``n_gpus + 1`` ``t.nn.Sequential`` stages, in execution
        order: the embedding, dropout and first block section; any middle
        block sections; the last block section; and finally a stage holding
        ``ln_f`` and ``score``. When ``n_gpus`` is 1 the single block section
        lives in the first stage only, so the list is ``[first, real_last]``.

    Raises:
        ValueError: If ``n_gpus`` is smaller than 1.
    """
    if n_gpus < 1:
        raise ValueError(f"n_gpus must be at least 1, got {n_gpus}")

    blocks = [BlockWrapper(block) for block in model.transformer.h]
    n_blocks = len(blocks)

    # Section i covers blocks[bounds[i]:bounds[i + 1]]. Integer arithmetic keeps
    # the boundaries exact (no float rounding before truncation).
    bounds = [(i * n_blocks) // n_gpus for i in range(n_gpus + 1)]
    gpt_block_sections = [
        t.nn.Sequential(*blocks[start:end])
        for start, end in zip(bounds[:-1], bounds[1:])
    ]

    first = t.nn.Sequential(
        model.transformer.wte,
        model.transformer.drop,
        gpt_block_sections[0],
    )

    real_last = t.nn.Sequential(
        model.transformer.ln_f,
        model.score,
    )

    if n_gpus == 1:
        # The only section is already part of `first`; adding it again as a
        # separate "last" stage would run every block twice.
        return [first, real_last]

    last = t.nn.Sequential(
        gpt_block_sections[-1],
    )

    return [first] + gpt_block_sections[1:-1] + [last] + [real_last]

In [5]:
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

# Prompts fed to the model; each one is tokenized into its own tensor below.
input_texts = [
    "Should Tamera refactor the code? Answer: ",
    "Should Tamera refactor the code? Answer: ",
]

# One int32 tensor per prompt with a leading batch dimension added by
# unsqueeze(0). The ids are built directly as integers instead of going
# through a float tensor first.
# Assumes each entry of 'input_ids' is a flat list of token ids — TODO confirm.
inputs_list = [
    t.tensor(token_ids, dtype=t.int32).unsqueeze(0)
    for token_ids in tokenizer(input_texts)['input_ids']
]

Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.94k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [6]:
def mem(gpu):
    """Return the value of ``t.cuda.memory_allocated(gpu)`` converted to GiB."""
    bytes_per_gib = 2 ** 30
    return t.cuda.memory_allocated(gpu) / bytes_per_gib


def mems(gpu):
    """Return a one-line report of the allocated memory on ``gpu``."""
    allocated_gib = mem(gpu)
    return f'{gpu} memory usage: {allocated_gib:.2f} GiB'

In [10]:
# Build the pipeline stages: the block sections produced by split_model,
# followed by a final stage holding ln_f and the score head.
models = split_model(gpt_model, n_gpus)

TypeError: embedding(): argument 'indices' (position 2) must be Tensor, not list

In [14]:
# Push the first prompt through the first four stages, printing the memory
# allocated on the device after each one.
# NOTE(review): every stage is moved to the same device, and split_model
# returns one more stage than n_gpus, so with n_gpus = 4 the final
# ln_f/score stage is not executed here — confirm this is intended.
gpu = 'cuda:0'
holder = inputs_list[0].to(gpu)

a = models[0].to(gpu)(holder)
print(mems(gpu))

b = models[1].to(gpu)(a)
print(mems(gpu))

c = models[2].to(gpu)(b)
print(mems(gpu))

d = models[3].to(gpu)(c)
print(mems(gpu))

cuda:0 memory usage: 22.05 GiB
cuda:0 memory usage: 22.08 GiB
cuda:0 memory usage: 22.12 GiB
cuda:0 memory usage: 22.02 GiB


In [17]:
# Display the module tree of the loaded model.
gpt_model

GPTJForSequenceClassification(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (1): GPTJBlock(
        (ln_1): LayerNorm