
Commit 1157e1b

Committed Feb 4, 2025
Initial commit
1 parent 87290f5 commit 1157e1b

File tree

1 file changed: +157 -0 lines changed


train_fsdp.py

@@ -0,0 +1,157 @@
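# Fault-tolerant HSDP training example: fine-tunes Meta-Llama/Llama-3.2-1B-Instruct
# on the SAMSum dialogue-summarization dataset, using torchft to coordinate replica
# groups and FSDP2 (fully_shard) to shard parameters within each group.
#
# Typical launch (an assumption, not part of this commit): one torchrun invocation
# per replica group, e.g.
#   REPLICA_GROUP_ID=0 NUM_REPLICA_GROUPS=2 NUM_REPLICAS=2 \
#       torchrun --nproc_per_node=2 train_fsdp.py
# torchrun sets RANK for each worker process.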
import os
from datasets import load_dataset

import torch
from transformers import LlamaForCausalLM, AutoTokenizer
from torch.distributed._composable.fsdp import fully_shard
import torch.distributed as dist
from tqdm import tqdm
from transformers.data import DataCollatorForSeq2Seq
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

from torchdata.stateful_dataloader import StatefulDataLoader

from torchft import (
    DistributedSampler,
    Manager,
    Optimizer,
    ProcessGroupBabyNCCL,
    ProcessGroupGloo,
)
from torchft.process_group import ft_init_device_mesh

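# Build a 2-D HSDP device mesh: the outer "dp_replicate" dimension spans the
# fault-tolerant replica groups managed by torchft, and the inner "dp_shard"
# dimension is the FSDP sharding group within a single replica group.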
def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None, manager=None):

    if replica_group_size is None or sharding_group_size is None:
        raise ValueError("Both replica_group_size and sharding_group_size must be provided.")

    device = device or "cuda"

    device_mesh = ft_init_device_mesh(
        device_type=device,
        mesh_shape=(replica_group_size, sharding_group_size),
        mesh_dim_names=("dp_replicate", "dp_shard"),
        replicate_dim=0,
        manager=manager,
    )
    if device_mesh is None:
        raise RuntimeError("Failed to create a valid device mesh.")

    return device_mesh

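# Shard every LlamaDecoderLayer with FSDP2's fully_shard, then shard the root
# module so the remaining parameters (embeddings, lm_head, norms) are covered too.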
def parallelize_llama(model, mesh):
    sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]

    for m in reversed(list(model.modules())):
        if any(c(m) for c in sharding_conditions):
            # fully_shard(m, mesh=mesh, reshard_after_forward=True)
            fully_shard(m, mesh=mesh)
    # fully_shard([model.model.embed_tokens, model.lm_head], mesh=mesh)
    fully_shard(model, mesh=mesh)

def main():
    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
    NUM_REPLICAS = int(os.environ.get("NUM_REPLICAS", 2))

    rank = int(os.environ.get("RANK", 0))

    model_name = "Meta-Llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = LlamaForCausalLM.from_pretrained(model_name)

    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Sanity check: the tokenizer vocab size must match the model's embedding
    # matrix; resize the embeddings before training if they ever diverge.
    assert len(tokenizer) == model.get_input_embeddings().weight.shape[0]

    train_data = load_dataset("samsum", split="train")

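    # Each example is "Summarize this dialog: <dialogue> --- Summary: <summary>";
    # prompt tokens are masked out with -100 in the labels so the loss is computed
    # only on the summary tokens.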
    class SAMSumDataset(torch.utils.data.Dataset):
        def __init__(self, data, tokenizer):
            self.data = data
            self.tokenizer = tokenizer

        def __getitem__(self, idx):
            text = self.data[idx]
            prompt = self.tokenizer.encode(
                self.tokenizer.bos_token + f"Summarize this dialog: {text['dialogue']}\n---\nSummary: ",
                add_special_tokens=False,
            )
            summary = self.tokenizer.encode(text["summary"] + self.tokenizer.eos_token, add_special_tokens=False)
            input_ids = prompt + summary
            labels = len(prompt) * [-100] + summary
            return {"input_ids": input_ids, "labels": labels}

        def __len__(self):
            return len(self.data)

    train_dataset = SAMSumDataset(train_data, tokenizer)

    batch_size = 8

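    # torchft's DistributedSampler shards the data twice: across replica groups
    # (replica_group / num_replica_groups) and across ranks within this group
    # (rank / num_replicas).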
    sampler = DistributedSampler(
        train_dataset,
        replica_group=REPLICA_GROUP_ID,
        num_replica_groups=NUM_REPLICA_GROUPS,
        rank=rank,
        shuffle=True,
        num_replicas=NUM_REPLICAS,
    )

    train_dataloader = StatefulDataLoader(
        train_dataset,
        batch_size=batch_size,
        collate_fn=DataCollatorForSeq2Seq(tokenizer),
        sampler=sampler,
    )

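    # State-dict hooks handed to the Manager below: a recovering replica restores
    # weights via load_state_dict, and state_dict is what a healthy replica sends it.
    # Note that `optimizer` is defined later in main(); these closures are only
    # invoked after it exists.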
    def load_state_dict(state_dict):
        set_state_dict(
            model,
            optimizer.optim,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
        )

    def state_dict():
        model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer.optim)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()

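    # The torchft Manager tracks quorum for this replica group and drives recovery
    # through the state-dict hooks above; min_replica_size=1 allows training to
    # continue even when only a single replica group is healthy.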
    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"train_fsdp_{REPLICA_GROUP_ID}",
        use_async_quorum=False,
    )

    mesh = hsdp_device_mesh(
        NUM_REPLICA_GROUPS,
        NUM_REPLICAS,
        "cuda" if torch.cuda.is_available() else "cpu",
        manager=manager,
    )

    parallelize_llama(model, mesh)

    model.to(device)

    optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5))

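    # manager.current_step() is the fault-tolerant step counter kept by the Manager,
    # so the 500-step budget is shared across replica groups. The torchft Optimizer
    # wrapper coordinates zero_grad()/step() with the Manager so that updates are
    # only committed when the quorum is healthy.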
    while manager.current_step() < 500:
        model.train()
        for batch in tqdm(train_dataloader):
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            optimizer.zero_grad()

            outputs = model(input_ids, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            if manager.current_step() % 100 == 0:
                print(f"[{manager.current_step()}] loss = {loss.item()}")


if __name__ == "__main__":
    main()
