Skip to content

Commit

Permalink
Fix docs, mypy and linter
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Aug 8, 2022
1 parent 78851e6 commit 22c9850
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions torchvision/models/video/mvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

# Reference: https://github.com/facebookresearch/SlowFast/commit/1aebd71a2efad823d52b827a3deaf15a56cf4932#


def get_rel_pos(rel_pos: torch.Tensor, d: int) -> torch.Tensor:
if rel_pos.shape[0] == d:
return rel_pos
Expand Down Expand Up @@ -290,9 +291,9 @@ def __init__(
norm_layer(self.head_dim),
)

self.rel_pos_h: Optional[nn.Module] = None
self.rel_pos_w: Optional[nn.Module] = None
self.rel_pos_t: Optional[nn.Module] = None
self.rel_pos_h: Optional[nn.Parameter] = None
self.rel_pos_w: Optional[nn.Parameter] = None
self.rel_pos_t: Optional[nn.Parameter] = None
if rel_pos:
assert input_size[1] == input_size[2] # TODO: remove this limitation
size = input_size[1]
Expand Down Expand Up @@ -471,6 +472,8 @@ def __init__(
temporal_size (int): The temporal size ``T`` of the input.
block_setting (sequence of MSBlockConfig): The Network structure.
residual_pool (bool): If True, use MViTv2 pooling residual connection.
rel_pos (bool): TODO
dim_mul_in_att (bool): TODO
dropout (float): Dropout rate. Default: 0.0.
attention_dropout (float): Attention dropout rate. Default: 0.0.
stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
Expand Down Expand Up @@ -508,7 +511,7 @@ def __init__(
# Spatio-Temporal Class Positional Encoding
self.pos_encoding = PositionalEncoding(
embed_size=block_setting[0].input_channels,
spatial_size=tuple(input_size[1:]),
spatial_size=(input_size[1], input_size[2]),
temporal_size=input_size[0],
rel_pos=rel_pos,
)
Expand Down

0 comments on commit 22c9850

Please sign in to comment.