Encoder model

In [None]:
class Encoder(nn.Module):
    """ResNet-50 feature extractor that yields one feature vector per spatial cell.

    The network's last two children are dropped, so the final convolutional
    feature map is kept. That map is then flattened into a sequence an
    attention decoder can attend over.
    """

    def __init__(self):
        # The original called super(EncoderAttention, self), which names a class
        # other than this one; zero-argument super() always targets Encoder.
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # releases in favour of `weights=...`; confirm against the pinned version.
        resnet = resnet50(pretrained=True)
        # Drop the last two children (in torchvision's ResNet: avgpool and fc)
        # to keep the spatial feature map.
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])
        self.relu = nn.ReLU()

    def forward(self, images):
        """Encode a batch of images into a sequence of per-pixel features.

        Args:
            images: batch fed to the ResNet trunk, presumably (batch, 3, H, W);
                TODO confirm against the data loader.

        Returns:
            Tensor of shape (batch, num_pixels, channels), where num_pixels is
            the product of the feature map's spatial dims (channels is 2048 for
            ResNet-50).
        """
        features = self.relu(self.resnet(images))  # (batch, C, h, w)
        # Derive the channel count from the tensor instead of hard-coding 2048;
        # flatten also copes with non-contiguous feature maps, where view fails.
        features = features.flatten(start_dim=2)  # (batch, C, h*w)
        return features.permute(0, 2, 1)  # (batch, h*w, C)

Bahdanau Attention

In [None]:
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of encoder states.

    Every encoder position is scored against the decoder query with
    ``Va(tanh(Ua(query) + Wa(keys)))``. The scores are softmax-normalised over
    positions, and the weighted sum of the encoder states is returned.

    Args:
        hidden_size: feature size shared by the decoder query and encoder states.
        attention_dim: width of the intermediate projection space.
    """

    def __init__(self, hidden_size=256, attention_dim=64):
        super().__init__()
        # Attribute names are part of the state_dict, so they stay as they are.
        self.Ua = nn.Linear(hidden_size, attention_dim)  # projects the decoder query
        self.Wa = nn.Linear(hidden_size, attention_dim)  # projects the encoder states
        self.Va = nn.Linear(attention_dim, 1)  # reduces each position to one score

    def forward(self, decoder_hidden, encoder_hidden):
        """Compute the attention context vector.

        Args:
            decoder_hidden: query broadcastable against the keys, e.g.
                (batch, 1, hidden_size).
            encoder_hidden: keys/values of shape (batch, positions, hidden_size).

        Returns:
            Context tensor of shape (batch, 1, hidden_size).
        """
        # (batch, 1, A) + (batch, P, A) broadcasts to (batch, P, A).
        energy = torch.tanh(self.Ua(decoder_hidden) + self.Wa(encoder_hidden))
        # (batch, P, 1) -> (batch, 1, P) so the softmax runs over positions.
        alignment = self.Va(energy).transpose(1, 2)
        attn = torch.softmax(alignment, dim=-1)
        return torch.bmm(attn, encoder_hidden)

Decoder RNN with attention

In [None]:
class DecoderAttention(nn.Module):
    """LSTM decoder that attends over encoder features at every decoding step.

    Args:
        vocab_size: number of output tokens (defaults to the notebook-level
            ``vocab_size`` captured when the class is defined).
        feature_size: channel size of the incoming encoder features.
        emb_size: token embedding size.
        hidden_size: LSTM hidden size; also the width the attention operates on.
    """

    def __init__(self, vocab_size=vocab_size, feature_size=2048, emb_size=100, hidden_size=256):
        super().__init__()
        self.linear1 = nn.Linear(feature_size, hidden_size)
        self.embed = nn.Embedding(vocab_size, emb_size)
        # Pass hidden_size through: the attention consumes the LSTM state and
        # the projected encoder features, both hidden_size wide. The bare
        # default (256) made any other hidden_size crash in forward().
        self.attention = BahdanauAttention(hidden_size=hidden_size)
        # Each step consumes [token embedding ; attention context].
        self.lstm = nn.LSTM(emb_size + hidden_size, hidden_size, batch_first=True)
        self.linear2 = nn.Linear(hidden_size, vocab_size)
        self.relu = nn.ReLU()

    def forward(self, encoder_outputs, target=None, max_len=None):
        """Decode a token sequence from encoder features.

        Args:
            encoder_outputs: features of shape (batch, pixels, feature_size).
            target: optional token ids of shape (batch, >= steps). When given,
                teacher forcing is used: the input at step i > 0 is
                target[:, i - 1]. Otherwise the previous step's
                highest-scoring token is fed back (greedy decoding).
            max_len: number of decoding steps; defaults to the notebook-level
                ``max_length``.

        Returns:
            Log-probabilities of shape (batch, steps, vocab_size).
        """
        # Keep every tensor on the same device as the model inputs instead of
        # relying on a notebook-level `device` global.
        device = encoder_outputs.device
        encoder_outputs = self.linear1(encoder_outputs)  # (batch, pixels, hidden)
        batch_size = encoder_outputs.size(0)
        steps = max_len if max_len is not None else max_length
        # Token id 1 seeds every sequence; presumably the <start> index —
        # confirm against the vocabulary.
        decoder_input = torch.ones(batch_size, 1, dtype=torch.long, device=device)
        # Hidden and cell state both start from the mean feature: (1, batch, hidden).
        initial_state = encoder_outputs.mean(dim=1).unsqueeze(0)
        decoder_hidden, decoder_cell = initial_state, initial_state
        decoder_outputs = []
        for i in range(steps):
            decoder_emb = self.relu(self.embed(decoder_input))  # (batch, 1, emb)
            query = decoder_hidden.permute(1, 0, 2)  # (batch, 1, hidden)
            context = self.attention(query, encoder_outputs)  # (batch, 1, hidden)
            lstm_input = torch.cat((decoder_emb, context), dim=2)
            decoder_output, (decoder_hidden, decoder_cell) = self.lstm(
                lstm_input, (decoder_hidden, decoder_cell)
            )
            logits = self.linear2(self.relu(decoder_output))  # (batch, 1, vocab)
            decoder_outputs.append(logits)
            if target is not None:
                # Teacher forcing: the ground-truth token becomes the next input.
                decoder_input = target[:, i].unsqueeze(1).to(device=device, dtype=torch.long)
            else:
                # Greedy decoding: feed back the highest-scoring token.
                _, top_index = logits.topk(1)
                decoder_input = top_index.squeeze(-1).detach()
        decoder_outputs = torch.cat(decoder_outputs, dim=1)  # (batch, steps, vocab)
        # The original applied log_softmax to the last step only and discarded
        # the result. It is now applied to the whole sequence. log_softmax is a
        # per-row shift, so argmax, softmax and CrossEntropyLoss are unaffected,
        # while NLLLoss now receives proper log-probabilities.
        return nn.functional.log_softmax(decoder_outputs, dim=-1)