In [2]:
import transformers
import torch as t
import torchvision
import einops

In [42]:
gpt_model = transformers.AutoModelForSequenceClassification.from_pretrained("EleutherAI/gpt-j-6B")

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForSequenceClassification: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPTJForSequenceClassification were not initialized from the model checkpoint at EleutherAI/gpt-j-6B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [87]:
## Make the class
# Have self.model = model
# self.forward


class SplitModelWrapper(t.nn.Module):
    def __init__(self, s_model):
        super().__init__()
        self.model = s_model
    
    def forward(self, inputs):
        print(self.model)
        result = self.model(inputs)
        print(result)
        return result

In [43]:
n_gpus = 2

def split_model(model, n_gpus):
    starts = t.linspace(0, 28, n_gpus + 1).int()[:-1] # Starting index of each section
    ends = t.linspace(0, 28, n_gpus + 1).int()[1:]
    gpt_block_sections = [model.transformer.h[start:end] for start, end in zip(starts, ends)]

    first = t.nn.Sequential(
        model.transformer.wte,
        model.transformer.drop,
        gpt_block_sections[0]
    )

    last = t.nn.Sequential(
        gpt_block_sections[-1],
        model.transformer.ln_f,
        model.score
    )

    models = [first] + gpt_block_sections[1:-1] + [last]
    return models

models = split_model(gpt_model, n_gpus)
for i, model in enumerate(models):
    t.save(model, 'gpt-j-%d.pt' % i)

In [46]:
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

Downloading:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.94k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [72]:
inputs = t.Tensor(tokenizer("Should Tamera refactor the code? Answer: ")["input_ids"]).int()
inputs = inputs.unsqueeze(0)

In [88]:
torso = SplitModelWrapper(models[0])
torso(inputs)

Sequential(
  (0): Embedding(50400, 4096)
  (1): Dropout(p=0.0, inplace=False)
  (2): ModuleList(
    (0): GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): GPTJMLP(
        (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
        (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (1): GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropo

NotImplementedError: 

In [86]:
models[0]

Sequential(
  (0): Embedding(50400, 4096)
  (1): Dropout(p=0.0, inplace=False)
  (2): ModuleList(
    (0): GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): GPTJMLP(
        (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
        (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (1): GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropo

In [73]:
original_output = gpt_model(inputs)

In [75]:
original_output.logits

tensor([[0.3227, 1.1226]], grad_fn=<IndexBackward0>)

In [76]:
our_model = t.nn.Sequential(*models)

In [77]:
our_output = our_model(inputs)

NotImplementedError: 

In [79]:
gpt_model.transformer.wte(inputs)

tensor([[[ 0.0010, -0.0453,  0.0014,  ...,  0.0157, -0.0010,  0.0138],
         [-0.0095, -0.0071, -0.0048,  ...,  0.0296,  0.0011,  0.0059],
         [-0.0159, -0.0139, -0.0133,  ..., -0.0032, -0.0170,  0.0362],
         ...,
         [-0.0242, -0.0049,  0.0158,  ...,  0.0052,  0.0274, -0.0107],
         [ 0.0019,  0.0064,  0.0051,  ...,  0.0023, -0.0058,  0.0081],
         [ 0.0013, -0.0045, -0.0043,  ...,  0.0041,  0.0029, -0.0003]]],
       grad_fn=<EmbeddingBackward0>)

In [44]:
def init_processes(rank, size, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    print('I am initiating the process at:', rank, ' of ', size)
    fn(rank, size)
    
def main():
    size = n_gpus
    processes = []
    mp.set_start_method('spawn', force=True)
    for rank in range(size):
        p = mp.Process(target=init_processes, args=(rank, size, run))
        print("I am starting!", rank, size)
        p.start()
        processes.append(p)
    for p in processes:
        p.join()    

2dp.gin        pipelineparallel.gin                   pp_nominibatch.py
2dparallel.py  pipelineparallel.py                    save_model.py
dist_basic.py  pipelineparallel_gptj_imdb.gin         w3d1.ipynb
gpt-j-0.pt     pipelineparallel_resnet50_cifar10.gin
gpt-j-1.pt     pp_naive.gin
