In [1]:
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from llama.tokenizer import Tokenizer

from mixture_of_depths.routing_transformer import ModelArgs, MoDTransformer
from mixture_of_depths.train import MoDLlamaTrainer
from mixture_of_depths.utils import set_up_env

# Single seed for the whole notebook. set_up_env presumably seeds the RNGs; judging
# by this cell's output it also initialises size-1 model-parallel / DDP / pipeline groups.
random_state: int = 88
set_up_env(random_state)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1




In [2]:
# Load the training split and keep only its first `num_of_samples` rows
# (a deterministic head slice, not a random sample) to keep runs short.
dataset = load_dataset("igorktech/anekdots", split="train")
num_of_samples = 20000
# Clamp to the split size so `select` never receives out-of-range indices.
dataset = dataset.select(range(min(num_of_samples, len(dataset))))

In [3]:
# Llama tokenizer (llama.tokenizer.Tokenizer). The "tokenizer.model" file is not
# shipped with this notebook: obtain it by following the Llama repository's
# download instructions. The path is resolved relative to the working directory.
tokenizer = Tokenizer("tokenizer.model")
# Max tokens kept per example (BOS/EOS included) before the input/target shift.
text_length_max = 64
def collate_fn(batch):
    """Tokenize a batch of dataset rows into next-token-prediction pairs.

    Each row's ``'text'`` is encoded with BOS and EOS and truncated to at most
    ``text_length_max`` tokens. Rows are then right-padded to the longest
    (truncated) sequence in the batch, using ``tokenizer.eos_id`` as the pad value.

    Args:
        batch: list of dataset rows, each a mapping with a ``'text'`` field.

    Returns:
        ``(inputs, targets)``: two ``torch.long`` tensors of shape
        ``(len(batch), seq_len - 1)``. ``targets`` is the same token matrix
        shifted left by one position.

    Note:
        Padding reuses the EOS id and no mask is returned, so pad positions
        appear in ``targets``. Unless the trainer masks them, the loss includes
        predicting EOS after EOS. Texts longer than ``text_length_max`` lose
        their trailing EOS when truncated.
    """
    # Truncate the Python lists *before* building tensors so very long texts
    # are never materialised in full.
    token_ids = [
        tokenizer.encode(row['text'], bos=True, eos=True)[:text_length_max]
        for row in batch
    ]
    # Equals min(text_length_max, longest untruncated length).
    seq_len = max(len(ids) for ids in token_ids)

    pad_id = tokenizer.eos_id  # EOS doubles as the pad token
    tokens = torch.full((len(batch), seq_len), pad_id, dtype=torch.long)
    for k, ids in enumerate(token_ids):
        tokens[k, : len(ids)] = torch.tensor(ids, dtype=torch.long)

    # Shift by one: inputs drop the last token, targets drop the first.
    return tokens[:, :-1], tokens[:, 1:]

# Batches of 8 rows, tokenized/truncated/padded by collate_fn.
# NOTE(review): `shuffle` is left at its default (False), so every epoch -- and
# both model runs below -- see batches in the same order. Consider shuffle=True
# with a seeded torch.Generator if that is not intended.
dataloader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)

### Mixture-of-Depths (MoD) Llama

In [6]:
# Mixture-of-Depths configuration: same dim / n_layers / n_heads as the vanilla
# baseline, with token routing and the auxiliary loss switched on.
model_params = ModelArgs(
    dim=256,
    n_layers=4,
    n_heads=4,
    vocab_size=tokenizer.n_words,
    routing=True,
    aux_loss=True,
)
# Re-seed right before weight initialisation. The starting weights then depend
# only on `random_state`, not on how much RNG state earlier cells consumed.
# The execution counts show this notebook was not run in one linear pass.
torch.manual_seed(random_state)
model = MoDTransformer(model_params)
writer = SummaryWriter(log_dir='runs/MoD')
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

19.794432 M parameters


In [7]:
# Bundle the MoD model with its config, the shared tokenizer and the dataloader.
trainer = MoDLlamaTrainer(
    model=model,
    params=model_params,
    dataloader=dataloader,
    tokenizer=tokenizer,
)

In [8]:
# Train the MoD model: 5 epochs at lr 3e-4, TensorBoard logs via `writer`,
# model_dir 'models/MoDLlama' (presumably where checkpoints are saved).
# NOTE(review): use_aux_loss is read from `model_params` but use_aux_predictor
# from `model.params`. aux_routing was not passed to ModelArgs above, so it keeps
# its default. Presumably both refer to the same config object; confirm.
trainer.train(
    epochs=5,
    lr=3e-4,
    model_dir='models/MoDLlama',
    writer=writer,
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing,
)

Epoch: 1:  40%|████      | 1001/2500 [02:34<03:50,  6.50it/s]

Loss at step 1000: 4.983256434202194
Causal Loss at step 1000: 0.0024631750199141608


Epoch: 1:  80%|████████  | 2001/2500 [05:09<01:18,  6.32it/s]

Loss at step 2000: 4.061281044125557
Causal Loss at step 2000: 0.0013510382889885477


Epoch: 1: 100%|██████████| 2500/2500 [06:25<00:00,  6.48it/s]


Epoch 1/5 - Loss: 3.7215605016708375


Epoch: 2:  40%|████      | 1001/2500 [02:34<03:53,  6.43it/s]

Loss at step 1000: 2.3142513897418975
Causal Loss at step 1000: 7.633750385866734e-05


Epoch: 2:  80%|████████  | 2001/2500 [05:08<01:19,  6.27it/s]

Loss at step 2000: 2.3842985979914664
Causal Loss at step 2000: 8.665982752245327e-05


Epoch: 2: 100%|██████████| 2500/2500 [06:27<00:00,  6.45it/s]


Epoch 2/5 - Loss: 2.4195097153663636


Epoch: 3:  40%|████      | 1001/2500 [02:36<03:56,  6.33it/s]

Loss at step 1000: 2.2019019101858137
Causal Loss at step 1000: 0.000135026712607214


Epoch: 3:  80%|████████  | 2001/2500 [05:08<01:13,  6.79it/s]

Loss at step 2000: 2.0999294768571852
Causal Loss at step 2000: 0.0001003341482801261


Epoch: 3: 100%|██████████| 2500/2500 [06:24<00:00,  6.51it/s]


Epoch 3/5 - Loss: 2.081954874229431


Epoch: 4:  40%|████      | 1001/2500 [02:34<03:44,  6.67it/s]

Loss at step 1000: 1.9111106233596802
Causal Loss at step 1000: 2.0166175017038767e-05


Epoch: 4:  80%|████████  | 2001/2500 [05:07<01:12,  6.89it/s]

Loss at step 2000: 1.8612041330337525
Causal Loss at step 2000: 1.9162412942023367e-05


Epoch: 4: 100%|██████████| 2500/2500 [06:23<00:00,  6.52it/s]


Epoch 4/5 - Loss: 1.8424236430168153


Epoch: 5:  40%|████      | 1001/2500 [02:34<03:59,  6.26it/s]

Loss at step 1000: 1.6624065164327622
Causal Loss at step 1000: 1.0917767894625286e-05


Epoch: 5:  80%|████████  | 2001/2500 [05:06<01:18,  6.40it/s]

Loss at step 2000: 1.6389817885160447
Causal Loss at step 2000: 1.1160827034245812e-05


Epoch: 5: 100%|██████████| 2500/2500 [06:25<00:00,  6.48it/s]


Epoch 5/5 - Loss: 1.6381528455257415


### Vanilla Llama baseline (no routing, no auxiliary loss)

In [4]:
# Vanilla Llama baseline: same dim / n_layers / n_heads as the MoD model,
# with routing and the auxiliary loss disabled.
model_params = ModelArgs(
    dim=256,
    n_layers=4,
    n_heads=4,
    vocab_size=tokenizer.n_words,
    routing=False,
    aux_loss=False,
)
writer = SummaryWriter(log_dir='runs/Llama')
# Re-seed right before weight initialisation. The starting weights then depend
# only on `random_state`, not on how much RNG state earlier cells consumed.
# The execution counts show this notebook was not run in one linear pass.
torch.manual_seed(random_state)
model = MoDTransformer(model_params)
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

19.794176 M parameters


In [5]:
# Bundle the vanilla model with its config, the shared tokenizer and the dataloader.
trainer = MoDLlamaTrainer(
    model=model,
    params=model_params,
    dataloader=dataloader,
    tokenizer=tokenizer,
)

In [6]:
# Train the vanilla baseline with the same schedule as the MoD run: 5 epochs at
# lr 3e-4, TensorBoard logs via `writer`, model_dir 'models/Llama'
# (presumably where checkpoints are saved).
# NOTE(review): use_aux_loss is read from `model_params` but use_aux_predictor
# from `model.params`. Presumably the same config object; confirm.
trainer.train(
    epochs=5,
    lr=3e-4,
    model_dir='models/Llama',
    writer=writer,
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing,
)

Epoch: 1:  40%|████      | 1001/2500 [02:23<03:28,  7.20it/s]

Loss at step 1000: 5.0363804886341095


Epoch: 1:  80%|████████  | 2001/2500 [04:49<01:14,  6.66it/s]

Loss at step 2000: 4.484910386323929


Epoch: 1: 100%|██████████| 2500/2500 [06:01<00:00,  6.91it/s]


Epoch 1/5 - Loss: 4.320811706161499


Epoch: 2:  40%|████      | 1001/2500 [02:24<03:33,  7.02it/s]

Loss at step 1000: 3.4606797211170197


Epoch: 2:  80%|████████  | 2001/2500 [04:48<01:08,  7.32it/s]

Loss at step 2000: 3.3693976508378984


Epoch: 2: 100%|██████████| 2500/2500 [06:02<00:00,  6.90it/s]


Epoch 2/5 - Loss: 3.334313145637512


Epoch: 3:  40%|████      | 1001/2500 [02:27<03:44,  6.68it/s]

Loss at step 1000: 3.0729157400131224


Epoch: 3:  80%|████████  | 2001/2500 [04:52<01:08,  7.33it/s]

Loss at step 2000: 3.012603771328926


Epoch: 3: 100%|██████████| 2500/2500 [06:05<00:00,  6.85it/s]


Epoch 3/5 - Loss: 2.9882658170700074


Epoch: 4:  40%|████      | 1001/2500 [02:30<04:00,  6.24it/s]

Loss at step 1000: 2.8174964728355407


Epoch: 4:  80%|████████  | 2001/2500 [05:03<01:12,  6.91it/s]

Loss at step 2000: 2.768199549973011


Epoch: 4: 100%|██████████| 2500/2500 [06:20<00:00,  6.58it/s]


Epoch 4/5 - Loss: 2.749541953134537


Epoch: 5:  40%|████      | 1001/2500 [02:36<03:56,  6.35it/s]

Loss at step 1000: 2.656064735889435


Epoch: 5:  80%|████████  | 2001/2500 [05:07<01:14,  6.67it/s]

Loss at step 2000: 2.626405920088291


Epoch: 5: 100%|██████████| 2500/2500 [06:23<00:00,  6.52it/s]


Epoch 5/5 - Loss: 2.6183425469875337
