In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet50
import torch.nn.functional as F

import gensim
# Seed PyTorch's global RNG so weight initialisation is reproducible across runs.
torch.manual_seed(1)

In [None]:
#Dataset and preprocessing
# TODO: dataset / caption tokenisation pipeline goes here.
# Word embeddings are built inside `textCNN` (nn.Embedding.from_pretrained), so this
# cell does not create a standalone embedding.

In [None]:
#Model
class ImageCNN(nn.Module):
    """Image branch: ImageNet-pretrained ResNet-50 trunk followed by two FC-BN-ReLU layers.

    Args:
        stageI: if True, every ResNet-50 parameter gets ``requires_grad=False`` so that
            only the layers added on top are trained.

    Input is an image batch ``(N, 3, H, W)`` (224x224 in the test below); output is
    ``(N, 2048)`` features after dropout.
    """

    def __init__(self, stageI=True):
        super(ImageCNN, self).__init__()
        backbone = resnet50(pretrained=True)
        # Remove the last classification layer; keep everything up to global avg-pooling.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])

        # During Stage I training the ResNet weights stay fixed at their pretrained values.
        # NOTE(review): this does not freeze the BatchNorm running statistics inside
        # self.resnet; they still update whenever the module is in train() mode.
        if stageI:
            for param in self.resnet.parameters():
                param.requires_grad_(False)

        self.fc = nn.Linear(2048, 2048)
        self.bn = nn.BatchNorm1d(2048)
        # Second FC-BN pair (as per their MATLAB implementation). It needs its own modules:
        # re-applying self.fc/self.bn would tie both layers' weights and feed two different
        # activation distributions into a single set of BatchNorm running statistics.
        self.fc2 = nn.Linear(2048, 2048)
        self.bn2 = nn.BatchNorm1d(2048)
        self.relu = nn.ReLU()
        # nn.Dropout(p) zeroes each element with probability p.
        self.dropOut = nn.Dropout(0.8)

    def forward(self, x):
        x = self.resnet(x)          # (N,2048,1,1)
        x = torch.flatten(x, 1)     # (N,2048)
        x = self.relu(self.bn(self.fc(x)))
        x = self.relu(self.bn2(self.fc2(x)))
        return self.dropOut(x)

##Test Model -> Input (N,3,224,224) and Output (N,2048)
# net1 = ImageCNN()
# x1= torch.rand((2,3,224,224))
# y = net1(x1)
# print(y.shape)

In [None]:
#textCNN -> The conv stack expects word2vec features shaped (N,300,1,32)
class BasicBlockText(nn.Module):
    """Residual bottleneck block for the text branch; keeps channel count and width.

    Main path:
        1x1 conv (input_channel -> intermediate_channel) -> BN -> ReLU
        1x2 conv, dilation (1, 2), padding (0, 1)          -> BN -> ReLU
        1x1 conv (intermediate_channel -> input_channel)   -> BN
    The block input is then added back and passed through ReLU.
    """

    def __init__(self, input_channel, intermediate_channel):
        super().__init__()

        def pointwise(c_in, c_out):
            # Bias-free 1x1 convolution; a BatchNorm follows every conv.
            return nn.Conv2d(c_in, c_out, kernel_size=(1, 1), stride=(1, 1),
                             padding=(0, 0), bias=False)

        self.bbConv1 = pointwise(input_channel, intermediate_channel)
        self.bbBatchNorm1 = nn.BatchNorm2d(intermediate_channel)
        self.relu = nn.ReLU()

        # A width-2 kernel with dilation 2 spans 3 columns; padding 1 on both sides keeps
        # the sequence width unchanged so the residual addition lines up.
        self.bbConv2 = nn.Conv2d(intermediate_channel, intermediate_channel,
                                 kernel_size=(1, 2), stride=(1, 1),
                                 padding=(0, 1), dilation=(1, 2), bias=False)
        self.bbBatchNorm2 = nn.BatchNorm2d(intermediate_channel)

        self.bbConv3 = pointwise(intermediate_channel, input_channel)
        self.bbBatchNorm3 = nn.BatchNorm2d(input_channel)

    def forward(self, x):
        shortcut = x
        h = self.bbConv1(x)
        h = self.relu(self.bbBatchNorm1(h))
        h = self.bbConv2(h)
        h = self.relu(self.bbBatchNorm2(h))
        h = self.bbBatchNorm3(self.bbConv3(h))
        h += shortcut
        return self.relu(h)


        
class textCNN(nn.Module):
    """Text branch: a ResNet-50-style stack of (1 x k) convolutions over a word sequence.

    Four stages, each a projection bottleneck (main path + strided 1x1 projection
    shortcut) followed by 2, 3, 5 and 0 identity ``BasicBlockText`` blocks. Stages 2 and 3
    use stride 2, so a width-32 input ends at width 8. Global average pooling and a
    2048 -> 2048 FC-BN-ReLU-Dropout head follow.

    Args:
        input_channel: channels of the input features (300 for word2vec).
        path: optional path to a binary word2vec file. When given, a frozen
            ``nn.Embedding`` is built from it and ``forward`` expects integer token
            indices of shape ``(N, L)``. When ``None`` (default), ``forward`` expects
            already-embedded features of shape ``(N, input_channel, 1, L)``.

    Returns ``(N, 2048)`` features from ``forward``.
    """

    def __init__(self, input_channel, path=None): #300
        super(textCNN, self).__init__()
        self.relu = nn.ReLU()
        # Frozen word2vec lookup when a file is given; otherwise inputs arrive pre-embedded.
        # NOTE(review): the GoogleNews file is 3,000,000 x 300 float32 (~3.6 GB in memory).
        if path is not None:
            self.word2Vec = nn.Embedding.from_pretrained(self.__load_word2vec(path))
        else:
            self.word2Vec = None
        #--------First CNN block: input_channel -> 256, width preserved----------------
        self.b1Conv1_0 = nn.Conv2d(input_channel, 128,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b1bn1_0 = nn.BatchNorm2d(128)

        self.b1Conv2 = nn.Conv2d(128, 128,
                                 kernel_size=(1,2), stride=(1,1),
                                 padding=(0,1), bias=False, dilation=(1,2))
        self.b1bn2 = nn.BatchNorm2d(128)

        self.b1Conv3 = nn.Conv2d(128, 256,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b1bn3 = nn.BatchNorm2d(256)

        # Projection shortcut: the block input also goes through its own 1x1 conv.
        self.b1Conv1_1 = nn.Conv2d(input_channel, 256,
                               kernel_size=(1,1), stride=(1,1),
                               padding=(0,0), bias=False)
        self.b1bn1_1 = nn.BatchNorm2d(256)

        # First group of identity blocks (in the MATLAB code i=2:3)
        self.layer1  = self.__make_layer(input_channel=256, intermediate_channel=64,
                                     num_blocks=2)

        #--------Second CNN block: 256 -> 512, width halved----------------
        self.b2Conv1_0 = nn.Conv2d(256, 512,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b2bn1_0 = nn.BatchNorm2d(512)

        self.b2Conv2 = nn.Conv2d(512, 512,
                                 kernel_size=(1,2), stride=(2,2),
                                 padding=(0,1), bias=False, dilation= (1,2))
        self.b2bn2 = nn.BatchNorm2d(512)

        self.b2Conv3 = nn.Conv2d(512, 512,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b2bn3 = nn.BatchNorm2d(512)

        self.b2Conv1_1 = nn.Conv2d(256, 512,
                                 kernel_size=(1,1), stride=(2,2),
                                 padding=(0,0), bias=False)
        self.b2bn1_1 = nn.BatchNorm2d(512)

        # Second group of identity blocks (in the MATLAB code i = 2:4)
        self.layer2 =  self.__make_layer(input_channel=512, intermediate_channel=128,
                                     num_blocks=3)

        #--------Third CNN block: 512 -> 1024, width halved----------------
        self.b3Conv1_0 = nn.Conv2d(512, 1024,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b3bn1_0 = nn.BatchNorm2d(1024)

        self.b3Conv2 = nn.Conv2d(1024, 1024,
                                 kernel_size=(1,2), stride=(2,2),
                                 padding=(0,1), bias=False, dilation= (1,2))
        self.b3bn2 = nn.BatchNorm2d(1024)

        self.b3Conv3 = nn.Conv2d(1024, 1024,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b3bn3 = nn.BatchNorm2d(1024)

        self.b3Conv1_1 = nn.Conv2d(512, 1024,
                                 kernel_size=(1,1), stride=(2,2),
                                 padding=(0,0), bias=False)
        self.b3bn1_1 = nn.BatchNorm2d(1024)

        # Third group of identity blocks (in the MATLAB code i = 2:6)
        self.layer3 =  self.__make_layer(input_channel=1024, intermediate_channel=256,
                                     num_blocks=5)

        #--------Fourth CNN block: 1024 -> 2048, width preserved----------------
        self.b4Conv1_0 = nn.Conv2d(1024, 2048,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b4bn1_0 = nn.BatchNorm2d(2048)

        self.b4Conv2 = nn.Conv2d(2048, 2048,
                                 kernel_size=(1,2), stride=(1,1),
                                 padding=(0,1), bias=False, dilation= (1,2))
        self.b4bn2 = nn.BatchNorm2d(2048)

        self.b4Conv3 = nn.Conv2d(2048, 2048,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b4bn3 = nn.BatchNorm2d(2048)

        self.b4Conv1_1 = nn.Conv2d(1024, 2048,
                                 kernel_size=(1,1), stride=(1,1),
                                 padding=(0,0), bias=False)
        self.b4bn1_1 = nn.BatchNorm2d(2048)

        #--------Head----------------
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc1 = nn.Linear(2048,2048)
        self.fc1_bn = nn.BatchNorm1d(2048)
        self.dropout = nn.Dropout(0.8)

    def __load_word2vec(self, path):
        """Load a binary word2vec file and return its vectors as a float32 tensor."""
        model = gensim.models.KeyedVectors.load_word2vec_format(path, binary=True)
        return torch.as_tensor(model.vectors, dtype=torch.float32)

    def __make_layer(self, input_channel, intermediate_channel, num_blocks):
        """Stack ``num_blocks`` identity-shortcut BasicBlockText modules."""
        return nn.Sequential(*[BasicBlockText(input_channel, intermediate_channel)
                               for _ in range(num_blocks)])

    def _project(self, x, conv_a, bn_a, conv_b, bn_b, conv_c, bn_c, conv_skip, bn_skip):
        """Bottleneck whose shortcut is a 1x1 projection (used where channels/width change)."""
        out = self.relu(bn_a(conv_a(x)))
        out = self.relu(bn_b(conv_b(out)))
        out = bn_c(conv_c(out))
        shortcut = bn_skip(conv_skip(x))
        shortcut += out
        return self.relu(shortcut)

    def forward(self, x):
        """Encode a batch of sentences into (N, 2048) features.

        x: token indices (N, L) when a word2vec file was loaded, otherwise
           features (N, input_channel, 1, L), e.g. (N, 300, 1, 32).
        """
        if self.word2Vec is not None:
            # (N, L) -> (N, L, C) -> (N, C, L) -> (N, C, 1, L): embedding dims become
            # channels and the words run along the width axis expected by the convs.
            x = self.word2Vec(x).permute(0, 2, 1).unsqueeze(2)

        # Shape comments below are for L = 32.
        out = self._project(x, self.b1Conv1_0, self.b1bn1_0, self.b1Conv2, self.b1bn2,
                            self.b1Conv3, self.b1bn3, self.b1Conv1_1, self.b1bn1_1)
        out = self.layer1(out)  # (N,256,1,32)

        out = self._project(out, self.b2Conv1_0, self.b2bn1_0, self.b2Conv2, self.b2bn2,
                            self.b2Conv3, self.b2bn3, self.b2Conv1_1, self.b2bn1_1)
        out = self.layer2(out)  # (N,512,1,16)

        out = self._project(out, self.b3Conv1_0, self.b3bn1_0, self.b3Conv2, self.b3bn2,
                            self.b3Conv3, self.b3bn3, self.b3Conv1_1, self.b3bn1_1)
        out = self.layer3(out)  # (N,1024,1,8)

        out = self._project(out, self.b4Conv1_0, self.b4bn1_0, self.b4Conv2, self.b4bn2,
                            self.b4Conv3, self.b4bn3, self.b4Conv1_1, self.b4bn1_1)  # (N,2048,1,8)

        out = torch.flatten(self.avgpool(out), 1)  # (N,2048)
        return self.dropout(self.relu(self.fc1_bn(self.fc1(out))))
        
##Test TextCNN
# net2 = textCNN(300)
# x2 = torch.rand(2,300,1,32)
# y2 = net2(x2)
# print(y2.shape) # (N,2048)

In [None]:
#Full model: the image and text branches share one classification weight matrix
class Model(nn.Module):
    """Image/text dual-branch model whose classification weight is shared by both branches.

    Args:
        stageI: passed to ImageCNN; True freezes its ResNet-50 trunk.
        path: word2vec binary file passed through to textCNN.
        num_classes: number of rows of the shared weight matrix (default 113287).
    """

    def __init__(self, stageI=True, path=None, num_classes=113287):
        super(Model, self).__init__()
        # Shared weight in nn.Linear layout: (out_features, in_features). Wrapping it in
        # nn.Parameter registers it with the module, so it receives gradients, appears in
        # parameters()/state_dict(), and follows .to(device). A plain tensor does none of these.
        self.weights = nn.Parameter(torch.empty(num_classes, 2048))
        # Same init nn.Linear uses for its weight; a uniform [0, 1) init would make every
        # weight positive and every logit large and same-signed.
        nn.init.kaiming_uniform_(self.weights, a=5 ** 0.5)
        self.imageCNN = ImageCNN(stageI)
        self.textCNN = textCNN(300, path)

    def forward(self, img, txt):
        """Return (image logits, text logits), both (N, num_classes), from the shared weight."""
        img_out = self.imageCNN(img)  # (N,2048)
        txt_out = self.textCNN(txt)   # (N,2048)
        return F.linear(img_out, self.weights), F.linear(txt_out, self.weights)

##Test whole model
# net = Model()
# img = torch.rand((2,3,224,224))
# txt = torch.rand((2,300,1,32))
# fc_img, fc_txt = net(img, txt)
# print(fc_img.shape) #(N,113287)
# print(fc_txt.shape) #(N,113287)

In [None]:
# Build the image branch and show its structure (rich display instead of print).
net1 = ImageCNN()
net1

In [None]:
# Exploratory cell: loads a pretrained ResNet-50, presumably to inspect its structure.
# NOTE(review): the name `re` shadows the standard-library `re` module if it is ever imported.
re = resnet50(pretrained=True)

In [1]:
import gensim

# Binary GoogleNews word2vec vectors.
path = '../../../Downloads/GoogleNews-vectors-negative300.bin'
# Load the vectors here: the next cell reads `model.vectors`, and no other cell defines
# `model`, so a Restart-and-Run-All would otherwise fail with NameError.
model = gensim.models.KeyedVectors.load_word2vec_format(path, binary=True)




NameError: name 'torch' is not defined

In [3]:
import torch

# Assumes `model` is a gensim KeyedVectors instance; `.vectors` is its embedding matrix.
# torch.as_tensor replaces the legacy typed constructor torch.FloatTensor(...), which the
# PyTorch docs discourage; for a float32 array it can reuse the NumPy buffer.
weights = torch.as_tensor(model.vectors, dtype=torch.float32)

In [4]:
weights.size()  # (vocab_size, embedding_dim)

torch.Size([3000000, 300])

In [5]:
# Lookup table over the word2vec matrix; freeze=True is the default, spelled out for clarity.
embed = torch.nn.Embedding.from_pretrained(weights, freeze=True)

In [None]:
# Scratch tensor: two samples drawn uniformly from [0, 1).
x = torch.rand(2)