# In [11]:
from module.CoTNet import cotn_attn
import torch.nn as nn
import torch
from torchmultimodal.modules.layers.mlp import MLP

class ImageEncoder(nn.Module):
    """Image branch: a ``cotn_attn`` backbone whose ``fc`` head is replaced.

    Args:
        num_classes: Output width of the replacement ``fc`` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.cnn = cotn_attn(attn_layer="psa")
        # Swap the backbone's head for a fresh linear layer. The input width
        # of 2048 is assumed to match the backbone's pooled features --
        # TODO confirm against cotn_attn.
        self.cnn.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        """Run the backbone and flatten every axis after the batch axis."""
        x = self.cnn(x)
        return x.view(x.size(0), -1)
    
class TextEncoder(nn.Module):
    """Text branch: embedding -> 1-D CNN -> stacked bidirectional LSTMs -> linear.

    Args:
        num_classes: Output width of the final linear layer.
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding (CNN input channels).
        hidden_dim: Per-direction hidden size of both LSTMs.

    The three ``MaxPool1d(3, stride=2)`` stages have no padding, so the token
    axis must be at least 15 long or the last pooling stage has nothing left
    to pool.
    """

    def __init__(self, num_classes, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Each conv keeps the sequence length (kernel 3, padding 1); each
        # pooling stage roughly halves it.
        self.cnn = nn.Sequential(
            nn.Conv1d(embedding_dim, 512, kernel_size=3, padding=1),
            nn.MaxPool1d(3, stride=2),
            nn.Conv1d(512, 256, kernel_size=3, padding=1),
            nn.MaxPool1d(3, stride=2),
            nn.Conv1d(256, 128, kernel_size=3, padding=1),
            nn.MaxPool1d(3, stride=2),
            nn.Conv1d(128, 64, kernel_size=3, padding=1),
        )

        self.lstm = nn.LSTM(
            64,
            hidden_dim,
            num_layers=2,
            dropout=0.2,
            bidirectional=True,
            batch_first=True,
        )
        # The first LSTM is bidirectional, so it emits 2 * hidden_dim features
        # per step; derive the input size from that instead of hard-coding 256
        # (which only matched hidden_dim == 128). No dropout here: nn.LSTM
        # applies dropout only between stacked layers, so it is a no-op (and a
        # warning) when num_layers == 1.
        self.lstm2 = nn.LSTM(
            2 * hidden_dim,
            hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        """Encode a batch of token-id sequences.

        Args:
            x: Integer token ids, indexed as (batch, sequence).

        Returns:
            Tensor of shape (batch, num_classes).
        """
        x = self.embedding(x)       # (batch, seq, embedding_dim)
        x = x.permute(0, 2, 1)      # Conv1d wants channels before length
        x = self.cnn(x)             # (batch, 64, reduced_seq)
        x = x.permute(0, 2, 1)      # back to (batch, reduced_seq, 64)
        x, _ = self.lstm(x)
        x, _ = self.lstm2(x)
        # Classify from the output at the final time step.
        return self.fc(x[:, -1, :])


class ITF_WPI_Net(nn.Module):
    """Fuse image and text encoder outputs with an MLP and emit log-probabilities.

    Args:
        num_classes: Output width of each encoder and of the fusion MLP.
        vocab_size: Forwarded to ``TextEncoder``.
        embedding_dim: Forwarded to ``TextEncoder``.
        hidden_dim: Forwarded to ``TextEncoder``.
    """

    def __init__(self, num_classes, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.image_encoder = ImageEncoder(num_classes)
        self.text_encoder = TextEncoder(num_classes, vocab_size, embedding_dim, hidden_dim)

        # Both encoders emit num_classes features, so the concatenation is
        # 2 * num_classes wide. The sizes were previously hard-coded as 34/17,
        # which only worked for num_classes == 17.
        self.mlp = MLP(
            2 * num_classes,
            num_classes,
            hidden_dims=[128],
            activation=nn.ReLU,
            normalization=nn.BatchNorm1d,
        )

        # Not used in forward(); retained so the parameter set (and therefore
        # the state_dict keys) stays the same as before.
        self.fc = nn.Linear(1, num_classes)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, images, texts):
        """Return log-softmax scores for a batch of (image, text) pairs.

        Args:
            images: Batch passed to the image encoder.
            texts: Batch passed to the text encoder.

        Returns:
            Log-probabilities over the last axis of the MLP output.
        """
        image_features = self.image_encoder(images)
        text_features = self.text_encoder(texts)
        # Concatenate along the feature axis before fusion.
        fused = torch.cat((image_features, text_features), dim=1)
        return self.softmax(self.mlp(fused))