In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm

torch.manual_seed(1)

<torch._C.Generator at 0x1230d73d0>

In [60]:
import sys
sys.path.append('../src/')

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [63]:
import CRF

In [64]:
class Model(nn.Module):
    def __init__(self
                 , tag_to_ix
                 , word_vocab_size
                 , word_emb_dim
                 , word_lstm_units
                 , char_vocab_size
                 , char_emb_dim
                 , char_lstm_units
                ):
        super(Model, self).__init__()
        self.word_vocab_size = word_vocab_size
        self.word_emb_dim = word_emb_dim
        self.word_lstm_units = word_lstm_units
        self.char_vocab_size = char_vocab_size
        self.char_emb_dim = char_emb_dim
        self.char_lstm_units = char_lstm_units
        self.n_class = len(tag_to_ix) + 2  # +2 means <START> and <STOP> tags.
        
        # Character input
        self.char_embeddings = nn.Embedding(num_embeddings=char_vocab_size
                                            , embedding_dim=char_emb_dim
                                            , padding_idx=0
                                           )
        self.char_lstm = nn.LSTM(input_size=char_emb_dim
                                 , hidden_size=char_lstm_units
                                 , bidirectional=True
                                 , batch_first=True
                                )
        
        # Word input
        self.word_embeddings = nn.Embedding(num_embeddings=word_vocab_size
                                            , embedding_dim=word_emb_dim
                                            , padding_idx=0
                                           )
        
        self.word_lstm = nn.LSTM(input_size=word_emb_dim + (char_emb_dim * 2)
                                 , hidden_size=word_lstm_units
                                 , bidirectional=True
                                 , batch_first=True
                                )
        
        self.hidden_to_tag = nn.Linear(in_features=word_lstm_units, out_features=self.n_class)

        self.crf = CRF.CRF(tag_to_ix)

    def forward(self, sentence, chars):
        lstm_feats = self._get_word_lstm_features(sentence)
        feats = self.hidden_to_tag(lstm_feats)
        _, tag_seq = self.crf(feats)
        
        return feats, tag_seq
    
    def _get_word_lstm_features(self, sentence, chars):
        '''
        sentence: (sentence_length) 
        chars: (sentence_length, word_length)
        '''
        word_embs = self.word_embeddings(sentence)
        chars_embs = self._get_char_lstm_features(chars)
        
        return torch.cat((word_embs, chars_embs), 1)
    
    def _get_char_lstm_features(self, chars):
        '''
        chars: (sentence_length, word_length) 
        '''
        #chars = chars.view(self.batch_size * chars.size()[1], -1)
        
        char_lengths = (chars > 0).sum(1)
        char_lengths, sorted_index = char_lengths.sort(0, descending=True)
        
        char_embs = self.char_embeddings(chars[sorted_index])
        packed = torch.nn.utils.rnn.pack_padded_sequence(char_embs, word_lengths, batch_first=True)
        
        l_packed_out, _ = self.char_lstm(packed)
        l_output, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(l_packed_out, batch_first=True)
        
        l_char_embs = torch.Tensor(torch.zeros((l_output.size(0), l_output.size(2))))
        l_char_embs[sorted_index] = torch.stack([self._cat_lstm_last(word_feat, length) for word_feat, length in zip(l_output, char_lengths)])
        
        return l_char_embs
        
    def _cat_lstm_last(self, output, length):
        return torch.cat((output[length - 1, :self.char_lstm_dim], output[0, self.char_lstm_dim:]))


In [65]:
Model({1: 2} ,1, 2, 3, 4, 5, 6)

Model(
  (char_embeddings): Embedding(4, 5, padding_idx=0)
  (char_lstm): LSTM(5, 6, batch_first=True, bidirectional=True)
  (word_embeddings): Embedding(1, 2, padding_idx=0)
  (word_lstm): LSTM(12, 3, batch_first=True, bidirectional=True)
  (hidden_to_tag): Linear(in_features=3, out_features=3, bias=True)
  (crf): CRF()
)