
Update Transforms tutorial — replace deprecated ToTensor(), migrate to v2 namespace, modernize one-hot encoding #3853

@sekyondaMeta

Description

The Transforms tutorial uses the deprecated `torchvision.transforms.ToTensor()` transform, operates entirely in the legacy v1 `torchvision.transforms` namespace, and builds one-hot targets with a low-level `scatter_` call, where `torch.nn.functional.one_hot` (`F.one_hot`) is the purpose-built alternative.

Changes needed

Deprecated APIs

| Issue | Current code | Replacement | Since |
|---|---|---|---|
| `torchvision.transforms.ToTensor` | `from torchvision.transforms import ToTensor, Lambda` and `transform=ToTensor()` (lines 31, 37) | `from torchvision.transforms import v2` and `transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])` | torchvision 0.16 (Oct 2023); `ToTensor` emits a deprecation warning, and the v2 transforms are the supported API |

Suboptimal / Outdated Patterns

| Issue | Current code | Modern alternative | Notes |
|---|---|---|---|
| v1 transforms namespace throughout | `from torchvision.transforms import ToTensor, Lambda` | `from torchvision.transforms import v2` | The entire tutorial operates in the legacy namespace. All imports and explanations should move to v2. |
| `Lambda` for one-hot encoding | `Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))` (line 38) | `torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float()` | `F.one_hot` is cleaner and more readable. `Lambda` itself still works in v2 (as `v2.Lambda`), but the one-hot pattern has a purpose-built function. |
| `scatter_` for one-hot encoding | `torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)` (lines 60-61) | `torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float()` | `scatter_` is a low-level op; `F.one_hot` is purpose-built and self-documenting. |
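The claimed equivalence between the two one-hot constructions can be checked in plain PyTorch. In this sketch the helper names `one_hot_scatter` and `one_hot_functional` are illustrative, not taken from the tutorial; it compares the legacy `scatter_` construction against `F.one_hot` for every FashionMNIST label. `F.one_hot` returns an int64 tensor, which is why the proposed replacement appends `.float()`.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 10  # FashionMNIST has 10 classes


def one_hot_scatter(y: int) -> torch.Tensor:
    """Legacy tutorial pattern: allocate zeros, then write a 1 at index y."""
    return torch.zeros(NUM_CLASSES, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)


def one_hot_functional(y: int) -> torch.Tensor:
    """Proposed pattern: F.one_hot returns int64, so cast to float to keep the old dtype."""
    return F.one_hot(torch.tensor(y), num_classes=NUM_CLASSES).float()


# Both helpers should agree on every label, including shape (10,) and dtype float32.
for label in range(NUM_CLASSES):
    assert torch.equal(one_hot_scatter(label), one_hot_functional(label))

print(one_hot_functional(3))  # tensor([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])
```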

Specific lines

  • Line 31: from torchvision.transforms import ToTensor, Lambda -> from torchvision.transforms import v2
  • Line 37: transform=ToTensor(), -> transform=v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
  • Line 38: target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) -> target_transform=v2.Lambda(lambda y: torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float())
  • Lines 42-48: Update the ToTensor() section heading and explanation to describe v2.ToImage() + v2.ToDtype()
  • Lines 60-61: target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)) -> target_transform = v2.Lambda(lambda y: torch.nn.functional.one_hot(torch.tensor(y), num_classes=10).float())

Files

  • beginner_source/basics/transforms_tutorial.py

Copied from sekyondaMeta#71

cc @subramen