From 5085f99f160efc693a3aa89f5faa7cfd87500440 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Thu, 30 Oct 2025 15:19:51 -0700
Subject: [PATCH 01/11] add single prompt mode

---
 tests/integration_tests/flux.py           | 11 +++++++++++
 torchtitan/models/flux/inference/infer.py | 31 ++++++++++++++++++++++++-------
 torchtitan/models/flux/job_config.py      |  2 ++
 torchtitan/models/flux/run_infer.sh       |  2 +-
 4 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py
index 321ac1280c..d759fd024b 100755
--- a/tests/integration_tests/flux.py
+++ b/tests/integration_tests/flux.py
@@ -41,6 +41,17 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
             "Flux Validation Test",
             "validation",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--checkpoint.enable",
+                ],
+                ["--inference.prompt='A beautiful sunset over the ocean'"],
+            ],
+            "Flux Generation script test",
+            "test_generate",
+            ngpu=2,
+        ),
     ]
 
     return integration_tests_flavors
diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py
index 0c06a385ef..cdb5e73135 100644
--- a/torchtitan/models/flux/inference/infer.py
+++ b/torchtitan/models/flux/inference/infer.py
@@ -25,11 +25,24 @@ def inference(config: JobConfig):
     # Distributed processing setup: Each GPU/process handles a subset of prompts
     world_size = int(os.environ["WORLD_SIZE"])
     global_rank = int(os.environ["RANK"])
-    original_prompts = open(config.inference.prompts_path).readlines()
-    total_prompts = len(original_prompts)
-    # Distribute prompts across processes using round-robin assignment
-    prompts = original_prompts[global_rank::world_size]
+    single_prompt_mode = config.inference.prompt is not None
+
+    # Use single prompt if specified, otherwise read from file
+    if single_prompt_mode:
+        original_prompts = [config.inference.prompt]
+        logger.info(f"Using single prompt: {config.inference.prompt}")
+        bs = 1
+        # If only single prompt, each rank will generate an image with the same prompt
+        prompts = original_prompts
+    else:
+        original_prompts = open(config.inference.prompts_path).readlines()
+        logger.info(f"Reading prompts from: {config.inference.prompts_path}")
+        bs = config.inference.local_batch_size
+        # Distribute prompts across processes using round-robin assignment
+        prompts = original_prompts[global_rank::world_size]
+
+    total_prompts = len(original_prompts)
 
     trainer.checkpointer.load(step=config.checkpoint.load_step)
     t5_tokenizer, clip_tokenizer = build_flux_tokenizer(config)
@@ -39,15 +52,19 @@ def inference(config: JobConfig):
 
     if prompts:
         # Generate images for this process's assigned prompts
-        bs = config.inference.local_batch_size
 
         output_dir = os.path.join(
            config.job.dump_folder,
            config.inference.save_img_folder,
         )
-        # Create mapping from local indices to global prompt indices
-        global_ids = list(range(global_rank, total_prompts, world_size))
+        if single_prompt_mode:
+            # In single prompt mode, all ranks process the same prompt (index 0)
+            # But each rank generates a different image (different seed/rank)
+            global_ids = [0] * len(prompts)
+        else:
+            # In multi-prompt mode, use round-robin distribution
+            global_ids = list(range(global_rank, total_prompts, world_size))
 
         for i in range(0, len(prompts), bs):
             images = generate_image(
diff --git a/torchtitan/models/flux/job_config.py b/torchtitan/models/flux/job_config.py
index 60422de2ee..f9bda99760 100644
--- a/torchtitan/models/flux/job_config.py
+++ b/torchtitan/models/flux/job_config.py
@@ -62,6 +62,8 @@ class Inference:
     """Path to save the inference results"""
     prompts_path: str = "./torchtitan/experiments/flux/inference/prompts.txt"
     """Path to file with newline separated prompts to generate images for"""
+    prompt: str = ""
+    """Single prompt to generate image for. If specified, takes precedence over prompts_path"""
     local_batch_size: int = 2
     """Batch size for inference"""
     img_size: int = 256
diff --git a/torchtitan/models/flux/run_infer.sh b/torchtitan/models/flux/run_infer.sh
index bf1b4aa5a6..67c2690a49 100755
--- a/torchtitan/models/flux/run_infer.sh
+++ b/torchtitan/models/flux/run_infer.sh
@@ -10,7 +10,7 @@ set -ex
 # use envs as local overrides for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./torchtitan/models/flux/run_train.sh
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
 export LOG_RANK=${LOG_RANK:-0}
 
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}
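A note on the two assignment paths introduced above: the round-robin slice `original_prompts[global_rank::world_size]` gives each rank a disjoint subset of the prompts file, while single-prompt mode hands every rank the same one-element list, so ranks differ only by seed. A minimal, self-contained sketch of that logic (illustrative names, not code from the patch):

```python
# Illustrative sketch only; assumes a config shaped like the Inference
# dataclass above. Caveat: because `prompt` defaults to "" rather than None,
# a check like `prompt is not None` is always true -- a truthiness check is
# what actually distinguishes "unset" from "set".
def assign_prompts(prompt: str, file_prompts: list[str], rank: int, world_size: int) -> list[str]:
    if prompt:  # single-prompt mode: same prompt on every rank
        return [prompt]
    # Round-robin shard: rank r takes items r, r + world_size, r + 2*world_size, ...
    return file_prompts[rank::world_size]

file_prompts = [f"prompt {i}" for i in range(5)]
for rank in range(2):
    print(rank, assign_prompts("", file_prompts, rank, world_size=2))
# 0 ['prompt 0', 'prompt 2', 'prompt 4']
# 1 ['prompt 1', 'prompt 3']
```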
From 2ceab76f4cef5c454969e8885cdcad26dc13e43a Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Thu, 30 Oct 2025 15:44:11 -0700
Subject: [PATCH 02/11] add flux

---
 torchtitan/models/flux/run_train.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/models/flux/run_train.sh b/torchtitan/models/flux/run_train.sh
index 2661e02691..dda0515f19 100755
--- a/torchtitan/models/flux/run_train.sh
+++ b/torchtitan/models/flux/run_train.sh
@@ -10,7 +10,7 @@ set -ex
 # use envs as local overrides for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_train.sh
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
 export LOG_RANK=${LOG_RANK:-0}
 
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}

From 882ad3243aae3b6b6558ad67821c1e01c59789cd Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Thu, 13 Nov 2025 15:50:27 -0800
Subject: [PATCH 03/11] revert

---
 torchtitan/models/flux/inference/infer.py    | 34 ++++++++++----------------
 torchtitan/models/flux/inference/prompts.txt | 24 ------------------
 torchtitan/models/flux/job_config.py         |  2 --
 torchtitan/models/flux/run_infer.sh          |  2 +-
 torchtitan/models/flux/run_train.sh          |  2 +-
 5 files changed, 14 insertions(+), 50 deletions(-)

diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py
index cdb5e73135..461f348873 100644
--- a/torchtitan/models/flux/inference/infer.py
+++ b/torchtitan/models/flux/inference/infer.py
@@ -26,21 +26,17 @@ def inference(config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     global_rank = int(os.environ["RANK"])
-    single_prompt_mode = config.inference.prompt is not None
-
-    # Use single prompt if specified, otherwise read from file
-    if single_prompt_mode:
-        original_prompts = [config.inference.prompt]
-        logger.info(f"Using single prompt: {config.inference.prompt}")
-        bs = 1
-        # If only single prompt, each rank will generate an image with the same prompt
-        prompts = original_prompts
-    else:
-        original_prompts = open(config.inference.prompts_path).readlines()
-        logger.info(f"Reading prompts from: {config.inference.prompts_path}")
-        bs = config.inference.local_batch_size
-        # Distribute prompts across processes using round-robin assignment
-        prompts = original_prompts[global_rank::world_size]
+    original_prompts = open(config.inference.prompts_path).readlines()
+    logger.info(f"Reading prompts from: {config.inference.prompts_path}")
+    if len(original_prompts) < world_size:
+        raise ValueError(
+            f"Number of prompts ({len(prompts)}) must be >= number of ranks ({world_size}). "
+            f"FSDP all-gather will hang if some ranks have no prompts to process."
+        )
+
+    bs = config.inference.local_batch_size
+    # Distribute prompts across processes using round-robin assignment
+    prompts = original_prompts[global_rank::world_size]
 
     total_prompts = len(original_prompts)
 
     trainer.checkpointer.load(step=config.checkpoint.load_step)
@@ -58,13 +54,7 @@ def inference(config: JobConfig):
             config.job.dump_folder,
             config.inference.save_img_folder,
         )
         # Create mapping from local indices to global prompt indices
-        if single_prompt_mode:
-            # In single prompt mode, all ranks process the same prompt (index 0)
-            # But each rank generates a different image (different seed/rank)
-            global_ids = [0] * len(prompts)
-        else:
-            # In multi-prompt mode, use round-robin distribution
-            global_ids = list(range(global_rank, total_prompts, world_size))
+        global_ids = list(range(global_rank, total_prompts, world_size))
 
         for i in range(0, len(prompts), bs):
             images = generate_image(
diff --git a/torchtitan/models/flux/inference/prompts.txt b/torchtitan/models/flux/inference/prompts.txt
index 23a76c5a34..feaf6a9b48 100644
--- a/torchtitan/models/flux/inference/prompts.txt
+++ b/torchtitan/models/flux/inference/prompts.txt
@@ -2,27 +2,3 @@ A serene mountain landscape at sunset with a crystal clear lake reflecting the g
 A futuristic cityscape with flying cars and neon lights illuminating the night sky
 A cozy cafe interior with steam rising from coffee cups and warm lighting
 A magical forest with glowing mushrooms and fireflies dancing between ancient trees
-A peaceful beach scene with turquoise waves and palm trees swaying in the breeze
-A steampunk-inspired mechanical dragon soaring through clouds
-A mystical library with floating books and magical artifacts
-A Japanese garden in spring with cherry blossoms falling gently
-A space station orbiting a colorful nebula
-A medieval castle on a hilltop during a dramatic thunderstorm
-A underwater scene with bioluminescent creatures and coral reefs
-A desert oasis with a majestic palace and palm trees
-A cyberpunk street market with holographic signs and diverse crowds
-A cozy winter cabin surrounded by snow-covered pine trees
-A fantasy tavern filled with unique characters and magical atmosphere
-A tropical rainforest with exotic birds and waterfalls
-A steampunk airship navigating through storm clouds
-A peaceful zen garden with a traditional Japanese tea house
-A magical potion shop with bubbling cauldrons and mysterious ingredients
-A futuristic space colony on Mars with domed habitats
-A mystical temple hidden in the clouds
-A vintage train station with steam locomotives and period architecture
-A magical bakery with floating pastries and enchanted ingredients
-A peaceful countryside scene with rolling hills and a rustic farmhouse
-A underwater city with advanced technology and marine life
-A fantasy marketplace with magical creatures and exotic goods
-A peaceful meditation garden with lotus flowers and koi ponds
-A steampunk laboratory with intricate machinery and glowing elements
diff --git a/torchtitan/models/flux/job_config.py b/torchtitan/models/flux/job_config.py
index f9bda99760..60422de2ee 100644
--- a/torchtitan/models/flux/job_config.py
+++ b/torchtitan/models/flux/job_config.py
@@ -62,8 +62,6 @@ class Inference:
     """Path to save the inference results"""
     prompts_path: str = "./torchtitan/experiments/flux/inference/prompts.txt"
     """Path to file with newline separated prompts to generate images for"""
-    prompt: str = ""
-    """Single prompt to generate image for. If specified, takes precedence over prompts_path"""
     local_batch_size: int = 2
     """Batch size for inference"""
     img_size: int = 256
diff --git a/torchtitan/models/flux/run_infer.sh b/torchtitan/models/flux/run_infer.sh
index 67c2690a49..bf1b4aa5a6 100755
--- a/torchtitan/models/flux/run_infer.sh
+++ b/torchtitan/models/flux/run_infer.sh
@@ -10,7 +10,7 @@ set -ex
 # use envs as local overrides for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./torchtitan/models/flux/run_train.sh
-NGPU=${NGPU:-"4"}
+NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
 
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}
diff --git a/torchtitan/models/flux/run_train.sh b/torchtitan/models/flux/run_train.sh
index dda0515f19..2661e02691 100755
--- a/torchtitan/models/flux/run_train.sh
+++ b/torchtitan/models/flux/run_train.sh
@@ -10,7 +10,7 @@ set -ex
 # use envs as local overrides for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_train.sh
-NGPU=${NGPU:-"4"}
+NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
 
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/flux/train_configs/debug_model.toml"}
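The guard added in the patch above encodes a real distributed-inference constraint: FSDP materializes sharded parameters with collective all-gathers inside the model forward, so every rank must enter the forward. Under round-robin assignment, fewer prompts than ranks leaves some ranks with an empty list; those ranks skip generation while their peers block in the collective until timeout. A small standalone illustration of the per-rank counts (not code from the patch):

```python
# Per-rank prompt counts under round-robin assignment (illustrative).
def rank_counts(total_prompts: int, world_size: int) -> list[int]:
    return [len(range(r, total_prompts, world_size)) for r in range(world_size)]

print(rank_counts(3, 4))  # [1, 1, 1, 0] -> rank 3 never enters the forward,
                          # so the other ranks stall in FSDP's all-gather
print(rank_counts(8, 4))  # [2, 2, 2, 2] -> safe: every rank has work
```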
From b0c16026a701f2537cbbdc0f28c6a78307038a55 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Thu, 13 Nov 2025 16:13:58 -0800
Subject: [PATCH 04/11] fix format

---
 torchtitan/models/flux/inference/infer.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py
index 461f348873..deef034c30 100644
--- a/torchtitan/models/flux/inference/infer.py
+++ b/torchtitan/models/flux/inference/infer.py
@@ -25,12 +25,12 @@ def inference(config: JobConfig):
     # Distributed processing setup: Each GPU/process handles a subset of prompts
     world_size = int(os.environ["WORLD_SIZE"])
     global_rank = int(os.environ["RANK"])
-    original_prompts = open(config.inference.prompts_path).readlines()
-    logger.info(f"Reading prompts from: {config.inference.prompts_path}")
-    if len(original_prompts) < world_size:
+    total_prompts = len(original_prompts)
+
+    if total_prompts < world_size:
         raise ValueError(
-            f"Number of prompts ({len(prompts)}) must be >= number of ranks ({world_size}). "
+            f"Number of prompts ({total_prompts}) must be >= number of ranks ({world_size}). "
             f"FSDP all-gather will hang if some ranks have no prompts to process."
         )
 
     bs = config.inference.local_batch_size
@@ -38,8 +38,6 @@ def inference(config: JobConfig):
     # Distribute prompts across processes using round-robin assignment
     prompts = original_prompts[global_rank::world_size]
 
-    total_prompts = len(original_prompts)
-
     trainer.checkpointer.load(step=config.checkpoint.load_step)
     t5_tokenizer, clip_tokenizer = build_flux_tokenizer(config)
From d5a5a880d86cd1a070fc9cdd3125d896bf0c403a Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Thu, 13 Nov 2025 16:18:16 -0800
Subject: [PATCH 05/11] remove commands

---
 tests/integration_tests/flux.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py
index d759fd024b..0967ebd44a 100755
--- a/tests/integration_tests/flux.py
+++ b/tests/integration_tests/flux.py
@@ -46,7 +46,7 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
                 [
                     "--checkpoint.enable",
                 ],
-                ["--inference.prompt='A beautiful sunset over the ocean'"],
+                [],
             ],
             "Flux Generation script test",
             "test_generate",
From 5f6dbc4aaa43c5dbbcb9457033a19832a7337d79 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Tue, 18 Nov 2025 22:51:31 -0800
Subject: [PATCH 06/11] trigger the CI flow

---
 .github/workflows/integration_test_8gpu_models.yaml | 1 +
 torchtitan/models/flux/inference/infer.py           | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/integration_test_8gpu_models.yaml b/.github/workflows/integration_test_8gpu_models.yaml
index 129049b8f6..b3fa769cca 100644
--- a/.github/workflows/integration_test_8gpu_models.yaml
+++ b/.github/workflows/integration_test_8gpu_models.yaml
@@ -54,3 +54,4 @@ jobs:
           python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8
           python -m tests.integration_tests.flux artifacts-to-be-uploaded/flux --ngpu 8
           rm -rf artifacts-to-be-uploaded/*/checkpoint
+          rm -rf artifacts-to-be-uploaded/flux/test_generate/inference_results/
diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py
index deef034c30..b89887ad51 100644
--- a/torchtitan/models/flux/inference/infer.py
+++ b/torchtitan/models/flux/inference/infer.py
@@ -34,7 +34,6 @@ def inference(config: JobConfig):
             f"FSDP all-gather will hang if some ranks have no prompts to process."
         )
 
-    bs = config.inference.local_batch_size
     # Distribute prompts across processes using round-robin assignment
     prompts = original_prompts[global_rank::world_size]
 
@@ -46,6 +45,7 @@ def inference(config: JobConfig):
 
     if prompts:
         # Generate images for this process's assigned prompts
+        bs = config.inference.local_batch_size
         output_dir = os.path.join(
             config.job.dump_folder,
             config.inference.save_img_folder,
From 6321841fd5b579528b0cb48cfad02560a062bd9c Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Thu, 20 Nov 2025 11:56:01 -0800
Subject: [PATCH 07/11] merge flux tests

---
 tests/integration_tests/flux.py              | 27 +++++----------------------
 torchtitan/models/flux/inference/prompts.txt | 24 ++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py
index 0967ebd44a..afb746c2e7 100755
--- a/tests/integration_tests/flux.py
+++ b/tests/integration_tests/flux.py
@@ -26,31 +26,14 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.data_parallel_shard_degree 2",
                     "--parallelism.data_parallel_replicate_degree 2",
                     "--parallelism.context_parallel_degree 2",
-                ]
-            ],
-            "HSDP+CP",
-            "hsdp+cp",
-            ngpu=8,
-        ),
-        OverrideDefinitions(
-            [
-                [
                     "--validation.enable",
-                ]
-            ],
-            "Flux Validation Test",
-            "validation",
-        ),
-        OverrideDefinitions(
-            [
-                [
                     "--checkpoint.enable",
                 ],
-                [],
+                []
             ],
-            "Flux Generation script test",
-            "test_generate",
-            ngpu=2,
+            "HSDP+CP+Validation+Inference",
+            "hsdp+cp+validation+inference",
+            ngpu=8,
         ),
     ]
 
@@ -84,7 +67,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir
             cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd
 
         # save checkpoint (idx == 0) and load it for generation (idx == 1)
-        if test_name == "test_generate" and idx == 1:
+        if test_name == "hsdp+cp+validation+inference" and idx == 1:
             # For flux generation, test using inference script
             cmd = (
                 f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} "
diff --git a/torchtitan/models/flux/inference/prompts.txt b/torchtitan/models/flux/inference/prompts.txt
index feaf6a9b48..23a76c5a34 100644
--- a/torchtitan/models/flux/inference/prompts.txt
+++ b/torchtitan/models/flux/inference/prompts.txt
@@ -2,3 +2,27 @@ A serene mountain landscape at sunset with a crystal clear lake reflecting the g
 A futuristic cityscape with flying cars and neon lights illuminating the night sky
 A cozy cafe interior with steam rising from coffee cups and warm lighting
 A magical forest with glowing mushrooms and fireflies dancing between ancient trees
+A peaceful beach scene with turquoise waves and palm trees swaying in the breeze
+A steampunk-inspired mechanical dragon soaring through clouds
+A mystical library with floating books and magical artifacts
+A Japanese garden in spring with cherry blossoms falling gently
+A space station orbiting a colorful nebula
+A medieval castle on a hilltop during a dramatic thunderstorm
+A underwater scene with bioluminescent creatures and coral reefs
+A desert oasis with a majestic palace and palm trees
+A cyberpunk street market with holographic signs and diverse crowds
+A cozy winter cabin surrounded by snow-covered pine trees
+A fantasy tavern filled with unique characters and magical atmosphere
+A tropical rainforest with exotic birds and waterfalls
+A steampunk airship navigating through storm clouds
+A peaceful zen garden with a traditional Japanese tea house
+A magical potion shop with bubbling cauldrons and mysterious ingredients
+A futuristic space colony on Mars with domed habitats
+A mystical temple hidden in the clouds
+A vintage train station with steam locomotives and period architecture
+A magical bakery with floating pastries and enchanted ingredients
+A peaceful countryside scene with rolling hills and a rustic farmhouse
+A underwater city with advanced technology and marine life
+A fantasy marketplace with magical creatures and exotic goods
+A peaceful meditation garden with lotus flowers and koi ponds
+A steampunk laboratory with intricate machinery and glowing elements
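The merged test above leans on the harness convention visible in `run_single_test`: each `OverrideDefinitions` carries a list of override-argument lists, and the harness launches one run per inner list, so `idx == 0` trains with `--checkpoint.enable` while the empty `idx == 1` entry re-launches through `run_infer.sh` against the saved checkpoint. A hedged sketch of that contract, reconstructed from the code above (field names are inferred; the real harness may differ in detail):

```python
# Sketch of the test-harness contract inferred from run_single_test above.
from dataclasses import dataclass, field

@dataclass
class OverrideDefinitions:
    override_args: list[list[str]] = field(default_factory=list)
    test_descr: str = "default"
    test_name: str = "default"
    ngpu: int = 4

def run(defn: OverrideDefinitions) -> None:
    for idx, overrides in enumerate(defn.override_args):
        if defn.test_name == "hsdp+cp+validation+inference" and idx == 1:
            # Second launch: reuse the checkpoint saved at idx == 0 and run
            # the generation path instead of training.
            print("launch run_infer.sh", " ".join(overrides))
        else:
            print("launch run_train.sh", " ".join(overrides))
```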
From 9935e9643baf53cb4622bbf3a5663a00dad6e1a4 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Thu, 20 Nov 2025 11:59:08 -0800
Subject: [PATCH 08/11] lint

---
 .github/workflows/integration_test_8gpu_models.yaml | 2 +-
 tests/integration_tests/flux.py                     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/integration_test_8gpu_models.yaml b/.github/workflows/integration_test_8gpu_models.yaml
index b3fa769cca..b673da5adf 100644
--- a/.github/workflows/integration_test_8gpu_models.yaml
+++ b/.github/workflows/integration_test_8gpu_models.yaml
@@ -54,4 +54,4 @@ jobs:
           python -m tests.integration_tests.run_tests --test_suite models artifacts-to-be-uploaded --ngpu 8
           python -m tests.integration_tests.flux artifacts-to-be-uploaded/flux --ngpu 8
           rm -rf artifacts-to-be-uploaded/*/checkpoint
-          rm -rf artifacts-to-be-uploaded/flux/test_generate/inference_results/
+          rm -rf artifacts-to-be-uploaded/flux/*/inference_results/
diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py
index afb746c2e7..0005805175 100755
--- a/tests/integration_tests/flux.py
+++ b/tests/integration_tests/flux.py
@@ -29,7 +29,7 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
                     "--validation.enable",
                     "--checkpoint.enable",
                 ],
-                []
+                [],
             ],
             "HSDP+CP+Validation+Inference",
             "hsdp+cp+validation+inference",

From 4f4918cbee7b789b4926661745d5cf1a1bd2b999 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Thu, 20 Nov 2025 16:08:21 -0800
Subject: [PATCH 09/11] test FLUX CP + validation

---
 tests/integration_tests/flux.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py
index 0005805175..709d5e7a15 100755
--- a/tests/integration_tests/flux.py
+++ b/tests/integration_tests/flux.py
@@ -27,12 +27,10 @@ def build_flux_test_list() -> list[OverrideDefinitions]:
                     "--parallelism.data_parallel_replicate_degree 2",
                     "--parallelism.context_parallel_degree 2",
                     "--validation.enable",
-                    "--checkpoint.enable",
                 ],
-                [],
             ],
-            "HSDP+CP+Validation+Inference",
-            "hsdp+cp+validation+inference",
+            "HSDP+CP+Validation",
+            "hsdp+cp+validation",
             ngpu=8,
         ),
     ]
"HSDP+CP+Validation+Inference", + "hsdp+cp+validation+inference", ngpu=8, ), ] @@ -55,7 +57,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir t5_encoder_version_arg = ( "--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/" ) - tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer" + hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer" all_ranks = ",".join(map(str, range(test_flavor.ngpu))) @@ -76,7 +78,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir cmd += " " + random_init_encoder_arg cmd += " " + clip_encoder_version_arg cmd += " " + t5_encoder_version_arg - cmd += " " + tokenzier_path_arg + cmd += " " + hf_assets_path_arg if override_arg: cmd += " " + " ".join(override_arg) diff --git a/torchtitan/models/flux/train_configs/debug_model.toml b/torchtitan/models/flux/train_configs/debug_model.toml index 47a033c546..b943925c1c 100644 --- a/torchtitan/models/flux/train_configs/debug_model.toml +++ b/torchtitan/models/flux/train_configs/debug_model.toml @@ -21,6 +21,7 @@ enable_wandb = false [model] name = "flux" flavor = "flux-debug" +hf_assets_path = "tests/assets/tokenizer" [optimizer] name = "AdamW" @@ -48,6 +49,7 @@ autoencoder_path = "assets/hf/FLUX.1-dev/ae.safetensors" # Autoencoder to use f [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 +context_parallel_degree = 1 [activation_checkpoint] mode = "full" From 44fed4d01f29971ae037f49752069bf69676478a Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 24 Nov 2025 13:25:27 -0800 Subject: [PATCH 11/11] fix typo --- tests/integration_tests/flux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py index 4a8040207d..a7ed51832f 100755 --- a/tests/integration_tests/flux.py +++ b/tests/integration_tests/flux.py @@ -27,7 +27,8 @@ def build_flux_test_list() -> list[OverrideDefinitions]: "--parallelism.data_parallel_replicate_degree 2", "--parallelism.context_parallel_degree 2", "--validation.enable", - "--validaiton.step 5" "--checkpoint.enable", + "--validation.steps 5", + "--checkpoint.enable", ], [], ],