<a href="https://colab.research.google.com/github/zhh210/flea_market/blob/master/Nested.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torchtext, spacy, torch

In [0]:
# Set up the fields used to load the raw IMDB splits.
# NOTE: TEXT and LABEL are rebound to different field objects in a later cell;
# the instances created here are only used by the IMDB.splits call below.
TEXT = torchtext.data.Field(lower=True, include_lengths=True, batch_first=True)
LABEL = torchtext.data.LabelField(dtype=torch.float)
# Load the IMDB train/test splits with the fields above.
train_d, test_d = torchtext.datasets.IMDB.splits(TEXT, LABEL)
# Disabled: shuffle both splits and keep only the first 500 training and
# 200 test examples.
# from random import shuffle
# train_d = [i for i in train_d]
# test_d = [i for i in test_d]
# shuffle(train_d)
# shuffle(test_d)
# train_d = train_d[:500]
# test_d = test_d[:200]

In [0]:
# ' '.join(train[0].text)
import numpy as np
# Print the label of every example in the raw test split.
for i in test_d: print(i.label)

In [0]:
# Load the spaCy pipeline used below for tokenizing and sentence splitting.
# NOTE(review): 'en' is a shortcut name; newer spaCy releases may require a
# full package name such as 'en_core_web_sm' -- confirm the installed version.
nlp = spacy.load('en')

def tokenizer(input):
    """Return the spaCy token strings of ``input``, skipping single-space tokens."""
    tokens = []
    for token in nlp.tokenizer(input):
        if token.text == " ":
            # A token that is exactly one space carries no content.
            continue
        tokens.append(token.text)
    return tokens

# Word-level field nested inside a sentence-level field, so each example's
# text is a list of sentences and each sentence is a list of tokens.
NESTED_TEXT = torchtext.data.Field(sequential=True, tokenize=tokenizer)
TEXT = torchtext.data.NestedField(NESTED_TEXT)
LABEL = torchtext.data.LabelField(dtype=torch.float)
# Raw (unprocessed) field carrying the per-sentence token counts.
SQ_LEN = torchtext.data.RawField()
SQ_LEN.is_target = False
fields = [('text',TEXT),('label',LABEL),('seqlen',SQ_LEN)]
import numpy as np


def _build_nested_dataset(reviews, keep_prob=1.0):
    """Turn raw reviews into a Dataset of sentence-split examples.

    Args:
        reviews: Iterable of examples exposing ``.text`` (a token list) and
            ``.label``.
        keep_prob: Probability of keeping each review. Values >= 1 keep every
            review without drawing random numbers (the previous inline filter
            compared a uniform [0, 1) draw against 2, which is always true).

    Returns:
        A ``torchtext.data.Dataset`` whose examples carry ``text``, ``label``
        and ``seqlen`` (number of tokens in each sentence of ``text``).
    """
    examples = []
    for review in reviews:
        if keep_prob < 1 and np.random.rand() >= keep_prob:
            continue
        # Re-join the tokens so spaCy can split the review into sentences.
        sentences = [sent.text for sent in nlp(' '.join(review.text)).sents]
        # Only text and label are supplied; seqlen is filled in right below.
        example = torchtext.data.Example.fromlist([sentences, review.label], fields)
        example.seqlen = [len(sentence) for sentence in example.text]
        examples.append(example)
    return torchtext.data.Dataset(examples, fields)


test_data = _build_nested_dataset(test_d)
train_data = _build_nested_dataset(train_d)

# Notebook display: first training example and its sentence lengths.
train_data[0].text, train_data[0].seqlen

In [0]:
# Build the token vocabulary from the training data (every token kept, since
# min_freq=1) and request the "glove.6B.200d" pretrained vectors for it.
NESTED_TEXT.build_vocab(train_data,min_freq=1,vectors="glove.6B.200d")
# Build the label vocabulary from the training data.
LABEL.build_vocab(train_data)
# SQ_LEN is a RawField, so no vocabulary is built for it.
# SQ_LEN.build_vocab(train_data)

In [0]:
import torch
# Number of examples per batch for both iterators.
BATCH_SIZE = 10

# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Bucketed iterators over both datasets; examples are keyed by len(x.text),
# i.e. the number of sentences in the example.
train_iterator,test_iterator = torchtext.data.BucketIterator.splits(
    (train_data, test_data), 
    batch_size=BATCH_SIZE, 
    device=device,sort_key=lambda x: len(x.text))
# Notebook display: one batch from the test iterator.
next(iter(test_iterator))

In [0]:
# Keep one test batch around for the inspection cells below.
tmp = next(iter(test_iterator))
# Print the shape of each slice of the batch text taken along dimension 1.
for i in range(tmp.text.shape[1]):print(tmp.text[:,i,:].shape)
# Notebook display: the raw seqlen entries of the batch.
tmp.seqlen

In [0]:
# Show the batch's seqlen value together with its Python type.
print(tmp.seqlen,type(tmp.seqlen))
# Disabled experiment: sorting seqlen as if it were a tensor.
# tmp.seqlen.sort(dim=(0,1),descending=True)
print(tmp.seqlen)

In [0]:
import torch.nn as nn

class RNN(nn.Module):
    """Two-level attention classifier for documents made of sentences.

    Every sentence is embedded and run through an LSTM. A word-level
    attention collapses the LSTM states of a sentence into one vector, and a
    sentence-level attention collapses those vectors into one document vector
    that feeds the final linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each token embedding.
        hidden_dim: LSTM hidden size per direction.
        output_dim: Output size of the final linear layer.
        n_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        dropout: Dropout probability, applied to the embeddings and passed to
            the LSTM.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim,
                 output_dim, n_layers, bidirectional, dropout):
        super().__init__()

        # Width of one LSTM output step. Previously hard-coded to
        # 2 * hidden_dim, which only matched a bidirectional LSTM.
        feature_dim = hidden_dim * (2 if bidirectional else 1)

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers,
                           bidirectional=bidirectional, dropout=dropout)
        self.fc = nn.Linear(feature_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

        # Word-level attention parameters.
        self.word_weight = nn.Parameter(torch.empty(feature_dim, feature_dim))
        self.word_bias = nn.Parameter(torch.empty(1, feature_dim))
        self.context_weight = nn.Parameter(torch.empty(feature_dim, 1))

        # Sentence-level attention parameters.
        self.sentence_weight = nn.Parameter(torch.empty(feature_dim, feature_dim))
        self.sentence_bias = nn.Parameter(torch.empty(1, feature_dim))
        self.sentence_context_weight = nn.Parameter(torch.empty(feature_dim, 1))

        self._create_weights(mean=0.0, std=0.05)

    def _create_weights(self, mean=0.0, std=0.05):
        """Initialise every attention parameter.

        Weights are drawn from N(mean, std). Biases are zeroed: they used to
        be left as uninitialised memory, which can hold arbitrary values
        (including NaN/inf) and make training non-reproducible.
        """
        with torch.no_grad():
            self.word_weight.normal_(mean, std)
            self.context_weight.normal_(mean, std)

            self.sentence_weight.normal_(mean, std)
            self.sentence_context_weight.normal_(mean, std)

            self.word_bias.zero_()
            self.sentence_bias.zero_()

    @staticmethod
    def _attention_scores(states, weight, bias, context):
        """Return tanh(tanh(states @ weight + bias) @ context) without the last axis."""
        projected = torch.tanh(torch.matmul(states, weight) + bias)
        return torch.tanh(torch.matmul(projected, context)).squeeze(-1)

    def forward_email(self, email, input_lengths):
        """Encode one document into a single feature vector.

        Args:
            email: Integer tensor of token ids, one row per sentence
                (rows beyond ``len(input_lengths)`` are ignored).
            input_lengths: Number of tokens in each real sentence; every
                entry must be at least 1 (required for packing).

        Returns:
            1-D tensor with one entry per LSTM output feature.
        """
        # Packing requires sentences ordered from longest to shortest.
        lengths, perm_idx = torch.as_tensor(
            input_lengths, dtype=torch.long).cpu().sort(0, descending=True)
        email = email[perm_idx.to(email.device)]

        embedded = self.dropout(self.embedding(email))
        packed_input = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True)

        packed_output, _ = self.rnn(packed_input)
        # states: [n_sentences, longest_sentence, feature_dim]
        states, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        # Word attention: softmax over the real tokens of each sentence only.
        word_scores = self._attention_scores(
            states, self.word_weight, self.word_bias, self.context_weight)
        positions = torch.arange(states.size(1), device=states.device)
        is_padding = positions.unsqueeze(0) >= lengths.to(states.device).unsqueeze(1)
        word_attention = torch.softmax(
            word_scores.masked_fill(is_padding, float('-inf')), dim=1)
        sentence_vectors = (states * word_attention.unsqueeze(-1)).sum(dim=1)

        # Sentence attention: softmax over the sentences of the document.
        sentence_scores = self._attention_scores(
            sentence_vectors, self.sentence_weight,
            self.sentence_bias, self.sentence_context_weight)
        sentence_attention = torch.softmax(sentence_scores, dim=0)

        return (sentence_vectors * sentence_attention.unsqueeze(-1)).sum(dim=0)

    def forward(self, text, seqlen):
        """Score a batch of documents.

        Args:
            text: Iterable of per-document token-id tensors (one row per
                sentence), e.g. a tensor whose first axis is the batch.
            seqlen: Per-document lists of sentence lengths, aligned with
                ``text``.

        Returns:
            Tensor of shape [batch, output_dim] with the raw (pre-sigmoid)
            outputs of the final linear layer.
        """
        documents = [self.forward_email(email, input_lengths)
                     for email, input_lengths in zip(text, seqlen)]
        return self.fc(torch.stack(documents, dim=0))

In [0]:
# Model hyperparameters; EMBEDDING_DIM matches the 200-d GloVe vectors
# requested when the vocabulary was built.
INPUT_DIM = len(NESTED_TEXT.vocab)
EMBEDDING_DIM = 200
HIDDEN_DIM = 100
OUTPUT_DIM = 1
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5
import torch.nn.functional as F
model = RNN(INPUT_DIM, EMBEDDING_DIM, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT)
print(model.word_weight)
import torch.optim as optim

# NOTE(review): the pretrained vectors in NESTED_TEXT.vocab are never copied
# into model.embedding -- confirm whether that is intentional.

# The iterators place every batch on `device`, so the model has to live there
# too; otherwise the first forward pass fails on a GPU runtime. Move it before
# building the optimizer so the optimizer tracks the on-device parameters.
model = model.to(device)

optimizer = optim.Adam(model.parameters())

# Sigmoid + binary cross-entropy in one step; expects raw model outputs.
criterion = nn.BCEWithLogitsLoss()

criterion = criterion.to(device)

def matrix_mul(input, weight, bias=False):
    """Apply ``tanh(slice @ weight [+ bias])`` to every slice along axis 0.

    Args:
        input: Iterable of 2-D tensors (e.g. a 3-D tensor).
        weight: 2-D tensor multiplied on the right of each slice.
        bias: Added to each product only when it is an ``nn.Parameter``;
            any other value (including the default ``False``) is ignored.

    Returns:
        The per-slice results stacked along axis 0, with every size-1
        dimension squeezed away.
    """
    use_bias = isinstance(bias, torch.nn.parameter.Parameter)

    def _project(rows):
        projected = torch.mm(rows, weight)
        if use_bias:
            n_rows = projected.size()[0]
            n_cols = bias.size()[1]
            projected = projected + bias.expand(n_rows, n_cols)
        return torch.tanh(projected).unsqueeze(0)

    return torch.cat([_project(rows) for rows in input], 0).squeeze()

def element_wise_mul(input1, input2):
    """Scale each slice of ``input1`` by the matching row of ``input2`` and sum.

    Slices are paired along axis 0. Every entry of an ``input2`` row is
    broadcast across the columns of the corresponding ``input1`` row. The
    scaled slices are summed over axis 0 and returned with a leading axis of
    size 1.
    """
    scaled_slices = [
        (values * scales.unsqueeze(1).expand_as(values)).unsqueeze(0)
        for values, scales in zip(input1, input2)
    ]
    stacked = torch.cat(scaled_slices, 0)
    return stacked.sum(0).unsqueeze(0)
  
def binary_accuracy(preds, y):
    """
    Fraction of predictions in a batch that match the labels.

    ``preds`` are raw scores: they go through a sigmoid and are rounded to
    0/1 before being compared with ``y``. The result is a ratio (8 correct
    out of 10 gives 0.8), not a count.
    """
    predicted_labels = torch.sigmoid(preds).round()
    # Cast the boolean matches to float so they can be summed and divided.
    matches = (predicted_labels == y).float()
    return matches.sum() / len(matches)
  
def train(model, iterator, optimizer, criterion):
    """Run one training epoch.

    Args:
        model: Module called as ``model(batch.text, batch.seqlen)``; its
            output is squeezed on dimension 1 before the loss.
        iterator: Sized iterable of batches exposing ``text``, ``seqlen``
            and ``label``.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss called as ``criterion(predictions, labels)``.

    Returns:
        Tuple ``(mean loss, mean accuracy)``, each averaged over the number
        of batches in ``iterator``.
    """
    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for batch in iterator:
        # Clear the gradients left by the previous batch. Without this,
        # backward() keeps adding to them and every step applies the running
        # sum of all earlier gradients instead of the current batch's.
        optimizer.zero_grad()

        predictions = model(batch.text, batch.seqlen).squeeze(1)

        loss = criterion(predictions, batch.label)
        acc = binary_accuracy(predictions, batch.label)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
  
def evaluate(model, iterator, criterion):
    """Compute mean loss and accuracy over ``iterator`` without updating the model.

    Args:
        model: Module called as ``model(batch.text, batch.seqlen)``; its
            output is squeezed on dimension 1 before the loss.
        iterator: Sized iterable of batches exposing ``text``, ``seqlen``
            and ``label``.
        criterion: Loss called as ``criterion(predictions, labels)``.

    Returns:
        Tuple ``(mean loss, mean accuracy)``, each averaged over the number
        of batches in ``iterator``.
    """
    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():
        for batch in iterator:
            predictions = model(batch.text, batch.seqlen).squeeze(1)

            # Labels are used as-is, as in train(). The former .squeeze(0)
            # turned a single-example batch into a scalar, whose shape no
            # longer matched the predictions.
            labels = batch.label.float()

            loss = criterion(predictions, labels)
            acc = binary_accuracy(predictions, labels)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
  
# Number of full passes over the training iterator.
N_EPOCHS = 2

# Each epoch: one training pass, then one evaluation pass, then a summary line.
for epoch in range(N_EPOCHS):

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    # NOTE: the "Val." figures below are measured on test_iterator; the
    # visible cells do not create a separate validation split.
    valid_loss, valid_acc = evaluate(model, test_iterator, criterion)
    
    print(f'| Epoch: {epoch+1:02} | Train Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}% | Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}% |')

In [0]:
# Notebook display: labels of the saved batch next to the model's raw outputs.
tmp.label,model(tmp.text,tmp.seqlen)

In [0]:
# Notebook display: the tokens stored at vocabulary indices 0 and 1.
NESTED_TEXT.vocab.itos[0], NESTED_TEXT.vocab.itos[1]

In [0]:
# For each example in the saved batch, print row 0 of its text tensor
# together with that example's seqlen entry.
for i,j in zip(tmp.text, tmp.seqlen):
  print(i[0,:],j)

In [0]:
# Scratch cell: 10 modulo 3 (evaluates to 1).
10%3