This notebook introduces a model that supports various poolers.

It is based on the notebook at https://www.kaggle.com/rhtsingh/utilizing-transformer-representations-efficiently.

**If you liked this notebook, please upvote it!**

In [None]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer

In [None]:
class WeightedLayerPooling(nn.Module):
    """Weighted average of the hidden states from layer ``layer_start`` onward.

    Args:
        num_hidden_layers: number of encoder layers; the default weight vector
            has ``num_hidden_layers + 1 - layer_start`` entries.
        layer_start: index of the first entry of ``all_hidden_states`` that
            takes part in the average.
        layer_weights: optional weight vector to use as-is. When omitted, a
            learnable ``nn.Parameter`` initialised to all ones is created.
    """

    def __init__(self, num_hidden_layers, layer_start: int = 4, layer_weights = None):
        super(WeightedLayerPooling, self).__init__()
        self.layer_start = layer_start
        self.num_hidden_layers = num_hidden_layers
        if layer_weights is None:
            n_selected = num_hidden_layers + 1 - layer_start
            layer_weights = nn.Parameter(
                torch.tensor([1] * n_selected, dtype=torch.float)
            )
        self.layer_weights = layer_weights

    def forward(self, all_hidden_states):
        """Return sum_i(w_i * layer_i) / sum_i(w_i) over the selected layers.

        ``all_hidden_states`` must be 4-D with the layer axis first; the
        result drops that axis.
        """
        selected = all_hidden_states[self.layer_start:, :, :, :]
        # Reshape weights to (n_layers, 1, 1, 1) so they broadcast over the rest.
        per_layer = self.layer_weights[:, None, None, None]
        numerator = (per_layer * selected).sum(dim=0)
        return numerator / self.layer_weights.sum()
    
class LSTMPooling(nn.Module):
    """Pool the first-token vector of every encoder layer with an LSTM.

    The first-token hidden state of layers ``1..num_layers`` is treated as a
    sequence of length ``num_layers`` and fed to a single-layer LSTM. The LSTM
    output at the last step, after dropout, is the pooled embedding.

    Args:
        num_layers: number of layers to read (entries ``1..num_layers`` of
            ``all_hidden_states``; entry 0 is skipped).
        hidden_size: feature size of each hidden state.
        hiddendim_lstm: LSTM hidden size, i.e. the size of the returned vector.
    """

    def __init__(self, num_layers, hidden_size, hiddendim_lstm):
        super(LSTMPooling, self).__init__()
        self.num_hidden_layers = num_layers
        self.hidden_size = hidden_size
        self.hiddendim_lstm = hiddendim_lstm
        self.lstm = nn.LSTM(self.hidden_size, self.hiddendim_lstm, batch_first=True)
        self.dropout = nn.Dropout(0.1)

    def _cls_sequence(self, all_hidden_states):
        """Return the first-token state of layers 1..num_layers as (batch, layers, hidden)."""
        # Stack along dim=1 so the layer axis comes out in the right place.
        # The former stack(dim=-1) + view(-1, layers, hidden) reinterpreted a
        # (batch, hidden, layers) buffer without transposing it, which mixed
        # features across layers. Its .squeeze() also removed the batch axis
        # when batch_size == 1.
        return torch.stack(
            [all_hidden_states[layer_i][:, 0]
             for layer_i in range(1, self.num_hidden_layers + 1)],
            dim=1,
        )

    def forward(self, all_hidden_states):
        """Pool ``all_hidden_states`` into a (batch, hiddendim_lstm) tensor.

        ``all_hidden_states`` is indexed as [layer][batch, token, hidden], e.g.
        ``torch.stack(outputs.hidden_states)``.
        """
        hidden_states = self._cls_sequence(all_hidden_states)
        out, _ = self.lstm(hidden_states, None)
        return self.dropout(out[:, -1, :])

class AttentionPooling(nn.Module):
    """Attention over the first-token vectors of each encoder layer.

    A learned query ``q`` scores the first-token state of layers
    ``1..num_layers``. The softmax-weighted sum of those states is projected
    by ``w_h`` to ``hiddendim_fc`` features and then passed through dropout.

    Args:
        num_layers: number of layers to read (entries ``1..num_layers`` of
            ``all_hidden_states``; entry 0 is skipped).
        hidden_size: feature size of each hidden state.
        hiddendim_fc: size of the returned vector.
    """

    def __init__(self, num_layers, hidden_size, hiddendim_fc):
        super(AttentionPooling, self).__init__()
        self.num_hidden_layers = num_layers
        self.hidden_size = hidden_size
        self.hiddendim_fc = hiddendim_fc
        self.dropout = nn.Dropout(0.1)

        # Convert to float32 BEFORE wrapping in nn.Parameter. Calling
        # .float()/.cuda() on a Parameter returns a plain tensor, so the
        # original code never registered q / w_h. They were not trained, not
        # saved in state_dict, and construction failed without a GPU. Device
        # placement now follows module.to()/.cuda().
        q_t = np.random.normal(loc=0.0, scale=0.1, size=(1, self.hidden_size))
        self.q = nn.Parameter(torch.from_numpy(q_t).float())
        w_ht = np.random.normal(loc=0.0, scale=0.1, size=(self.hidden_size, self.hiddendim_fc))
        self.w_h = nn.Parameter(torch.from_numpy(w_ht).float())

    def _cls_sequence(self, all_hidden_states):
        """Return the first-token state of layers 1..num_layers as (batch, layers, hidden)."""
        # stack(dim=1) yields the layer axis directly. The former
        # stack(dim=-1) + view() mixed features across layers. Its .squeeze()
        # also dropped the batch axis when batch_size == 1.
        return torch.stack(
            [all_hidden_states[layer_i][:, 0]
             for layer_i in range(1, self.num_hidden_layers + 1)],
            dim=1,
        )

    def forward(self, all_hidden_states):
        """Pool ``all_hidden_states`` into a (batch, hiddendim_fc) tensor.

        ``all_hidden_states`` is indexed as [layer][batch, token, hidden], e.g.
        ``torch.stack(outputs.hidden_states)``.
        """
        hidden_states = self._cls_sequence(all_hidden_states)
        out = self.attention(hidden_states)
        out = self.dropout(out)
        return out

    def attention(self, h):
        """Attend over the layer axis of ``h`` (batch, layers, hidden) and project to hiddendim_fc."""
        # Scores per layer: (1, hidden) @ (batch, hidden, layers) -> (batch, layers).
        v = torch.matmul(self.q, h.transpose(-2, -1)).squeeze(1)
        v = F.softmax(v, -1)
        # Weighted sum of layer vectors: (batch, 1, hidden) -> (batch, hidden, 1).
        v_temp = torch.matmul(v.unsqueeze(1), h).transpose(-2, -1)
        # Project: (fc, hidden) @ (batch, hidden, 1) -> (batch, fc).
        v = torch.matmul(self.w_h.transpose(1, 0), v_temp).squeeze(2)
        return v

In [None]:
class ModelWithPooling(nn.Module):
    """Transformer backbone followed by a pooling head that emits one logit.

    Args:
        base_model: model name or path passed to ``AutoConfig.from_pretrained``
            and ``AutoModel.from_pretrained``.
        pooling: one of ``'mean'``, ``'max'``, ``'mean_max'``, ``'conv'``,
            ``'concat'``, ``'weighted_layer'``, ``'lstm'``, ``'attention'``.

    Raises:
        ValueError: if ``pooling`` is not a supported name.
    """

    _SUPPORTED_POOLINGS = (
        'mean', 'max', 'mean_max', 'conv', 'concat',
        'weighted_layer', 'lstm', 'attention',
    )

    def __init__(self, base_model, pooling='mean'):
        super().__init__()
        # Reject unknown names before loading the backbone. Previously they
        # were only caught on the first forward pass.
        if pooling not in self._SUPPORTED_POOLINGS:
            raise ValueError(
                f'Incorrect pooler specified: {pooling!r}; '
                f'expected one of {self._SUPPORTED_POOLINGS}.'
            )
        self.pooling = pooling

        config = AutoConfig.from_pretrained(base_model)
        # All per-layer hidden states are required by the multi-layer poolers.
        config.update({"output_hidden_states": True,
                       "hidden_dropout_prob": 0.0,
                       "layer_norm_eps": 1e-7})
        self.base_model = AutoModel.from_pretrained(base_model, config=config)

        if pooling == 'mean_max':
            self.logits = nn.Linear(config.hidden_size * 2, 1)

        elif pooling == 'conv':
            self.cnn1 = nn.Conv1d(config.hidden_size, 256, kernel_size=2, padding=1)
            self.cnn2 = nn.Conv1d(256, 1, kernel_size=2, padding=1)

        elif pooling == 'concat':
            self.logits = nn.Linear(config.hidden_size * 4, 1)

        elif pooling == 'weighted_layer':
            # Clamp so backbones with <= 8 layers still get at least one layer
            # weight. Otherwise the weight vector is empty and the weighted
            # average divides by zero.
            layer_start = min(9, config.num_hidden_layers)
            self.pooler = WeightedLayerPooling(
                config.num_hidden_layers,
                layer_start=layer_start,
                layer_weights=None
            )
            self.logits = nn.Linear(config.hidden_size, 1)

        elif pooling == 'lstm':
            hiddendim_lstm = 256
            self.pooler = LSTMPooling(
                config.num_hidden_layers,
                config.hidden_size,
                hiddendim_lstm
            )
            self.logits = nn.Linear(hiddendim_lstm, 1)

        elif pooling == 'attention':
            hiddendim_fc = 128
            self.pooler = AttentionPooling(
                config.num_hidden_layers,
                config.hidden_size,
                hiddendim_fc
            )
            self.logits = nn.Linear(hiddendim_fc, 1)

        else:  # 'mean' and 'max'
            self.logits = nn.Linear(config.hidden_size, 1)

    @staticmethod
    def _masked_mean(hidden_state, attention_mask):
        """Mean over the token axis (dim 1), ignoring positions where attention_mask == 0."""
        mask = attention_mask.unsqueeze(-1).expand(hidden_state.size()).float()
        summed = torch.sum(hidden_state * mask, 1)
        # The clamp avoids 0/0 for rows with no attended tokens.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    @staticmethod
    def _masked_max(hidden_state, attention_mask):
        """Max over the token axis (dim 1), ignoring positions where attention_mask == 0."""
        # Use an out-of-place masked_fill. The old code overwrote the
        # backbone's output tensor in place with -1e9, which overflows
        # fp16/bf16; finfo(dtype).min is representable in every float dtype.
        is_pad = attention_mask.unsqueeze(-1) == 0
        filled = hidden_state.masked_fill(is_pad, torch.finfo(hidden_state.dtype).min)
        return torch.max(filled, 1)[0]

    def forward(self, input_ids, attention_mask):
        """Run the backbone and the configured pooling head; returns (batch, 1) logits."""
        base_model_output = self.base_model(input_ids=input_ids,
                                            attention_mask=attention_mask)
        # Access by name: positional index 2 is hidden_states only for
        # backbones that also return a pooler_output. For other backbones it
        # may be missing or a different field.
        hidden_states = base_model_output.hidden_states
        last_hidden_state = hidden_states[-1]

        if self.pooling == 'mean':
            logits = self.logits(self._masked_mean(last_hidden_state, attention_mask))

        elif self.pooling == 'max':
            logits = self.logits(self._masked_max(last_hidden_state, attention_mask))

        elif self.pooling == 'mean_max':
            # Masked like 'mean' / 'max' so padding tokens do not leak in.
            mean_max_embeddings = torch.cat(
                (self._masked_mean(last_hidden_state, attention_mask),
                 self._masked_max(last_hidden_state, attention_mask)),
                1,
            )
            logits = self.logits(mean_max_embeddings)

        elif self.pooling == 'conv':
            # NOTE(review): padding positions are not masked in this branch.
            cnn_embeddings = F.relu(self.cnn1(last_hidden_state.permute(0, 2, 1)))
            cnn_embeddings = self.cnn2(cnn_embeddings)
            logits = torch.max(cnn_embeddings, 2)[0]

        elif self.pooling == 'concat':
            # First-token vector of the last four layers, concatenated feature-wise.
            concatenate_pooling = torch.cat(
                (hidden_states[-1], hidden_states[-2], hidden_states[-3], hidden_states[-4]), -1
            )[:, 0]
            logits = self.logits(concatenate_pooling)

        elif self.pooling == 'weighted_layer':
            weighted_pooling_embeddings = self.pooler(torch.stack(hidden_states))[:, 0]
            logits = self.logits(weighted_pooling_embeddings)

        elif self.pooling == 'lstm':
            logits = self.logits(self.pooler(torch.stack(hidden_states)))

        elif self.pooling == 'attention':
            logits = self.logits(self.pooler(torch.stack(hidden_states)))

        else:
            raise ValueError('Incorrect pooler specified.')

        return logits