In [81]:
# Author: Robert Guthrie
from copy import copy
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(1)

<torch._C.Generator at 0x7ffb940f71d0>

In [7]:
def argmax(vec):
    # return the argmax as a python int
    _, idx = torch.max(vec, 1)
    return idx.item()


def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

In [8]:
class BiLSTM_CRF(nn.Module):

    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # Matrix of transition parameters.  Entry i,j is the score of
        # transitioning *to* i *from* j.
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        # These two statements enforce the constraint that we never transfer
        # to the start tag and we never transfer from the stop tag
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

        self.hidden = self.init_hidden()

    def init_hidden(self):
        return (torch.randn(2, 1, self.hidden_dim // 2),
                torch.randn(2, 1, self.hidden_dim // 2))

    def _forward_alg(self, feats):
        # Do the forward algorithm to compute the partition function
        init_alphas = torch.full((1, self.tagset_size), -10000.)
        # START_TAG has all of the score.
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.

        # Wrap in a variable so that we will get automatic backprop
        forward_var = init_alphas

        # Iterate through the sentence
        for feat in feats:
            alphas_t = []  # The forward tensors at this timestep
            for next_tag in range(self.tagset_size):
                # broadcast the emission score: it is the same regardless of
                # the previous tag
                emit_score = feat[next_tag].view(
                    1, -1).expand(1, self.tagset_size)
                # the ith entry of trans_score is the score of transitioning to
                # next_tag from i
                trans_score = self.transitions[next_tag].view(1, -1)
                # The ith entry of next_tag_var is the value for the
                # edge (i -> next_tag) before we do log-sum-exp
                next_tag_var = forward_var + trans_score + emit_score
                # The forward variable for this tag is log-sum-exp of all the
                # scores.
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        alpha = log_sum_exp(terminal_var)
        return alpha

    def _get_lstm_features(self, sentence):
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self, feats, tags):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1)
        tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags])
        for i, feat in enumerate(feats):
            score = score + \
                self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        backpointers = []

        # Initialize the viterbi variables in log space
        init_vvars = torch.full((1, self.tagset_size), -10000.)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0

        # forward_var at step i holds the viterbi variables for step i-1
        forward_var = init_vvars
        for feat in feats:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step

            for next_tag in range(self.tagset_size):
                # next_tag_var[i] holds the viterbi variable for tag i at the
                # previous step, plus the score of transitioning
                # from tag i to next_tag.
                # We don't include the emission scores here because the max
                # does not depend on them (we add them in below)
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Now add in the emission scores, and assign forward_var to the set
            # of viterbi variables we just computed
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        feats = self._get_lstm_features(sentence)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, sentence):  # dont confuse this with _forward_alg above.
        # Get the emission scores from the BiLSTM
        lstm_feats = self._get_lstm_features(sentence)

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

In [9]:
START_TAG = "<START>"
STOP_TAG = "<STOP>"
EMBEDDING_DIM = 5
HIDDEN_DIM = 4

# Make up some training data
training_data = [(
    "the wall street journal reported today that apple corporation made money".split(),
    "B I I I O O O B I O O".split()
), (
    "georgia tech is a university in georgia".split(),
    "B I O O O O B".split()
)]

word_to_ix = {}
for sentence, tags in training_data:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

tag_to_ix = {"B": 0, "I": 1, "O": 2, START_TAG: 3, STOP_TAG: 4}

model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

# Check predictions before training
with torch.no_grad():
    precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
    precheck_tags = torch.tensor([tag_to_ix[t] for t in training_data[0][1]], dtype=torch.long)
    print(model(precheck_sent))

# Make sure prepare_sequence from earlier in the LSTM section is loaded
for epoch in range(
        300):  # again, normally you would NOT do 300 epochs, it is toy data
    for sentence, tags in training_data:
        # Step 1. Remember that Pytorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()

        # Step 2. Get our inputs ready for the network, that is,
        # turn them into Tensors of word indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)

        # Step 3. Run our forward pass.
        loss = model.neg_log_likelihood(sentence_in, targets)

        # Step 4. Compute the loss, gradients, and update the parameters by
        # calling optimizer.step()
        loss.backward()
        optimizer.step()

# Check predictions after training
with torch.no_grad():
    precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
    print(model(precheck_sent))
# We got it!

(tensor(2.6907), [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1])
(tensor(20.4906), [0, 1, 1, 1, 2, 2, 2, 0, 1, 2, 2])


# model

In [67]:
import sys
sys.path.append('..')
from utils.dataset.ec import ECDataset
from utils.dataloader.ec import ECDataLoader
from models.han.word_model import WordAttention
from models.han.sentence_model import SentenceWithPosition
device = torch.device('cuda: 0')

In [61]:
batch_size = 16
vocab_size = 23071
num_classes = 2
sequence_length = 41
embedding_dim = 300
dropout = 0.5
word_rnn_size = 300
word_rnn_layer = 2
sentence_rnn_size = 300
sentence_rnn_layer = 2
pos_size = 103
pos_embedding_dim = 300
pos_embedding_file= '/data/wujipeng/ec/data/embedding/pos_embedding.pkl'

In [56]:
train_dataset = ECDataset(data_root='/data/wujipeng/ec/data/test/', vocab_root='/data/wujipeng/ec/data/raw_data/', train=True)
test_dataset = ECDataset(data_root='/data/wujipeng/ec/data/test/', vocab_root='/data/wujipeng/ec/data/raw_data/', train=False)

In [57]:
train_loader = ECDataLoader(dataset=train_dataset, clause_length=sequence_length, batch_size=16, shuffle=True, sort=True, collate_fn=train_dataset.collate_fn)

In [132]:
for batch in train_loader:
    clauses, keywords, poses = ECDataset.batch2input(batch)
    labels = ECDataset.batch2target(batch)
    clauses = torch.from_numpy(clauses).to(device)
    keywords = torch.from_numpy(keywords).to(device)
    poses = torch.from_numpy(poses).to(device)
    labels = torch.from_numpy(labels).to(device)
    targets = labels
    break

In [62]:
class HierachicalAttentionModelCRF:
    def __init__(self,
                 vocab_size,
                 num_classes,
                 embedding_dim,
                 hidden_size,
                 word_model,
                 sentence_model,
                 dropout=0.5,
                 fix_embed=True,
                 name='HAN'):
        super(HierachicalAttentionModelCRF, self).__init__()

        self.num_classes = num_classes
        self.fix_embed = fix_embed
        self.name = name

        self.Embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.word_rnn = WordAttention(
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            batch_size=batch_size,
            sequence_length=sequence_length,
            rnn_size=word_rnn_size,
            rnn_layers=word_rnn_layer,
            dropout=dropout)
        self.sentence_rnn = SentenceAttention(
            batch_size=batch_size,
            word_rnn_size = word_rnn_size,
            rnn_size = sentence_rnn_size,
            rnn_layers=sentence_rnn_layer,
            pos_size=pos_size,
            pos_embedding_dim=pos_embedding_dim,
            pos_embedding_file=pos_embedding_file
        )
        self.fc = nn.Linear(
            2 * self.word_rnn_size + 2 * self.sentence_rnn_size, num_classes)
        self.dropout = nn.Dropout(dropout)
        # self.fc = nn.Sequential(
        #     nn.Linear(2 * self.sentence_rnn_size, linear_hidden_dim),
        #     nn.ReLU(inplace=True),
        #     nn.Dropout(dropout),
        #     nn.Linear(linear_hidden_dim, num_classes)
        # )

    def init_weights(self, embeddings):
        if embeddings is not None:
            self.Embedding = self.Embedding.from_pretrained(embeddings)

    def forward(self, clauses, keywords, poses):
        inputs = self.linear(self.Embedding(clauses))
        queries = self.linear(self.Embedding(keywords))
        documents, word_attn = self.word_rnn(inputs, queries)
        outputs, sentence_attn = self.sentence_rnn(documents, poses)
        # outputs = self.fc(outputs)
        s_c = torch.cat((documents, outputs), dim=-1)
        outputs = self.fc(self.dropout(s_c))
        return outputs, word_attn, sentence_attn

## init

In [133]:
def argmax(vec):
    # return the argmax as a python int
    _, idx = torch.max(vec, 1)
    return idx.item()

def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

In [134]:
START_TAG = "<START>"
STOP_TAG = "<STOP>"
tag_to_ix = {0: 0, 1: 1, START_TAG: 2, STOP_TAG: 3}
tagsize = len(tag_to_ix)

In [135]:
Embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0).to(device)
word_rnn = WordAttention(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    batch_size=batch_size,
    sequence_length=sequence_length,
    rnn_size=word_rnn_size,
    rnn_layers=word_rnn_layer,
    dropout=dropout).to(device)
sentence_rnn = SentenceWithPosition(
    batch_size=batch_size,
    word_rnn_size = word_rnn_size,
    rnn_size = sentence_rnn_size,
    rnn_layers=sentence_rnn_layer,
    pos_size=pos_size,
    pos_embedding_dim=pos_embedding_dim,
    pos_embedding_file=pos_embedding_file
).to(device)
fc = nn.Linear(2 * word_rnn_size + 2 * sentence_rnn_size, num_classes+2).to(device)
drop = nn.Dropout(dropout).to(device)

In [282]:
transitions = nn.Parameter(torch.randn(tagset_size, tagset_size)).to(device)
transitions.data[tag_to_ix[START_TAG], :] = -10000
transitions.data[:, tag_to_ix[STOP_TAG]] = -10000

## forward

In [137]:
inputs = Embedding(clauses)
queries = Embedding(keywords)
documents, word_attn = word_rnn(inputs, queries)
outputs, sentence_attn = sentence_rnn(documents, poses)

In [138]:
s_c = torch.cat((documents, outputs), dim=-1)
outputs = fc(drop(s_c))

In [139]:
outputs.size()

torch.Size([16, 12, 4])

In [141]:
lstm_feats = copy(outputs)
lstm_feats.size()

torch.Size([16, 12, 4])

In [187]:
ignore_index = -100
masks = targets != ignore_index
masks

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0',
       dtype=torch.uint8)

### _forward_alg

In [345]:
lstm_feats.size(), masks.size()

(torch.Size([16, 12, 4]), torch.Size([16, 12]))

In [402]:
# para: feats masks
feats = copy(lstm_feats)

alpha = torch.zeros(1).to(device)
forward_vars = torch.FloatTensor().to(device)
# Iterate through the sentence
for feat, mask in zip(feats, masks):
    init_alphas = torch.full((1, tagset_size), -10000.).to(device)
    # START_TAG has all of the score.
    init_alphas[0][tag_to_ix[START_TAG]] = 0.

    # Wrap in a variable so that we will get automatic backprop
    forward_var = init_alphas
    feat = torch.masked_select(feat, mask.unsqueeze(-1).expand_as(feat)).view(feat.size(0), -1)
    for f in feat:
        alphas_t = []  # The forward tensors at this timestep
        for next_tag in range(tagset_size):
            # broadcast the emission score: it is the same regardless of
            # the previous tag
            emit_score = f[next_tag].view(1, -1).expand(1, tagset_size)
            # the ith entry of trans_score is the score of transitioning to
            # next_tag from i
            trans_score = transitions[next_tag].view(1, -1)
            # The ith entry of next_tag_var is the value for the
            # edge (i -> next_tag) before we do log-sum-exp
            next_tag_var = forward_var + trans_score + emit_score
            # The forward variable for this tag is log-sum-exp of all the
            # scores.
            alphas_t.append(log_sum_exp(next_tag_var).view(1))
        forward_var = torch.cat(alphas_t).view(1, -1)
        forward_vars = torch.cat((forward_vars, forward_var))
    terminal_var = forward_var + transitions[tag_to_ix[STOP_TAG]]
    alpha += log_sum_exp(terminal_var)

In [356]:
feat.size()

torch.Size([12, 4])

In [353]:
feat.size()

torch.Size([12, 4])

In [369]:
forward_score = alpha

In [368]:
alpha

tensor([192.0690], device='cuda:0', grad_fn=<AddBackward0>)

### _score_sentence

In [410]:
# params: feats, tags, masks
feats = copy(lstm_feats)
tags = copy(targets)

score = torch.zeros(1).to(device)
for feat, tag, mask in zip(feats, tags, masks):
    feat = torch.masked_select(feat, mask.unsqueeze(-1).expand_as(feat)).view(-1, feat.size(-1))
    tag = torch.masked_select(tag, mask)
    tag = torch.cat([torch.LongTensor([tag_to_ix[START_TAG]]).to(device), tag])
    for i, f in enumerate(feat):
        score = score + transitions[tag[i + 1], tag[i]] + f[tag[i + 1]]
    score = score + transitions[tag_to_ix[STOP_TAG], tag[-1]]

In [409]:
tag.size()

torch.Size([13])

In [363]:
gold_score = score

In [364]:
score = forward_score - gold_score
forward_score, gold_score, score

(tensor([192.0690], device='cuda:0', grad_fn=<AddBackward0>),
 tensor([99.7979], device='cuda:0', grad_fn=<AddBackward0>),
 tensor([92.2711], device='cuda:0', grad_fn=<SubBackward0>))

### _viterbi_decode

In [284]:
lstm_feats.size()

torch.Size([16, 12, 4])

In [428]:
# params: feats
feats = copy(lstm_feats)
path_score = torch.zeros(1)
best_paths = []
for feat in feats:
    backpointers = []

    # Initialize the viterbi variables in log space
    init_vvars = torch.full((1, tagset_size), -10000.).to(device)
    init_vvars[0][tag_to_ix[START_TAG]] = 0

    # forward_var at step i holds the viterbi variables for step i-1
    forward_var = init_vvars
    for f in feat:
        bptrs_t = []  # holds the backpointers for this step
        viterbivars_t = []  # holds the viterbi variables for this step

        for next_tag in range(tagset_size):
            # next_tag_var[i] holds the viterbi variable for tag i at the
            # previous step, plus the score of transitioning
            # from tag i to next_tag.
            # We don't include the emission scores here because the max
            # does not depend on them (we add them in below)
            next_tag_var = forward_var + transitions[next_tag]
            best_tag_id = argmax(next_tag_var)
            bptrs_t.append(best_tag_id)
            viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
        # Now add in the emission scores, and assign forward_var to the set
        # of viterbi variables we just computed
        forward_var = (torch.cat(viterbivars_t) + f).view(1, -1)
        backpointers.append(bptrs_t)

    # Transition to STOP_TAG
    terminal_var = forward_var + transitions[tag_to_ix[STOP_TAG]]
    best_tag_id = argmax(terminal_var)
    path_score += terminal_var[0][best_tag_id]

    # Follow the back pointers to decode the best path.
    best_path = [best_tag_id]
    for bptrs_t in reversed(backpointers):
        best_tag_id = bptrs_t[best_tag_id]
        best_path.append(best_tag_id)
    # Pop off the start tag (we dont want to return that to the caller)
    start = best_path.pop()
    assert start == tag_to_ix[START_TAG]  # Sanity check
    best_path.reverse()
    best_paths += best_path
best_paths = torch.LongTensor(best_paths)

In [427]:
# params: feats, masks
feats = copy(lstm_feats)
path_score = torch.zeros(1)
best_paths = []
forward_vars = torch.FloatTensor().to(device)

for feat in feats:
    backpointers = []

    # Initialize the viterbi variables in log space
    init_vvars = torch.full((1, tagset_size), -10000.).to(device)
    init_vvars[0][tag_to_ix[START_TAG]] = 0

    # forward_var at step i holds the viterbi variables for step i-1
    forward_var = init_vvars
    for f in feat:
        bptrs_t = []  # holds the backpointers for this step
        viterbivars_t = []  # holds the viterbi variables for this step
        for next_tag in range(tagset_size):
            # next_tag_var[i] holds the viterbi variable for tag i at the
            # previous step, plus the score of transitioning
            # from tag i to next_tag.
            # We don't include the emission scores here because the max
            # does not depend on them (we add them in below)
            next_tag_var = forward_var + transitions[next_tag]
            best_tag_id = argmax(next_tag_var)
            bptrs_t.append(best_tag_id)
            viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
        # Now add in the emission scores, and assign forward_var to the set
        # of viterbi variables we just computed
        forward_var = (torch.cat(viterbivars_t) + f).view(1, -1)
        forward_vars = torch.cat((forward_vars, forward_var))
        backpointers.append(bptrs_t)

    # Transition to STOP_TAG
    terminal_var = forward_var + transitions[tag_to_ix[STOP_TAG]]
    best_tag_id = argmax(terminal_var)
    path_score += terminal_var[0][best_tag_id]

    # Follow the back pointers to decode the best path.
    best_path = [best_tag_id]
    for bptrs_t in reversed(backpointers):
        best_tag_id = bptrs_t[best_tag_id]
        best_path.append(best_tag_id)
    # Pop off the start tag (we dont want to return that to the caller)
    start = best_path.pop()
    assert start == tag_to_ix[START_TAG]  # Sanity check
    best_path.reverse()
    best_paths += best_path
best_paths = torch.LongTensor(best_paths)

In [434]:
forward_vars.view_as(feats).size()

torch.Size([16, 12, 4])

In [433]:
feats.size()

torch.Size([16, 12, 4])

In [332]:
path_score

tensor([161.3597], grad_fn=<AddBackward0>)

In [394]:
best_paths.size()

torch.Size([192])

In [397]:
forward_vars[:, :2].size()

torch.Size([192, 2])

# class

In [438]:
# Author: Robert Guthrie
# https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html
def argmax(vec):
    # return the argmax as a python int
    _, idx = torch.max(vec, 1)
    return idx.item()


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))


class HierarchicalAttentionNetworkCRF(nn.Module):
    def __init__(self,
                 vocab_size,
                 num_classes,
                 tagset_size,
                 embedding_dim,
                 word_model,
                 sentence_model,
                 dropout=0.5,
                 fix_embed=True,
                 name='HAN'):
        super(HierarchicalAttentionNetworkCRF, self).__init__()
        self.word_rnn_size = word_model['args']['rnn_size']
        self.sentence_rnn_size = sentence_model['args']['rnn_size']
        self.num_classes = num_classes
        self.tagset_size = tagset_size
        self.fix_embed = fix_embed
        self.name = name

        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        self.Embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.word_rnn = eval(word_model['class'])(**word_model['args'])
        self.sentence_rnn = eval(sentence_model['class'])(**sentence_model['args'])
        self.fc = nn.Linear(2 * self.word_rnn_size + 2 * self.sentence_rnn_size, tagset_size)
        self.dropout = nn.Dropout(dropout)

    def init_weights(self, embeddings):
        if embeddings is not None:
            self.Embedding = self.Embedding.from_pretrained(embeddings)

    def set_device(self, device):
        self.device = device

    def _forward_alg(self, feats, masks):
        alpha = torch.zeros(1).to(self.device)
        forward_vars = torch.FloatTensor().to(self.device)
        # Iterate through the sentence
        for feat, mask in zip(feats, masks):
            # Do the forward algorithm to compute the partition function
            init_alphas = torch.full((1, self.tagset_size), -10000.).to(self.device)
            # START_TAG has all of the score.
            init_alphas[0][2] = 0.

            # Wrap in a variable so that we will get automatic backprop
            forward_var = init_alphas

            feat = torch.masked_select(feat, mask.unsqueeze(-1).expand_as(feat)).view(-1, feat.size(-1))
            for f in feat:
                alphas_t = []  # The forward tensors at this timestep
                for next_tag in range(self.tagset_size):
                    # broadcast the emission score: it is the same regardless of
                    # the previous tag
                    emit_score = f[next_tag].view(
                        1, -1).expand(1, self.tagset_size)
                    # the ith entry of trans_score is the score of transitioning to
                    # next_tag from i
                    trans_score = self.transitions[next_tag].view(1, -1)
                    # The ith entry of next_tag_var is the value for the
                    # edge (i -> next_tag) before we do log-sum-exp
                    next_tag_var = forward_var + trans_score + emit_score
                    # The forward variable for this tag is log-sum-exp of all the
                    # scores.
                    alphas_t.append(log_sum_exp(next_tag_var).view(1))
                forward_var = torch.cat(alphas_t).view(1, -1)
                forward_vars = torch.cat((forward_vars, forward_var))
            terminal_var = forward_var + self.transitions[-1]
            alpha += log_sum_exp(terminal_var)
        return alpha, forward_vars[:, :2]

    def _get_lstm_features(self, sentences, keywords, poses):

        inputs = self.Embedding(sentences)
        queries = self.Embedding(keywords)
        documents, word_attn = self.word_rnn(inputs, queries)
        outputs, sentence_attn = self.sentence_rnn(documents, poses)
        # outputs = self.fc(outputs)
        s_c = torch.cat((documents, outputs), dim=-1)
        lstm_feats = self.fc(self.dropout(s_c))
        return lstm_feats, word_attn, sentence_attn

    def _score_sentence(self, feats, tags, masks):
        # Gives the score of a provided tag sequence
        score = torch.zeros(1).to(self.device)
        for feat, tag, mask in zip(feats, tags, masks):
            feat = torch.masked_select(feat, mask.unsqueeze(-1).expand_as(feat)).view(-1, feat.size(-1))
            tag = torch.masked_select(tag, mask)
            tag = torch.cat((torch.LongTensor([2]).to(self.device), tag))
            for i, f in enumerate(feat):
                score += self.transitions[tag[i + 1], tag[i]] + f[tag[i + 1]]
            score += self.transitions[-1, tag[-1]]
        return score

    def _viterbi_decode(self, feats):
        path_score = torch.zeros(1).to(self.device)
        best_paths = []
        forward_vars = torch.FloatTensor().to(self.device)

        for feat in feats:
            backpointers = []

            # Initialize the viterbi variables in log space
            init_vvars = torch.full((1, self.tagset_size), -10000.).to(self.device)
            init_vvars[0][2] = 0

            # forward_var at step i holds the viterbi variables for step i-1
            forward_var = init_vvars
            for f in feat:
                bptrs_t = []  # holds the backpointers for this step
                viterbivars_t = []  # holds the viterbi variables for this step

                for next_tag in range(self.tagset_size):
                    # next_tag_var[i] holds the viterbi variable for tag i at the
                    # previous step, plus the score of transitioning
                    # from tag i to next_tag.
                    # We don't include the emission scores here because the max
                    # does not depend on them (we add them in below)
                    next_tag_var = forward_var + self.transitions[next_tag]
                    best_tag_id = argmax(next_tag_var)
                    bptrs_t.append(best_tag_id)
                    viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
                # Now add in the emission scores, and assign forward_var to the set
                # of viterbi variables we just computed
                forward_var = (torch.cat(viterbivars_t) + f).view(1, -1)
                forward_vars = torch.cat((forward_vars, forward_var))
                backpointers.append(bptrs_t)

            # Transition to STOP_TAG
            terminal_var = forward_var + self.transitions[-1]
            best_tag_id = argmax(terminal_var)
            path_score += terminal_var[0][best_tag_id]

            # Follow the back pointers to decode the best path.
            best_path = [best_tag_id]
            for bptrs_t in reversed(backpointers):
                best_tag_id = bptrs_t[best_tag_id]
                best_path.append(best_tag_id)
            # Pop off the start tag (we dont want to return that to the caller)
            start = best_path.pop()
            assert start == 2  # Sanity check
            best_path.reverse()
            best_paths.append(best_path)
        forward_vars = forward_vars.view_as(feats)[:, :, :2]
        return path_score, torch.LongTensor(best_paths).to(self.device), forward_vars

    def neg_log_likelihood(self, sentences, keywords, poses, tags, masks):
        feats, word_attn, sentence_attn = self._get_lstm_features(sentences, keywords, poses)
        forward_score, forward_probs = self._forward_alg(feats, masks)
        gold_score = self._score_sentence(feats, tags, masks)
        return forward_score - gold_score, forward_probs, word_attn, sentence_attn

    def forward(self, sentences, keywords, poses):  # dont confuse this with _forward_alg above.
        # Get the emission scores from the BiLSTM
        lstm_feats, word_attn, sentence_attn = self._get_lstm_features(sentences, keywords, poses)

        # Find the best path, given the features.
        score, tag_seq, tag_probs = self._viterbi_decode(lstm_feats)
        return score, tag_seq, tag_probs, word_attn, sentence_attn

In [439]:
model = HierarchicalAttentionNetworkCRF()

In [458]:
class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.linear = nn.Linear(200, 2)
        
    def forward(self, inputs):
        self.w = torch.zeros(100)
        c = torch.cat((inputs, self.w))
        return self.linear(c)

In [459]:
inputs = torch.randn(100).to(device)
model = Test().to(device)

In [460]:
model(inputs)

RuntimeError: Expected object of backend CUDA but got backend CPU for sequence element 1 in sequence argument at position #1 'tensors'