[bug] Fix interpolation of positional embeddings (facebookresearch#378)
Use size instead of scale factor to specify the output size of nn.functional.interpolate(): this avoids rounding issues that could lead to a mismatching output size, and consistently generates the same output size as the previous kludge (from facebookresearch/dino#8).
patricklabatut committed Feb 22, 2024
1 parent 2302b6b commit e1277af
Showing 1 changed file with 16 additions and 12 deletions.
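
A minimal sketch of the rounding issue the commit message refers to (the 37x37 source grid and 53x46 target grid are illustrative assumptions, not values from the commit): with scale_factor, nn.functional.interpolate() recovers the output size as floor(input_size * scale), which floating-point rounding can leave one pixel short; with size, the output grid is pinned exactly.

    import torch
    import torch.nn.functional as F

    M = 37                         # side of the pretrained positional-embedding grid (assumed)
    w0, h0 = 53, 46                # target patch grid for a larger input image (assumed)
    x = torch.randn(1, 384, M, M)

    # Scale-factor path: output size is floor(M * scale); for some grids the
    # rounding lands on w0 - 1 or h0 - 1, hence the old offset kludge.
    out_scale = F.interpolate(x, scale_factor=(w0 / M, h0 / M), mode="bicubic")

    # Size path: the output grid is (w0, h0) by construction.
    out_size = F.interpolate(x, size=(w0, h0), mode="bicubic")
    assert out_size.shape[-2:] == (w0, h0)
    print(out_scale.shape[-2:], out_size.shape[-2:])
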
28 changes: 16 additions & 12 deletions dinov2/models/vision_transformer.py
@@ -188,21 +188,25 @@ def interpolate_pos_encoding(self, x, w, h):
         dim = x.shape[-1]
         w0 = w // self.patch_size
         h0 = h // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
-
-        sqrt_N = math.sqrt(N)
-        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
+        assert N == M * M
+        kwargs = {}
+        if self.interpolate_offset:
+            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+            sx = float(w0 + self.interpolate_offset) / M
+            sy = float(h0 + self.interpolate_offset) / M
+            kwargs["scale_factor"] = (sx, sy)
+        else:
+            # Simply specify an output size instead of a scale factor
+            kwargs["size"] = (w0, h0)
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
-            scale_factor=(sx, sy),
+            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
             mode="bicubic",
             antialias=self.interpolate_antialias,
+            **kwargs,
         )
-
-        assert int(w0) == patch_pos_embed.shape[-2]
-        assert int(h0) == patch_pos_embed.shape[-1]
+        assert (w0, h0) == patch_pos_embed.shape[-2:]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
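
For reference, a standalone sketch of the reshape -> interpolate -> flatten round trip in this hunk, with assumed values (dim=384, a 37x37 pretrained grid, a 53x46 target grid); this is the same tensor plumbing in isolation, not the module's code:

    import torch
    import torch.nn.functional as F

    dim, M = 384, 37                              # embedding dim and pretrained grid side (assumed)
    w0, h0 = 53, 46                               # target patch grid (assumed)
    patch_pos_embed = torch.randn(1, M * M, dim)  # (1, N, dim) learned patch positional embeddings

    # (1, N, dim) -> (1, M, M, dim) -> (1, dim, M, M): lay the tokens out as a 2D grid
    grid = patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(w0, h0), mode="bicubic", antialias=False)
    assert grid.shape[-2:] == (w0, h0)
    # (1, dim, w0, h0) -> (1, w0, h0, dim) -> (1, w0 * h0, dim): back to a token sequence
    patch_pos_embed = grid.permute(0, 2, 3, 1).reshape(1, -1, dim)
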

@@ -306,7 +310,7 @@ def get_intermediate_layers(
         if norm:
             outputs = [self.norm(out) for out in outputs]
         class_tokens = [out[:, 0] for out in outputs]
-        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
         if reshape:
             B, _, w, h = x.shape
             outputs = [
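
The token slice touched in the hunk above assumes the layout [CLS | register tokens | patch tokens]; a small sketch with hypothetical shapes:

    import torch

    B, R, N, dim = 2, 4, 256, 384            # batch, register tokens, patch tokens, embed dim (assumed)
    out = torch.randn(B, 1 + R + N, dim)     # one block's output: CLS token, registers, then patches

    class_token = out[:, 0]                  # (B, dim)
    patch_tokens = out[:, 1 + R :]           # (B, N, dim): drop the CLS and register tokens
    assert patch_tokens.shape == (B, N, dim)
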
