
Commit

Add Pytorch 2.1 compatibility facebookresearch#281
EKami committed Apr 17, 2024
1 parent e1277af commit b05584f
Showing 2 changed files with 9 additions and 8 deletions.
dinov2/fsdp/__init__.py (4 additions, 5 deletions)
@@ -62,11 +62,10 @@ def is_sharded_fsdp(x):
     return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD


-def free_if_fsdp(x):
-    if is_sharded_fsdp(x):
-        handles = x._handles
-        true_list = [True for h in handles]
-        _reshard(x, handles, true_list)
+def free_if_fsdp(x: FSDP):
+    if is_sharded_fsdp(x) and x._has_params:
+        handle = x._handle
+        _reshard(x, handle, True)


 def get_fsdp_modules(x):
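
The change above tracks an FSDP-internals rename: before PyTorch 2.1 each wrapped module carried a list of flat-parameter handles (`_handles`) and `_reshard` took parallel lists, while from 2.1 on there is a single `_handle`, a `_has_params` flag, and `_reshard` takes one handle and one boolean. A version-agnostic variant could look roughly like the sketch below; it is only a sketch, not part of this commit, and it leans on the same private FSDP internals, which may shift again in later releases.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp._runtime_utils import _reshard


def free_if_fsdp_compat(x):
    # Sketch only: free the unsharded flat parameter(s) on either side of
    # the PyTorch 2.1 handle refactor.
    is_sharded = isinstance(x, FSDP) and x.sharding_strategy is not ShardingStrategy.NO_SHARD
    if not is_sharded:
        return
    if hasattr(x, "_handle"):
        # PyTorch >= 2.1: a single handle per FSDP module, plus a _has_params flag.
        if x._has_params:
            _reshard(x, x._handle, True)
    else:
        # PyTorch < 2.1: a list of handles and a matching list of free flags.
        handles = x._handles
        _reshard(x, handles, [True for _ in handles])
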
dinov2/train/ssl_meta_arch.py (5 additions, 3 deletions)
@@ -348,9 +348,11 @@ def get_teacher_output():
     def fsdp_synchronize_streams(self):
         if self.need_to_synchronize_fsdp_streams:
             torch.cuda.synchronize()
-            self.student.dino_head._streams = (
-                self.teacher.dino_head._streams
-            ) = self.student.backbone._streams = self.teacher.backbone._streams
+            for attr in {"_unshard_stream", "_post_backward_stream", "_pre_unshard_stream", "_all_reduce_stream", "_default_stream"}:
+                stream = getattr(self.teacher.backbone, attr)
+                setattr(self.student.dino_head, attr, stream)
+                setattr(self.teacher.dino_head, attr, stream)
+                setattr(self.student.backbone, attr, stream)
             self.need_to_synchronize_fsdp_streams = False

     def update_teacher(self, m):
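
Here the single shared `_streams` attribute from older FSDP builds is replaced by copying each of the per-purpose stream attributes that newer PyTorch versions keep directly on the FSDP root. The same idea, factored into a standalone helper, might look like the sketch below; it is not part of the commit, the attribute names are private, and any that are absent in a given PyTorch build are simply skipped.

_FSDP_STREAM_ATTRS = (
    "_unshard_stream",
    "_post_backward_stream",
    "_pre_unshard_stream",
    "_all_reduce_stream",
    "_default_stream",
)


def share_fsdp_streams(source, *targets):
    # Sketch: copy every private FSDP stream attribute present on `source`
    # onto each module in `targets`, so all FSDP roots issue work on the
    # same CUDA streams after torch.cuda.synchronize().
    for attr in _FSDP_STREAM_ATTRS:
        if hasattr(source, attr):
            stream = getattr(source, attr)
            for module in targets:
                setattr(module, attr, stream)

With such a helper, fsdp_synchronize_streams would reduce to something like share_fsdp_streams(self.teacher.backbone, self.student.backbone, self.student.dino_head, self.teacher.dino_head).
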

0 comments on commit b05584f
