In [1]:
import os
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from datasets import load_dataset

from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from llama.tokenizer import Tokenizer

from mixture_of_depths.routing_transformer import ModelArgs, MoDTransformer

In [2]:
if not torch.distributed.is_initialized():
    # torch.distributed.init_process_group("nccl")
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
if not model_parallel_is_initialized():
    model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
    initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.manual_seed(42)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


## Load Data

In [3]:
dataset = load_dataset(
    'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
)

In [4]:
dataset

Dataset({
    features: ['text', 'timestamp', 'url'],
    num_rows: 45576
})

In [5]:
tokenizer = Tokenizer("../llama/tokenizer.model")
def collate_fn(batch):
    bsz = len(batch)
    tokenized_texts = [tokenizer.encode(x['text'], bos=True, eos=True) for x in batch]
    max_text_len = max(len(t) for t in tokenized_texts)

    pad_id = tokenizer.eos_id
    tokens = torch.full((bsz, min(2048, max_text_len)), pad_id, dtype=torch.long)
    for k, t in enumerate(tokenized_texts):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)[:2048]
    
    return tokens[:,:-1], tokens[:,1:]

dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=collate_fn,
)

In [6]:
next(iter(dataloader))

(tensor([[    1,   450,  6114,  ...,     2,     2,     2],
         [    1,  4473,  6751,  ...,     2,     2,     2],
         [    1, 26871, 17101,  ...,  4404,   331, 29889],
         [    1,   382,  5348,  ...,     2,     2,     2]]),
 tensor([[  450,  6114,  1058,  ...,     2,     2,     2],
         [ 4473,  6751,  1788,  ...,     2,     2,     2],
         [26871, 17101,   379,  ...,   331, 29889,     2],
         [  382,  5348,   399,  ...,     2,     2,     2]]))

## Train

In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [8]:
class MoDLlamaTrainer():
    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader

    def train(
        self,
        epochs=10,
        lr=1e-5,
        model_path="./models/MoDLlama.pt",
        log_path="./logs/MoDLlama_log.txt",
        causal_loss=False,
        log_steps=1000,
    ):
        min_loss = float("inf")
        criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_id)
        bce_criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs*len(dataloader))

        for epoch in range(epochs):
            self.model.train()
            running_loss = 0.0
            running_causal_loss = 0.0
            for i, (inputs, targets) in enumerate(tqdm(self.dataloader, desc=f"Epoch: {epoch}")):
                optimizer.zero_grad()

                outputs = self.model(inputs, start_pos=0)

                loss = criterion(outputs['output'].permute(0, 2, 1), targets)
                if causal_loss:
                    # compute auxiliary loss
                    token_weights = torch.stack(outputs['token_weights']).flatten(0,1)
                    aux_targets = torch.zeros_like(token_weights)
                    batches = torch.arange(token_weights.size(0)).unsqueeze(-1)
                    aux_targets[batches, torch.stack(outputs['topk_indices']).flatten(0,1)] = 1.0
                    causal_loss = bce_criterion(token_weights.flatten().to("cuda"), aux_targets.flatten().to("cuda"))

                loss += causal_loss
                loss.backward()
                optimizer.step()
                scheduler.step()
                nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)

                running_loss += loss.detach().cpu().item()
                running_causal_loss += causal_loss.detach().cpu().item()

                if (i+1) % log_steps == 0:
                    avg_loss = running_loss / (i+1)
                    print(f"Loss at step {i+1}: {avg_loss}")
                    if causal_loss:
                        avg_causal_loss = running_causal_loss / (i+1)
                        print(f"Causal Loss at step {i+1}: {avg_causal_loss}")

            epoch_loss = running_loss / len(self.dataloader)
            if min_loss > epoch_loss:
                torch.save(self.model.state_dict(), model_path)
                min_loss = epoch_loss
            
            print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss}")
            with open(log_path, 'w') as f:
                f.write(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss}")
                f.write("\n")

### With MoD and Auxiliary Loss

In [9]:
model_params = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=True,
    aux_loss=True
)
model = MoDTransformer(model_params)

In [10]:
count_parameters(model)

53222401

In [11]:
trainer = MoDLlamaTrainer(model, dataloader)

In [12]:
%%time
trainer.train(
    epochs=1,
    causal_loss=model_params.aux_loss
)

Epoch: 0:   0%|          | 0/11394 [00:00<?, ?it/s]

Loss at step 1000: 6.3720275304317475
Causal Loss at step 1000: 0.1424809495024383
Loss at step 2000: 5.227943378686905
Causal Loss at step 2000: 0.1003828225969337
Loss at step 3000: 4.766925075292587
Causal Loss at step 3000: 0.080016283841338
Loss at step 4000: 4.4906316210329535
Causal Loss at step 4000: 0.06525151234818623
Loss at step 5000: 4.310698924183845
Causal Loss at step 5000: 0.055193195453938096
Loss at step 6000: 4.174149837970734
Causal Loss at step 6000: 0.048282769833652615
Loss at step 7000: 4.077880990794727
Causal Loss at step 7000: 0.04323345345704417
Loss at step 8000: 3.9933470735400913
Causal Loss at step 8000: 0.03922612870761077
Loss at step 9000: 3.929166636255052
Causal Loss at step 9000: 0.03613605445200422
Loss at step 10000: 3.88042227602005
Causal Loss at step 10000: 0.03365151501047658
Loss at step 11000: 3.835826164787466
Causal Loss at step 11000: 0.03159500788033686
Epoch 1/1 - Loss: 3.816949661413745
CPU times: total: 2min 46s
Wall time: 24min 13s

### With MoD and Auxiliary Router

### Baseline

In [9]:
model_params = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=False,
)
model = MoDTransformer(model_params)

In [10]:
count_parameters(model)

53221888

In [11]:
trainer = MoDLlamaTrainer(model, dataloader)

In [12]:
%%time
trainer.train(
    epochs=1,
)

Epoch: 0:   0%|          | 0/11394 [00:00<?, ?it/s]

Loss as step 1000: 7.209186413764954
Loss as step 2000: 6.076255900025368
Loss as step 3000: 5.418029305934906
Loss as step 4000: 5.024323759019375
Loss as step 5000: 4.773851863384246
Loss as step 6000: 4.5845877949992815
Loss as step 7000: 4.4457164600746975
Loss as step 8000: 4.323944765463471
Loss as step 9000: 4.226143040763008
Loss as step 10000: 4.146285073900223
Loss as step 11000: 4.072505737889896
Epoch 1/1 - Loss: 4.042734246672884
CPU times: total: 4min 41s
Wall time: 32min 33s
