In [1]:
import torch
import torch.nn as nn

In [2]:
# Run on the CUDA device when PyTorch reports one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Character-level word encoder: character embeddings fed through a bank of 1-D convolutions, each ReLU-activated and max-pooled over the characters.
class CharCNN(nn.Module):
    """Character-level CNN that turns each word (a row of character ids) into a vector.

    Each word is embedded character by character, then passed through
    ``max_word_length - kernel_size + 1`` independent ``nn.Conv1d`` layers
    (all with the same ``kernel_size``). Each conv output is ReLU-activated and
    max-pooled over the character axis, and the pooled features are concatenated.

    Args:
        character_embedding_size: Size of each character embedding (conv input channels).
        num_filters: Output channels of every conv layer.
        kernel_size: Width, in characters, of every conv layer.
        max_word_length: Only used to decide how many conv layers to build.
        char_vocab_size: Number of character ids accepted by the embedding.

    Raises:
        ValueError: If ``kernel_size`` exceeds ``max_word_length``. In that case no
            conv layer would be built and ``forward`` would have nothing to concatenate.
    """

    def __init__(self, character_embedding_size, num_filters, kernel_size, max_word_length, char_vocab_size):
        super(CharCNN, self).__init__()
        num_convs = max_word_length - kernel_size + 1
        if num_convs < 1:
            raise ValueError(
                f"kernel_size ({kernel_size}) must not exceed max_word_length ({max_word_length})"
            )
        self.char_embedding = nn.Embedding(char_vocab_size, character_embedding_size)
        # NOTE(review): every layer shares the same kernel_size, so this bank behaves like a
        # single conv with num_filters * num_convs channels. It is kept as separate modules so
        # parameter names (conv_layers.<i>.*) stay compatible with existing checkpoints.
        self.conv_layers = nn.ModuleList(
            [nn.Conv1d(character_embedding_size, num_filters, kernel_size) for _ in range(num_convs)]
        )
        # Width of the feature vector produced for every word.
        self.output_dim = num_filters * num_convs

    def forward(self, x):
        """Encode character ids into one feature vector per word.

        Args:
            x: Integer tensor of character ids with shape ``(..., word_length)``, e.g.
                ``(batch_size, max_word_length)`` or ``(batch_size, seq_len, max_word_length)``.
                ``word_length`` must be at least ``kernel_size``.

        Returns:
            Tensor of shape ``(..., num_filters * (max_word_length - kernel_size + 1))``.
        """
        leading_shape = x.shape[:-1]
        # Flatten all leading dims into a single batch axis so Conv1d sees (N, C, L).
        words = x.reshape(-1, x.shape[-1])
        emb = self.char_embedding(words)  # (N, word_length, character_embedding_size)
        emb = emb.permute(0, 2, 1)  # (N, character_embedding_size, word_length)
        # ReLU, then max over the character (time) axis, for every conv layer.
        pooled = [torch.max(torch.relu(conv(emb)), dim=2).values for conv in self.conv_layers]
        features = torch.cat(pooled, dim=1)  # (N, output_dim)
        # Restore the caller's leading dims.
        return features.reshape(*leading_shape, features.shape[-1])

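# ELMo: character-CNN token encoder followed by two stacked bidirectional LSTMs, with the layer outputs mixed by learned softmax weights.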
class ELMo(nn.Module):
    """ELMo-style encoder: char-CNN token vectors refined by two stacked biLSTMs.

    The output for every word is a learned, softmax-normalised mix of three
    layers, each of width ``2 * lstm_hidden_size``:

    * the char-CNN token vector, linearly projected when its width differs;
    * the first biLSTM output;
    * the second biLSTM output.

    Args:
        cnn_config: Dict with keys ``character_embedding_size``, ``num_filters``,
            ``kernel_size`` and ``max_word_length``, forwarded to ``CharCNN``.
        elmo_config: Dict with keys ``lstm_hidden_size`` and ``lstm_num_layers``
            (the latter is the ``num_layers`` of each of the two LSTMs).
        char_vocab_size: Number of distinct character ids, forwarded to ``CharCNN``.
    """

    def __init__(self, cnn_config, elmo_config, char_vocab_size):
        super(ELMo, self).__init__()
        # Width of each CharCNN word vector: one num_filters-wide block per conv layer.
        token_dim = cnn_config['num_filters'] * (cnn_config['max_word_length'] - cnn_config['kernel_size'] + 1)
        hidden_size = elmo_config['lstm_hidden_size']
        num_layers = elmo_config['lstm_num_layers']
        # A bidirectional LSTM concatenates both directions, doubling its output width.
        context_dim = 2 * hidden_size

        # The vocabulary size comes from the constructor argument; cnn_config need not repeat it.
        self.char_cnn = CharCNN(cnn_config['character_embedding_size'],
                                cnn_config['num_filters'],
                                cnn_config['kernel_size'],
                                cnn_config['max_word_length'],
                                char_vocab_size)
        self.lstm1 = nn.LSTM(token_dim,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             batch_first=True,
                             bidirectional=True)
        # lstm2 consumes lstm1's output, so its input width is context_dim, not token_dim.
        self.lstm2 = nn.LSTM(context_dim,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             batch_first=True,
                             bidirectional=True)
        # The token layer must have the LSTM width to be mixed with the LSTM outputs.
        # When it already does, use an identity: it adds no parameters, so the state_dict
        # is unchanged for configurations that worked before.
        if token_dim == context_dim:
            self.token_projection = nn.Identity()
        else:
            self.token_projection = nn.Linear(token_dim, context_dim)

        # NOTE(review): not used in forward(); kept so existing state_dicts still load.
        self.interpolation_linear = nn.Linear(3, 1)
        # Unnormalised layer-mixing weights; softmax is applied in forward(). They are created
        # without an explicit device so model.to(...) moves them with every other parameter.
        self.lambdas = nn.Parameter(torch.tensor([0.33, 0.33, 0.33], dtype=torch.float32))

    def forward(self, x):
        """Compute the mixed contextual representation of every word.

        Args:
            x: A batch of sentences that can be iterated sentence by sentence. Each
                sentence is passed to ``CharCNN``, presumably as a
                ``(num_words, max_word_length)`` tensor of character ids (TODO confirm
                against callers). All sentences must have the same number of words so
                their encodings can be stacked.

        Returns:
            Tensor of shape ``(batch_size, num_words, 2 * lstm_hidden_size)``.
        """
        # Encode each sentence's words, then stack into (batch, num_words, token_dim).
        tokens = torch.stack([self.char_cnn(sentence) for sentence in x], dim=0)
        lstm1, _ = self.lstm1(tokens)  # (batch, num_words, 2 * hidden)
        lstm2, _ = self.lstm2(lstm1)  # (batch, num_words, 2 * hidden)
        layer0 = self.token_projection(tokens)  # (batch, num_words, 2 * hidden)

        # Interpolation: softmax keeps the three mixing weights positive and summing to 1.
        alpha = torch.nn.functional.softmax(self.lambdas, dim=0)
        return alpha[0] * lstm1 + alpha[1] * lstm2 + alpha[2] * layer0


    

