Questions about Training Pipeline and Parallel Attention #19

Closed
YanShuang17 opened this issue Jan 6, 2022 · 3 comments


@YanShuang17

YanShuang17 commented Jan 6, 2022

Hello! The ideas presented in the paper were very inspiring to me.

I have two questions I'd like to ask:

1. About the training pipeline

The training procedure described in the paper only contains the two stages language-free and language-aware, which correspond to LF_2 and LA in the code. But the code adds an extra LF_1 stage that specifically pre-trains the backbone + VRM part, and in the LF_2 stage the optimizer also uses a different lr for the params already trained in LF_1 (a sketch of the kind of per-param-group setup I mean follows below). Is there a big difference between the two training schemes, LF_1 --> LF_2 --> LA versus LF_2 --> LA?
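For reference, the per-param-group lr setup I am referring to is roughly this (a minimal sketch of my own, not the repo's actual code; backbone_vrm_params, mlm_params and the lr values are placeholders):

import torch

# Hypothetical LF_2 optimizer: smaller lr for the backbone + VRM params that
# were already pre-trained in LF_1, a normal lr for the newly added MLM params.
optimizer = torch.optim.Adam([
    {"params": backbone_vrm_params, "lr": 1e-5},  # pre-trained in LF_1
    {"params": mlm_params, "lr": 1e-4},           # introduced in LF_2
])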

2. About the attention computation in the parallel decoding stage

My problem is that I don't quite understand the implementation logic.

Here I compare it against the attention process in the visual part (PVAM) of SRN:

(a) The attention process in SRN-PVAM (pseudocode, assuming q, k and v all have dimension d_model):

import torch
import torch.nn as nn
import torch.nn.functional as F

# e.g. d_model = 512, max_seq_len = seq_len_q = 25, vocab_size = 37
key2att = nn.Linear(d_model, d_model)
query2att = nn.Linear(d_model, d_model)
embedding = nn.Embedding(max_seq_len, d_model)
score = nn.Linear(d_model, 1)
classifier = nn.Linear(d_model, vocab_size)

# input is encoder_out, shape (batch_size, seq_len_k, d_model)
reading_order = torch.arange(max_seq_len, dtype=torch.long)
Q = embedding(reading_order)  # (max_seq_len, d_model)
K = V = encoder_out  # (batch_size, seq_len_k, d_model)

# This att_weight computation is easy to follow; it is the same additive
# attention used in classic attention models such as ASTER.
######
att_q = query2att(Q).unsqueeze(0).unsqueeze(2)  # (1, seq_len_q, 1, d_model)
att_k = key2att(K).unsqueeze(1)  # (batch_size, 1, seq_len_k, d_model)
att_weight = score(torch.tanh(att_q + att_k)).squeeze(3)  # (batch_size, seq_len_q, seq_len_k)
######

att_weight = F.softmax(att_weight, dim=-1)
decoder_out = torch.bmm(att_weight, V)  # (batch_size, seq_len_q, d_model)
logits = classifier(decoder_out)  # (batch_size, seq_len_q, vocab_size)

(b) The attention process in VisionLAN:

# e.g. d_model = 512, max_seq_len = seq_len_q = 25, vocab_size = 37
# seq_len_k = length of the flattened feature map from the encoder
embedding = nn.Embedding(max_seq_len, d_model)
w0 = nn.Linear(max_seq_len, seq_len_k)
wv = nn.Linear(d_model, d_model)
we = nn.Linear(d_model, max_seq_len)
classifier = nn.Linear(d_model, vocab_size)

# input is encoder_out, shape (batch_size, seq_len_k, d_model)
K = V = encoder_out  # (batch_size, seq_len_k, d_model)
reading_order = torch.arange(max_seq_len, dtype=torch.long)

# How should I understand the att_weight computation below?
######
reading_order = embedding(reading_order)  # (seq_len_q, d_model)
reading_order = reading_order.unsqueeze(0).expand(K.size(0), -1, -1)  # (batch_size, seq_len_q, d_model)
t = w0(reading_order.permute(0, 2, 1))  # (batch_size, d_model, seq_len_q) ==> (batch_size, d_model, seq_len_k)
t = torch.tanh(t.permute(0, 2, 1) + wv(K))  # (batch_size, seq_len_k, d_model)
att_weight = we(t)  # (batch_size, seq_len_k, d_model) ==> (batch_size, seq_len_k, seq_len_q)
att_weight = att_weight.permute(0, 2, 1)  # (batch_size, seq_len_q, seq_len_k)
######

att_weight = F.softmax(att_weight, dim=-1)
decoder_out = torch.bmm(att_weight, V)  # (batch_size, seq_len_q, d_model)
logits = classifier(decoder_out)  # (batch_size, seq_len_q, vocab_size)

Looking forward to your reply, thanks!

@wangyuxin87
Owner

1. Adding the MLM from the very beginning makes training much slower, so we first train the visual model and then fine-tune with the MLM. In our experiments we also found that this gives somewhat higher accuracy.
2. The goal of our prediction process is to decode the recognition result according to the reading order. The two implementations (a) and (b) are the same.

@YanShuang17
Author

Thanks for the reply!
Point 1 is clear now.
For point 2, though, the two implementations (a) and (b) are not really the same, are they? First, the shapes of the weights are not the same; second, in (a) you can explicitly see every query attending to every key, but I don't see that in (b).
The GPU memory usage of the two implementations also differs a lot: (a) materializes a very large intermediate tensor (shape = [batch_size, seq_len_q, seq_len_k, d_model]), so its memory footprint is much higher.
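To make the memory point concrete, here is a rough back-of-envelope comparison (my own sketch; the sizes, in particular seq_len_k = 256, are illustrative assumptions):

batch_size, seq_len_q, seq_len_k, d_model = 64, 25, 256, 512

# (a) materializes tanh(att_q + att_k) with shape (batch_size, seq_len_q, seq_len_k, d_model)
a_elems = batch_size * seq_len_q * seq_len_k * d_model    # ~2.1e8 floats, ~0.8 GB in fp32

# (b) only keeps tensors of shape (batch_size, seq_len_k, d_model) and (batch_size, seq_len_k, seq_len_q)
b_elems = batch_size * seq_len_k * (d_model + seq_len_q)  # ~8.8e6 floats, ~35 MB in fp32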

@wangyuxin87
Owner

Sorry, I had not paid attention to the dimensions you wrote out. The starting point of VisionLAN is a concise and fast pipeline, so we simplified the recognition part.
