In [1]:
import transformers
import torch as t
import torchvision
import einops

In [2]:
# Load the GPT-J 6B checkpoint into a sequence-classification wrapper. The warning
# output below reports that `lm_head` is dropped and `score.weight` is newly initialised.
gpt_model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "EleutherAI/gpt-j-6B"
)

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForSequenceClassification: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPTJForSequenceClassification were not initialized from the model checkpoint at EleutherAI/gpt-j-6B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
n_gpus: int = 8  # number of pipeline stages `split_model` should produce

class BlockWrapper(t.nn.Module):
    """Make a transformer block usable as a link in an ``nn.Sequential`` chain.

    The wrapped block is called with a single argument and returns a sequence;
    only its first element is passed on, and the rest is discarded.
    """

    def __init__(self, gpt_block):
        super().__init__()
        # Keep the attribute name `model`: it determines the state_dict / pickle layout.
        self.model = gpt_block

    def forward(self, inputs):
        hidden_states, *_discarded = self.model(inputs)
        return hidden_states

def split_model(model, n_gpus):
    """Split a sequence-classification transformer into ``n_gpus`` sequential stages.

    The blocks in ``model.transformer.h`` are divided into ``n_gpus`` contiguous
    sections. The boundary before stage ``i`` is ``floor(i * n_layers / n_gpus)``,
    which for 28 layers / 8 stages gives the sections 0-3, 3-7, 7-10, ... 24-28.
    The first stage also holds the token embedding (``wte``) and dropout. The last
    stage also holds the final layer norm (``ln_f``) and the classification head
    (``score``). Chaining the returned stages therefore maps token ids to per-position logits.

    NOTE(review): each block is called with the hidden states only (see
    BlockWrapper). Whether the block derives position ids / attention masks
    internally depends on the transformers version — confirm.

    Args:
        model: object exposing ``transformer.wte``, ``transformer.drop``,
            ``transformer.h``, ``transformer.ln_f`` and ``score``.
        n_gpus: number of stages to produce; must be >= 1.

    Returns:
        A list of exactly ``n_gpus`` ``nn.Sequential`` stages.

    Raises:
        ValueError: if ``n_gpus < 1``.
    """
    if n_gpus < 1:
        raise ValueError(f"n_gpus must be >= 1, got {n_gpus}")

    blocks = [BlockWrapper(block) for block in model.transformer.h]
    n_layers = len(blocks)  # read from the model instead of hard-coding 28

    # Integer boundaries avoid the float truncation of linspace(...).int(), which
    # could turn e.g. 13.9999 into 13.
    bounds = [(i * n_layers) // n_gpus for i in range(n_gpus + 1)]
    gpt_block_sections = [
        t.nn.Sequential(*blocks[start:end]) for start, end in zip(bounds[:-1], bounds[1:])
    ]

    embed = [model.transformer.wte, model.transformer.drop]
    head = [model.transformer.ln_f, model.score]

    if n_gpus == 1:
        # A single stage must carry both ends. Building separate first/last stages
        # here would put section 0 in both, so every block would run twice.
        return [t.nn.Sequential(*embed, gpt_block_sections[0], *head)]

    first = t.nn.Sequential(*embed, gpt_block_sections[0])
    last = t.nn.Sequential(gpt_block_sections[-1], *head)
    return [first] + gpt_block_sections[1:-1] + [last]


In [4]:
models = split_model(gpt_model, n_gpus)
# Pickle each stage as a whole module: gpt-j-0.pt, gpt-j-1.pt, ...
# NOTE(review): loading these files requires BlockWrapper to be importable under the same name.
for i, model in enumerate(models):
    t.save(model, f'gpt-j-{i}.pt')

In [None]:
# NOTE(review): `inputs_list` is built in a later cell; on a fresh kernel run that cell first.
inputs_list[0].size()

In [None]:
# Push one tokenised example through the first four pipeline stages by hand.
# The first stage starts with the token embedding, so it needs a single id tensor,
# not the Python list `inputs_list` itself.
# NOTE(review): `inputs_list` is built in a later cell; run that cell first.
a = models[0](inputs_list[0])
b = models[1](a)
c = models[2](b)
d = models[3](c)

In [None]:
# Shape of each intermediate activation produced by the cell above.
[act.shape for act in (a, b, c, d)]

In [None]:
# Five random (10, 512, 2) tensors for checking how t.cat joins along dim 0.
fpo = [t.randn(10, 512, 2) for _ in range(5)]

In [None]:
t.cat(fpo, dim=0).shape  # five (10, 512, 2) tensors -> torch.Size([50, 512, 2])

In [None]:
# Tokenizer for the same checkpoint as `gpt_model`.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-j-6B"
)


In [None]:
# Inspect which special tokens the tokenizer defines.
tokenizer.special_tokens_map

In [None]:
# Sanity check: chaining the pipeline stages must reproduce the full model's logits.
# `tokenizer`, `gpt_model` and `models` come from earlier cells.
input_texts = [
    "Should Tamera refactor the code? Answer: ",
    "Should Tamera refactor the code? Answer: ",
]

# t.tensor keeps the token ids integral; t.Tensor(...) would pass them through float32 first.
inputs_list = [
    t.tensor(ids, dtype=t.long).unsqueeze(0) for ids in tokenizer(input_texts)['input_ids']
]

# The stages do not change between inputs, so build the chained model once.
our_model = t.nn.Sequential(*models)

with t.no_grad():  # inference only: skip building an autograd graph through the 6B model
    for i, inputs in enumerate(inputs_list):
        original_output = gpt_model(inputs).logits  # (1, 2)
        # The staged model scores every position; compare its last position.
        our_output = our_model(inputs)[:, -1]
        # Tolerance rather than bitwise equality: the two paths may use different kernels.
        assert t.allclose(original_output, our_output), (i, original_output, our_output)

In [None]:
# Scratch: unpacking a 1-D tensor yields its elements as 0-d tensors.
# NOTE(review): this rebinds `a` and `b`, replacing the stage activations computed earlier.
a, b = t.tensor([1, 2]).unbind(0)

In [None]:
# Display the 0-d tensor produced by the unpacking cell above.
a

In [None]:
%load_ext autoreload
%autoreload 2

from utils import *

In [None]:
# Real IMDB loading is disabled. NOTE(review): later cells that read `train`
# raise NameError until the line below is re-enabled.
#train, test = imdb_data()

fake_train, fake_test = fake_imdb_data(n_batches=100)

In [None]:
def label_to_tensor(label):
    """Encode a class label as a length-2 float tensor.

    Label ``0`` maps to ``[0., 1.]``; any other label maps to ``[1., 0.]``.
    NOTE(review): the hot index is reversed from the usual one-hot convention
    (label 0 -> index 1). Confirm this matches how the two classifier logits are read.
    """
    return t.Tensor([0, 1]) if label == 0 else t.Tensor([1, 0])

def preprocess_batch(batch):
    """Split a batch of ``(label, input)`` pairs into stacked labels and raw inputs.

    Returns a tuple ``(labels, inputs)``:
    - ``labels`` is a ``(len(batch), 2)`` tensor built row by row with ``label_to_tensor``.
    - ``inputs`` is the tuple of inputs, left unstacked.
    """
    raw_labels, inputs = zip(*batch)
    encoded = t.stack(list(map(label_to_tensor, raw_labels)))
    return encoded, inputs

In [None]:
# Preprocess the first fake batch; `l` holds the label tensor and `i` the inputs.
l, i = preprocess_batch(batch=fake_train[0])

In [None]:
l.size()  # (batch_size, 2): one length-2 row per label

In [None]:
# NOTE(review): `train` exists only once the `imdb_data()` call above is re-enabled.
# Presumably (n_batches, batch_size, label type, input shape, input dtype).
(
    len(train),
    len(train[0]),
    type(train[0][0][0]),
    train[0][0][1].shape,
    train[0][0][1].dtype,
)

In [None]:
# Same summary for the fake data:
# (n_batches, batch_size, label type, input shape, input dtype).
(
    len(fake_train),
    len(fake_train[0]),
    type(fake_train[0][0][0]),
    fake_train[0][0][1].shape,
    fake_train[0][0][1].dtype,
)

In [None]:
# Device of the first real input tensor.
# NOTE(review): needs `train`, which is only defined if `imdb_data()` is re-enabled.
train[0][0][1].device