fix: prevent batch_size=1 crashes, secure torch.load, fix device/contiguity issues#901

Merged
jhnwu3 merged 1 commit into sunlabuiuc:master from haoyu-haoyu:fix/squeeze-torchload-contiguous
Mar 22, 2026
Conversation

@haoyu-haoyu
Contributor

Summary

Four categories of bug fixes across 8 files, addressing silent correctness errors, security warnings, and device compatibility crashes.

1. Bare .squeeze() → explicit dim (prevents batch_size=1 crashes)

When batch_size=1, bare .squeeze() removes the batch dimension along with the intended size-1 axis, so all downstream operations silently compute on the wrong axes without raising any error.

| File | Line | Before | After |
| --- | --- | --- | --- |
| concare.py | 129 | `.squeeze()` | `.squeeze(dim=-1)` |
| concare.py | 148 | `.squeeze()` | `.squeeze(dim=1)` |
| agent.py | 465 | `.squeeze()` | `.squeeze(dim=-1)` |
| agent.py | 469 | `.sum(-1).squeeze()` | `.sum(dim=-1)` (already 1-D) |
| agent.py | 473 | `.mean(dim=-1).squeeze()` | `.mean(dim=-1)` (already 1-D) |

2. torch.load() + weights_only=True (5 call sites)

PyTorch 2.6 changed the default of torch.load() to weights_only=True; earlier releases default to False and emit a FutureWarning. Passing weights_only=True explicitly keeps behavior consistent across versions, silences the warning, and blocks arbitrary code execution via pickle deserialization (security).

Files: trainer.py, biot.py, tfm_tokenizer.py (×2), kg_base.py
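A minimal sketch of the pattern (the `ckpt.pt` path and `Linear` model are illustrative, not PyHealth's actual checkpoints):

```python
import torch

# Save a plain state_dict -- tensors and primitives only, which is
# exactly what weights_only loading permits.
model = torch.nn.Linear(4, 2)
torch.save(model.state_dict(), "ckpt.pt")

# weights_only=True restricts unpickling to tensors/primitive containers,
# so a maliciously crafted checkpoint cannot execute arbitrary code.
state = torch.load("ckpt.pt", weights_only=True)
model.load_state_dict(state)
```

Checkpoints that pickle arbitrary Python objects (custom classes, lambdas) will fail under weights_only=True by design; state_dicts load cleanly.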

3. RNN .contiguous() before pack_padded_sequence (fixes #800)

Non-contiguous tensors (from slicing/transposing) passed to pack_padded_sequence cause CUDNN_STATUS_NOT_SUPPORTED errors. Added .contiguous() call.
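A minimal reproduction of how a non-contiguous tensor arises and where the fix applies (shapes are illustrative; this is not the RNNLayer code itself):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# transpose() returns a strided view, not a dense copy:
x = torch.randn(10, 3, 8).transpose(0, 1)  # (batch=3, seq=10, feat=8)
assert not x.is_contiguous()

lengths = torch.tensor([10, 7, 5])
# .contiguous() materializes a dense layout; without it, the packed
# tensor handed to cuDNN can trigger CUDNN_STATUS_NOT_SUPPORTED on GPU.
packed = pack_padded_sequence(x.contiguous(), lengths, batch_first=True)
```

On CPU the non-contiguous version may happen to work, which is why the bug only surfaced during GPU training.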

4. StageNet device mismatch fix

torch.zeros()/torch.ones() were created on CPU regardless of the input's device, causing device-mismatch crashes during GPU training. They are now created on x.device. Also fixed `time == None` → `time is None` (PEP8).
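The fix in sketch form (`init_state` is a hypothetical helper, not StageNet's actual method; the point is the `device=` argument):

```python
import torch

def init_state(x: torch.Tensor, hidden_dim: int) -> torch.Tensor:
    # Buggy version: torch.zeros(x.size(0), hidden_dim) always allocates
    # on CPU, so the first op mixing it with a CUDA input crashes.
    # Fixed version: inherit device (and dtype) from the input tensor,
    # so the same code runs unchanged on CPU and GPU.
    return torch.zeros(x.size(0), hidden_dim, device=x.device, dtype=x.dtype)
```

Usage: `init_state(x, 64)` yields a `(batch, 64)` zero tensor co-located with `x`, whatever device `x` lives on.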

Test plan

  • Run ConCare with batch_size=1 — should produce correct predictions
  • Load checkpoint with PyTorch 2.6+ — no deprecation warnings
  • Train RNN model with non-contiguous input — no cuDNN error
  • Train StageNet on GPU — no device mismatch

…x device/contiguity issues

1. Fix bare .squeeze() calls that silently remove the batch dimension
   when batch_size=1, causing wrong results during single-sample inference:
   - concare.py: .squeeze() → .squeeze(dim=-1) and .squeeze(dim=1)
   - agent.py: .squeeze() → .squeeze(dim=-1) or removed (already 1-D after .sum/.mean)

2. Add weights_only=True to all torch.load() calls for PyTorch 2.6+
   compatibility and security (prevents arbitrary code execution via
   pickle deserialization):
   - trainer.py, biot.py, tfm_tokenizer.py (2 calls), kg_base.py

3. Add .contiguous() before pack_padded_sequence in RNNLayer to prevent
   cuDNN errors with non-contiguous input tensors (fixes sunlabuiuc#800)

4. Fix StageNet device mismatch — tensors were created on CPU instead of
   the input tensor's device, causing crashes during GPU training:
   - torch.zeros/ones(...) → torch.zeros/ones(..., device=device)
   - time == None → time is None (PEP8)
@jhnwu3 jhnwu3 requested a review from joshuasteier March 21, 2026 18:05
Collaborator

@jhnwu3 jhnwu3 left a comment


Ah, nice catch! Thanks for the bug fixes. If you're on the discord, happy to chat more.

@jhnwu3 jhnwu3 merged commit 06f19ab into sunlabuiuc:master Mar 22, 2026
1 check passed
@haoyu-haoyu
Contributor Author

> Ah, nice catch! Thanks for the bug fixes. If you're on the discord, happy to chat more.

No problem at all! I'll join the Discord, happy to chat with you more there.



Development

Successfully merging this pull request may close these issues.

RNN contiguous error issue
