In [1]:
import os

import torch
from torch.utils.data import DataLoader

from datasets import load_dataset

from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from llama.tokenizer import Tokenizer

from mixture_of_depths.routing_transformer import ModelArgs, MoDTransformer
from mixture_of_depths.train import MoDLlamaTrainer

In [2]:
# Single-process bootstrap: process group -> model-parallel group -> CUDA defaults.
if not torch.distributed.is_initialized():
    # Rendezvous settings for a one-rank group on this machine. The gloo backend
    # is used here; the NCCL variant, init_process_group("nccl"), is left disabled.
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12345')
    torch.distributed.init_process_group(rank=0, world_size=1, backend='gloo')

if not model_parallel_is_initialized():
    # Model-parallel degree follows WORLD_SIZE when it is set, otherwise 1.
    model_parallel_size = int(os.getenv("WORLD_SIZE", 1))
    initialize_model_parallel(model_parallel_size)

# Select this process's GPU (LOCAL_RANK, defaulting to 0) and seed torch's RNG.
local_rank = int(os.getenv("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.manual_seed(42)
# Tensors created without an explicit dtype/device now default to CUDA float32.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


## Load Data

In [3]:
# Load only the "validation" split of the roneneldan/TinyStories dataset.
dataset = load_dataset("roneneldan/TinyStories", split="validation")

Repo card metadata block was not found. Setting CardData to empty.


In [4]:
# Display the dataset summary (feature names and row count).
dataset

Dataset({
    features: ['text'],
    num_rows: 21990
})

In [5]:
# Build the tokenizer from the model file at ../llama/tokenizer.model
# (the path is relative to the notebook's working directory).
tokenizer = Tokenizer("../llama/tokenizer.model")
def collate_fn(batch, max_seq_len=2048):
    """Tokenize a batch of examples and build next-token (inputs, targets) tensors.

    Each example's ``'text'`` field is encoded with BOS and EOS markers, cut to
    at most ``max_seq_len`` tokens, and right-padded with ``tokenizer.eos_id``
    up to the longest (truncated) sequence in the batch.

    Args:
        batch: Non-empty sequence of mappings that each have a ``'text'`` entry.
        max_seq_len: Maximum number of tokens kept per example before the
            one-token shift. Defaults to 2048, the previously hard-coded limit.

    Returns:
        A pair ``(inputs, targets)`` of ``torch.long`` tensors with shape
        ``(len(batch), width - 1)``, where ``width`` is the padded length.
        ``targets`` is ``inputs`` shifted one position to the left.

    Raises:
        ValueError: If ``batch`` is empty or ``max_seq_len`` is smaller than 2
            (at least two tokens are needed to form an input/target pair).
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    if max_seq_len < 2:
        raise ValueError("max_seq_len must be at least 2")

    # Truncate right after encoding so no oversized tensors are ever built.
    tokenized_texts = [
        tokenizer.encode(x['text'], bos=True, eos=True)[:max_seq_len]
        for x in batch
    ]
    width = max(len(t) for t in tokenized_texts)

    # The EOS id doubles as the padding value.
    pad_id = tokenizer.eos_id
    tokens = torch.full((len(batch), width), pad_id, dtype=torch.long)
    for k, t in enumerate(tokenized_texts):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)

    # Inputs drop the last token, targets drop the first (next-token prediction).
    return tokens[:, :-1], tokens[:, 1:]

# Batches of 4 examples, each turned into an (inputs, targets) pair by collate_fn.
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=4)

In [6]:
# Pull one batch to inspect the (inputs, targets) pair produced by collate_fn.
next(iter(dataloader))

(tensor([[    1,  1706,   327,  ...,     2,     2,     2],
         [    1,  9038,  2501,  ..., 29890,  3099, 29889],
         [    1,  9038,  2501,  ...,     2,     2,     2],
         [    1,  9038,  2501,  ...,     2,     2,     2]]),
 tensor([[ 1706,   327, 29889,  ...,     2,     2,     2],
         [ 9038,  2501,   263,  ...,  3099, 29889,     2],
         [ 9038,  2501,   263,  ...,     2,     2,     2],
         [ 9038,  2501,   263,  ...,     2,     2,     2]]))

## Train

In [7]:
def count_parameters(model):
    """Return the total element count over the model's trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

### With MoD and Auxiliary Loss

In [8]:
# Configuration for the routed model with the auxiliary loss enabled.
model_params = ModelArgs(
    dim=512, n_layers=6, n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=True, aux_loss=True,
)
model = MoDTransformer(model_params)

In [9]:
# Number of trainable parameters in the current model.
count_parameters(model)

53222400

In [10]:
# Bind the model, its config, the tokenizer and the data pipeline into a trainer.
trainer = MoDLlamaTrainer(
    model=model, params=model_params,
    dataloader=dataloader, tokenizer=tokenizer,
)

In [11]:
%%time
# Five epochs; the auxiliary-loss / auxiliary-predictor switches mirror the model config.
trainer.train(
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing,
    epochs=5,
)

Epoch: 0:  18%|█▊        | 1002/5498 [01:16<06:30, 11.51it/s]

Loss at step 1000: 7.006802941322326
Causal Loss at step 1000: 0.08809202676918358


Epoch: 0:  36%|███▋      | 2002/5498 [02:32<04:18, 13.54it/s]

Loss at step 2000: 5.918302805185318
Causal Loss at step 2000: 0.06418199791555526


Epoch: 0:  55%|█████▍    | 3002/5498 [03:47<03:00, 13.83it/s]

Loss at step 3000: 5.321991624275843
Causal Loss at step 3000: 0.05198440214487103


Epoch: 0:  73%|███████▎  | 4002/5498 [05:03<01:49, 13.64it/s]

Loss at step 4000: 4.9382692607045175
Causal Loss at step 4000: 0.04592174272070406


Epoch: 0:  91%|█████████ | 5002/5498 [06:22<00:45, 10.82it/s]

Loss at step 5000: 4.626036303830147
Causal Loss at step 5000: 0.04173516972358338


Epoch: 0: 100%|██████████| 5498/5498 [07:01<00:00, 13.05it/s]


Epoch 1/5 - Loss: 4.513263336026743


Epoch: 1:  18%|█▊        | 1002/5498 [01:15<06:33, 11.43it/s]

Loss at step 1000: 3.236232976078987
Causal Loss at step 1000: 0.0156562289170688


Epoch: 1:  36%|███▋      | 2002/5498 [02:31<04:13, 13.78it/s]

Loss at step 2000: 3.1664874161481857
Causal Loss at step 2000: 0.01575509571362636


Epoch: 1:  55%|█████▍    | 3002/5498 [03:47<03:02, 13.70it/s]

Loss at step 3000: 3.105056991259257
Causal Loss at step 3000: 0.015763346709737863


Epoch: 1:  73%|███████▎  | 4002/5498 [05:03<01:49, 13.65it/s]

Loss at step 4000: 3.066151744902134
Causal Loss at step 4000: 0.016194940210931236


Epoch: 1:  91%|█████████ | 5002/5498 [06:21<00:45, 10.89it/s]

Loss at step 5000: 3.002454901599884
Causal Loss at step 5000: 0.01648979005278088


Epoch: 1: 100%|██████████| 5498/5498 [07:00<00:00, 13.07it/s]


Epoch 2/5 - Loss: 2.990427528263743


Epoch: 2:  18%|█▊        | 1002/5498 [01:15<06:32, 11.46it/s]

Loss at step 1000: 2.80145192360878
Causal Loss at step 1000: 0.0125106017353246


Epoch: 2:  36%|███▋      | 2002/5498 [02:31<04:13, 13.78it/s]

Loss at step 2000: 2.768246438384056
Causal Loss at step 2000: 0.012678612090501702


Epoch: 2:  55%|█████▍    | 3002/5498 [03:46<03:03, 13.59it/s]

Loss at step 3000: 2.7416327044169106
Causal Loss at step 3000: 0.012900318554563759


Epoch: 2:  73%|███████▎  | 4002/5498 [05:03<01:49, 13.69it/s]

Loss at step 4000: 2.7311190115809443
Causal Loss at step 4000: 0.013410496817494276


Epoch: 2:  91%|█████████ | 5002/5498 [06:21<00:45, 10.83it/s]

Loss at step 5000: 2.6926740930080415
Causal Loss at step 5000: 0.013704256183304824


Epoch: 2: 100%|██████████| 5498/5498 [07:00<00:00, 13.08it/s]


Epoch 3/5 - Loss: 2.6920377413287255


Epoch: 3:  18%|█▊        | 1002/5498 [01:14<06:35, 11.36it/s]

Loss at step 1000: 2.6364756326675414
Causal Loss at step 1000: 0.011534790724166669


Epoch: 3:  36%|███▋      | 2002/5498 [02:30<04:15, 13.68it/s]

Loss at step 2000: 2.615027206659317
Causal Loss at step 2000: 0.011711884467717027


Epoch: 3:  55%|█████▍    | 3002/5498 [03:45<03:01, 13.75it/s]

Loss at step 3000: 2.600178484280904
Causal Loss at step 3000: 0.011988362297338124


Epoch: 3:  73%|███████▎  | 4002/5498 [05:01<01:51, 13.40it/s]

Loss at step 4000: 2.60011988979578
Causal Loss at step 4000: 0.012498831250646616


Epoch: 3:  91%|█████████ | 5002/5498 [06:19<00:45, 10.95it/s]

Loss at step 5000: 2.5716913623332975
Causal Loss at step 5000: 0.012757879426144064


Epoch: 3: 100%|██████████| 5498/5498 [06:58<00:00, 13.14it/s]


Epoch 4/5 - Loss: 2.5755449517634443


Epoch: 4:  18%|█▊        | 1002/5498 [01:14<06:26, 11.62it/s]

Loss at step 1000: 2.5731737401485444
Causal Loss at step 1000: 0.011133916068880353


Epoch: 4:  36%|███▋      | 2002/5498 [02:29<04:09, 14.02it/s]

Loss at step 2000: 2.5584543668627737
Causal Loss at step 2000: 0.011319278810551623


Epoch: 4:  55%|█████▍    | 3002/5498 [03:44<03:04, 13.53it/s]

Loss at step 3000: 2.549927330096563
Causal Loss at step 3000: 0.011631888091758203


Epoch: 4:  73%|███████▎  | 4002/5498 [05:00<01:48, 13.85it/s]

Loss at step 4000: 2.5555321508049964
Causal Loss at step 4000: 0.012148252352315467


Epoch: 4:  91%|█████████ | 5001/5498 [06:21<00:45, 10.99it/s]

Loss at step 5000: 2.53256622800827
Causal Loss at step 5000: 0.01245836284031393


Epoch: 4: 100%|██████████| 5498/5498 [06:59<00:00, 13.10it/s]


Epoch 5/5 - Loss: 2.5387222885478407
CPU times: total: 6min 7s
Wall time: 35min 2s


### With MoD and Auxiliary Router

In [9]:
# Configuration for the routed model with auxiliary routing enabled.
model_params = ModelArgs(
    dim=512, n_layers=6, n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=True, aux_routing=True,
)
model = MoDTransformer(model_params)
# Show the trainable-parameter count for this variant.
count_parameters(model)

53353728

In [10]:
# Bind the model, its config, the tokenizer and the data pipeline into a trainer.
trainer = MoDLlamaTrainer(
    model=model, params=model_params,
    dataloader=dataloader, tokenizer=tokenizer,
)

In [11]:
%%time
# Five epochs for the auxiliary-router variant, with its own model_dir and log_path
# (presumably checkpoint and log destinations; confirm against MoDLlamaTrainer.train).
trainer.train(
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing,
    model_dir="./models/MoDLlama_predictor/",
    log_path="./logs/MoDLlama_predictor_log.txt",
    epochs=5,
)

Epoch: 0:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 6.924468765258789
Token Predictor Accuracy at step 1000: 0.9802144169807434
Loss at step 2000: 5.814654923915863
Token Predictor Accuracy at step 2000: 0.9834375381469727
Loss at step 3000: 5.205756803115209
Token Predictor Accuracy at step 3000: 0.9811303615570068
Loss at step 4000: 4.827160907268524
Token Predictor Accuracy at step 4000: 0.978736400604248
Loss at step 5000: 4.530920606350898
Token Predictor Accuracy at step 5000: 0.9693493843078613
Epoch 1/5 - Loss: 4.4247111710646925


Epoch: 1:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 3.2280124650001527
Token Predictor Accuracy at step 1000: 0.9690613746643066
Loss at step 2000: 3.170233298122883
Token Predictor Accuracy at step 2000: 0.9609028697013855
Loss at step 3000: 3.1131341084639232
Token Predictor Accuracy at step 3000: 0.9612652063369751
Loss at step 4000: 3.080255357682705
Token Predictor Accuracy at step 4000: 0.9577699899673462
Loss at step 5000: 3.0225194420576096
Token Predictor Accuracy at step 5000: 0.9523912072181702
Epoch 2/5 - Loss: 3.012032274486109


Epoch: 2:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 2.80890511071682
Token Predictor Accuracy at step 1000: 0.9658401012420654
Loss at step 2000: 2.776543844342232
Token Predictor Accuracy at step 2000: 0.964267909526825
Loss at step 3000: 2.751001624822617
Token Predictor Accuracy at step 3000: 0.9614712595939636
Loss at step 4000: 2.741134049206972
Token Predictor Accuracy at step 4000: 0.9604632258415222
Loss at step 5000: 2.7065152017354963
Token Predictor Accuracy at step 5000: 0.9627573490142822
Epoch 3/5 - Loss: 2.7054980094973153


Epoch: 3:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 2.6169092404842376
Token Predictor Accuracy at step 1000: 0.9725882411003113
Loss at step 2000: 2.5960776077508925
Token Predictor Accuracy at step 2000: 0.9690541625022888
Loss at step 3000: 2.583292514403661
Token Predictor Accuracy at step 3000: 0.9690254926681519
Loss at step 4000: 2.5858921572268008
Token Predictor Accuracy at step 4000: 0.9694963693618774
Loss at step 5000: 2.560883826804161
Token Predictor Accuracy at step 5000: 0.9632113575935364
Epoch 4/5 - Loss: 2.5656060975176413


Epoch: 4:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 2.551059379696846
Token Predictor Accuracy at step 1000: 0.9690996408462524
Loss at step 2000: 2.539133041083813
Token Predictor Accuracy at step 2000: 0.9656355977058411
Loss at step 3000: 2.533072443127632
Token Predictor Accuracy at step 3000: 0.9659383296966553
Loss at step 4000: 2.5428641163706778
Token Predictor Accuracy at step 4000: 0.9659696221351624
Loss at step 5000: 2.5200535717725754
Token Predictor Accuracy at step 5000: 0.9642000794410706
Epoch 5/5 - Loss: 2.5262841506410227
CPU times: total: 6min 37s
Wall time: 36min 52s


### Baseline

In [9]:
# Baseline configuration: same dimensions as the routed runs, with routing turned off.
model_params = ModelArgs(
    dim=512, n_layers=6, n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=False,
)
model = MoDTransformer(model_params)
# Show the trainable-parameter count for the baseline.
count_parameters(model)

53221888

In [10]:
# Bind the model, its config, the tokenizer and the data pipeline into a trainer.
trainer = MoDLlamaTrainer(
    model=model, params=model_params,
    dataloader=dataloader, tokenizer=tokenizer,
)

In [11]:
%%time
# Five epochs for the baseline; the auxiliary switches are left at the trainer's defaults.
trainer.train(
    model_dir="./models/BaselineLlama/",
    log_path="./logs/BaselineLlama_log.txt",
    epochs=5,
)

Epoch: 0:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 6.869257176876068
Loss at step 2000: 5.80399141049385
Loss at step 3000: 5.199864533980687
Loss at step 4000: 4.810587356686592
Loss at step 5000: 4.494121722722054
Epoch 1/5 - Loss: 4.381307774953471


Epoch: 1:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 3.133148486495018
Loss at step 2000: 3.064842033982277
Loss at step 3000: 3.006342195947965
Loss at step 4000: 2.9690366214215755
Loss at step 5000: 2.9042758924722674
Epoch 2/5 - Loss: 2.8924330781268486


Epoch: 2:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 2.7258059195280073
Loss at step 2000: 2.690865971207619
Loss at step 3000: 2.6640439066092174
Loss at step 4000: 2.653085547119379
Loss at step 5000: 2.6118169929027557
Epoch 3/5 - Loss: 2.610617885939986


Epoch: 3:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 2.568033074617386
Loss at step 2000: 2.544290638566017
Loss at step 3000: 2.5286488727728527
Loss at step 4000: 2.5275957393050192
Loss at step 5000: 2.495972100639343
Epoch 4/5 - Loss: 2.4991861418317822


Epoch: 4:   0%|          | 0/5498 [00:00<?, ?it/s]

Loss at step 1000: 2.508699095606804
Loss at step 2000: 2.491215365052223
Loss at step 3000: 2.4812898535728456
Loss at step 4000: 2.485381192624569
Loss at step 5000: 2.4586771958351137
Epoch 5/5 - Loss: 2.4640085676575887
CPU times: total: 8min 39s
Wall time: 40min 16s
