In [85]:
import torch.nn as nn
import torch
import torch.nn.functional as F

In [14]:
from torchsnooper import snoop

In [95]:
class Gru_Decoder(nn.Module):

    def __init__(self,  word_embeddings, batch_size, vocab_size, input_size = 768, hidden_size = 2048, num_layers = 2):
        super(Gru_Decoder,self).__init__()
        self.batch_size, self.num_layers, self.hidden_size = batch_size, num_layers, hidden_size
        self.gru = nn.GRU(input_size + input_size, hidden_size, num_layers, batch_first = True)  # teacher_forcing的是768，trg的也是768
        self.bert_embedding = word_embeddings
        self.fc_1 = nn.Linear(2048, 4096)
        self.fc_2 = nn.Linear(4096, vocab_size)
        self.dropout = nn.Dropout(p = 0.5)
    @snoop()
    def forward_step(self, trg_emb, teacher_forcing_input, hidden):
        '''
        gru一开始的hidden 应该怎么传入进去
        '''
        teacher_forcing_embed = self.bert_embedding(teacher_forcing_input)  # (batch_size, 1, 768)
        x = torch.cat((trg_emb, teacher_forcing_embed), -1)        # (batch_size, 1, 768*2 -> teacher_forcing)
        output, hidden = self.gru(x, hidden)  # (output:batch_size, 1, 2048)
        o1 = F.relu(self.fc_1(output))
        o2 = self.dropout(o1)
        o3 = self.fc_2(o1)  
        return o3, hidden


    def forward(self, trg_emb, labels):  # trg_embed是整个batch的
        '''
        返回该batch中所有trg的概率结果，输出格式为batch_size, seq_len, vocab_size
        '''
        hidden = nn.Parameter(torch.zeros(self.num_layers, self.batch_size, self.hidden_size))
        add_seq_len = []
        for time_step in range(1, labels.size(1)):  # labels.size(1) == seq_len
            teacher_forcing_input = labels[:, time_step].unsqueeze(1)  # (batch_size, 1), seq_len都是1
            output, hidden = self.forward_step(trg_emb, teacher_forcing_input, hidden)  # 每个时间点放入的都是trg和当前decode的词，output : batch_size, 1, vocab_size 
            # hidden : batch_size, 1, hidden
            add_seq_len.append(output)  # 每个时间点的预测结果存起来
        
        trg_pred = torch.stack(add_seq_len).squeeze().permute(1,0,2)  # (batch_size, seq_len, vocab_size)
        return trg_pred

In [96]:
batch_size, vocab_size, word_embedding_dim = 3, 5, 768 
labels = torch.LongTensor([[0,1,3,4,0],[2,3,3,4,0],[1,2,0,0,0]])
trg_embed = torch.randn(batch_size, 1, word_embedding_dim)

In [97]:
embedding = nn.Embedding(vocab_size, word_embedding_dim)

In [98]:
embedding(torch.LongTensor([0,1])).shape

torch.Size([2, 768])

In [99]:
model = Gru_Decoder(embedding, batch_size, vocab_size)

In [100]:
pred = model(trg_embed, labels)

Source path:... <ipython-input-95-6ea889145288>
Starting var:.. self = Gru_Decoder(  (gru): GRU(1536, 2048, num_layers=...=True)  (dropout): Dropout(p=0.5, inplace=False))
Starting var:.. trg_emb = tensor<(3, 1, 768), float32, cpu>
Starting var:.. teacher_forcing_input = tensor<(3, 1), int64, cpu>
Starting var:.. hidden = tensor<(2, 3, 2048), float32, cpu, grad>
20:48:38.771030 call        12     def forward_step(self, trg_emb, teacher_forcing_input, hidden):
20:48:38.773328 line        16         teacher_forcing_embed = self.bert_embedding(teacher_forcing_input)  # (batch_size, 1, 768)
New var:....... teacher_forcing_embed = tensor<(3, 1, 768), float32, cpu, grad>
20:48:38.775502 line        17         x = torch.cat((trg_emb, teacher_forcing_embed), -1)        # (batch_size, 1, 768*2 -> teacher_forcing)
New var:....... x = tensor<(3, 1, 1536), float32, cpu, grad>
20:48:38.776548 line        18         output, hidden = self.gru(x, hidden)  # (output:batch_size, 1, 2048)
New var:.......

In [101]:
labels.size()

torch.Size([3, 5])

In [102]:
pred.shape

torch.Size([3, 4, 5])

In [27]:
x = embedding(labels[:,1].unsqueeze(1))

In [28]:
x.shape

torch.Size([3, 1, 768])

In [25]:
trg_embed.shape

torch.Size([3, 1, 768])

In [30]:
x = torch.cat((trg_embed, x), -1)

In [31]:
x.shape

torch.Size([3, 1, 1536])

In [32]:
gru = nn.GRU(input_size = 1536, hidden_size = 2048, num_layers = 2, batch_first = True)

In [33]:
o, h = gru(x)

In [34]:
o.shape

torch.Size([3, 1, 2048])

In [35]:
h.shape

torch.Size([2, 3, 2048])

In [37]:
y, z = torch.randn(3,1,2048), torch.randn(3,1,2048)

In [67]:
p = [o,y]

In [68]:
m = torch.stack(p)

In [69]:
m.shape

torch.Size([2, 3, 1, 2048])

In [72]:
m.squeeze().permute(1,0,2).shape

torch.Size([3, 2, 2048])