import os
from datasets import load_dataset

import torch
from transformers import LlamaForCausalLM, AutoTokenizer
from torch.distributed._composable.fsdp import fully_shard
import torch.distributed as dist
from tqdm import tqdm
from transformers.data import DataCollatorForSeq2Seq
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

from torchdata.stateful_dataloader import StatefulDataLoader

from torchft import (
    DistributedSampler,
    Manager,
    Optimizer,
    ProcessGroupBabyNCCL,
    ProcessGroupGloo,
)
from torchft.process_group import ft_init_device_mesh

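# Build a 2-D HSDP device mesh: the outer "dp_replicate" dimension spans the
# fault-tolerant replica groups, while the inner "dp_shard" dimension is what
# FSDP shards parameters over within each group. (Assumption: ft_init_device_mesh
# backs the replicate dimension with the torchft Manager so group membership can
# change at runtime; see the torchft documentation for details.)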
def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None, manager=None):
    if replica_group_size is None or sharding_group_size is None:
        raise ValueError("Both replica_group_size and sharding_group_size must be provided.")

    device = device or "cuda"

    device_mesh = ft_init_device_mesh(
        device_type=device,
        mesh_shape=(replica_group_size, sharding_group_size),
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,
        manager=manager,
    )
    if device_mesh is None:
        raise RuntimeError("Failed to create a valid device mesh.")

    return device_mesh

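# Apply FSDP2-style sharding: each LlamaDecoderLayer is wrapped with fully_shard()
# first, then the root module is wrapped so the remaining parameters (embeddings,
# lm_head, norms) form the root shard group. With the 2-D HSDP mesh this shards
# within "dp_shard" and replicates across "dp_replicate".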
def parallelize_llama(model, mesh):
    sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]

    for m in reversed(list(model.modules())):
        if any(c(m) for c in sharding_conditions):
            # fully_shard(m, mesh=mesh, reshard_after_forward=True)
            fully_shard(m, mesh=mesh)
    # fully_shard([model.model.embed_tokens, model.lm_head], mesh=mesh)
    fully_shard(model, mesh=mesh)

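# Environment variables read below (typically set by the launch script for each
# replica group / rank):
#   REPLICA_GROUP_ID   - index of this replica group
#   NUM_REPLICA_GROUPS - number of fault-tolerant replica groups
#   NUM_REPLICAS       - number of ranks (shards) within each replica group
#   RANK               - this process's rank within its replica group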
def main():
    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
    NUM_REPLICAS = int(os.environ.get("NUM_REPLICAS", 2))

    rank = int(os.environ.get("RANK", 0))

    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = LlamaForCausalLM.from_pretrained(model_name)

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # If there is a mismatch between the tokenizer vocab size and the embedding
    # matrix, warn and then expand the embedding matrix.
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print("WARNING: resizing the embedding matrix to match the tokenizer vocab size.")
        model.resize_token_embeddings(len(tokenizer))

    train_data = load_dataset("samsum", split="train")

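    # Format each SAMSum example as "Summarize this dialog: ... Summary: <summary>".
    # Prompt tokens are masked out of the loss with -100 (the ignore index used by
    # the Hugging Face causal-LM loss), so only the summary tokens contribute.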
    class SAMSumDataset(torch.utils.data.Dataset):
        def __init__(self, data, tokenizer):
            self.data = data
            self.tokenizer = tokenizer

        def __getitem__(self, idx):
            text = self.data[idx]
            prompt = self.tokenizer.encode(
                self.tokenizer.bos_token + f"Summarize this dialog: {text['dialogue']}\n---\nSummary: ",
                add_special_tokens=False,
            )
            summary = self.tokenizer.encode(text["summary"] + self.tokenizer.eos_token, add_special_tokens=False)
            input_ids = prompt + summary
            labels = len(prompt) * [-100] + summary
            return {"input_ids": input_ids, "labels": labels}

        def __len__(self):
            return len(self.data)

    train_dataset = SAMSumDataset(train_data, tokenizer)

    batch_size = 8

    sampler = DistributedSampler(
        train_dataset,
        replica_group=REPLICA_GROUP_ID,
        num_replica_groups=NUM_REPLICA_GROUPS,
        rank=rank,
        shuffle=True,
        num_replicas=NUM_REPLICAS,
    )

    train_dataloader = StatefulDataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=DataCollatorForSeq2Seq(tokenizer),
        sampler=sampler,
    )

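    # torchft calls these two functions to transfer live state between replicas:
    # when a replica (re)joins the quorum, the Manager fetches state_dict() from a
    # healthy peer and applies it via load_state_dict(), so no checkpoint files are
    # needed for recovery. (This mirrors the torchft Manager's documented
    # state_dict/load_state_dict hooks; details may differ across versions.)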
    def load_state_dict(state_dict):
        set_state_dict(
            model,
            optimizer.optim,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )

    def state_dict():
        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer.optim)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

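    # ProcessGroupBabyNCCL runs NCCL in a managed subprocess so collectives can be
    # aborted and reinitialized when the quorum changes; Gloo is used as a CPU
    # fallback. The Manager coordinates quorum membership across replica groups.
    # (Assumption: a torchft lighthouse server is reachable, e.g. via the
    # TORCHFT_LIGHTHOUSE environment variable, as in the torchft examples.)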
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()

    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"train_fsdp_{REPLICA_GROUP_ID}",
        use_async_quorum=False,
    )

    mesh = hsdp_device_mesh(
        NUM_REPLICA_GROUPS,
        NUM_REPLICAS,
        "cuda" if torch.cuda.is_available() else "cpu",
        manager=manager,
    )

    parallelize_llama(model, mesh)

    model.to(device)

    optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5))

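    # torchft's Optimizer wrapper ties each step to the Manager: zero_grad() starts
    # a new quorum/step, and step() only applies the update if the Manager reports
    # that the step can be committed (i.e. no rank failed mid-step), so failed
    # steps are discarded rather than corrupting the replicas.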
    while manager.current_step() < 500:
        model.train()
        for batch in tqdm(train_dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            optimizer.zero_grad()

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            if manager.current_step() % 100 == 0:
                print(f"[{manager.current_step()}] loss = {loss.item()}")

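# Example launch (illustrative; adjust the script name and counts to your setup).
# Each replica group runs its own torchrun instance, and a torchft lighthouse
# must be running:
#   REPLICA_GROUP_ID=0 NUM_REPLICA_GROUPS=2 NUM_REPLICAS=2 \
#       torchrun --nproc_per_node=2 train_fsdp.py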
if __name__ == "__main__":
    main()