In [1]:
%matplotlib inline

In [2]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from torch.nn.utils.rnn import pack_padded_sequence


In [3]:
class ImageCNN(nn.Module):
    """Image encoder: frozen ResNet-152 trunk + trainable linear projection + BatchNorm."""

    def __init__(self, image_vector_size):
        """Load the pretrained ResNet-152 and replace its top fc layer.

        Args:
            image_vector_size: length of the feature vector produced per image.
        """
        super(ImageCNN, self).__init__()
        # Downloads ImageNet weights on first use. `pretrained=True` is deprecated in
        # newer torchvision (in favour of `weights=...`); kept for the version this
        # notebook was written against.
        resnet = models.resnet152(pretrained=True)
        modules = list(resnet.children())[:-1]  # drop the final fc layer
        self.resnet = nn.Sequential(*modules)
        self.linear = nn.Linear(resnet.fc.in_features, image_vector_size)
        self.bn = nn.BatchNorm1d(image_vector_size, momentum=0.01)
        self.init_weights()

    def init_weights(self):
        """Initialize the projection: weights ~ N(0, 0.02), bias = 0."""
        self.linear.weight.data.normal_(0.0, 0.02)
        self.linear.bias.data.fill_(0)

    def forward(self, images):
        """Extract the image feature vectors.

        Args:
            images: batch_size * 3 * height * width (the original note says
                height/width should be at least 224).

        Returns:
            batch_size * image_vector_size features after linear + BatchNorm.
        """
        # The trunk is frozen: its output used to be detached with
        # Variable(features.data). Running it under no_grad yields the same values
        # and gradients but skips building an autograd graph that was discarded anyway.
        with torch.no_grad():
            features = self.resnet(images)
        features = features.view(features.size(0), -1)
        return self.bn(self.linear(features))

In [29]:
class MatchCNN(nn.Module):
    """Multimodal matching CNN that scores (image vector, sentence) pairs.

    Four branches each run three "scan 3 words -> linear -> ReLU -> zero gate ->
    max-pool 2" stages over the embedded sentence, followed by a two-layer MLP.
    They differ only in where the image vector is concatenated in:

    * word branch         (``*_word``): at the first convolution,
    * short-phrase branch (``*_phs``):  at the second convolution,
    * long-phrase branch  (``*_phl``):  at the third convolution,
    * sentence branch     (``*_sen``):  at the first MLP layer.

    The four branch scores are summed.
    """

    def __init__(self, embed_size, image_vector_size, vocab_size, sentence_size=30, verbose=True):
        """
        Args:
            embed_size: dimensionality of the word embeddings.
            image_vector_size: length of each image feature vector.
            vocab_size: number of rows of the embedding table.
            sentence_size: fixed number of tokens per sentence; determines the
                flattened size fed to the MLPs (600 for the default of 30).
            verbose: when True, print intermediate tensor shapes (the notebook's
                original, always-on debug output).

        Raises:
            ValueError: if ``sentence_size`` is too short to survive three
                scan + pool stages.
        """
        super(MatchCNN, self).__init__()
        self.verbose = verbose
        self.embed = nn.Embedding(vocab_size, embed_size)

        # Every branch applies three scan(3) + pool(2) stages to the sentence, so all
        # MLPs see the same flattened size: 2 * 300 = 600 for a 30-token sentence.
        pooled_length = self._pooled_length(sentence_size)
        if pooled_length < 1:
            raise ValueError(
                "sentence_size=%d is too short for three convolution stages "
                "(need at least 22 tokens)" % sentence_size)
        flat_size = pooled_length * 300

        # Word branch: image fused at the first convolution.
        self.muti_conv1_word = nn.Linear(3 * embed_size + image_vector_size, 200)
        self.conv2_word = nn.Linear(200 * 3, 300)
        self.conv3_word = nn.Linear(300 * 3, 300)
        self.linear1_word = nn.Linear(flat_size, 400)
        self.linear2_word = nn.Linear(400, 1)

        # Short-phrase branch: image fused at the second convolution.
        self.conv1_phs = nn.Linear(embed_size * 3, 200)
        self.muti_conv2_phs = nn.Linear(3 * 200 + image_vector_size, 300)
        self.conv3_phs = nn.Linear(300 * 3, 300)
        self.linear1_phs = nn.Linear(flat_size, 400)
        self.linear2_phs = nn.Linear(400, 1)

        # Long-phrase branch: image fused at the third convolution.
        self.conv1_phl = nn.Linear(embed_size * 3, 200)
        self.conv2_phl = nn.Linear(200 * 3, 300)
        self.muti_conv3_phl = nn.Linear(3 * 300 + image_vector_size, 300)
        self.linear1_phl = nn.Linear(flat_size, 400)
        self.linear2_phl = nn.Linear(400, 1)

        # Sentence branch: image fused at the first MLP layer.
        self.conv1_sen = nn.Linear(embed_size * 3, 200)
        self.conv2_sen = nn.Linear(200 * 3, 300)
        self.conv3_sen = nn.Linear(300 * 3, 300)
        self.muti_linear1_sen = nn.Linear(flat_size + image_vector_size, 400)
        self.linear2_sen = nn.Linear(400, 1)

    @staticmethod
    def _pooled_length(length, stages=3):
        """Sequence length left after ``stages`` rounds of scan(3) + max-pool(2)."""
        for _ in range(stages):
            length = (length - 3 + 1) // 2
        return length

    def _log(self, *args):
        """Print debug output when ``self.verbose`` is set (defaults to printing)."""
        if getattr(self, "verbose", True):
            print(*args)

    def forward(self, image_vectors, sentences):
        """Score each image/sentence pair.

        Args:
            image_vectors: batch_size * image_vector_size.
            sentences: batch_size * sentence_size token ids. Rows at the same batch
                index of both inputs are scored as a pair.

        Returns:
            batch_size * 1 scores: the sum of the four branch scores.
        """
        sentence_vectors = self.embed(sentences)

        features_word = self.conv(sentence_vectors, self.muti_conv1_word, image_vectors)
        features_word = self.conv(features_word, self.conv2_word)
        features_word = self.conv(features_word, self.conv3_word)
        features_word = self.mlp(features_word, self.linear1_word, self.linear2_word)

        features_phs = self.conv(sentence_vectors, self.conv1_phs)
        features_phs = self.conv(features_phs, self.muti_conv2_phs, image_vectors)
        features_phs = self.conv(features_phs, self.conv3_phs)
        features_phs = self.mlp(features_phs, self.linear1_phs, self.linear2_phs)

        features_phl = self.conv(sentence_vectors, self.conv1_phl)
        features_phl = self.conv(features_phl, self.conv2_phl)
        features_phl = self.conv(features_phl, self.muti_conv3_phl, image_vectors)
        features_phl = self.mlp(features_phl, self.linear1_phl, self.linear2_phl)

        features_sen = self.conv(sentence_vectors, self.conv1_sen)
        features_sen = self.conv(features_sen, self.conv2_sen)
        features_sen = self.conv(features_sen, self.conv3_sen)
        features_sen = self.mlp(features_sen, self.muti_linear1_sen, self.linear2_sen, image_vectors)

        return features_word + features_phs + features_phl + features_sen

    def mlp(self, features, linear_function1, linear_function2, image_vectors=None):
        """Flatten conv features, optionally append the image vector, apply two ReLU layers.

        Args:
            features: batch_size * sentence_size * channel_size.
            linear_function1, linear_function2: the two MLP layers.
            image_vectors: optional batch_size * image_size, concatenated after flattening.

        Returns:
            batch_size * out_features of ``linear_function2`` (batch_size * 1 here).
        """
        features = features.contiguous()
        features_num = self.num_flat_features(features)
        self._log("flat size:", features_num)
        features = features.view(-1, features_num)

        if image_vectors is not None:
            features = torch.cat([features, image_vectors], dim=1)

        features = F.relu(linear_function1(features))
        # NOTE(review): ReLU on the final score clamps it to >= 0 (and zeroes its
        # gradient when negative); kept as in the original design.
        features = F.relu(linear_function2(features))
        # tuple(size) prints like numpy's .shape but also works for CUDA tensors.
        self._log("final shape:", tuple(features.size()))
        return features

    def conv(self, features, conv_function, image_vectors=None):
        """One stage: scan 3-word windows, linear + ReLU, zero gate, then max-pool by 2.

        Args:
            features: batch_size * sentence_size * channel_size.
            conv_function: linear layer applied to every concatenated window.
            image_vectors: optional batch_size * image_size appended to every window.

        Returns:
            batch_size * ((sentence_size - 2) // 2) * out_channels.
        """
        windows = self.scan_conv(features, image_vectors)
        features = F.relu(conv_function(windows))
        features = self.zero_gate(windows, features)
        self._log("muti_convlution1 features shape:", features.size())
        return self.sentence_pooling(features)

    def num_flat_features(self, x):
        """Number of elements per sample, i.e. the product of all but the first dim.

        x: batch_size * feature1_size * ... * featuren_size
        """
        num_features = 1
        for s in x.size()[1:]:
            num_features *= s
        return num_features

    def sentence_pooling(self, features):
        """Max-pool pairs of adjacent positions along the sentence axis.

        features: batch_size * sentence_size * channel_size
        return:   batch_size * (sentence_size // 2) * channel_size
        """
        return F.max_pool1d(features.permute(0, 2, 1), 2).permute(0, 2, 1)

    def scan_conv(self, features, image_vectors=None):
        """Concatenate every run of 3 consecutive positions, optionally with the image vector.

        Args:
            features: batch_size * sentence_size * channel_size.
            image_vectors: optional batch_size * image_size.

        Returns:
            batch_size * (sentence_size - 3 + 1) * (3 * channel_size [+ image_size]),
            on the same device and dtype as ``features``.

        Raises:
            ValueError: if the sentence has fewer than 3 positions.
        """
        sentence_size = features.size(1)
        n_windows = sentence_size - 3 + 1
        self._log("muti_convlution input features shape:", features.size())
        if n_windows < 1:
            raise ValueError("need at least 3 positions to scan, got %d" % sentence_size)

        # Window j is [features[j], features[j + 1], features[j + 2]] along channels.
        pieces = [features[:, i:i + n_windows, :] for i in range(3)]
        if image_vectors is not None:
            pieces.append(image_vectors.unsqueeze(1).expand(-1, n_windows, -1))
        sentence_vectors = torch.cat(pieces, dim=2)
        self._log("sentence_vectors shape:", sentence_vectors.size())
        return sentence_vectors

    def zero_gate(self, feature1, feature2):
        """Zero rows of ``feature2`` whose corresponding row of ``feature1`` is all zeros.

        Uses a 0/1 mask ("row has any non-zero entry"). The previous version multiplied
        by the row *sum*, which negated or zeroed non-zero rows whose sum was <= 0.

        Args:
            feature1: batch_size * length * c1 (the gate input).
            feature2: batch_size * length * c2 (the values to gate).
        """
        mask = feature1.ne(0).sum(dim=2, keepdim=True).gt(0).type_as(feature2)
        return feature2 * mask

In [30]:
# Stand-alone construction check. The sizes are spelled out because `embed_size`
# and `image_vector_size` are only defined in the next cell; referencing them here
# raised NameError on a fresh kernel (Restart & Run All).
matchCNN = MatchCNN(embed_size=50, image_vector_size=256, vocab_size=1000)


In [31]:
"""ensemble test"""
# Hyper-parameters for the end-to-end smoke test.
image_vector_size = 256
embed_size = 50
margin = 0.5      # hinge margin of the ranking loss
batch_size = 10
epoch = 1

# Seed the RNGs used for weight init and random inputs so that
# Restart & Run All reproduces the same numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

imageCNN = ImageCNN(image_vector_size=image_vector_size)
matchCNN = MatchCNN(embed_size=embed_size, image_vector_size=image_vector_size, vocab_size=1000)

# Optimizer: only imageCNN's projection head (linear + bn) and matchCNN are trained.
# The ResNet trunk is left out; training everything would use
# list(imageCNN.parameters()) + list(matchCNN.parameters()).
params = list(imageCNN.linear.parameters()) + list(imageCNN.bn.parameters()) + list(matchCNN.parameters())
optimizer = optim.SGD(params, momentum=0.9, lr=0.001)


In [32]:
# for i in range(epoch):
"""input data"""
# Fixed sentence length: MatchCNN's default layer sizes assume 30-token sentences.
sentence_size = 30
assert batch_size >= 2, "need at least two pairs to build mismatched pairs"

image = Variable(torch.randn(batch_size, 3, 224, 224))

# Mismatched pairs via a cyclic shift by a random non-zero offset. A plain randperm
# can leave some images in place, pairing them with their own sentence; a shift
# never does, so every "wrong" pair really is wrong.
shift = int(torch.randint(1, batch_size, (1,)))
wrong_index = (torch.arange(batch_size) + shift) % batch_size
image_wrong = image[wrong_index]

# Token ids in [0, vocab_size). randint's upper bound is exclusive, so pass the
# embedding size itself; the previous `high=999` never produced the last id.
sentences = Variable(torch.randint(0, matchCNN.embed.num_embeddings, (batch_size, sentence_size)))

# Extract image features once and permute them for the mismatched batch. In training
# mode BatchNorm uses batch statistics, which do not depend on sample order, so this
# matches imageCNN(image_wrong) up to float rounding. It costs one ResNet-152 pass
# instead of two and updates the BN running statistics once instead of twice.
image_vectors = imageCNN(image)
image_vectors_wrong = image_vectors[wrong_index]


In [33]:
"""get correct score"""
scores = matchCNN(image_vectors, sentences)
print("-" * 20)
scores_wrong = matchCNN(image_vectors_wrong, sentences)

# Hinge ranking loss: each matched pair should outscore its mismatched pair by at
# least `margin`; pairs that already do contribute nothing.
loss = (margin + scores_wrong - scores).clamp(min=0).sum()

# Clear stale gradients on both networks, backpropagate, take one SGD step.
matchCNN.zero_grad()
imageCNN.zero_grad()
loss.backward()
optimizer.step()

muti_convlution input features shape: torch.Size([10, 30, 50])
sentence_vectors shape: torch.Size([10, 28, 406])
muti_convlution1 features shape: torch.Size([10, 28, 200])
muti_convlution input features shape: torch.Size([10, 14, 200])
sentence_vectors shape: torch.Size([10, 12, 600])
muti_convlution1 features shape: torch.Size([10, 12, 300])
muti_convlution input features shape: torch.Size([10, 6, 300])
sentence_vectors shape: torch.Size([10, 4, 900])
muti_convlution1 features shape: torch.Size([10, 4, 300])
flat size: 600
final shape: (10, 1)
muti_convlution input features shape: torch.Size([10, 30, 50])
sentence_vectors shape: torch.Size([10, 28, 150])
muti_convlution1 features shape: torch.Size([10, 28, 200])
muti_convlution input features shape: torch.Size([10, 14, 200])
sentence_vectors shape: torch.Size([10, 12, 856])
muti_convlution1 features shape: torch.Size([10, 12, 300])
muti_convlution input features shape: torch.Size([10, 6, 300])
sentence_vectors shape: torch.Size([10, 4

In [None]:
print("")  # emit a single empty line

In [None]:
# Gradients also flow through indexed (slice) assignment: copying a slice of `x`
# into `z` links `z` back to `x` in the autograd graph.
x = Variable(torch.ones(2, 2, 2), requires_grad=True)
print(x.grad)  # None: no backward pass has run yet

z = Variable(torch.randn(2, 2))
z[:, :] = x[0]  # equivalent to x[0][:][:]; full-range slices are no-ops

y = 3 * z
loss = y.sum()
loss.backward()

print(x.grad)  # 3 wherever x[0] was copied into z, 0 for x[1]

In [None]:
a = nn.Linear(5,10)
def test(feature, linear_func):
    b = linear_func(feature)
    return b

In [None]:
# Run one all-ones sample with 5 features through layer `a` via the helper.
f = Variable(torch.ones(1, 5))
print(test(f, a))

In [None]:
def test(input, extra=None):
    if(extra == None):
        print("none")


In [None]:
test(None, extra=None)  # omitted/None `extra` -> prints "none"

In [None]:
# Python `if` blocks do not open a new scope, so `d` is still visible afterwards.
if True:
    d = 1
print(d)