In [1]:
import os

import torch
from torch.utils.data import DataLoader

from datasets import load_dataset

from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from llama.tokenizer import Tokenizer

from mixture_of_depths.routing_transformer import ModelArgs, MoDTransformer
from mixture_of_depths.train import MoDLlamaTrainer

In [2]:
# Single-process distributed setup: the model-parallel layers need a process
# group even when everything runs in one process. The NCCL variant
# (`init_process_group("nccl")`) was left disabled by the author; gloo is used.
if not torch.distributed.is_initialized():
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12345'})
    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)

# Model-parallel size comes from WORLD_SIZE when set, otherwise 1.
if not model_parallel_is_initialized():
    model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
    initialize_model_parallel(model_parallel_size)

# Fixed seed so the run is repeatable.
torch.manual_seed(42)

> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1




<torch._C.Generator at 0x132297290>

## Load Data

In [3]:
# Number of leading examples of the train split to keep for training.
# The previous cell declared 1000 here but then selected a hard-coded 10000;
# the constant is now the single source of truth and keeps the 10000 that was
# actually in effect.
num_of_samples = 10000

dataset = load_dataset("igorktech/anekdots", split="train")
# Clamp to the split size so a shorter split does not produce out-of-range indices.
dataset = dataset.select(range(min(num_of_samples, len(dataset))))

In [4]:
# Maximum number of token ids (BOS/EOS included) kept per example by collate_fn.
# NOTE(review): 4 is very short for text data — confirm this is intentional.
text_length_max = 4
# Tokenizer built from the "tokenizer.model" path, resolved relative to the working directory.
tokenizer = Tokenizer("tokenizer.model")
def collate_fn(batch):
    """Turn a list of dataset rows into next-token-prediction tensors.

    Each row's ``'text'`` field is encoded with BOS and EOS markers, cut to at
    most ``text_length_max`` ids, and right-padded with the tokenizer's EOS id
    up to the longest (truncated) sequence in the batch.

    Args:
        batch: sequence of mappings, each with a ``'text'`` entry.

    Returns:
        ``(inputs, targets)``: two long tensors of identical shape, where
        ``targets`` is ``inputs`` shifted one position to the left.
    """
    encoded = [tokenizer.encode(row['text'], bos=True, eos=True) for row in batch]
    # Batch width: longest encoded sequence, capped at text_length_max.
    width = min(text_length_max, max(map(len, encoded)))

    # Start from an all-padding matrix; EOS doubles as the pad id.
    tokens = torch.full((len(batch), width), tokenizer.eos_id, dtype=torch.long)
    for row_idx, ids in enumerate(encoded):
        clipped = ids[:text_length_max]
        tokens[row_idx, :len(clipped)] = torch.tensor(clipped, dtype=torch.long)

    # Drop the last column for inputs and the first for targets.
    return tokens[:, :-1], tokens[:, 1:]

# Batches of 4 rows, each batch assembled by collate_fn into (inputs, targets).
# No shuffle argument is passed, so DataLoader's default ordering applies.
dataloader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

### MoD

In [5]:
# MoD configuration: routing and the auxiliary loss flag are both switched on.
# Vocabulary size is taken from the tokenizer so the two stay in sync.
model_params = ModelArgs(
    dim=256, n_layers=4, n_heads=4,
    vocab_size=tokenizer.n_words,
    routing=True, aux_loss=True,
)
model = MoDTransformer(model_params)

# Report the total parameter count in millions.
print(sum(param.numel() for param in model.parameters()) / 1e6, 'M parameters')

19.794432 M parameters


In [6]:
# Trainer for the MoD model, wired to the shared tokenizer and dataloader.
trainer = MoDLlamaTrainer(
    params=model_params, model=model,
    tokenizer=tokenizer, dataloader=dataloader,
)

In [8]:
%%time
# Run 5 epochs of MoD training, passing 'models/MoDLlama' as the model directory.
# NOTE(review): one flag is read from `model_params` and the other from
# `model.params` — presumably the same object; confirm and unify.
# A `device='mps'` argument was tried by the author and left disabled.
trainer.train(
    epochs=5,
    use_aux_loss=model_params.aux_loss,
    use_aux_predictor=model.params.aux_routing,
    model_dir='models/MoDLlama',
)

Epoch: 1:   4%|▎         | 88/2500 [00:04<01:53, 21.22it/s]


KeyboardInterrupt: 

### Baseline

In [8]:
# Baseline configuration: same constructor, with routing and the auxiliary loss off.
# This rebinds `model_params` and `model`, replacing the MoD objects created earlier.
# NOTE(review): dim=128 here versus dim=256 in the MoD configuration — confirm
# the two runs are meant to differ in width.
model_params = ModelArgs(
    dim=128, n_layers=4, n_heads=4,
    vocab_size=tokenizer.n_words,
    routing=False, aux_loss=False,
)
model = MoDTransformer(model_params)

# Report the total parameter count in millions.
print(sum(param.numel() for param in model.parameters()) / 1e6, 'M parameters')

53221888

In [9]:
# Trainer for the baseline model; rebinds `trainer`, reusing the same tokenizer and dataloader.
trainer = MoDLlamaTrainer(
    params=model_params, model=model,
    tokenizer=tokenizer, dataloader=dataloader,
)

In [10]:
%%time
# Run 5 epochs of baseline training with explicit model-directory and log-file paths.
trainer.train(
    epochs=5, model_dir="./models/BaselineLlama/", log_path="./logs/BaselineLlama_log.txt"
)

Epoch: 1:  18%|█▊        | 1001/5498 [01:21<08:10,  9.17it/s]

Loss at step 1000: 6.869257190704346


Epoch: 1:  36%|███▋      | 2001/5498 [02:51<05:13, 11.17it/s]

Loss at step 2000: 5.803991420865059


Epoch: 1:  55%|█████▍    | 3002/5498 [04:21<03:28, 11.97it/s]

Loss at step 3000: 5.199864539782206


Epoch: 1:  73%|███████▎  | 4002/5498 [05:49<02:04, 12.03it/s]

Loss at step 4000: 4.810587359905243


Epoch: 1:  91%|█████████ | 5001/5498 [07:18<00:49, 10.10it/s]

Loss at step 5000: 4.494121723914146


Epoch: 1: 100%|██████████| 5498/5498 [07:57<00:00, 11.50it/s]


Epoch 1/5 - Loss: 4.3813077763411385


Epoch: 2:  18%|█▊        | 1001/5498 [01:14<07:05, 10.58it/s]

Loss at step 1000: 3.1331484850645066


Epoch: 2:  36%|███▋      | 2002/5498 [02:43<05:03, 11.54it/s]

Loss at step 2000: 3.0648420315980913


Epoch: 2:  55%|█████▍    | 3001/5498 [04:15<03:49, 10.87it/s]

Loss at step 3000: 3.0063421929279963


Epoch: 2:  73%|███████▎  | 4002/5498 [05:49<02:10, 11.47it/s]

Loss at step 4000: 2.9690366181135177


Epoch: 2:  91%|█████████ | 5001/5498 [07:26<00:59,  8.34it/s]

Loss at step 5000: 2.9042758903980257


Epoch: 2: 100%|██████████| 5498/5498 [08:14<00:00, 11.11it/s]


Epoch 2/5 - Loss: 2.892433075524972


Epoch: 3:  18%|█▊        | 1001/5498 [01:32<08:16,  9.05it/s]

Loss at step 1000: 2.7258059184551238


Epoch: 3:  36%|███▋      | 2002/5498 [03:03<05:14, 11.13it/s]

Loss at step 2000: 2.6908659694194794


Epoch: 3:  55%|█████▍    | 3002/5498 [04:31<03:30, 11.85it/s]

Loss at step 3000: 2.6640439071655275


Epoch: 3:  73%|███████▎  | 4002/5498 [06:00<02:05, 11.92it/s]

Loss at step 4000: 2.6530855478644373


Epoch: 3:  91%|█████████ | 5001/5498 [07:31<00:54,  9.16it/s]

Loss at step 5000: 2.611816993069649


Epoch: 3: 100%|██████████| 5498/5498 [08:16<00:00, 11.07it/s]


Epoch 3/5 - Loss: 2.6106178867422307


Epoch: 4:  18%|█▊        | 1001/5498 [01:27<07:43,  9.69it/s]

Loss at step 1000: 2.5680330770015716


Epoch: 4:  36%|███▋      | 2002/5498 [02:56<05:17, 11.01it/s]

Loss at step 2000: 2.544290637731552


Epoch: 4:  55%|█████▍    | 3001/5498 [04:23<03:32, 11.75it/s]

Loss at step 3000: 2.528648873925209


Epoch: 4:  73%|███████▎  | 4002/5498 [05:52<02:07, 11.70it/s]

Loss at step 4000: 2.527595740944147


Epoch: 4:  91%|█████████ | 5001/5498 [07:23<00:54,  9.12it/s]

Loss at step 5000: 2.4959721014022827


Epoch: 4: 100%|██████████| 5498/5498 [08:09<00:00, 11.24it/s]


Epoch 4/5 - Loss: 2.4991861420702874


Epoch: 5:  18%|█▊        | 1001/5498 [01:28<07:43,  9.71it/s]

Loss at step 1000: 2.508699094891548


Epoch: 5:  36%|███▋      | 2002/5498 [02:56<04:59, 11.66it/s]

Loss at step 2000: 2.491215362250805


Epoch: 5:  55%|█████▍    | 3002/5498 [04:24<03:30, 11.84it/s]

Loss at step 3000: 2.4812898528575897


Epoch: 5:  73%|███████▎  | 4002/5498 [05:52<02:06, 11.83it/s]

Loss at step 4000: 2.4853811900913714


Epoch: 5:  91%|█████████ | 5001/5498 [07:23<00:54,  9.06it/s]

Loss at step 5000: 2.4586771923065185


Epoch: 5: 100%|██████████| 5498/5498 [08:08<00:00, 11.25it/s]


Epoch 5/5 - Loss: 2.4640085644052436
CPU times: total: 13min 40s
Wall time: 40min 49s
