In [14]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import os

In [3]:
class CNNENcoder(nn.Module):
    def __init__(self):
        super(CNNENcoder,self).__init__()
        resnet = torchvision.models.resnet152(pretrained=True)
        
        for param in resnet.parameters():
            param.required_grad = False
        
        modules = list(resnet.children())[:-2]
        #print(modules)
        self.model = nn.Sequential(*modules)
    
    def forward(self,IMAGE):
        feature = self.model(IMAGE)
        #for LSTM lets permute
        B,C,H,W = feature.size()
        feature = feature.permute(0,2,3,1)
        feature = feature.view(B,H*W,C)
        return feature

In [13]:
class Attention(nn.Module):
    def __init__(self,features,dims,output):
        super(Attention,self).__init__()
        
        self.features = features 
        self.dims = dims
        self.output = output
        
        self.Wa = nn.Linear(features,dims)
        
        self.Ua = nn.Linear(dims,dims)
        
        self.va = nn.Linear(dims,output)
        
    
    def forward(self,feature_map,hidden):
        
        FM = feature_map
        
        feature_map = self.Wa(feature_map)
        hidden = self.Ua(hidden)
        
        _resultant_vec = torch.tanh(feature_map + hidden)
        _resultant_vec = self.va(_resultant_vec)
        print(_resultant_vec.shape)
        _resultant_vec = torch.softmax(_resultant_vec,dim=1)
        
        context = torch.sum(_resultant_vec * FM,dim=1)
        return context,_resultant_vec
        

In [5]:
av = torch.randn(1,49,2048)
ff = torch.randn(1,49,1)
print((av * ff).shape)

torch.Size([1, 49, 2048])


In [6]:
ff = Attention(2048,512,1)
hdn = torch.randn(1,1,512)
ff.forward(mdl.forward(demo),hdn)[0].shape

NameError: name 'mdl' is not defined

In [12]:
class Decoder(nn.Module):
    def __init__(self,vocab,hidden,embed_dim,features):
        
        self.embedding = nn.Embedding(vocab,embed_dim)
        self.hidden = hidden
        self.features = features
        self.vocab = vocab
        self.LSSTFC = nn.Linear(hidden,vocab)
        
        self.init_fc_h = nn.Linear(features,hidden)
        self.init_fc_c = nn.Linear(features,hidden)
        self.BrunAtten = Attention(2048,256,1)
        self.LSTM = nn.LSTMCell(features + embed_dim,hidden)
        
        
    
    def init_hidden(self,feature):
        
        to_fed = torch.mean(feature,dim=1)
        h0 = self.init_fc_h(to_fed)
        c0 = self.init_fc_c(to_fed)
        return h0,c0
        
        
        
    
    def forward(self,caption,features):
        
        caption_vector = self.embedding(caption)
        #B,SEQ,EMBED
        
        SEQ_LEN = caption.size(1)
        
        vocabs  = torch.zeros(features.size(0),SEQ_LEN,self.vocab)
        attn_weights = torch.zeros(features.size(0),SEQ_LEN,features.size(1))
        h,c = init_hidden(features)
        
        
        for t in range(SEQ_LEN):
            sample_prob = 0.0 if t == 0 else 0.5
            use_sampling = np.random.random() < sample_prob
            if use_sampling == False:
                word_embed = caption_vector[:,t,:]
            contxt,attn_weight = self.BrunAtten.forward(features,h)
            _vec = torch.cat([contxt,word_embed])
            h,c = self.LSTM(_vec,(h,c))
            output = self.LSSTFC(h)
            if use_sampling == True:
                idex = F.log_softmax(output,dim=1)
                # normalized probs
                idex = idex.topk(1)[1]
                word_embed = self.embedding(idex).squeeze(1)
            
            vocabs[:,t,:] = output
            attn_weights[:,t,:] = attn_weight
            #h = 32,SEQ,hidden
        return vocabs,attn_weights
        

In [11]:
image  = torch.randn(1,3,224,224)
mdl = CNNENcoder()
mdl.forward(image).shape

torch.Size([1, 49, 2048])

In [9]:
#HYPER PARAM
BATCH = 64
EMBED_SIZE = 256
HIDDEN_SIZE = 512
FEATURES = 2048
EPoCHS = 20

In [None]:
def train(epochs,
         embed_size,
         hidden_size,
         features,epochs,vocab,
         train_loader,
         device,
         loss_criterion):
    
    encoder = CNNENcoder()
    decoder = Decoder(vocab,hidden,embed_size,features)
    
    encoder.eval()
    decoder.train()
    
    optim = torch.nn.optim.Adam(decoder.parameters(),lr=0.001)
    
    for ep in range(epochs):
        
        for i,(image,caption) in enumerate(train_loader):
            image = image.to(device)
            
            encoder.zero_grad()
            decoder.zer_grad()
            
            caption_train = caption[:,:-1].to(device)
            caption_target = caption[:,1:].to(device)
            
            # B,49,2048
            _vec_features = encoder.forward(image)
            results,attn_weghts = decoder.forward(caption_train,_vec_features)
            
            _loss = loss_criterion(results.view(-1,vocab_size),caption_target.reshape(-1))
            
            _loss.backward()
            
            optim.zero_grad()
            
            optim.step()
            
            print(f'Epochs {ep} Step {i} Loss {_loss.item()}')
        
        torch.save(decoder.state_dict(),os.path.join('/model',f'decoder{ep}.pkl'))
        torch.save(encoder.state_dict(),os.path.join('/model',f'encoder{ep}.pkl'))
            
            
            
            
            
            
    