In [1]:
import os
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from datasets import load_dataset

from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from llama.tokenizer import Tokenizer

from mixture_of_depths.routing_transformer import ModelArgs, MoDTransformer

In [2]:
if not torch.distributed.is_initialized():
    # torch.distributed.init_process_group("nccl")
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
if not model_parallel_is_initialized():
    model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
    initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.manual_seed(42)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


## Load Data

In [3]:
dataset = load_dataset(
    'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
)

In [4]:
dataset

Dataset({
    features: ['text', 'timestamp', 'url'],
    num_rows: 45576
})

In [5]:
tokenizer = Tokenizer("../llama/tokenizer.model")
def collate_fn(batch):
    bsz = len(batch)
    tokenized_texts = [tokenizer.encode(x['text'], bos=True, eos=True) for x in batch]
    max_text_len = max(len(t) for t in tokenized_texts)

    pad_id = tokenizer.eos_id
    tokens = torch.full((bsz, min(2048, max_text_len)), pad_id, dtype=torch.long)
    for k, t in enumerate(tokenized_texts):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)[:2048]
    
    return tokens[:,:-1], tokens[:,1:]

dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=collate_fn,
)

In [6]:
next(iter(dataloader))

(tensor([[    1,   450,  6114,  ...,     2,     2,     2],
         [    1,  4473,  6751,  ...,     2,     2,     2],
         [    1, 26871, 17101,  ...,  4404,   331, 29889],
         [    1,   382,  5348,  ...,     2,     2,     2]]),
 tensor([[  450,  6114,  1058,  ...,     2,     2,     2],
         [ 4473,  6751,  1788,  ...,     2,     2,     2],
         [26871, 17101,   379,  ...,   331, 29889,     2],
         [  382,  5348,   399,  ...,     2,     2,     2]]))

## Train

In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [8]:
class MoDLlamaTrainer():
    def __init__(self, params, model, dataloader):
        self.params = params
        self.model = model
        self.dataloader = dataloader

    def train(
        self,
        epochs=10,
        lr=1e-5,
        model_path="./models/MoDLlama.pt",
        log_path="./logs/MoDLlama_log.txt",
        use_aux_loss=False,
        use_aux_predictor=False,
        log_steps=1000,
    ):  
        self.model.train()

        min_loss = float("inf")
        criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_id)
        bce_criterion = nn.BCEWithLogitsLoss()
        
        if use_aux_predictor:
            parameters = []
            aux_parameters = []
            for name, param in self.model.named_parameters():
                if name.startswith("aux_router"):
                    aux_parameters.append(param)
                else:
                    parameters.append(param)
            optimizer = torch.optim.AdamW(parameters, lr)
            aux_optimizer = torch.optim.AdamW(aux_parameters, 1e-3)
        else:
            optimizer = torch.optim.AdamW(self.model.parameters(), lr)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs*len(dataloader))

        for epoch in range(epochs):
            running_loss = 0.0
            running_causal_loss = 0.0

            correct = 0
            total = 0
            for i, (inputs, targets) in enumerate(tqdm(self.dataloader, desc=f"Epoch: {epoch}")):
                optimizer.zero_grad()

                outputs = self.model(inputs, start_pos=0)

                loss = criterion(outputs['output'].permute(0, 2, 1), targets)

                causal_loss = 0.0
                if use_aux_loss or use_aux_predictor:
                    # compute auxiliary loss
                    weights = outputs['aux_weights'] if use_aux_predictor else outputs['token_weights']
                    token_weights = torch.stack(weights).flatten(0,1)
                    aux_targets = torch.zeros_like(token_weights)
                    batches = torch.arange(token_weights.size(0)).unsqueeze(-1)
                    aux_targets[batches, torch.stack(outputs['topk_indices']).flatten(0,1)] = 1.0
                    aux_targets = aux_targets.flatten()
                    causal_loss = bce_criterion(token_weights.flatten().to("cuda"), aux_targets.to("cuda"))
                    running_causal_loss += causal_loss.detach().cpu().item()
                    
                    if use_aux_predictor:
                        # measure accuracy during training
                        k = min(token_weights.size(-1), self.params.capacity)
                        pred_indices = torch.topk(token_weights, k=k, sorted=False).indices
                        preds = torch.zeros_like(token_weights)
                        preds[batches, pred_indices] = 1.0
                        correct += (preds.flatten() == aux_targets).sum()  
                        total += len(aux_targets)

                loss += causal_loss
                loss.backward()
                optimizer.step()
                scheduler.step()
                nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)

                if use_aux_predictor:
                    aux_optimizer.step()

                running_loss += loss.detach().cpu().item()

                if (i+1) % log_steps == 0:
                    avg_loss = running_loss / (i+1)
                    print(f"Loss at step {i+1}: {avg_loss}")
                    if use_aux_loss:
                        avg_causal_loss = running_causal_loss / (i+1)
                        print(f"Causal Loss at step {i+1}: {avg_causal_loss}")
                    if use_aux_predictor:
                        accuracy = correct / total
                        correct, total = 0, 0
                        print(f"Token Predictor Accuracy at step {i+1}: {accuracy}")

            epoch_loss = running_loss / len(self.dataloader)
            if min_loss > epoch_loss:
                torch.save(self.model.state_dict(), model_path)
                min_loss = epoch_loss
            
            print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss}")
            with open(log_path, 'w') as f:
                f.write(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss}")
                f.write("\n")

### With MoD and Auxiliary Loss

In [9]:
model_params = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=True,
    aux_loss=True
)
model = MoDTransformer(model_params)

In [10]:
count_parameters(model)

53222400

In [12]:
trainer = MoDLlamaTrainer(model_params, model, dataloader)

In [13]:
%%time
trainer.train(
    epochs=1,
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing
)

Epoch: 0:   0%|          | 0/11394 [00:00<?, ?it/s]

Loss at step 1000: 6.328129431724548
Causal Loss at step 1000: 0.14201861082948744
Loss at step 2000: 5.206919669151306
Causal Loss at step 2000: 0.09980903305439279
Loss at step 3000: 4.755034089326858
Causal Loss at step 3000: 0.08180954160541296
Loss at step 4000: 4.48384180355072
Causal Loss at step 4000: 0.06814707864960655
Loss at step 5000: 4.307397578907013
Causal Loss at step 5000: 0.058223096687858925
Loss at step 6000: 4.171844494640827
Causal Loss at step 6000: 0.05103358306797842
Loss at step 7000: 4.0757166142123085
Causal Loss at step 7000: 0.0456944351373573
Loss at step 8000: 3.9912534810900686
Causal Loss at step 8000: 0.041461185942491284
Loss at step 9000: 3.927055069035954
Causal Loss at step 9000: 0.038154133067246424
Loss at step 10000: 3.8782797238588333
Causal Loss at step 10000: 0.035498045698471835
Loss at step 11000: 3.8336125680641695
Causal Loss at step 11000: 0.033279732182219794
Epoch 1/1 - Loss: 3.8147398159548715
CPU times: total: 2min 53s
Wall time: 2

### With MoD and Auxiliary Router

In [9]:
model_params = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=True,
    aux_routing=True
)
model = MoDTransformer(model_params)
count_parameters(model)

53353728

In [10]:
trainer = MoDLlamaTrainer(model_params, model, dataloader)

In [11]:
%%time
trainer.train(
    epochs=1,
    model_path="./models/MoDLlama_predictor.pt",
    log_path="./logs/MoDLlama_predictor_log.txt",
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing
)

Epoch: 0:   0%|          | 0/11394 [00:00<?, ?it/s]

Loss at step 1000: 6.632888051509857
Token Predictor Accuracy at step 1000: 0.831622838973999
Loss at step 2000: 5.561314790248871
Token Predictor Accuracy at step 2000: 0.8503028154373169
Loss at step 3000: 5.114324019749959
Token Predictor Accuracy at step 3000: 0.8806051015853882
Loss at step 4000: 4.840769713997841
Token Predictor Accuracy at step 4000: 0.8832852840423584
Loss at step 5000: 4.664250442028045
Token Predictor Accuracy at step 5000: 0.8767071962356567
Loss at step 6000: 4.533340272188187
Token Predictor Accuracy at step 6000: 0.8746333718299866
Loss at step 7000: 4.441962276424681
Token Predictor Accuracy at step 7000: 0.8692646622657776
Loss at step 8000: 4.360247699260712
Token Predictor Accuracy at step 8000: 0.8706242442131042
Loss at step 9000: 4.298070017523235
Token Predictor Accuracy at step 9000: 0.8725472092628479
Loss at step 10000: 4.25081290872097
Token Predictor Accuracy at step 10000: 0.8689330816268921
Loss at step 11000: 4.206770216616717
Token Predic

### Baseline

In [9]:
model_params = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=False,
)
model = MoDTransformer(model_params)
count_parameters(model)

53221888

In [10]:
trainer = MoDLlamaTrainer(model_params, model, dataloader)

In [11]:
%%time
trainer.train(
    epochs=1,
    model_path="./models/BaselineLlama.pt",
    log_path="./logs/BaselineLlama_log.txt",
)

Epoch: 0:   0%|          | 0/11394 [00:00<?, ?it/s]

Loss at step 1000: 6.059305932760239
Loss at step 2000: 5.038956387758255
Loss at step 3000: 4.623529662529627
Loss at step 4000: 4.368253807425499
Loss at step 5000: 4.2002917053222655
Loss at step 6000: 4.071478405634562
Loss at step 7000: 3.981069858942713
Loss at step 8000: 3.9011246553510426
Loss at step 9000: 3.8405606058438617
Loss at step 10000: 3.794752025604248
Loss at step 11000: 3.7525986541726373
Epoch 1/1 - Loss: 3.734649644832099
CPU times: total: 4min 11s
Wall time: 27min 59s
