# Attentive Pooling Networks
## Combined: BiLSTMs and CNNs

In [1]:
import time
import string
import random
import copy

import numpy as np
import warnings
from sklearn.metrics import average_precision_score
import datasets
import random
from tqdm import tqdm
from collections import defaultdict

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

GLOVE_FILE = '/hdd/data/NLP_data/word_vectors/glove.6B/glove.6B.300d.txt' 

### Hyperparameters

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TRAIN_EPOCHS = 20
BATCH_SIZE = 20
TEST_BATCH_SIZE = 128
LOSS_MARGIN = 0.5
TRAIN_NEG_COUNT = 50  # Amount of random negative answers for every question in training
Q_LENGTH = 20
A_LENGTH = 100
PAD_WORD = '<UNK>'  
KERNEL_COUNT = 400
KERNEL_SIZE = 3
RNN_HIDDEN = 150

device

device(type='cuda')

### Select questions-answer pairs with atleast one correct answer

In [3]:
def get_valid_questions(wikiqa):
    question_status = dict()

    for split in wikiqa:
        split_dataset = wikiqa[split]
        n_samples = len(split_dataset)

        for i in range(n_samples):
            qid = split_dataset[i]['question_id']
            label = split_dataset[i]['label']
            if qid not in question_status:
                question_status[qid] = label
            else:
                question_status[qid] = max(question_status[qid], label)

    valid_questions = set([qid for qid in question_status if question_status[qid] > 0])
    
    return valid_questions 

In [4]:
wikiqa = load_dataset('wiki_qa')
valid_questions = get_valid_questions(wikiqa)
wikiqa_f = wikiqa.filter(lambda sample: sample['question_id'] in valid_questions)

wikiqa_f

Using custom data configuration default
Reusing dataset wiki_qa (/home/at/.cache/huggingface/datasets/wiki_qa/default/0.1.0/d2d236b5cbdc6fbdab45d168b4d678a002e06ddea3525733a24558150585951c)


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/21 [00:00<?, ?ba/s]

DatasetDict({
    test: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 2351
    })
    validation: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 1130
    })
    train: Dataset({
        features: ['question_id', 'question', 'document_title', 'answer', 'label'],
        num_rows: 8672
    })
})

In [5]:
#positive_questions 
len([s for s in wikiqa_f['train'] if s['label']==1]  )

1040

### Embedding

In [6]:
def load_glove(filename):
    with open(filename, encoding='utf-8') as f:
        word_emb = list()
        word_dict = dict()
        word_emb.append([0])
        word_dict['<UNK>'] = 0
        for line in f.readlines():
            tokens = line.split(' ')
            word_emb.append([float(i) for i in tokens[1:]])
            word_dict[tokens[0]] = len(word_dict)
        word_emb[0] = [0] * len(word_emb[1])
    return word_emb, word_dict

In [7]:
word_emb, word_dict = load_glove(GLOVE_FILE)
len(word_emb), len(word_dict)

(400001, 400001)

### DataSet Class

In [8]:
class WikiQADataset(Dataset):
    def __init__(self, word_dict, wikiqa, mode='train'):
        self.mode = mode
        pad_num = word_dict[PAD_WORD]
        
        def sent_process(sent, p_len):  # vocab to id -> padding
            return [word_dict.get(w.lower(), pad_num) for w in sent[:p_len]] + [pad_num] * (p_len - len(sent)) 
        
        def get_tokens(big_str):
            return big_str.translate(str.maketrans('', '', string.punctuation)).lower().split()

        positive_questions = [s for s in wikiqa if s['label']==1]
        quests, answer_pos, answer_neg, answers, labels = [], [], [], [], []
        
        for sample in tqdm(positive_questions):
            quest = sample['question']
            quest = sent_process(get_tokens(quest), Q_LENGTH)      # List of vocab indices
            pos_ans = sample['answer'] 
            pos_ans = sent_process(get_tokens(pos_ans), A_LENGTH)  # List of vocab indices
            labels.append(1)

            qid = sample['question_id']               
            # First preference for above question's negative answers provided in dataset
            neg_answers = [s['answer'] for s in wikiqa if s['label']==0 and s['question_id']==qid ] # List of strings
            # Next filled randomly
            more_neg_ans_required = TRAIN_NEG_COUNT - len(neg_answers)
            eligible_samples = [s['answer'] for s in wikiqa if s['question_id']!=qid ] # List of strings
            more_neg_ans = random.sample(eligible_samples, more_neg_ans_required)
            neg_answers = neg_answers + more_neg_ans # List of strings
            neg_answers = [sent_process(get_tokens(a), A_LENGTH) for a in neg_answers]  # List of List of vocab indices

            quests.append(quest)           # List of List of vocab indices
            answer_pos.append(pos_ans)     # List of List of vocab indices
            answer_neg.append(neg_answers) # List of List of List of vocab indices
            
            labels += [0]*len(neg_answers)    # List of binary 0/1
            
            all_ans = [pos_ans] + neg_answers   # List of List of vocab indices      
            answers += all_ans                  # List of List of vocab indices

        self.q = torch.LongTensor(quests)        
        
        if mode == 'train':
            self.a_pos = torch.LongTensor(answer_pos)
            self.a_neg = torch.LongTensor(answer_neg)                
        else: 
            self.a = torch.LongTensor(answers)
            self.y = torch.LongTensor(labels)            

    def __getitem__(self, idx):
        if self.mode == 'train':
            return self.q[idx], self.a_pos[idx], self.a_neg[idx]
        return self.q[idx], self.a[idx], self.y[idx]

    def __len__(self): return self.q.shape[0]

    def __str__(self): return f'Dataset {self.mode}: {len(self.q)} samples.'                

In [9]:
train_data = WikiQADataset(word_dict, wikiqa_f['train'], mode='train')
print(train_data)
valid_data = WikiQADataset(word_dict, wikiqa_f['validation'], mode='valid')
print(valid_data)

100%|██████████████████████████████████████| 1040/1040 [22:48<00:00,  1.32s/it]


Dataset train: 1040 samples.


100%|████████████████████████████████████████| 140/140 [00:23<00:00,  5.96it/s]

Dataset valid: 140 samples.





In [10]:
train_dlr = DataLoader(train_data, batch_size=BATCH_SIZE, num_workers=4, shuffle=True)
valid_dlr = DataLoader(valid_data, batch_size=BATCH_SIZE, num_workers=4)

### Utils

In [11]:
def evaluate(model, dataloader):
    predict = defaultdict(list)

    for q, a, y in dataloader:
        cos = model(q, a)
        i=0
        for pred, label in zip(cos.detach().cpu().numpy(), y.numpy()):
            predict[i].append((pred, label))
            i += 1

    accuracy = 0
    MRR = 0
    average_precisions = [] 
    
    warnings.filterwarnings("ignore")    
    for p in predict.values():
        
        y_true = np.array([s[1] for s in p ])
        y_pred = np.array([s[0] for s in p ])
        ap = average_precision_score(y_true, y_pred)
        average_precisions.append(ap)
        
        p.sort(key=lambda x: -x[0])
        if p[0][1] == 1:
            accuracy += 1
            
        for i, t in enumerate(p):
            if t[1] == 1:
                MRR += 1 / (i + 1)
                break
    
    accuracy = accuracy / len(predict)
    MRR = MRR / len(predict)
    mAP = sum(average_precisions)/len(average_precisions)
    #warnings.filterwarnings("default")
    return accuracy, MRR, mAP

In [12]:
def process_bar(percent, start_str='', end_str='', auto_rm=True):
    bar = '=' * int(percent * 50)
    bar = '\r{}|{}| {:.1%} | {}'.format(start_str, bar.ljust(50), percent, end_str)
    print(bar, end='', flush=True)
    if percent == 1:
        print(end=('\r' + ' ' * len(bar) + '\r') if auto_rm else '\n', flush=True)

In [13]:
class LearnRate:
    def __init__(self, optimizer):
        self.opt = optimizer
        self.init_lr = optimizer.state_dict()['param_groups'][0]['lr']
        self.epoch = 1

    def step(self):
        self.epoch += 1
        for p in self.opt.param_groups:
            p['lr'] = self.init_lr / self.epoch

    def get_last_lr(self): return [self.init_lr / self.epoch]

### Models

In [14]:
# Without Attention

class CNN(nn.Module):
    def __init__(self, word_dim, kernel_count, kernel_size):
        super().__init__()
        self.encode = nn.Conv1d(
            in_channels=word_dim,
            out_channels=kernel_count,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2)

    def forward(self, vec):
        latent = self.encode(vec.permute(0, 2, 1))
        return latent


class BiLSTM(nn.Module):
    def __init__(self, word_dim, hidden_size):
        super().__init__()
        self.encode = nn.LSTM(input_size=word_dim, hidden_size=hidden_size, bidirectional=True, batch_first=True)

    def forward(self, vec):
        self.encode.flatten_parameters()
        latent, _ = self.encode(vec)
        return latent.transpose(-1, -2)

In [15]:
class CoAttention(nn.Module):
    def __init__(self, hidden_size, init_U='randn'):
        super().__init__()
        if init_U == 'zeros':
            self.U = nn.Parameter(torch.zeros(hidden_size, hidden_size))
        else:
            self.U = nn.Parameter(torch.randn(hidden_size, hidden_size))

    def forward(self, Q, A):
        G = Q.transpose(-1, -2) @ self.U.expand(Q.shape[0], -1, -1) @ A
        G = torch.tanh(G)
        Q_pooling = G.max(dim=-1)[0]
        A_pooling = G.max(dim=-2)[0]
        Q_pooling = Q_pooling.softmax(dim=-1)
        A_pooling = A_pooling.softmax(dim=-1)
        rq = Q @ Q_pooling.unsqueeze(-1)
        ra = A @ A_pooling.unsqueeze(-1)
        rq = rq.squeeze(-1)
        ra = ra.squeeze(-1)
        return rq, ra

In [16]:
class QAModel(nn.Module):
    def __init__(self, word_emb, model_name):
        super().__init__()
        self.model_name = model_name
        self.embedding = nn.Embedding.from_pretrained(torch.Tensor(word_emb))
        self.embedding.weight.requires_grad_()

        if 'CNN' in model_name:
            self.encode = CNN(self.embedding.embedding_dim, KERNEL_COUNT, KERNEL_SIZE)
            if 'AP' in model_name:
                self.coAttention = CoAttention(KERNEL_COUNT, init_U='zeros')
        elif 'biLSTM' in model_name:
            self.encode = BiLSTM(self.embedding.embedding_dim, RNN_HIDDEN)
            if 'AP' in model_name:
                self.coAttention = CoAttention(RNN_HIDDEN * 2)

    def forward(self, questions, answers):
        device = next(self.parameters()).device
        questions = questions.to(device)
        answers = answers.to(device)

        q_emb = self.embedding(questions)
        a_emb = self.embedding(answers)
        Q = self.encode(q_emb)
        A = self.encode(a_emb)
        if 'AP' in self.model_name:
            rq, ra = self.coAttention(Q, A)
        else:
            rq = Q.max(dim=-1)[0]
            ra = A.max(dim=-1)[0]
            rq = torch.tanh(rq)
            ra = torch.tanh(ra)
        cos = torch.sum(rq * ra, dim=-1) / (rq.norm(dim=-1) * ra.norm(dim=-1))
        return cos

In [17]:
def run_model(model_name, learning_rate):
    model = QAModel(word_emb, model_name).to(device)
    model_path = model_name + '.pt'

    opt = torch.optim.SGD(model.parameters(), learning_rate, weight_decay=1e-6)
    lr_sch = LearnRate(opt)

    max_mAP = 0
    for epoch in (range(TRAIN_EPOCHS)):
        model.train()
        total_loss, total_samples = 0, 0
        for q, a_pos, a_neg in train_dlr:
            cos_pos = model(q, a_pos)
            # Only the negative answer with max score value is used to update model weights
            input_q = q.unsqueeze(-2).expand(-1, a_neg.shape[-2], -1).reshape(-1, q.shape[-1])
            input_a = a_neg.view(-1, a_neg.shape[-1])
            cos_neg = model(input_q, input_a)
            cos_neg = cos_neg.view(len(q), -1).max(dim=-1)[0]

            loss = torch.max(torch.zeros(1).to(cos_pos.device), LOSS_MARGIN - cos_pos + cos_neg).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss += loss.item() * len(q)
            total_samples += len(q)
            process_bar(total_samples / len(train_dlr.dataset), start_str=f'Epoch {epoch+1}')
        curr_lr = lr_sch.get_last_lr()[0]
        lr_sch.step()
        model.eval()

        train_loss = total_loss / total_samples
        mAP, MRR, _ = evaluate(model, valid_dlr)
        
        # Early Stopping 
        if max_mAP < mAP:
            max_mAP = mAP
            best_MRR = MRR
            if isinstance(model, torch.nn.DataParallel):
                torch.save(model.module, model_path)
            else:
                torch.save(model, model_path)
        print(f'Epoch {epoch+1:2d}; learning rate {curr_lr:.4f}; train loss {train_loss:.6f}; '
                    f'validation mAP {mAP * 100:.2f}%, MRR {MRR:.4f}')
    print(f'End of training') 
    return max_mAP, best_MRR

### Experiments

In [18]:
QA_CNN_mAP, QA_CNN_MRR = run_model('QA-CNN', learning_rate=0.05) 

Epoch  1; learning rate 0.0500; train loss 0.509099; validation mAP 10.42%, MRR 0.1250
Epoch  2; learning rate 0.0250; train loss 0.496364; validation mAP 19.64%, MRR 0.2278
Epoch  3; learning rate 0.0167; train loss 0.485139; validation mAP 28.32%, MRR 0.3145
Epoch  4; learning rate 0.0125; train loss 0.484369; validation mAP 37.48%, MRR 0.3952
Epoch  5; learning rate 0.0100; train loss 0.473812; validation mAP 45.89%, MRR 0.4555
Epoch  6; learning rate 0.0083; train loss 0.443380; validation mAP 48.74%, MRR 0.4923
Epoch  7; learning rate 0.0071; train loss 0.413030; validation mAP 51.54%, MRR 0.5301
Epoch  8; learning rate 0.0063; train loss 0.372736; validation mAP 54.98%, MRR 0.5666
Epoch  9; learning rate 0.0056; train loss 0.332482; validation mAP 56.49%, MRR 0.5824
Epoch 10; learning rate 0.0050; train loss 0.312259; validation mAP 59.62%, MRR 0.6009
Epoch 11; learning rate 0.0045; train loss 0.292061; validation mAP 61.86%, MRR 0.6201
Epoch 12; learning rate 0.0042; train loss 

In [19]:
AP_CNN_mAP, AP_CNN_MRR = run_model('AP-CNN', learning_rate=1.1)  

Epoch  1; learning rate 1.1000; train loss 0.503856; validation mAP 08.51%, MRR 0.1023
Epoch  2; learning rate 0.5500; train loss 0.499277; validation mAP 17.86%, MRR 0.1987
Epoch  3; learning rate 0.3667; train loss 0.496948; validation mAP 26.22%, MRR 0.2815
Epoch  4; learning rate 0.2750; train loss 0.491942; validation mAP 35.45%, MRR 0.3663
Epoch  5; learning rate 0.2200; train loss 0.472646; validation mAP 42.72%, MRR 0.4221
Epoch  6; learning rate 0.1833; train loss 0.442426; validation mAP 47.86%, MRR 0.4617
Epoch  7; learning rate 0.1571; train loss 0.404502; validation mAP 50.42%, MRR 0.5971
Epoch  8; learning rate 0.1375; train loss 0.370023; validation mAP 53.91%, MRR 0.6138
Epoch  9; learning rate 0.1222; train loss 0.337608; validation mAP 56.22%, MRR 0.6241
Epoch 10; learning rate 0.1100; train loss 0.310621; validation mAP 59.31%, MRR 0.6318
Epoch 11; learning rate 0.1000; train loss 0.279053; validation mAP 61.42%, MRR 0.6402
Epoch 12; learning rate 0.0917; train loss 

In [20]:
QA_biLSTM_mAP, QA_biLSTM_MRR = run_model('QA-biLSTM', learning_rate=1.1)

Epoch  1; learning rate 1.1000; train loss 0.508980; validation mAP 09.33%, MRR 0.1099
Epoch  2; learning rate 0.5500; train loss 0.499248; validation mAP 18.25%, MRR 0.1814
Epoch  3; learning rate 0.3667; train loss 0.495358; validation mAP 25.28%, MRR 0.2683
Epoch  4; learning rate 0.2750; train loss 0.492272; validation mAP 33.61%, MRR 0.3471
Epoch  5; learning rate 0.2200; train loss 0.473643; validation mAP 38.78%, MRR 0.4121
Epoch  6; learning rate 0.1833; train loss 0.441291; validation mAP 42.86%, MRR 0.4518
Epoch  7; learning rate 0.1571; train loss 0.403296; validation mAP 47.33%, MRR 0.4935
Epoch  8; learning rate 0.1375; train loss 0.373356; validation mAP 51.45%, MRR 0.5222
Epoch  9; learning rate 0.1222; train loss 0.335239; validation mAP 54.62%, MRR 0.5554
Epoch 10; learning rate 0.1100; train loss 0.316254; validation mAP 57.39%, MRR 0.5894
Epoch 11; learning rate 0.1000; train loss 0.278164; validation mAP 59.22%, MRR 0.6083
Epoch 12; learning rate 0.0917; train loss 

In [21]:
AP_biLSTM_mAP, AP_biLSTM_MRR = run_model('AP-biLSTM', learning_rate=1.1)

Epoch  1; learning rate 1.1000; train loss 0.516541; validation mAP 10.07%, MRR 0.1121
Epoch  2; learning rate 0.5500; train loss 0.498452; validation mAP 18.13%, MRR 0.1955
Epoch  3; learning rate 0.3667; train loss 0.496472; validation mAP 26.29%, MRR 0.2598
Epoch  4; learning rate 0.2750; train loss 0.489521; validation mAP 34.81%, MRR 0.3342
Epoch  5; learning rate 0.2200; train loss 0.474712; validation mAP 41.94%, MRR 0.4045
Epoch  6; learning rate 0.1833; train loss 0.446541; validation mAP 47.45%, MRR 0.4664
Epoch  7; learning rate 0.1571; train loss 0.419647; validation mAP 51.34%, MRR 0.5182
Epoch  8; learning rate 0.1375; train loss 0.375412; validation mAP 54.62%, MRR 0.5452
Epoch  9; learning rate 0.1222; train loss 0.332471; validation mAP 56.18%, MRR 0.5641
Epoch 10; learning rate 0.1100; train loss 0.314581; validation mAP 58.25%, MRR 0.5922
Epoch 11; learning rate 0.1000; train loss 0.273548; validation mAP 60.21%, MRR 0.6144
Epoch 12; learning rate 0.0917; train loss 

In [None]:
average_precision_score(np.array([0, 0, 0, 0, 0]), np.array([0.1, 0.1, 0.1, 0.1, 0.1]))

In [None]:
list(model.named_parameters())