In [1]:
import os

import torch
from torch.utils.data import DataLoader

from datasets import load_dataset

from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from llama.tokenizer import Tokenizer

from mixture_of_depths.routing_transformer import ModelArgs, MoDTransformer
from mixture_of_depths.train import MoDLlamaTrainer

In [2]:
if not torch.distributed.is_initialized():
    # torch.distributed.init_process_group("nccl")
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'
    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
if not model_parallel_is_initialized():
    model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
    initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
torch.manual_seed(42)
torch.set_default_tensor_type(torch.cuda.FloatTensor)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1


  _C._set_default_tensor_type(t)


## Load Data

In [3]:
dataset = load_dataset("roneneldan/TinyStories", split="validation")

Repo card metadata block was not found. Setting CardData to empty.


In [4]:
dataset

Dataset({
    features: ['text'],
    num_rows: 21990
})

In [5]:
tokenizer = Tokenizer("../llama/tokenizer.model")
def collate_fn(batch):
    bsz = len(batch)
    tokenized_texts = [tokenizer.encode(x['text'], bos=True, eos=True) for x in batch]
    max_text_len = max(len(t) for t in tokenized_texts)

    pad_id = tokenizer.eos_id
    tokens = torch.full((bsz, min(2048, max_text_len)), pad_id, dtype=torch.long)
    for k, t in enumerate(tokenized_texts):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)[:2048]
    
    return tokens[:,:-1], tokens[:,1:]

dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=collate_fn,
)

In [6]:
next(iter(dataloader))

(tensor([[    1,  1706,   327,  ...,     2,     2,     2],
         [    1,  9038,  2501,  ..., 29890,  3099, 29889],
         [    1,  9038,  2501,  ...,     2,     2,     2],
         [    1,  9038,  2501,  ...,     2,     2,     2]]),
 tensor([[ 1706,   327, 29889,  ...,     2,     2,     2],
         [ 9038,  2501,   263,  ...,  3099, 29889,     2],
         [ 9038,  2501,   263,  ...,     2,     2,     2],
         [ 9038,  2501,   263,  ...,     2,     2,     2]]))

## Train

In [7]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

### With MoD and Auxiliary Loss

In [8]:
model_params = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=True,
    aux_loss=True
)
model = MoDTransformer(model_params)
count_parameters(model)

53222400

In [12]:
trainer = MoDLlamaTrainer(
    params=model_params,
    model=model,
    tokenizer=tokenizer,
    dataloader=dataloader
)

In [13]:
%%time
trainer.train(
    epochs=5,
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing
)

Epoch: 1:  18%|█▊        | 1001/5498 [01:23<08:12,  9.13it/s]

Loss at step 1000: 7.058915052413941
Causal Loss at step 1000: 0.08566971867182292


Epoch: 1:  36%|███▋      | 2001/5498 [03:02<05:32, 10.52it/s]

Loss at step 2000: 6.005800765275955
Causal Loss at step 2000: 0.06237742305314169


Epoch: 1:  55%|█████▍    | 3000/5498 [04:40<04:03, 10.27it/s]

Loss at step 3000: 5.407233722130457
Causal Loss at step 3000: 0.04932976132611899


Epoch: 1:  73%|███████▎  | 4000/5498 [06:16<02:24, 10.34it/s]

Loss at step 4000: 5.01274158936739
Causal Loss at step 4000: 0.04268608822947135


Epoch: 1:  91%|█████████ | 5001/5498 [07:55<00:56,  8.80it/s]

Loss at step 5000: 4.690546411371231
Causal Loss at step 5000: 0.03874449907559319


Epoch: 1: 100%|██████████| 5498/5498 [08:44<00:00, 10.48it/s]


Epoch 1/5 - Loss: 4.573528359568566


Epoch: 2:  18%|█▊        | 1001/5498 [01:36<08:20,  8.98it/s]

Loss at step 1000: 3.2603589498996737
Causal Loss at step 1000: 0.014048396761820186


Epoch: 2:  36%|███▋      | 2001/5498 [03:11<05:08, 11.33it/s]

Loss at step 2000: 3.188169597506523
Causal Loss at step 2000: 0.014249010106999777


Epoch: 2:  55%|█████▍    | 3001/5498 [04:46<03:20, 12.47it/s]

Loss at step 3000: 3.1245650848150253
Causal Loss at step 3000: 0.014287763172362853


Epoch: 2:  73%|███████▎  | 4002/5498 [06:06<01:57, 12.77it/s]

Loss at step 4000: 3.083756571292877
Causal Loss at step 4000: 0.014641977031242277


Epoch: 2:  91%|█████████ | 5002/5498 [07:26<00:46, 10.72it/s]

Loss at step 5000: 3.0176549889802935
Causal Loss at step 5000: 0.015028004694939591


Epoch: 2: 100%|██████████| 5498/5498 [08:05<00:00, 11.32it/s]


Epoch 2/5 - Loss: 3.0047075379410844


Epoch: 3:  18%|█▊        | 1001/5498 [01:24<08:08,  9.21it/s]

Loss at step 1000: 2.814422668814659
Causal Loss at step 1000: 0.01111296907128417


Epoch: 3:  36%|███▋      | 2000/5498 [03:01<05:29, 10.62it/s]

Loss at step 2000: 2.780482889652252
Causal Loss at step 2000: 0.01132837732146436


Epoch: 3:  55%|█████▍    | 3002/5498 [04:38<04:03, 10.27it/s]

Loss at step 3000: 2.7531852059761683
Causal Loss at step 3000: 0.011526678332971642


Epoch: 3:  73%|███████▎  | 4000/5498 [06:15<02:21, 10.56it/s]

Loss at step 4000: 2.7420962545573713
Causal Loss at step 4000: 0.011937062763237919


Epoch: 3:  91%|█████████ | 5001/5498 [07:52<00:55,  9.01it/s]

Loss at step 5000: 2.7021408375024794
Causal Loss at step 5000: 0.012376864799112082


Epoch: 3: 100%|██████████| 5498/5498 [08:41<00:00, 10.55it/s]


Epoch 3/5 - Loss: 2.7010415472434453


Epoch: 4:  18%|█▊        | 1001/5498 [01:37<08:06,  9.24it/s]

Loss at step 1000: 2.648483075261116
Causal Loss at step 1000: 0.00987689599065925


Epoch: 4:  36%|███▋      | 2000/5498 [03:12<08:07,  7.18it/s]

Loss at step 2000: 2.626298886656761
Causal Loss at step 2000: 0.010153757582142134


Epoch: 4:  55%|█████▍    | 3001/5498 [04:31<03:04, 13.53it/s]

Loss at step 3000: 2.6109603035052618
Causal Loss at step 3000: 0.010414849330942767


Epoch: 4:  73%|███████▎  | 4002/5498 [05:49<01:52, 13.33it/s]

Loss at step 4000: 2.6105089434981346
Causal Loss at step 4000: 0.01086059432811453


Epoch: 4:  91%|█████████ | 5002/5498 [07:09<00:46, 10.55it/s]

Loss at step 5000: 2.580741057062149
Causal Loss at step 5000: 0.011334059409727343


Epoch: 4: 100%|██████████| 5498/5498 [07:48<00:00, 11.74it/s]


Epoch 4/5 - Loss: 2.5843389183879117


Epoch: 5:  18%|█▊        | 1002/5498 [01:15<06:44, 11.12it/s]

Loss at step 1000: 2.587320025086403
Causal Loss at step 1000: 0.009450314964866266


Epoch: 5:  36%|███▋      | 2002/5498 [02:32<04:29, 12.99it/s]

Loss at step 2000: 2.571541560292244
Causal Loss at step 2000: 0.00974879532410705


Epoch: 5:  55%|█████▍    | 3002/5498 [03:49<03:03, 13.58it/s]

Loss at step 3000: 2.5622856976588566
Causal Loss at step 3000: 0.010035118803187894


Epoch: 5:  73%|███████▎  | 4002/5498 [05:06<01:52, 13.32it/s]

Loss at step 4000: 2.5672589445412157
Causal Loss at step 4000: 0.010508904797614377


Epoch: 5:  91%|█████████ | 5002/5498 [06:25<00:46, 10.72it/s]

Loss at step 5000: 2.54253048825264
Causal Loss at step 5000: 0.011013002326834248


Epoch: 5: 100%|██████████| 5498/5498 [07:04<00:00, 12.94it/s]


Epoch 5/5 - Loss: 2.5483269950353957
---------- Training Auxiliary Router ----------
CPU times: total: 15min 28s
Wall time: 40min 26s


### With MoD and Auxiliary Router

In [8]:
model_params = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=True,
    aux_routing=True
)
model = MoDTransformer(model_params)
count_parameters(model)

53353728

In [9]:
trainer = MoDLlamaTrainer(
    params=model_params,
    model=model,
    tokenizer=tokenizer,
    dataloader=dataloader
)

In [10]:
%%time
trainer.train(
    epochs=5,
    model_dir="./models/MoDLlama_predictor/",
    log_path="./logs/MoDLlama_predictor_log.txt",
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing
)

Epoch: 1:  18%|█▊        | 1001/5498 [01:15<06:28, 11.59it/s]

Loss at step 1000: 6.880225940465927


Epoch: 1:  36%|███▋      | 2001/5498 [02:30<04:15, 13.70it/s]

Loss at step 2000: 5.789195282936096


Epoch: 1:  55%|█████▍    | 3001/5498 [03:45<03:05, 13.49it/s]

Loss at step 3000: 5.182257438659668


Epoch: 1:  73%|███████▎  | 4001/5498 [05:01<01:51, 13.48it/s]

Loss at step 4000: 4.795879737734794


Epoch: 1:  91%|█████████ | 5001/5498 [06:19<00:45, 10.93it/s]

Loss at step 5000: 4.484475578093528


Epoch: 1: 100%|██████████| 5498/5498 [06:58<00:00, 13.14it/s]


Epoch 1/5 - Loss: 4.373087156465072


Epoch: 2:  18%|█▊        | 1002/5498 [01:15<06:36, 11.33it/s]

Loss at step 1000: 3.135460281252861


Epoch: 2:  36%|███▋      | 2002/5498 [02:30<04:20, 13.43it/s]

Loss at step 2000: 3.0683096895813944


Epoch: 2:  55%|█████▍    | 3002/5498 [03:46<03:03, 13.60it/s]

Loss at step 3000: 3.0107614399194715


Epoch: 2:  73%|███████▎  | 4002/5498 [05:02<01:49, 13.63it/s]

Loss at step 4000: 2.97463177767396


Epoch: 2:  91%|█████████ | 5002/5498 [06:21<00:45, 10.89it/s]

Loss at step 5000: 2.912122939181328


Epoch: 2: 100%|██████████| 5498/5498 [06:59<00:00, 13.11it/s]


Epoch 2/5 - Loss: 2.900725018310304


Epoch: 3:  18%|█▊        | 1002/5498 [01:15<06:35, 11.36it/s]

Loss at step 1000: 2.735691972732544


Epoch: 3:  36%|███▋      | 2002/5498 [02:31<04:15, 13.67it/s]

Loss at step 2000: 2.7026817091703417


Epoch: 3:  55%|█████▍    | 3002/5498 [03:46<03:00, 13.84it/s]

Loss at step 3000: 2.6766930611928306


Epoch: 3:  73%|███████▎  | 4002/5498 [05:02<01:50, 13.57it/s]

Loss at step 4000: 2.6669078089892864


Epoch: 3:  91%|█████████ | 5002/5498 [06:21<00:45, 10.90it/s]

Loss at step 5000: 2.627320266747475


Epoch: 3: 100%|██████████| 5498/5498 [07:01<00:00, 13.06it/s]


Epoch 3/5 - Loss: 2.626275461276864


Epoch: 4:  18%|█▊        | 1002/5498 [01:15<06:32, 11.47it/s]

Loss at step 1000: 2.5814638652801514


Epoch: 4:  36%|███▋      | 2002/5498 [02:31<04:14, 13.74it/s]

Loss at step 2000: 2.5593398686647415


Epoch: 4:  55%|█████▍    | 3002/5498 [03:46<03:01, 13.75it/s]

Loss at step 3000: 2.5444553532997767


Epoch: 4:  73%|███████▎  | 4002/5498 [05:02<01:49, 13.72it/s]

Loss at step 4000: 2.5444714041650296


Epoch: 4:  91%|█████████ | 5001/5498 [06:34<00:53,  9.27it/s]

Loss at step 5000: 2.5143767155885697


Epoch: 4: 100%|██████████| 5498/5498 [07:24<00:00, 12.38it/s]


Epoch 4/5 - Loss: 2.5176444780440623


Epoch: 5:  18%|█▊        | 1001/5498 [01:37<09:11,  8.15it/s]

Loss at step 1000: 2.5236118294000627


Epoch: 5:  36%|███▋      | 2002/5498 [03:11<05:17, 11.02it/s]

Loss at step 2000: 2.5076241349577906


Epoch: 5:  55%|█████▍    | 3001/5498 [04:46<03:49, 10.90it/s]

Loss at step 3000: 2.498526682774226


Epoch: 5:  73%|███████▎  | 4000/5498 [06:22<02:18, 10.82it/s]

Loss at step 4000: 2.5035794451236724


Epoch: 5:  91%|█████████ | 5001/5498 [08:00<00:56,  8.76it/s]

Loss at step 5000: 2.47829863512516


Epoch: 5: 100%|██████████| 5498/5498 [08:48<00:00, 10.39it/s]


Epoch 5/5 - Loss: 2.4835991757095575
---------- Training Auxiliary Router ----------


Epoch: 1:  18%|█▊        | 1003/5498 [00:26<02:21, 31.80it/s]

Token Predictor Accuracy at step 1000: 0.9636683464050293


Epoch: 1:  36%|███▋      | 2005/5498 [00:53<01:31, 38.34it/s]

Token Predictor Accuracy at step 2000: 0.9673848152160645


Epoch: 1:  55%|█████▍    | 3004/5498 [01:20<01:06, 37.58it/s]

Token Predictor Accuracy at step 3000: 0.9725164771080017


Epoch: 1:  73%|███████▎  | 4005/5498 [01:47<00:38, 38.61it/s]

Token Predictor Accuracy at step 4000: 0.9683014750480652


Epoch: 1:  91%|█████████ | 5006/5498 [02:15<00:14, 33.50it/s]

Token Predictor Accuracy at step 5000: 0.9649678468704224


Epoch: 1: 100%|██████████| 5498/5498 [02:28<00:00, 36.95it/s]


Epoch 1/1 for aux router - Causal Loss: 0.15164356773336618
CPU times: total: 8min 11s
Wall time: 39min 43s


### Baseline

In [8]:
model_params = ModelArgs(
    dim=512,
    n_layers=6,
    n_heads=8,
    vocab_size=tokenizer.n_words,
    routing=False,
)
model = MoDTransformer(model_params)
count_parameters(model)

53221888

In [9]:
trainer = MoDLlamaTrainer(
    params=model_params,
    model=model,
    tokenizer=tokenizer,
    dataloader=dataloader
)

In [10]:
%%time
trainer.train(
    epochs=5,
    model_dir="./models/BaselineLlama/",
    log_path="./logs/BaselineLlama_log.txt",
)

Epoch: 1:  18%|█▊        | 1001/5498 [01:21<08:10,  9.17it/s]

Loss at step 1000: 6.869257190704346


Epoch: 1:  36%|███▋      | 2001/5498 [02:51<05:13, 11.17it/s]

Loss at step 2000: 5.803991420865059


Epoch: 1:  55%|█████▍    | 3002/5498 [04:21<03:28, 11.97it/s]

Loss at step 3000: 5.199864539782206


Epoch: 1:  73%|███████▎  | 4002/5498 [05:49<02:04, 12.03it/s]

Loss at step 4000: 4.810587359905243


Epoch: 1:  91%|█████████ | 5001/5498 [07:18<00:49, 10.10it/s]

Loss at step 5000: 4.494121723914146


Epoch: 1: 100%|██████████| 5498/5498 [07:57<00:00, 11.50it/s]


Epoch 1/5 - Loss: 4.3813077763411385


Epoch: 2:  18%|█▊        | 1001/5498 [01:14<07:05, 10.58it/s]

Loss at step 1000: 3.1331484850645066


Epoch: 2:  36%|███▋      | 2002/5498 [02:43<05:03, 11.54it/s]

Loss at step 2000: 3.0648420315980913


Epoch: 2:  55%|█████▍    | 3001/5498 [04:15<03:49, 10.87it/s]

Loss at step 3000: 3.0063421929279963


Epoch: 2:  73%|███████▎  | 4002/5498 [05:49<02:10, 11.47it/s]

Loss at step 4000: 2.9690366181135177


Epoch: 2:  91%|█████████ | 5001/5498 [07:26<00:59,  8.34it/s]

Loss at step 5000: 2.9042758903980257


Epoch: 2: 100%|██████████| 5498/5498 [08:14<00:00, 11.11it/s]


Epoch 2/5 - Loss: 2.892433075524972


Epoch: 3:  18%|█▊        | 1001/5498 [01:32<08:16,  9.05it/s]

Loss at step 1000: 2.7258059184551238


Epoch: 3:  36%|███▋      | 2002/5498 [03:03<05:14, 11.13it/s]

Loss at step 2000: 2.6908659694194794


Epoch: 3:  55%|█████▍    | 3002/5498 [04:31<03:30, 11.85it/s]

Loss at step 3000: 2.6640439071655275


Epoch: 3:  73%|███████▎  | 4002/5498 [06:00<02:05, 11.92it/s]

Loss at step 4000: 2.6530855478644373


Epoch: 3:  91%|█████████ | 5001/5498 [07:31<00:54,  9.16it/s]

Loss at step 5000: 2.611816993069649


Epoch: 3: 100%|██████████| 5498/5498 [08:16<00:00, 11.07it/s]


Epoch 3/5 - Loss: 2.6106178867422307


Epoch: 4:  18%|█▊        | 1001/5498 [01:27<07:43,  9.69it/s]

Loss at step 1000: 2.5680330770015716


Epoch: 4:  36%|███▋      | 2002/5498 [02:56<05:17, 11.01it/s]

Loss at step 2000: 2.544290637731552


Epoch: 4:  55%|█████▍    | 3001/5498 [04:23<03:32, 11.75it/s]

Loss at step 3000: 2.528648873925209


Epoch: 4:  73%|███████▎  | 4002/5498 [05:52<02:07, 11.70it/s]

Loss at step 4000: 2.527595740944147


Epoch: 4:  91%|█████████ | 5001/5498 [07:23<00:54,  9.12it/s]

Loss at step 5000: 2.4959721014022827


Epoch: 4: 100%|██████████| 5498/5498 [08:09<00:00, 11.24it/s]


Epoch 4/5 - Loss: 2.4991861420702874


Epoch: 5:  18%|█▊        | 1001/5498 [01:28<07:43,  9.71it/s]

Loss at step 1000: 2.508699094891548


Epoch: 5:  36%|███▋      | 2002/5498 [02:56<04:59, 11.66it/s]

Loss at step 2000: 2.491215362250805


Epoch: 5:  55%|█████▍    | 3002/5498 [04:24<03:30, 11.84it/s]

Loss at step 3000: 2.4812898528575897


Epoch: 5:  73%|███████▎  | 4002/5498 [05:52<02:06, 11.83it/s]

Loss at step 4000: 2.4853811900913714


Epoch: 5:  91%|█████████ | 5001/5498 [07:23<00:54,  9.06it/s]

Loss at step 5000: 2.4586771923065185


Epoch: 5: 100%|██████████| 5498/5498 [08:08<00:00, 11.25it/s]


Epoch 5/5 - Loss: 2.4640085644052436
CPU times: total: 13min 40s
Wall time: 40min 49s
