From b05584f6f85d700d4becc03b0b91f6598cf4aee4 Mon Sep 17 00:00:00 2001
From: Tuatini Godard
Date: Wed, 17 Apr 2024 14:31:30 +0200
Subject: [PATCH] Add Pytorch 2.1 compatibility

https://github.com/facebookresearch/dinov2/pull/281
---
 dinov2/fsdp/__init__.py       | 9 ++++-----
 dinov2/train/ssl_meta_arch.py | 8 +++++---
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/dinov2/fsdp/__init__.py b/dinov2/fsdp/__init__.py
index ed454480e..56683c7b6 100644
--- a/dinov2/fsdp/__init__.py
+++ b/dinov2/fsdp/__init__.py
@@ -62,11 +62,10 @@ def is_sharded_fsdp(x):
     return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD
 
 
-def free_if_fsdp(x):
-    if is_sharded_fsdp(x):
-        handles = x._handles
-        true_list = [True for h in handles]
-        _reshard(x, handles, true_list)
+def free_if_fsdp(x: FSDP):
+    if is_sharded_fsdp(x) and x._has_params:
+        handle = x._handle
+        _reshard(x, handle, True)
 
 
 def get_fsdp_modules(x):
diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py
index 3ccf15e90..cdacfe556 100644
--- a/dinov2/train/ssl_meta_arch.py
+++ b/dinov2/train/ssl_meta_arch.py
@@ -348,9 +348,11 @@ def get_teacher_output():
     def fsdp_synchronize_streams(self):
         if self.need_to_synchronize_fsdp_streams:
             torch.cuda.synchronize()
-            self.student.dino_head._streams = (
-                self.teacher.dino_head._streams
-            ) = self.student.backbone._streams = self.teacher.backbone._streams
+            for attr in {"_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream", "_default_stream"}:
+                stream = getattr(self.teacher.backbone, attr)
+                setattr(self.student.dino_head, attr, stream)
+                setattr(self.teacher.dino_head, attr, stream)
+                setattr(self.student.backbone, attr, stream)
             self.need_to_synchronize_fsdp_streams = False
 
     def update_teacher(self, m):
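
Background for the two hunks above: PyTorch 2.1 reworked FSDP's internals so that each wrapped module carries a single `_handle` instead of a `_handles` list, and the shared `_streams` dict was replaced by named stream attributes (`_unshard_stream`, `_post_backward_stream`, `_pre_unshard_stream`, `_all_reduce_stream`, `_default_stream`). For codebases that still need to run on both releases, a minimal version-guarded variant of `free_if_fsdp` is sketched below; the helper name `free_if_fsdp_compat` and the `hasattr` guard are illustrative (not part of this PR), and the `_reshard` import path is assumed to match the one already used by dinov2/fsdp/__init__.py.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp._runtime_utils import _reshard


def is_sharded_fsdp(x):
    # Mirrors the check in dinov2/fsdp/__init__.py: an FSDP module whose strategy actually shards.
    return isinstance(x, FSDP) and x.sharding_strategy is not ShardingStrategy.NO_SHARD


def free_if_fsdp_compat(x: FSDP):
    # Hypothetical helper (not in this PR): free the unsharded parameters of a
    # sharded FSDP module under both the pre-2.1 and 2.1 internal layouts.
    if not is_sharded_fsdp(x):
        return
    if hasattr(x, "_handle"):
        # PyTorch >= 2.1: a single FlatParamHandle per FSDP module, plus _has_params.
        if x._has_params:
            _reshard(x, x._handle, True)
    else:
        # PyTorch 2.0.x: a list of handles and a matching list of "free" flags.
        handles = x._handles
        _reshard(x, handles, [True for _ in handles])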