# Hierarchical Attention Networks for Document Classification
PyTorch implementation of **Hierarchical Attention Networks for Document Classification (NAACL 2016)**

# Table of Contents
* [Preamble](#Preamble)
* [Word2Vec Module](#Word2Vec-Module)
* [Load Vocabulary and Embeddings](#Load-Vocabulary-and-Embeddings)
* [PyTorch Dataset class](#PyTorch-Dataset-class)
* [HAN Model](#HAN-Model)
* [Training](#Training)

# Preamble

In [None]:
# Preamble
import time, random
import re, string
import os, sys
import math
import json

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

from tqdm import tqdm

# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Number of dataset items the DataLoaders put in one batch.
BATCH_SIZE = 64

# Word2Vec Module
Following the original paper, we trained the word embeddings using the *negative sampling* technique introduced in **Distributed Representations of Words and Phrases and their Compositionality (NIPS 2013)**.

In [None]:
# Negative sampling embedding module
class NegSamplingEmbedding(nn.Module):
    '''
    Skip-gram style embedding module trained with a negative-sampling loss.

    Symbols used in the shape comments:
        V: vocabulary size
        E: embedding size
        B: batch size
        L: text length
        C: third axis of <wo> (presumably outside words per center word)
        K: fourth axis of <wk> (presumably negative samples per outside word)

    Consult: https://github.com/mindspore-courses/DeepNLP-models-MindSpore/
            blob/main/notebooks/02.Skip-gram-Negative-Sampling.ipynb
    '''
    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        # Two embedding tables. forward() looks center words up in V and
        # outside / sampled words up in U.
        self.U = nn.Embedding(vocab_size, embedding_size)
        self.V = nn.Embedding(vocab_size, embedding_size)
        self.LogSig = nn.LogSigmoid()

    def forward(self, wc, wo, wk, mask_c, mask_o, check_shape= False):
        '''
        Return the negative-sampling loss (a scalar tensor) for one batch.

        wc: center word ids, shape (B, L)
        wo: outside word ids, shape (B, L, C)
        wk: sampled word ids, shape (B, L, C, K)
        mask_c: accepted for interface compatibility; not used by the computation
        mask_o: multiplied into the per-outside-word terms, shape (B, L, C)
        check_shape: when True, print the intermediate tensor shapes
        '''
        center = self.V(wc)    # Shape: (B, L, E)
        outside = self.U(wo)   # Shape: (B, L, C, E)
        sampled = self.U(wk)   # Shape: (B, L, C, K, E)

        if check_shape:
            B, L, C, K, E = (sampled.shape[axis] for axis in range(5))
            print(f"Basic shapes: B = {B}; L = {L}; C = {C}; K = {K}; E = {E}")
            print('*********************************')
            print('Shape of vc:', center.shape)
            print('Shape of uo:', outside.shape)
            print('Shape of uk:', sampled.shape)
            print('*********************************')

        # Dot products between each center vector and its outside / sampled vectors.
        pos_score = torch.einsum('blce,ble->blc', outside, center)    # Shape: (B, L, C)
        neg_score = torch.einsum('blcke,ble->blck', sampled, center)  # Shape: (B, L, C, K)

        # log sigmoid(u_o . v_c), masked, then summed over C.
        pos_term = (self.LogSig(pos_score) * mask_o).sum(dim=-1)  # Shape: (B, L)
        # sum_k log sigmoid(-u_k . v_c), masked, then summed over C.
        neg_term = (self.LogSig(-neg_score).sum(dim=-1) * mask_o).sum(dim=-1)  # Shape: (B, L)

        loss = torch.mean(pos_term + neg_term)

        if check_shape:
            print('Shape of cmp1:', pos_term.shape)
            print('Shape of cmp2:', neg_term.shape)
            print('Shape of LOSS:', loss.shape)

        # The objective is maximised, so its negation is returned as the loss.
        return -loss

# Load Vocabulary and Embeddings

In [None]:
# Prepare vocab, counter, tokenizer
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab, build_vocab_from_iterator

from collections import Counter

# torchtext's built-in "basic_english" tokenizer; used by YELPDataset.text_pipeline.
tokenizer = get_tokenizer("basic_english")

# NOTE: this assignment rebinds `vocab`, shadowing the `vocab` factory imported
# from torchtext.vocab above.
vocab = torch.load('/kaggle/input/hlt-word2vec/vocab.pth')
# NOTE(review): `counter` is loaded from the same vocab.pth file as `vocab`, which
# looks like a copy-paste slip — confirm the intended path.
counter = torch.load('/kaggle/input/hlt-word2vec/vocab.pth')

In [None]:
# Task configs
VOCAB_SIZE = len(vocab)  # number of entries in the loaded vocabulary
EMBEDDING_DIM = 200  # embedding width; passed to NegSamplingEmbedding and HANModel
GRU_DIM = 50  # hidden size per direction of the HANModel GRUs
NUM_CLASSES = 5  # labels are `stars - 1` (see YELPDataset.label_pipeline); presumably stars are 1..5

In [None]:
# Prepare word embeddings
WORD2VEC_PATH = '/kaggle/input/hlt-word2vec/word2vec.pth'
# Rebuild the embedding module and restore its trained weights; load_state_dict
# requires the checkpoint tensors to match (VOCAB_SIZE, EMBEDDING_DIM).
word2vec = NegSamplingEmbedding(VOCAB_SIZE, EMBEDDING_DIM)
word2vec.load_state_dict(torch.load(WORD2VEC_PATH, map_location= DEVICE))
word2vec.eval()

# Embedding matrix for the HAN model: the V table (the one forward() applies to
# center words). detach() cuts it off from autograd, so it receives no gradient.
W_e = word2vec.V.weight.detach().to(DEVICE)

# PyTorch Dataset class

In [None]:
# YELP Dataset
def collate_batch(batch):
    '''
    Merge (text, label) samples into batch tensors, zero-padding on the right.

    Returns (texts, labels, mask):
        texts: stacked texts, each padded with 0 up to the longest text in the batch
        labels: stacked labels
        mask: float tensor shaped like <texts>; 1 on real positions, 0 on padding

    Consult: https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html
    '''
    texts, labels = map(list, zip(*batch))

    width = max(len(seq) for seq in texts)

    padded, masks = [], []
    for seq in texts:
        n_tokens = seq.shape[0]
        n_pad = width - n_tokens
        # Ones over the real tokens followed by zeros over the padding.
        masks.append(torch.cat((torch.ones(n_tokens), torch.zeros(n_pad))))
        # Only texts shorter than the batch maximum need padding (right side only).
        padded.append(F.pad(seq, (0, n_pad), 'constant', 0) if n_pad > 0 else seq)

    return torch.stack(padded), torch.stack(labels), torch.stack(masks)

def collate_batch_HAN(batch):
    '''
    Identity collate function: return the list of dataset items unchanged.

    With this collate_fn a DataLoader batch is a plain list of whatever
    the dataset's __getitem__ returned, with no stacking or padding across items.
    '''
    return batch

class YELPDataset(Dataset):
    '''
    Dataset over a DataFrame of reviews with a 'text' and a 'stars' column.

    df: DataFrame holding the reviews
    vocab: callable mapping a list of tokens to a list of ids; must support len()
    tokenizer: callable mapping a string to a list of tokens
    df_sort: when True, reorder the rows by decreasing text length (in characters)
    punct_splt: selects the item format
        False -> (token_ids, label), token_ids of shape (T,)
        True  -> (X, label, mask): the review is split into sentences on runs of
                 punctuation; X has shape (S, L) with S sentences zero-padded on
                 the right to L tokens (L = longest sentence, in tokens), and
                 mask is a float tensor of the same shape with 1 on real tokens.
                 A review without any token yields (None, None, None).
    '''
    # Sentence separator: one or more ASCII punctuation characters. The set is
    # escaped so that characters such as ']', '\\', '^' and '-' are taken literally.
    _SENTENCE_SPLIT = re.compile("[" + re.escape(string.punctuation) + "]+")

    def __init__(self, df, vocab, tokenizer, df_sort= True, punct_splt= False):
        self.df = df
        self.tokenizer = tokenizer
        self.vocab = vocab
        self.len_vocab = len(vocab)
        self.punct_splt = punct_splt
        if df_sort:
            self.sort_df_by_txt_len()

    def sort_df_by_txt_len(self):
        '''Reorder self.df so that the longest texts (by character count) come first.'''
        # Iterate the column directly; per-row iloc lookups are far slower.
        neg_lengths = [-len(text) for text in self.df['text']]
        self.df = self.df.iloc[np.argsort(neg_lengths)]

    def __len__(self):
        return len(self.df)

    def text_pipeline(self, x):
        '''Tokenize the string <x> and map the tokens to vocabulary ids.'''
        return self.vocab(self.tokenizer(x))

    def label_pipeline(self, x):
        '''Shift a star rating down by one so that labels start at 0.'''
        return int(x) - 1

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        if not self.punct_splt:
            txt = torch.tensor(self.text_pipeline(row['text']), dtype= torch.int64)
            label = torch.tensor(self.label_pipeline(row['stars']), dtype= torch.int64)
            return (txt, label)

        # Encode every sentence and keep only those that contain at least one token
        # (fragments that are empty or whitespace-only carry no information).
        sentences = []
        for fragment in self._SENTENCE_SPLIT.split(row['text']):
            if not fragment:
                continue
            token_ids = self.text_pipeline(fragment)
            if len(token_ids) > 0:
                sentences.append(token_ids)

        if not sentences:
            return (None, None, None)

        # Pad to the longest sentence measured in tokens, not characters.
        L = max(len(token_ids) for token_ids in sentences)
        X = torch.zeros((len(sentences), L), dtype= torch.int64)
        mask = torch.zeros((len(sentences), L))
        for i, token_ids in enumerate(sentences):
            n = len(token_ids)
            X[i, :n] = torch.tensor(token_ids, dtype= torch.int64)
            mask[i, :n] = 1

        label = torch.tensor(self.label_pipeline(row['stars']), dtype= torch.int64)

        return (X, label, mask)

In [None]:
# CSV Preparation
# Read at most `cnt` reviews from the line-delimited JSON dump into a DataFrame.
cnt = 1569264 # Size of YELP 2015 dataset
# cnt = 10000

data = []
# The context manager closes the file even when a line fails to parse.
with open("/kaggle/input/yelp-dataset/yelp_academic_dataset_review.json",
          encoding= "utf-8") as data_file:
    for line in data_file:
        data.append(json.loads(line))  # one JSON object per line
        if len(data) == cnt:
            break

df = pd.DataFrame(data)

print("Number of datapoints:", len(df))
df.head()

In [None]:
# Train-val-test splits
# Shuffle the row positions with a fixed seed so that the split is reproducible.
df_size = len(df)
idx = list(range(df_size))
random.Random(555).shuffle(idx)

# Split sizes as fractions of the dataset.
# NOTE(review): test_num is computed but never used — the test split below takes
# every position left after the train and val slices. Confirm this is intended.
train_num, val_num, test_num = (int(df_size * frac) for frac in (0.8, 0.01, 0.1))

# Consecutive slices of the shuffled positions.
train_idx = idx[:train_num]
val_idx = idx[train_num:train_num + val_num]
test_idx = idx[train_num + val_num:]

train_df, val_df, test_df = (df.iloc[rows] for rows in (train_idx, val_idx, test_idx))

print('Size of trainset:', len(train_df))
print('Size of valset:', len(val_df))
print('Size of testset:', len(test_df))

In [None]:
# Dataset, Dataloader
# punct_splt=True: every item is an (X, label, mask) triple for one review.
trainset = YELPDataset(train_df, vocab, tokenizer, punct_splt= True)
valset = YELPDataset(val_df, vocab, tokenizer, punct_splt= True)
testset = YELPDataset(test_df, vocab, tokenizer, punct_splt= True)

# collate_batch_HAN returns the batch untouched, so each batch is a list of
# per-review triples rather than stacked tensors.
# NOTE(review): the training loader uses shuffle=False while YELPDataset sorts by
# text length by default, so every epoch visits the reviews in the same
# longest-first order — confirm this is intended.
trainloader = DataLoader(trainset, batch_size= BATCH_SIZE, 
                         shuffle= False, pin_memory= True, collate_fn= collate_batch_HAN)
valloader = DataLoader(valset, batch_size= BATCH_SIZE, 
                         shuffle= False, pin_memory= True, collate_fn= collate_batch_HAN)
testloader = DataLoader(testset, batch_size= BATCH_SIZE, 
                         shuffle= False, pin_memory= True, collate_fn= collate_batch_HAN)

# Sanity check: print the tensor shapes of the first review of the first batch.
for batch in trainloader:
    X, y, mask = batch[0]
    print("Shape of Texts:", X.shape)
    print("Shape of Labels:", y.shape)
    print("Shape of Mask:", mask.shape)
    break

# HAN Model

In [None]:
class HANModel(nn.Module):
    '''
    Hierarchical Attention Network that classifies one document per call.

    W_e: pretrained word-embedding matrix of shape (vocab_size, embedding_dim)
    embedding_dim: embedding size E
    gru_dim: hidden size G of each GRU direction (encoders output 2G features)
    num_classes: number of output classes

    forward() returns unnormalized class scores (logits), which is what
    nn.CrossEntropyLoss expects; apply softmax to them to obtain probabilities.
    '''
    def __init__(self, W_e, embedding_dim, gru_dim, 
                 num_classes):
        super(HANModel, self).__init__()

        # Initialize
        self.E = embedding_dim
        self.G = gru_dim
        self.num_classes = num_classes

        # Pretrained word embeddings.
        # A plain tensor is registered as a non-persistent buffer: it follows the
        # module across .to(device) calls but stays out of state_dict, so existing
        # checkpoints keep loading. An nn.Parameter is assigned as before.
        if isinstance(W_e, nn.Parameter):
            self.W_e = W_e
        else:
            self.register_buffer('W_e', W_e, persistent= False)

        # Word-level attention
        self.WordEncoder = nn.GRU(
            input_size= embedding_dim,
            hidden_size= gru_dim,
            batch_first= True,
            bidirectional= True
        )
        self.WordMLP = nn.Sequential(
            nn.Linear(2*gru_dim, 2*gru_dim),
            nn.Tanh()
        )
        self.u_w = nn.Parameter(torch.randn(2*gru_dim))  # word context vector

        # Sequence-level attention
        self.SeqEncoder = nn.GRU(
            input_size= 2*gru_dim,
            hidden_size= gru_dim,
            batch_first= True,
            bidirectional= True
        )
        self.SeqMLP = nn.Sequential(
            nn.Linear(2*gru_dim, 2*gru_dim),
            nn.Tanh()
        )
        self.u_s = nn.Parameter(torch.randn(2*gru_dim))  # sentence context vector

        # Classifier. No Softmax layer: the loss applies log-softmax itself.
        # Kept as a Sequential so the parameter names stay 'classifier.0.*'.
        self.classifier = nn.Sequential(
            nn.Linear(2*gru_dim, num_classes)
        )

    def forward(self, text, mask, check_shape= False):
        '''
        <text> is expected to have shape (S, L) and hold word ids
        <mask> is expected to have shape (S, L), non-zero on real words and 0 on padding
        Returns the logits of the document, shape (num_classes,)

        Number of sentence: S
        Max sentence length: L
        Embedding size: E
        GRU dimension: G -> Bidirectional: 2G
        '''
        S = text.shape[0]
        L = text.shape[1]

        x = self.W_e[text] # Text embedding, Shape: (S, L, E)
        hW,_ = self.WordEncoder(x) # Word annotation. Shape: (S, L, 2G)
        uW = self.WordMLP(hW) # Word hidden representation. Shape: (S, L, 2G)

        # Word attention weights. Padded positions get the lowest representable
        # score so that softmax assigns them (numerically) zero weight. A finite
        # fill value keeps a fully padded row finite (uniform weights, no NaN).
        word_scores = torch.einsum('slg,g->sl', uW, self.u_w) # Shape: (S, L)
        word_scores = word_scores.masked_fill(mask == 0, torch.finfo(word_scores.dtype).min)
        alphaW = torch.softmax(word_scores, dim= 1) # Shape: (S, L)

        # Sentence vector
        s = torch.einsum('slg,sl->sg', hW, alphaW) # Shape: (S, 2G)
        # The S sentences form one sequence; add and remove an explicit batch axis.
        hS,_ = self.SeqEncoder(s.unsqueeze(0)) # Shape: (1, S, 2G)
        hS = hS.squeeze(0) # Sentence annotation. Shape: (S, 2G)
        uS = self.SeqMLP(hS) # Sentence hidden representation. Shape: (S, 2G)

        # Sentence attention weights
        alphaS = torch.softmax(torch.matmul(uS, self.u_s), dim= 0) # Shape: (S)

        # Document vector
        v = torch.matmul(alphaS, hS) # Shape: (2G)

        logits = self.classifier(v) # Shape: (num_classes)

        if check_shape:
            print(f'Basic shapes: S = {S}, E = {self.E}, L = {L}, G = {self.G}')
            print('*********************************')
            print('Shape of x:', x.shape)
            print('Shape of hW:', hW.shape)
            print('Shape of uW:', uW.shape)
            print('Shape of alphaW:', alphaW.shape)
            print('*********************************')
            print('Shape of s:', s.shape)
            print('Shape of hS:', hS.shape)
            print('Shape of uS:', uS.shape)
            print('Shape of alphaS:', alphaS.shape)
            print('*********************************')
            print('Shape of v:', v.shape)
            print('Shape of logits:', logits.shape)

        return logits
        

In [None]:
model = HANModel(W_e, EMBEDDING_DIM, GRU_DIM, NUM_CLASSES).to(DEVICE)

# Smoke test: run the first review of the first batch through the model and
# print the intermediate shapes.
# NOTE: this leaves `X`, `y`, `mask` and `logits` defined at module level.
for batch in trainloader:
    X, y, mask = batch[0]
    X, mask = X.to(DEVICE), mask.to(DEVICE)
    logits = model(X, mask, check_shape= True)
    break
    
print('*********************************')
print(model)

# Training

In [None]:
# Training configs
LR = 1e-3
MOMENTUM = 0.9  # only referenced by the commented-out SGD optimizer below
EPOCHS = 3
ITER = EPOCHS * len(trainloader)  # total number of batches over all epochs; T_max of the scheduler
OPTIMIZER = torch.optim.AdamW(model.parameters(), lr= LR)
# OPTIMIZER = torch.optim.SGD(model.parameters(), lr= LR, momentum= MOMENTUM, nesterov= True)
# NOTE: the chained assignment also binds the scheduler to the name `lr_scheduler`.
SCHEDULER = lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max = ITER)
# NOTE(review): nn.CrossEntropyLoss applies log-softmax internally, so the model
# is expected to output unnormalized scores — verify against HANModel.classifier.
LOSS_FN = nn.CrossEntropyLoss()

In [None]:
# Train procedures
def test(testloader, model, loss_fn):
    '''
    Evaluate <model> on every review served by <testloader>.

    Each batch is expected to be a list of (X, y, mask) triples; triples whose
    X is None are skipped.

    Returns (test_loss, accuracy):
        test_loss: mean over the batches of (summed per-review loss / batch length)
        accuracy: correct predictions / len(testloader.dataset), so skipped
                  reviews count as misclassified
    Both are 0 when the loader is empty. The model is left in eval mode.
    '''
    model.eval()
    test_loss = 0
    correct = 0

    # No gradients are needed for evaluation; this avoids building autograd graphs.
    with torch.no_grad():
        for batch in testloader:
            tmp_loss = 0
            for (X, y, mask) in batch:
                if X is None:
                    continue
                X, y, mask = X.to(DEVICE), y.to(DEVICE), mask.to(DEVICE)
                pred = model(X, mask)
                # Score the prediction that was just computed for this review.
                tmp_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(0) == y).type(torch.float).sum().item()
            test_loss += tmp_loss / len(batch)

    # max(..., 1) keeps an empty loader / dataset from raising ZeroDivisionError.
    test_loss /= max(len(testloader), 1)
    accuracy = correct / max(len(testloader.dataset), 1)

    return test_loss, accuracy

def _validate_and_checkpoint(valloader, model, loss_fn, step, train_loss, best_acc):
    '''Run validation, log it, save the weights when accuracy improved; return the best accuracy.'''
    model.eval()
    val_loss, val_acc = test(valloader, model, loss_fn)
    model.train()
    print(f'Iter {step}, loss = {train_loss}, val_acc = {val_acc}')
    if best_acc < val_acc:
        best_acc = val_acc
        print('Saving model...')
        torch.save(model.state_dict(), f'HAN_{val_acc*100}.pth')
    return best_acc


def train(trainloader, valloader, model, optimizer, scheduler, loss_fn, val_freq):
    '''
    Train <model> for one pass over <trainloader>.

    Each batch is expected to be a list of (X, y, mask) triples; triples whose
    X is None are skipped. The batch loss is the sum of the per-review losses
    divided by the batch length. The scheduler is stepped once per batch.

    val_freq: validate (and checkpoint on improvement) every <val_freq>
              iterations; values <= 0 disable it. One more validation always
              runs at the end of the pass.

    Returns the list of batch losses, one entry per batch that was trained on.
    '''
    model.train()
    tloss = []
    cur_acc = 0
    i = -1  # stays -1 when the loader is empty
    for i, batch in enumerate(trainloader):
        losses = []
        for (X, y, mask) in batch:
            if X is None:
                continue
            X, y, mask = X.to(DEVICE), y.to(DEVICE), mask.to(DEVICE)
            logits = model(X, mask)
            losses.append(loss_fn(logits, y))

        # A batch without any usable review has nothing to back-propagate.
        if losses:
            loss = sum(losses) / len(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Recorded exactly once per trained batch.
            tloss.append(loss.cpu().detach().numpy())
        # Stepped for every batch so the schedule stays aligned with its T_max.
        scheduler.step()

        # Validate after the update, once the training graph has been released.
        if val_freq > 0 and i % val_freq == 0 and tloss:
            cur_acc = _validate_and_checkpoint(valloader, model, loss_fn,
                                               i, tloss[-1], cur_acc)

    if tloss:
        cur_acc = _validate_and_checkpoint(valloader, model, loss_fn,
                                           i, tloss[-1], cur_acc)

    return tloss

In [None]:
# TRAINING
iter_loss = []  # per-iteration training losses across all epochs
epoch_loss = []  # mean training loss of each epoch
best_acc = 0  # NOTE(review): never read or updated below — confirm whether it is still needed

for t in range(EPOCHS):
    print(f'Epoch {t} starts.')
    # Validate every 1000 iterations while training.
    tloss = train(trainloader, valloader, model, OPTIMIZER, SCHEDULER, LOSS_FN, 1000)
    # NOTE: train() already runs a validation pass at the end of the epoch,
    # so this is a second pass over valloader.
    val_loss, val_acc = test(valloader, model, LOSS_FN)
    
    iter_loss = iter_loss + tloss
    epoch_loss.append(sum(tloss) / len(tloss))
    
    print(f'Epoch {t}: LOSS = {epoch_loss[-1]}, VAL-ACC = {val_acc}')
    
# Plot the training loss per iteration.
fig, axes = plt.subplots()
axes.plot(iter_loss, label = 'train-loss')
axes.legend()
axes.set_xlabel('Iteration')
axes.set_ylabel('Loss')
plt.show()

In [None]:
# Final evaluation on the held-out test split.
# NOTE: despite its name, `val_acc` holds the test accuracy here.
_, val_acc = test(testloader, model, LOSS_FN)
print(f'Test accuracy: {val_acc}')