In [1]:
import os
import time
import random
from typing import Optional
import transformers
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import torchtext
import torchvision
import einops

# One CUDA device per pipeline stage / worker process.
DEVICES = [f"cuda:{index}" for index in range(4)]

In [2]:
# Load the GPT-J 6B checkpoint with a sequence-classification head. As the load
# log below reports, the checkpoint's `lm_head` weights are dropped and
# `score.weight` is newly initialised, so the classifier head is untrained.
model = transformers.AutoModelForSequenceClassification.from_pretrained("EleutherAI/gpt-j-6B")

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForSequenceClassification: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPTJForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPTJForSequenceClassification were not initialized from the model checkpoint at EleutherAI/gpt-j-6B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Scratch check: indexing a Sequential returns the contained layer itself.
torch.nn.Sequential(torch.nn.Linear(10, 10))[0]

Linear(in_features=10, out_features=10, bias=True)

In [None]:
class WrappedGPTJBlock(torch.nn.Module):
    """Adapter that lets a tuple-returning block be chained in `torch.nn.Sequential`.

    The wrapped `block` is expected to return an iterable holding exactly one
    item; `forward` returns that item on its own so the next layer in a
    `Sequential` receives a plain value rather than a tuple.
    """

    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        """Call the wrapped block on `x` and return its single output.

        Raises:
            ValueError: if the block returns zero or more than one item.
        """
        outputs = self.block(x)
        # Unpack instead of indexing so an unexpected extra output fails loudly.
        (hidden_states,) = outputs
        return hidden_states


def save_gptj_blocks(model):
    """Split `model` into four sequential pipeline stages and save each to disk.

    Layers `model.transformer.h[0:28]` are divided into four runs of seven,
    each layer wrapped in `WrappedGPTJBlock`. The first stage is prefixed with
    the token embedding (`wte`) and dropout (`drop`); the last stage is
    followed by the final layer norm (`ln_f`) and the classification head
    (`model.score`). Stage `i` is written with `torch.save` to
    `gptj_block_{i}.pt` in the current working directory.

    Args:
        model: object exposing `transformer.wte`, `transformer.drop`,
            `transformer.h` (at least 28 entries), `transformer.ln_f` and
            `score`.
    """
    transformer = model.transformer
    layers_per_stage = 7

    def wrapped_layers(stage):
        # Wrap the seven consecutive layers belonging to pipeline stage `stage`.
        first = stage * layers_per_stage
        return [
            WrappedGPTJBlock(transformer.h[layer])
            for layer in range(first, first + layers_per_stage)
        ]

    stages = [
        torch.nn.Sequential(transformer.wte, transformer.drop, *wrapped_layers(0)),
        torch.nn.Sequential(*wrapped_layers(1)),
        torch.nn.Sequential(*wrapped_layers(2)),
        torch.nn.Sequential(*wrapped_layers(3), transformer.ln_f, model.score),
    ]

    for stage_index, stage in enumerate(stages):
        torch.save(stage, f"gptj_block_{stage_index}.pt")


# save_gptj_blocks(model)

In [None]:
def compare_our_model_to_theirs(our_model, their_model):
    """Assert that `our_model` reproduces the logits of `their_model`.

    Both models are switched to eval mode and run, without gradient tracking,
    on the same random batch of token ids (shape (1, 2), ids in [0, 100)).

    Args:
        our_model: module called as `our_model(inp)` and returning a tensor.
            If that tensor has one more dimension than the reference logits
            it is treated as per-position output and only the last position
            is compared.
        their_model: reference module whose call result exposes `.logits`.

    Raises:
        AssertionError: if the two outputs are not `torch.allclose`.
    """
    our_model.eval()
    their_model.eval()

    with torch.no_grad():
        # Use the model that was passed in; it must not be rebuilt or replaced
        # here, otherwise the argument (and its eval mode) is silently ignored.
        inp = torch.randint(0, 100, (1, 2))
        expected_outputs = their_model(inp).logits  # presumably (batch, num_class) -- TODO confirm
        actual_outputs = our_model(inp)

        # A per-position output (batch, seq, num_class) would otherwise be
        # broadcast against (batch, num_class) and compared at every position.
        # NOTE(review): assumes the reference classifies from the last token
        # for unpadded input -- confirm against the reference model.
        if actual_outputs.dim() == expected_outputs.dim() + 1:
            actual_outputs = actual_outputs[:, -1]

        # The two models may live on different devices; compare on the CPU.
        expected_outputs = expected_outputs.cpu()
        actual_outputs = actual_outputs.cpu()

        assert torch.allclose(expected_outputs, actual_outputs), f"Got {actual_outputs} but expected {expected_outputs}"


# compare_our_model_to_theirs(our_model, their_model)

In [None]:
# Tokenizer matching the GPT-J checkpoint loaded above.
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# Reuse the end-of-sequence token for padding so that tokenizer calls made with
# `padding=...` have a pad token to use (presumably the tokenizer defines none
# by default -- TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

# TODO: find out why the following breaks `tokenizer(['hi how are you', 'something'], padding='longest').input_ids`

In [None]:
# Hard-coded batch geometry; these are the default arguments of the batching
# helper defined below and must match the batches that are actually produced.
BATCH_SIZE = 64
MAX_SEQ_LEN = 128

In [None]:
# Materialise both IMDB splits as in-memory lists so they can be sorted,
# sliced and measured with len().
data_train, data_test = (
    list(split)
    for split in torchtext.datasets.IMDB(root='.data', split=('train', 'test'))
)
print(len(data_train), len(data_test))

def to_batches(data, batch_size=BATCH_SIZE, max_seq_len=MAX_SEQ_LEN):
    """Tokenize `data` into shuffled, fixed-length batches.

    Items are sorted by the length of their second element, cut into
    consecutive groups of `batch_size` (the final group may be smaller),
    tokenized with the module-level `tokenizer` (padded/truncated to exactly
    `max_seq_len` tokens), and the resulting list of batches is shuffled in
    place with `random.shuffle`.

    Args:
        data: sequence of 2-item entries; assumes (sentiment, review_text)
            where sentiment is the string 'pos' for positive examples --
            TODO confirm against the installed torchtext version. Any other
            sentiment value is mapped to class 0.
        batch_size: number of examples per batch.
        max_seq_len: token length every review is padded or truncated to.

    Returns:
        List of `(input_ids, sentiments)` tuples; only the token ids of the
        tokenizer output are kept.
    """
    by_length = sorted(data, key=lambda example: len(example[1]))
    # Ceiling division so a trailing partial batch is kept.
    batch_count = (len(data) + batch_size - 1) // batch_size

    def encode(chunk):
        # Turn one slice of (sentiment, review) pairs into (token ids, labels).
        labels = torch.tensor([1 if sentiment == 'pos' else 0 for sentiment, _ in chunk])
        texts = [review for _, review in chunk]
        encoded = tokenizer(texts, padding='max_length', max_length=max_seq_len, truncation=True, return_tensors="pt")
        return (encoded.input_ids, labels)

    batched_data = [
        encode(by_length[index * batch_size:index * batch_size + batch_size])
        for index in range(batch_count)
    ]
    random.shuffle(batched_data)
    return batched_data


# Training batches, built once up front.
# NOTE(review): batch_size=16 here differs from BATCH_SIZE (64) -- confirm
# which one the rest of the notebook is meant to use.
data_batches = to_batches(data_train, batch_size=16)

25000 25000
25000 25000


In [None]:
def train(model, batches=None, device=None):
    """Run one pass of Adam (lr=1e-5) over the training batches.

    Args:
        model: module called as `model(input_ids)` and returning a 2-tuple
            whose second item is the class logits fed to cross-entropy.
        batches: iterable of `(input_ids, target)` tensor pairs. Defaults to
            the module-level `data_batches`.
        device: device the inputs and targets are moved to. Defaults to the
            device of the model's first parameter (CPU if it has none).

    Prints the batch index and loss after every optimizer step.
    """
    if batches is None:
        batches = data_batches
    if device is None:
        # There is no module-level `device`; follow the model's placement.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    for i, (tokens, target) in enumerate(batches):
        optimizer.zero_grad()
        _, class_logits = model(tokens.to(device))
        loss = torch.nn.functional.cross_entropy(class_logits, target.to(device))
        loss.backward()
        optimizer.step()
        print(f"{i} {loss}")

In [None]:
def eval_on_imdb(model, num_batches_to_use=None, test_batches=None, device=None):
    """Measure classification accuracy of `model` on the test batches.

    Args:
        model: module called as `model(input_ids)` and returning a 2-tuple
            whose second item is the class logits.
        num_batches_to_use: evaluate at most this many batches; `None` means
            all of them.
        test_batches: list of `(input_ids, target)` tensor pairs. Defaults to
            `to_batches(data_test, batch_size=8)`.
        device: device the inputs and targets are moved to. Defaults to the
            device of the model's first parameter (CPU if it has none).

    Returns:
        0-dim tensor holding correct / total over the evaluated batches, or a
        NaN tensor when no batch was evaluated.
    """
    model.eval()
    if test_batches is None:
        test_batches = to_batches(data_test, batch_size=8)
    if device is None:
        # There is no module-level `device`; follow the model's placement.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    total_correct = 0
    total_samples = 0
    if num_batches_to_use is None:
        num_batches_to_use = len(test_batches)
    print(f'evaluating using {num_batches_to_use} batches')

    # Evaluation needs no autograd graph.
    with torch.no_grad():
        for i, (tokens, target) in enumerate(test_batches):
            # Check the limit before doing the work so exactly
            # `num_batches_to_use` batches are evaluated, not one extra.
            if i >= num_batches_to_use:
                break
            _, class_logits = model(tokens.to(device))
            answers = torch.argmax(class_logits, dim=-1)
            total_correct += torch.sum(answers == target.to(device))
            total_samples += answers.shape[0]
            print(i, total_correct, total_samples)

    if total_samples == 0:
        # Accuracy is undefined without samples; avoid dividing 0 by 0.
        return torch.tensor(float("nan"))
    return total_correct / total_samples


# eval_on_imdb(my_bert, num_batches_to_use=50)

In [None]:
# Toggle for the optimizer-state sharding experiment; it is only referenced by
# the commented-out training loop further down in this cell.
SHARD_OPTIMIZER_STATE=False
# Experiment notes kept as a bare string expression (it has no runtime effect).
# Presumably they record which mini-batch sizes ran with and without sharding
# -- TODO confirm.
"""
mini_batch_size 350: both pass
mini_batch_size 375: sharding passes
mini_batch_size 400: both fail
"""

# Rank that fetches the input batch and the labels in run().
LEADER = 0
# Size of the last dimension of the activation buffers allocated in run().
HIDDEN_SIZE = 4096


def load_block(rank):
    """Load the pipeline stage saved for `rank`.

    Reads `gptj_block_{rank}.pt`, the name `save_gptj_blocks` writes. If that
    file is missing but a file with the older name `block_{rank}.pt` exists,
    the older file is loaded instead.

    Args:
        rank: index of the stage / worker process.

    Returns:
        Whatever object was stored in the file with `torch.save`.
    """
    path = f"gptj_block_{rank}.pt"
    legacy_path = f"block_{rank}.pt"
    if not os.path.exists(path) and os.path.exists(legacy_path):
        path = legacy_path
    # NOTE: torch.load unpickles the file; only load stage files you created.
    return torch.load(path)


def run(rank, size, world_size=None):
    """Run one forward pass of the pipeline stage owned by `rank`.

    Stage `rank` loads its block onto `DEVICES[rank]`. The leader (rank
    `LEADER`, which is also the first stage) takes one batch from
    `data_batches`; every other stage receives the activations of the
    previous stage. Each stage applies its block and forwards the result to
    the next stage. The last stage treats the final sequence position of its
    output as the classification logits, computes the cross-entropy loss
    against the labels, back-propagates and prints the loss.

    Limitation: received activations are plain buffers without autograd
    history, so `loss.backward()` only reaches the parameters of the last
    stage.

    Args:
        rank: rank of this process; also the index into `DEVICES`.
        size: number of processes, as passed by `init_process`.
        world_size: total number of ranks; defaults to `size` so the function
            can be called as `fn(rank, size)`.
    """
    if world_size is None:
        world_size = size
    last_rank = world_size - 1

    device = DEVICES[rank]
    block = load_block(rank).to(device)

    # dist.new_group is collective over the default group: every rank must
    # create every group, in the same order, even groups it is not part of.
    stage_groups = [dist.new_group([src, src + 1]) for src in range(last_rank)]

    # The leader fetches a batch. next(iter(...)) accepts both a list (first
    # batch) and an iterator (next batch), and yields None when out of data.
    inps = None
    labels = None
    if rank == LEADER:
        batch = next(iter(data_batches), None)
        if batch is not None:
            inps, labels = batch
            # Tensor.to returns a new tensor; the result has to be kept.
            inps = inps.to(device)
            labels = labels.to(device=device, dtype=torch.long)

    # The leader announces (batch, seq_len) so the other stages size their
    # receive buffers from the real batch; (0, 0) means "no data".
    # Assumes the leader's inputs are 2-D (batch, seq_len) token ids.
    shape_info = torch.zeros(2, dtype=torch.long, device=device)
    if inps is not None:
        shape_info = torch.tensor(list(inps.shape[:2]), dtype=torch.long, device=device)
    dist.broadcast(tensor=shape_info, src=LEADER)
    batch_size, seq_len = shape_info.tolist()
    if batch_size == 0:
        print(f"Rank {rank}: no data to process")
        return

    # Labels are tiny, so share them with every rank over the default group
    # instead of building a dedicated leader/last-stage group. They must be
    # int64 both to match the sender and to serve as class indices.
    if rank != LEADER:
        labels = torch.zeros(batch_size, dtype=torch.long, device=device)
    dist.broadcast(tensor=labels, src=LEADER)

    if rank != LEADER:
        # Receive the previous stage's activations: (batch, seq_len, hidden).
        # Assumes float32 activations -- TODO confirm if blocks are cast.
        inps = torch.zeros(batch_size, seq_len, HIDDEN_SIZE, device=device)
        dist.broadcast(tensor=inps, src=rank - 1, group=stage_groups[rank - 1])

    # Put it through block
    out = block(inps)
    print(f"Rank {rank} output {out.shape}")

    if rank < last_rank:
        # Send output on to next block; detach so the collective operates on
        # a plain tensor rather than one that is part of the autograd graph.
        dist.broadcast(tensor=out.detach().contiguous(), src=rank, group=stage_groups[rank])
    else:
        # Handle loss and backprop at last block
        classification_logits = out[:, -1]

        # inputs [N, C] and targets [N]
        loss = torch.nn.functional.cross_entropy(classification_logits, labels)
        loss.backward()
        print(loss.detach().item())

#     if rank == 0:
#         start_time = time.time()


#     model = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(device)

#     if SHARD_OPTIMIZER_STATE:
#         params_to_optimize = []
#         for i, param in enumerate(model.parameters()):
#             if i % size == 0:
#                 params_to_optimize.append(param)
#         optimizer = torch.optim.Adam(params_to_optimize, lr=1e-5)
#     else:
#         optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

#     # Iterate over minibatches
#     model.train()
#     for epoch in range(4):
#         print(epoch)
#         ddl = DistributedDataLoader(rank, len(DEVICES), 375, random_seed = epoch)
#         for minibatch_data in ddl:
#             optimizer.zero_grad()
#             # Normal training loop
#             # FIXME
#             minibatch_data = {'input_ids': minibatch_data,
#                             'attention_mask': torch.ones_like(minibatch_data, dtype=torch.long)}
#             outputs = model(**minibatch_data, labels=minibatch_data['input_ids']) 
#             loss = outputs.loss
#             loss.backward()
#             print(loss.detach())
#             # All-reduce to share gradients, for each parameter
#             for param in model.parameters():
#                 old_grad = param.grad.detach().clone()
#                 # Taking the mean over the gradients
#                 dist.all_reduce(param.grad, dist.ReduceOp.SUM)
#                 # assert not torch.allclose(old_grad, param.grad.detach())
#                 param.grad = param.grad / size
#             # Does it take real long? Maybe time optimizer.step() and dist.broadcast, and compare them
#             optimizer.step()
#             if SHARD_OPTIMIZER_STATE:
#                 for i, param in enumerate(model.parameters()):
#                     dist.broadcast(param.data, src=i % 3)

#     print('Training completed')
#     loss = 0.

#     ddl = DistributedDataLoader(rank, len(DEVICES), 32, random_seed = epoch)

#     if rank == 0:
#         model.eval()
#         test_data = ddl.test_dataloader
#         c = 0
#         for test_datum in test_data:
#             test_datum = ddl.tokenize(test_datum)
#             test_datum = {'input_ids': test_datum,
#                         'attention_mask': torch.ones_like(test_datum, dtype=torch.long)}
#             outputs = model(**test_datum, labels=test_datum['input_ids']) 
#             loss += outputs.loss.detach()
#             c +=1
#             if c % 10 == 0:
#                 print(loss)
#             if c > 100:
#                 break    
#         print("eval loss: ", loss / len(test_data) / c)  

#         print("time: ", time.time() - start_time)          
    # After 4 epochs evaluate on test set

def init_process(rank, size, fn, backend='gloo'):
    """Initialize the distributed environment, run `fn`, then tear it down.

    Sets the rendezvous address and port to 127.0.0.1:29503, joins the
    default process group and calls `fn(rank, size)`. The process group is
    destroyed afterwards, also when `fn` raises.

    Args:
        rank: rank of this process.
        size: total number of processes (world size).
        fn: callable invoked as `fn(rank, size)`.
        backend: torch.distributed backend name.
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29503'
    dist.init_process_group(backend, rank=rank, world_size=size)
    try:
        fn(rank, size)
    finally:
        # Release the group's resources even if the worker function failed.
        dist.destroy_process_group()


In [None]:
# Launch one worker process per device, then wait for all of them to finish.
size = len(DEVICES)
processes = []
# mp.set_start_method("spawn")
for rank in range(size):
    processes.append(
        mp.Process(target=init_process, args=(rank, size, run), kwargs={"backend": "gloo"})
    )
    processes[-1].start()

for p in processes:
    p.join()