
Commit 9681543: Update modeling.py
uakarsh committed Jun 2, 2022
1 parent 007ef4d commit 9681543
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions src/docformer/modeling.py
@@ -6,9 +6,6 @@
 from einops import rearrange
 from torch import Tensor
 
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-DEVICE = device
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
         super().__init__()
@@ -374,7 +371,7 @@ def __init__(self, embed_dim, n_heads, max_relative_position, max_seq_length, dropout):
             nn.Linear(embed_dim, embed_dim),
             nn.Dropout(dropout)
         )
-        self.scale = torch.sqrt(torch.FloatTensor([embed_dim]))
+        self.scale = embed_dim**0.5  # plain float sqrt(embed_dim); attention logits are divided by this below

     def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):
         text_feat = text_feat
@@ -385,10 +382,11 @@ def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):

         # self attention of text
         # b -> batch, t -> time steps (l -> length has same meaning), head -> # of heads, k -> head dim.
-        key_text_nh = rearrange(self.fc_k_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads).to(DEVICE)
-        query_text_nh = rearrange(self.fc_q_text(text_feat), 'b l (head k) -> head b l k', head=self.n_heads).to(DEVICE)
-        value_text_nh = rearrange(self.fc_v_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads).to(DEVICE)
-        dots_text = torch.einsum('hblk,hbtk->hblt', query_text_nh, key_text_nh) / self.scale.to(DEVICE)
+        key_text_nh = rearrange(self.fc_k_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads)
+        query_text_nh = rearrange(self.fc_q_text(text_feat), 'b l (head k) -> head b l k', head=self.n_heads)
+        value_text_nh = rearrange(self.fc_v_text(text_feat), 'b t (head k) -> head b t k', head=self.n_heads)
+        dots_text = torch.einsum('hblk,hbtk->hblt', query_text_nh, key_text_nh)
+        dots_text = dots_text / self.scale

         # 1D relative positions (query, key)
         rel_pos_embed_text = self.relative_positions_text(seq_length, seq_length)
@@ -400,16 +398,18 @@ def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):
         query_spatial_text = self.fc_q_spatial(text_spatial_feat)
         key_spatial_text_nh = rearrange(key_spatial_text, 'b t (head k) -> head b t k', head=self.n_heads)
         query_spatial_text_nh = rearrange(query_spatial_text, 'b l (head k) -> head b l k', head=self.n_heads)
-        dots_text_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_text_nh, key_spatial_text_nh) / self.scale.to(DEVICE)
+        dots_text_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_text_nh, key_spatial_text_nh)
+        dots_text_spatial = dots_text_spatial / self.scale
 
         # Line 38 of pseudo-code
         text_attn_scores = dots_text + rel_pos_key_text + rel_pos_query_text + dots_text_spatial

         # self-attention of image
-        key_img_nh = rearrange(self.fc_k_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads).to(DEVICE)
-        query_img_nh = rearrange(self.fc_q_img(img_feat), 'b l (head k) -> head b l k', head=self.n_heads).to(DEVICE)
-        value_img_nh = rearrange(self.fc_v_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads).to(DEVICE)
-        dots_img = torch.einsum('hblk,hbtk->hblt', query_img_nh, key_img_nh) / self.scale.to(DEVICE)
+        key_img_nh = rearrange(self.fc_k_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads)
+        query_img_nh = rearrange(self.fc_q_img(img_feat), 'b l (head k) -> head b l k', head=self.n_heads)
+        value_img_nh = rearrange(self.fc_v_img(img_feat), 'b t (head k) -> head b t k', head=self.n_heads)
+        dots_img = torch.einsum('hblk,hbtk->hblt', query_img_nh, key_img_nh)
+        dots_img = dots_img / self.scale

         # 1D relative positions (query, key)
         rel_pos_embed_img = self.relative_positions_img(seq_length, seq_length)
@@ -421,7 +421,8 @@ def forward(self, text_feat, img_feat, text_spatial_feat, img_spatial_feat):
         query_spatial_img = self.fc_q_spatial(img_spatial_feat)
         key_spatial_img_nh = rearrange(key_spatial_img, 'b t (head k) -> head b t k', head=self.n_heads)
         query_spatial_img_nh = rearrange(query_spatial_img, 'b l (head k) -> head b l k', head=self.n_heads)
-        dots_img_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_img_nh, key_spatial_img_nh) / self.scale.to(DEVICE)
+        dots_img_spatial = torch.einsum('hblk,hbtk->hblt', query_spatial_img_nh, key_spatial_img_nh)
+        dots_img_spatial = dots_img_spatial / self.scale

         # Line 59 of pseudo-code
         img_attn_scores = dots_img + rel_pos_key_img + rel_pos_query_img + dots_img_spatial
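The substance of the change: the old code built the attention scale as a shape-(1,) CPU tensor (torch.sqrt(torch.FloatTensor([embed_dim]))), which then had to be shipped to the right device with .to(DEVICE) at every use, while the new code makes it a plain Python float, which divides a tensor on any device without a transfer. A minimal sketch of the equivalence (the embed_dim value and tensor shapes here are illustrative assumptions, not values from the repository):

    import torch

    embed_dim = 512                          # illustrative; not the repo's config
    dots = torch.randn(8, 2, 4, 4)           # (heads, batch, query_len, key_len)

    old_scale = torch.sqrt(torch.FloatTensor([embed_dim]))  # shape-(1,) CPU tensor
    new_scale = embed_dim**0.5                              # device-agnostic float

    # Same numbers, but the float never needs .to(DEVICE)
    assert torch.allclose(dots / old_scale, dots / new_scale)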

1 comment on commit 9681543

@uakarsh (Owner, Author) commented on 9681543, Jun 2, 2022

Solves this issue: shabie#30
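For context, the pattern the issue pointed at: intermediates were being forced onto a module-level DEVICE constant instead of inheriting the device of the module's inputs. A hedged sketch of the device-agnostic pattern this commit moves toward (TinyAttention is a made-up illustration, not the repository's module):

    import torch
    import torch.nn as nn
    from einops import rearrange

    class TinyAttention(nn.Module):          # hypothetical, for illustration only
        def __init__(self, embed_dim, n_heads):
            super().__init__()
            self.n_heads = n_heads
            self.fc_q = nn.Linear(embed_dim, embed_dim)
            self.fc_k = nn.Linear(embed_dim, embed_dim)
            self.scale = embed_dim**0.5      # plain float: works on any device

        def forward(self, x):
            # q and k are created on x's device; no .to(DEVICE) calls needed
            q = rearrange(self.fc_q(x), 'b l (head k) -> head b l k', head=self.n_heads)
            k = rearrange(self.fc_k(x), 'b t (head k) -> head b t k', head=self.n_heads)
            return torch.einsum('hblk,hbtk->hblt', q, k) / self.scale

    attn = TinyAttention(embed_dim=16, n_heads=4)
    scores = attn(torch.randn(2, 5, 16))     # runs on CPU; move the module and input to GPU and it works unchanged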
