# 0. Imports, setup, utilities

In [18]:
import functools
import pandas as pd
import torch
import seaborn as sns
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.nn import Module, Linear, Dropout, ELU, Sequential, Sigmoid
from torch.nn.functional import binary_cross_entropy
from torch.optim import Adam
from scipy.stats import spearmanr

from sklearn.model_selection import train_test_split

%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../utilities/')
from utilities import Timer, ProgressBar

sns.set()





The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Utility function: eager version of map
def lmap(func, iterable):
    '''Apply func to every item of iterable and return the results as a list'''
    return [func(item) for item in iterable]

In [4]:
# Debug switch read by later cells: when True, only the first 10 training rows
# are kept and the data loaders use a batch size of 1 instead of 32
debug = True

# 1. Prepare the Data
### 1a. Preprocess/Tokenize

In [5]:
%%capture
def preprocess_BERT(encoder, q_row):
    '''Tokenize one question/answer row into fixed-length BERT inputs

    The layout produced is
        [CLS] question tokens [SEP] answer tokens [SEP] padding...

    Parameters
    ----------
    encoder: callable, takes in text (or a list of token strings) and the
             keyword add_special_tokens, returns a list of token ids;
             should be provided by the pre-trained model's tokenizer

    q_row  : dataframe row containing columns for question_title,
             question_body, and answer

    Returns
    -------
    pandas Series of entries, each entry a list of length 512.
    Entries: 'tokens' (token ids, zero padded), 'segment_ids' (0 for the
    [CLS]/question/[SEP] part and the padding, 1 for the answer/[SEP] part)
    and 'mask' (1 for real tokens, 0 for padding).
    '''

    # Set max length allowed by BERT model
    MAX_LENGTH = 512

    # Three slots are reserved for [CLS] and the two [SEP] tokens
    content_budget = MAX_LENGTH - 3

    # Get question title, body, and answer from dataframe row.
    # Title and body are joined with a space so that the last word of the
    # title and the first word of the body are not fused into one word.
    question = q_row.question_title + ' ' + q_row.question_body
    answer   = q_row.answer

    # Encode question and answer without [CLS] and [SEP]
    question_tok = encoder(question, add_special_tokens = False)
    answer_tok   = encoder(answer, add_special_tokens = False)

    # Decide how many tokens of each part to keep: repeatedly drop one token
    # from the end of the longer part, shortening the answer on a tie.
    # Only the lengths are tracked, so no list is rebuilt inside the loop.
    question_len = len(question_tok)
    answer_len   = len(answer_tok)
    while question_len + answer_len > content_budget:
        if answer_len >= question_len:
            answer_len   -= 1
        else:
            question_len -= 1

    question_tok = question_tok[:question_len]
    answer_tok   = answer_tok[:answer_len]

    # Get encodings for [CLS] and [SEP]
    cls_token_encoded = encoder(['[CLS]'], add_special_tokens = False)
    sep_token_encoded = encoder(['[SEP]'], add_special_tokens = False)

    # Combine question, answer, and special tokens
    content_tok = cls_token_encoded + question_tok + \
                  sep_token_encoded + answer_tok   + \
                  sep_token_encoded

    # Zero padding up to MAX_LENGTH; the same zeros pad all three outputs
    padding_len = MAX_LENGTH - len(content_tok)
    padding     = [0] * padding_len

    # Add padding
    final_tok   = content_tok + padding

    # Segment ids: 0 for [CLS] + question + [SEP], 1 for answer + [SEP]
    segment_ids = [0] * (len(question_tok) + 2) + \
                  [1] * (len(answer_tok)   + 1) + \
                  padding

    # Attention mask: 1 on real tokens, 0 on padding
    mask        = [1] * len(content_tok) + padding

    return pd.Series({
        'tokens'      : final_tok,
        'segment_ids' : segment_ids,
        'mask'        : mask
    })

# Load in tokenizer for BERT base uncased via torch.hub
# (fetched from the huggingface/pytorch-transformers hub repository)
BERT_base_uncased_tokenizer  = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-uncased') 

# Bind the tokenizer's encode method as the first (encoder) argument of
# preprocess_BERT, leaving a callable that takes only a dataframe row
preprocess_BERT_base_uncased = functools.partial(preprocess_BERT, 
                                                 BERT_base_uncased_tokenizer.encode)

### 1b. Create a PyTorch Dataset from the Processed Data

In [6]:
# Names of the 30 label columns to predict: 21 question_* columns followed by
# 9 answer_* columns. The length matches the 30 outputs of the model's head.
target_cols = ['question_asker_intent_understanding', 'question_body_critical', 
               'question_conversational', 'question_expect_short_answer', 
               'question_fact_seeking', 'question_has_commonly_accepted_answer', 
               'question_interestingness_others', 'question_interestingness_self', 
               'question_multi_intent', 'question_not_really_a_question', 
               'question_opinion_seeking', 'question_type_choice', 
               'question_type_compare', 'question_type_consequence', 
               'question_type_definition', 'question_type_entity', 
               'question_type_instructions', 'question_type_procedure', 
               'question_type_reason_explanation', 'question_type_spelling', 
               'question_well_written', 'answer_helpful', 
               'answer_level_of_information', 'answer_plausible', 
               'answer_relevance', 'answer_satisfaction', 
               'answer_type_instructions', 'answer_type_procedure', 
               'answer_type_reason_explanation', 'answer_well_written']

In [7]:
def create_dataset(df, preprocessor, target_cols = None, batch_size = None, shuffle = True):
    '''Create a batched PyTorch DataLoader from a pandas dataframe

    Parameters
    ----------

    df: Pandas dataframe with text columns available for the preprocessor
        and containing the target columns

    preprocessor: callable taking a row of a dataframe and returning
                  a Series with entries tokens, segment_ids and mask,
                  each a list of integers (all rows the same length)

    target_cols: list of column names to use as the target.
    If None, no labels are included

    batch_size: number of rows per batch. If None (default), 1 is used when
                the module-level debug flag is set and 32 otherwise

    shuffle: whether the loader reshuffles the rows on every pass
             (default True). Pass False for validation or test data
             when predictions must stay in the row order of df

    Returns
    -------

    PyTorch DataLoader over a TensorDataset living on the module-level device.
    Each batch is (tokens, segment_ids, mask) followed by targets when
    target_cols is given.

    '''
    # Process the input data into a dataframe with 3 columns
    processed_data = df.apply(preprocessor, axis = 'columns')

    # Convert each of those three columns into a tensor
    def convert_col_to_tensor(col):
        # Every entry is an equal-length list of ints, so the whole column
        # converts to one (rows, length) tensor in a single call
        return torch.tensor(col.tolist(), dtype = torch.long)

    tokens      = convert_col_to_tensor(processed_data.tokens).to(device)
    segment_ids = convert_col_to_tensor(processed_data.segment_ids).to(device)
    # 'mask' is accessed with brackets because DataFrame.mask is a method
    mask        = convert_col_to_tensor(processed_data['mask']).to(device)

    data        = [tokens, segment_ids, mask]

    # Collect the target columns
    if target_cols is not None:
        targets = torch.tensor(df[target_cols].values, dtype = torch.float32).to(device)
        data.append(targets)

    # Fall back to the historical batch sizes when none is requested
    if batch_size is None:
        batch_size = 1 if debug else 32

    # Construct a Torch Dataset, then a DataLoader that batches (and shuffles)
    dataset = TensorDataset(*data)
    return DataLoader(dataset, batch_size, shuffle = shuffle)

### 1c. Train Test Split

In [8]:
%%capture
# Suppress warnings when tokenizing sentences longer than the allowed length of 512
# (the %%capture cell magic above hides this cell's output)

# Load the original data
train_df_all = pd.read_csv('../input/google-quest-challenge/train.csv')
test_df      = pd.read_csv('../input/google-quest-challenge/test.csv')

# In debug mode keep only the first 10 labelled rows so the cell runs quickly
if debug:
    train_df_all = train_df_all.iloc[:10]

# Create Train and Validation Splits (80% train / 20% validation, fixed seed)
train_df, valid_df = train_test_split(train_df_all, random_state = 42, train_size = 0.8)

# Timer.end() is called in the next cell to report the elapsed time
Timer.start()
# Create PyTorch data loaders. The test loader is built without target columns,
# so its batches hold only (tokens, segment_ids, mask).
# NOTE(review): create_dataset shuffles by default, so test batches are not
# in test_df row order - confirm before using them for a submission.
train = create_dataset(train_df, preprocess_BERT_base_uncased, target_cols)
valid = create_dataset(valid_df, preprocess_BERT_base_uncased, target_cols)
test  = create_dataset( test_df, preprocess_BERT_base_uncased)

In [9]:
# Close the timing started by Timer.start() in the dataset-creation cell;
# judging by the cell output this reports the elapsed seconds
Timer.end()

8.57 seconds elapsed


# 2. Construct the Model

In [10]:
# Build the BERT Model with a head
class BERT(Module):
    '''Pre-trained BERT base (uncased) encoder followed by a dense head

    The head consumes the concatenation of BERT's pooled output and the
    mask-weighted mean of its per-token outputs (2 * 768 features) and
    produces 30 logits, one per target column.

    Parameters
    ----------
    dropout: dropout probability used after each hidden layer of the head
    '''

    def __init__(self, dropout):
        super(BERT, self).__init__()
        # Downloads (or reads from the torch.hub cache) the pre-trained encoder
        self.bert_embed = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-uncased')
        self.classifier = Sequential(
            Linear(768 * 2, 2048),
            Dropout(dropout)     ,
            ELU()                ,
            Linear(2048   , 2048),
            Dropout(dropout)     ,
            ELU()                ,
            Linear(2048   ,   30),
        )

    def forward(self, tokens, segment_ids, mask):
        '''Compute the 30 output logits for a batch

        Parameters
        ----------
        tokens      : (batch, length) token ids
        segment_ids : (batch, length) segment (token type) ids
        mask        : (batch, length) attention mask, 1 on real tokens

        Returns
        -------
        (batch, 30) tensor of logits (no sigmoid applied)
        '''
        # Apply the main BERT. The inputs are passed by keyword because the
        # positional order of token_type_ids and attention_mask differs
        # between releases of the underlying library.
        outputs = self.bert_embed(tokens,
                                  token_type_ids = segment_ids,
                                  attention_mask = mask)

        # Index instead of unpacking: works for plain tuples as well as for
        # dict-like output objects, whose iteration yields keys
        sequence_output = outputs[0]
        pooled_output   = outputs[1]

        # Average the per-token outputs, taking into account the mask
        mask_expanded = torch.unsqueeze(mask, dim = -1).to(sequence_output.dtype)
        seq_reduced   = torch.sum(sequence_output * mask_expanded, dim = 1)
        # Kept as (batch, 1) so the division broadcasts over the feature axis;
        # clamped so an all-padding row does not divide by zero
        mask_size     = torch.sum(mask_expanded, dim = 1).clamp(min = 1)
        seq_reduced   = seq_reduced / mask_size

        # Concatenate the pooled and seq(uential)_reduced tensors
        signal        = torch.cat([pooled_output, seq_reduced], dim = 1)

        # Run the forward classifier. Output: logits (for cross-entropy loss)
        signal        = self.classifier(signal)

        return signal
      

# 3. Train the Model

In [51]:
def train_loop(model, model_name,
          optimizer, loss_fn,
          train_data, val_data,
          epochs = 30,
          early_stopping = 2,
          restore_best_model = False
         ):
    '''Train model with per-epoch validation, checkpointing and early stopping

    Parameters
    ----------
    model        : module called as model(tokens, segment_ids, mask),
                   returning logits
    model_name   : the best model is saved to ./checkpoints/{model_name}.pt
    optimizer    : optimizer already bound to the parameters to train
    loss_fn      : callable (logits, target) -> scalar loss tensor
    train_data   : iterable of (tokens, segment_ids, mask, target) batches
                   that supports len()
    val_data     : same structure as train_data, used for validation
    epochs       : maximum number of epochs
    early_stopping : stop once this many epochs have passed since the best
                     validation Spearman score
    restore_best_model : on early stopping, return the checkpointed best
                         model instead of the last one

    Returns
    -------
    The trained model (the reloaded best checkpoint when early stopping
    triggers and restore_best_model is True)
    '''
    import math
    import os

    # Make sure the checkpoint directory exists before the first save
    os.makedirs('./checkpoints', exist_ok = True)
    checkpoint_path = f'./checkpoints/{model_name}.pt'

    best_val_spearman = None
    best_val_epoch    = None

    for epoch in range(epochs):
        ############################################
        #                Training                  #
        ############################################

        print(f'Epoch {epoch+1}')

        # Set up a progress bar, training loss accumulator, mini batch counter
        bar = ProgressBar(len(train_data))
        bar.start()
        train_loss_total   = 0.0
        train_loss_running = 0.0
        mini_batch         = 0

        model.train()
        for tokens, segment_ids, mask, target in train_data:

            # Do a gradient descent step
            optimizer.zero_grad()
            output = model(tokens, segment_ids, mask)
            loss   = loss_fn(output, target)
            loss.backward()
            optimizer.step()

            # Compute metrics and display progress bar
            train_loss_total  += loss.item()
            mini_batch        += 1
            train_loss_running = train_loss_total / mini_batch
            bar.update(mini_batch, {'train_loss' : train_loss_running})

        ############################################
        #                Validation                #
        ############################################

        # Compute validation loss and metrics
        val_loss        = 0.0
        all_predictions = []
        all_targets     = []
        model.eval()

        # No gradients are needed here; disabling them avoids keeping one
        # autograd graph per validation batch alive
        with torch.no_grad():
            for tokens, segment_ids, mask, target in val_data:

                # Accumulate the mean batch loss as a plain float
                output    = model(tokens, segment_ids, mask)
                val_loss += loss_fn(output, target).item() / len(val_data)

                # Accumulate the predictions and targets
                all_predictions.append(torch.sigmoid(output).cpu())
                all_targets.append(target.cpu())

        # Concatenate along the batch axis -> (num_rows, num_targets).
        # cat (unlike stack) also works when batches hold more than one row
        # or the last batch is smaller than the others.
        all_predictions = torch.cat(all_predictions, dim = 0).numpy()
        all_targets     = torch.cat(all_targets, dim = 0).numpy()

        # Calculate the mean Spearman correlation coefficient over the targets.
        # A column that is constant in this validation set has an undefined
        # (NaN) correlation; such columns are left out of the average.
        NUM_TARGETS  = all_targets.shape[-1]
        correlations = [
            spearmanr(all_targets[:, i], all_predictions[:, i]).correlation
            for i in range(NUM_TARGETS)
        ]
        valid_correlations = [c for c in correlations if not math.isnan(c)]
        if valid_correlations:
            spearman_coef = sum(valid_correlations) / len(valid_correlations)
        else:
            spearman_coef = float('nan')

        bar.update(mini_batch, {'train_loss'     : train_loss_running,
                                'val_loss'       : val_loss,
                                'val_spearman'   : spearman_coef
                               })

        ############################################
        #                Callbacks                 #
        ############################################

        # An undefined score ranks below every real score
        score = -math.inf if math.isnan(spearman_coef) else spearman_coef

        # Checkpoint saving
        if best_val_spearman is None or best_val_spearman < score:
            # Save the new best model, overwriting the old one
            # NOTE: this pickles the whole module; loading the file requires
            # the model class to be importable
            torch.save(model, checkpoint_path)
            best_val_spearman = score
            best_val_epoch    = epoch

        # Early Stopping
        elif epoch >= best_val_epoch + early_stopping:
            if restore_best_model:
                model = torch.load(checkpoint_path)
            break

    return model

In [21]:
def cross_entropy_loss(output_batch, target_batch):
    '''Mean binary cross-entropy between logits and targets in [0, 1]

    Parameters
    ----------
    output_batch: tensor of raw logits (no sigmoid applied)
    target_batch: tensor of the same shape with target probabilities

    Returns
    -------
    Scalar tensor: the loss averaged over every element
    '''
    # The fused logits form is numerically stable: applying sigmoid first
    # saturates to exactly 0 or 1 for large logits, which clips the loss
    # and zeroes the gradient
    return torch.nn.functional.binary_cross_entropy_with_logits(
        output_batch, target_batch, reduction = 'mean'
    )

In [13]:
# Build the model and move it to the same device as the data tensors
# (create_dataset places every tensor on `device`); this must happen before
# the optimizer is created so it holds the moved parameters
model     = BERT(0.1).to(device)
# Only the classifier head's parameters are handed to the optimizer, so the
# BERT encoder weights are not updated during training
optimizer = Adam(model.classifier.parameters())
loss_fn   = cross_entropy_loss

Using cache found in /Users/rcharan/.cache/torch/hub/huggingface_pytorch-transformers_master


In [None]:
# train_loop returns a single object, the trained model, so it is bound to one name
model = train_loop(model, 'bert_base_uncased_1', optimizer, loss_fn, train, valid, restore_best_model=True)

In [None]:
# Timer.start()
# writer = SummaryWriter()
# writer.add_graph(model, next(iter(train))[:3])
# Timer.end()