
Training hangs after 1 epoch when using QAT #2423

@kiyoonyoo

Description


Hi,
I am trying to run QAT together with a LoRA adapter.
Training proceeds normally, but after one epoch it halts with a timeout error.
The same configuration without QAT runs fine.

The error log indicates that the failure occurred during an AllReduce operation.

Has anyone experienced anything similar?

My environment is:

  • torch 2.8.0.dev20250621+cu126
  • torchao 0.12.0.dev20250621+cu126
  • torchtune 0.7.0.dev20250621+cpu

This is my config.

# Model arguments
model:
  _component_: torchtune.models.qwen3.lora_qwen3_1_7b_instruct
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  lora_rank: 128  # higher increases accuracy and memory
  lora_alpha: 256  # usually alpha=2*rank
  lora_dropout: 0.0

# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen3.qwen3_tokenizer
  path: /tmp/Qwen3-1.7B/vocab.json
  merges_file: /tmp/Qwen3-1.7B/merges.txt
  max_seq_len: 1024

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen3-1.7B
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN3
resume_from_checkpoint: False
save_adapter_weights_only: True

# Dataset
dataset:
  _component_: torchtune.datasets.wikitext_dataset
  packed: False  # True increases speed
seed: null
shuffle: True

# Fine-tuning arguments
epochs: 3
max_steps_per_epoch: 10
batch_size: 8
gradient_accumulation_steps: 1  # Use to increase effective batch size
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  lr: 2e-5
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
fake_quant_after_n_steps: 200

# Training env
device: cuda

# Memory management / performance
enable_activation_checkpointing: False  # True reduces memory
enable_activation_offloading: False  # True reduces memory
dtype: bf16
clip_grad_norm: null
compile: False  # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
log_level: INFO  # DEBUG, WARN, etc.


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  #Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  #`torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  #trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1

# QAT arguments
quantizer:
  _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
  groupsize: 256
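
For context, the `quantizer` component above resolves to torchao's Int8DynActInt4WeightQATQuantizer, which the recipe drives through a prepare/convert flow roughly like the sketch below. This is only a minimal illustration: the import path is assumed from recent torchao releases, and a toy model stands in for the Qwen3 LoRA model built by the recipe.

import torch.nn as nn
# Import path assumed from recent torchao releases; torchtune re-exports this class
# as torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer.
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Hypothetical toy model in place of the Qwen3 1.7B LoRA model.
model = nn.Sequential(nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, 512))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = quantizer.prepare(model)  # swap nn.Linear for fake-quantized linears (int8 activations / int4 weights)

# ... fine-tuning loop runs here; the recipe delays fake quantization
#     until fake_quant_after_n_steps training steps have elapsed ...

model = quantizer.convert(model)  # replace fake-quant modules with actually quantized ops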

This is the error log:

[rank1]:[E622 13:43:15.343471750 ProcessGroupNCCL.cpp:685] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=2539, OpType=ALLREDUCE, NumelIn=1, NumelOut=1, Timeout(ms)=600000) ran for 600045 milliseconds before timing out.
[rank1]:[E622 13:43:15.343632760 ProcessGroupNCCL.cpp:2237] [PG ID 0 PG GUID 0(default_pg) Rank 1]  failure detected by watchdog at work sequence id: 2539 PG status: last enqueued work: 2539, last completed work: 2538
[rank1]:[E622 13:43:15.344535441 ProcessGroupNCCL.cpp:729] Stack trace of the failed collective:
#0 barrier from /home/jovyan/.venv/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py:4809
#1 wrapper from /home/jovyan/.venv/lib/python3.11/site-packages/torch/distributed/c10d_logger.py:81
#2 _save_checkpoint_sync from /home/jovyan/.venv/lib/python3.11/site-packages/torchtune/training/checkpointing/_checkpoint_client.py:336
#3 save_checkpoint from /home/jovyan/.venv/lib/python3.11/site-packages/torchtune/training/checkpointing/_checkpoint_client.py:373
#4 save_checkpoint from /home/jovyan/.venv/lib/python3.11/site-packages/recipes/qat_lora_finetune_distributed.py:672
#5 train from /home/jovyan/.venv/lib/python3.11/site-packages/recipes/qat_lora_finetune_distributed.py:837
#6 recipe_main from /home/jovyan/.venv/lib/python3.11/site-packages/recipes/qat_lora_finetune_distributed.py:870
#7 wrapper from /home/jovyan/.venv/lib/python3.11/site-packages/torchtune/config/_parse.py:99
#8 <module> from /home/jovyan/.venv/lib/python3.11/site-packages/recipes/qat_lora_finetune_distributed.py:875
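
The trace shows the timeout firing inside the dist.barrier() (implemented as a one-element all-reduce) that _save_checkpoint_sync issues during the end-of-epoch checkpoint save: rank 1 reaches the barrier, but another rank, presumably the one doing the checkpoint serialization, does not finish within the default 600-second NCCL timeout. As a diagnostic only, not a fix for the underlying hang, the collective timeout can be raised when the process group is initialized. A minimal sketch of the raw torch.distributed call is below; the torchtune recipe initializes the process group itself, so in practice this would have to be wired in through the recipe.

from datetime import timedelta
import torch.distributed as dist

# Raise the NCCL collective timeout above the default 10 minutes so a slow
# checkpoint serialization on one rank does not trip the watchdog while the
# other ranks wait at the barrier. The value here is illustrative.
dist.init_process_group(backend="nccl", timeout=timedelta(minutes=60))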
