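# Testing transformer pooling strategies

Exploratory notebook: prototypes a `BertPoolingLayer` and checks the output shapes it produces for different pooling types and hidden-state pooling strategies on a BERT model.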

In [1]:
import os
import sys

def adding_module_path():
    """Append the directory three levels above the working directory to ``sys.path``.

    The path is made absolute first and is only appended when it is not
    already listed, so calling this repeatedly does not create duplicates.
    The later ``src.*`` imports presumably rely on this entry.
    """
    project_root = os.path.abspath(os.path.join(os.pardir, os.pardir, os.pardir))

    if project_root in sys.path:
        return

    sys.path.append(project_root)

adding_module_path()

In [2]:
import tensorflow as tf
from src.types.transformer_pooling import TransformerPooling
from src.types.transformer_pooling_strategy import TransformerPoolingStrategy

In [3]:
def verify_bert_pooling_input(
    pooling_type,
    transformer_pooling_strategy=None,
    transformer_start_index=None,
    transformer_end_index=None
):
    """Reject hidden-state options combined with a pooling type that ignores them.

    A pooling strategy and layer start/end indices are only meaningful for
    ``TransformerPooling.HiddenStates``. For ``LastHiddenState`` and
    ``Pooler`` they must be left unset.

    Args:
        pooling_type: Member of ``TransformerPooling`` selecting the pooling mode.
        transformer_pooling_strategy: Hidden-state pooling strategy, or ``None``
            when unset.
        transformer_start_index: Layer start index; ``None`` and ``-1`` both
            count as "unset".
        transformer_end_index: Layer end index; ``None`` and ``-1`` both count
            as "unset".

    Returns:
        None. The function only validates.

    Raises:
        ValueError: If a strategy or a layer index is supplied together with
            ``LastHiddenState`` or ``Pooler`` pooling.
    """
    if pooling_type not in [TransformerPooling.LastHiddenState, TransformerPooling.Pooler]:
        return None

    def _is_set(index):
        # The defaults are None while the original check compared against -1,
        # so both sentinels have to be accepted as "not provided".
        return index is not None and index != -1

    if (
        transformer_pooling_strategy is not None
        or _is_set(transformer_start_index)
        or _is_set(transformer_end_index)
    ):
        # The original used `assert Exception(...)`, which asserts a truthy
        # object and therefore never fails; raise so the check is enforced.
        raise ValueError(
            f"Cannot use pooling strategy when is not used {TransformerPooling.HiddenStates.value}"
        )

    return None

In [4]:
from src.data_loading.get_dataset_object_from import get_dataset_all
from transformers import AutoConfig
from transformers import TFAutoModel
from src.types.transformer_name import TransformerName
from src.tokenizers.prepare_dataset_from_tokenizer import prepare_dataset_from_tokenizer
from src.utils.create_path_to_gutenberg import get_paths_to_gutenberg, get_path_to_gutenberg_sets
from src.tokenizers.transformer_tokenizer import TransformerTokenizer
from src.encoder.create_encoder_from_path import create_encoder_from_path
from src.data_loading.get_dataset_object_from import get_dataset_object_from_path, get_datasets


In [5]:
# Checkpoint name shared by the tokenizer and the transformer config in later cells.
model_name = TransformerName.BertBaseUncased.value
# Paths to the dataset files and to the authors file; the arguments (5, 3)
# presumably select the 5-author / 3-sentence Gutenberg variant — TODO confirm.
path_data, path_authors = get_path_to_gutenberg_sets(5, 3)

In [6]:
# Load the train/validation/test splits. ';' is presumably the CSV delimiter; the
# meaning of the trailing None is not visible here — TODO confirm in get_datasets.
train, valid, test = get_datasets(path_data, ';', None)

Loading dataset from=C:\Users\Vojta\Desktop\diploma\data\gutenberg\5Authors\Sentence3\train.csv
Loading dataset from=C:\Users\Vojta\Desktop\diploma\data\gutenberg\5Authors\Sentence3\valid.csv
Loading dataset from=C:\Users\Vojta\Desktop\diploma\data\gutenberg\5Authors\Sentence3\test.csv


In [7]:
# Display the checkpoint name that the tokenizer and model below are built from.
model_name

'bert-base-uncased'

In [8]:
# Build the project tokenizer wrapper for `model_name`, handing it an encoder
# created from the authors file (presumably a label encoder — TODO confirm).
tokenizer = TransformerTokenizer(
    model_name, 
    create_encoder_from_path(
        path_authors
    )
)

In [9]:
# Request all hidden states in the model output; the HiddenStates pooling
# mode tested below reads them.
config = AutoConfig.from_pretrained(model_name, output_hidden_states=True)
# NOTE(review): per the transformers documentation, from_config builds the
# architecture only and does not load pretrained weights. That is sufficient for
# the shape checks in this notebook; use TFAutoModel.from_pretrained when the
# actual embeddings matter.
transformer = TFAutoModel.from_config(config)

In [10]:
class BertPoolingLayer(tf.keras.layers.Layer):
    """Reduce a transformer model output to one vector per example.

    Three pooling types are supported, selected by ``pooling_type``:

    * ``LastHiddenState``: mean of the last hidden state over axis 1.
    * ``Pooler``: the model's pooler output, returned unchanged.
    * ``HiddenStates``: a window of hidden-state layers, combined according
      to ``transformer_pooling_strategy``.
    """

    def call(
        self,
        inputs,
        pooling_type,
        transformer_pooling_strategy,
        transformer_start_index,
        transformer_end_index
    ):
        """Pool ``inputs`` according to the requested type and strategy.

        Args:
            inputs: Model output indexable by the ``TransformerPooling`` member
                values. Each hidden state is assumed to be rank 3,
                (batch, tokens, features) — TODO confirm against the model.
            pooling_type: ``TransformerPooling`` member selecting the mode.
            transformer_pooling_strategy: ``TransformerPoolingStrategy`` member;
                only used for ``HiddenStates``. Anything other than
                ``ConcatAverage``, ``ConcatCLS`` or ``CLS`` averages over both
                layers and tokens.
            transformer_start_index: Offset of the first selected layer,
                counted back from the last layer (0 = last layer).
            transformer_end_index: Offset of the last selected layer, counted
                back from the last layer. Must not exceed the start offset.

        Returns:
            A tensor with one row per example. The ``Concat*`` strategies join
            the selected layers along the feature axis, so the feature size
            grows with the number of selected layers.

        Raises:
            ValueError: If ``pooling_type`` is not supported, or the layer
                window is out of range or empty.
        """
        verify_bert_pooling_input(
            pooling_type,
            transformer_pooling_strategy,
            transformer_start_index,
            transformer_end_index
        )

        if pooling_type == TransformerPooling.LastHiddenState:
            last_hidden_state = inputs[TransformerPooling.LastHiddenState.value]
            # NOTE(review): no attention mask is available here, so every token
            # position (padding included, if any) contributes to the mean.
            return tf.reduce_mean(last_hidden_state, axis=1)

        if pooling_type == TransformerPooling.Pooler:
            return inputs[TransformerPooling.Pooler.value]

        if pooling_type != TransformerPooling.HiddenStates:
            # Previously this fell through and silently returned None.
            raise ValueError(f"Unsupported pooling type: {pooling_type!r}")

        hidden_states = inputs[TransformerPooling.HiddenStates.value]

        # Offsets count back from the last entry, e.g. start=1, end=0 selects
        # the final two entries of `hidden_states`.
        last_layer_index = len(hidden_states) - 1
        first_index = last_layer_index - transformer_start_index
        stop_index = last_layer_index - transformer_end_index + 1

        if first_index < 0:
            # A negative slice start would wrap around and pick the wrong layers.
            raise ValueError(
                f"transformer_start_index={transformer_start_index} exceeds the "
                f"{last_layer_index} layers available before the last one"
            )

        selected_layers = hidden_states[first_index:stop_index]

        if len(selected_layers) == 0:
            raise ValueError(
                "Empty layer window: transformer_start_index="
                f"{transformer_start_index} must be >= transformer_end_index="
                f"{transformer_end_index}"
            )

        if transformer_pooling_strategy in [
            TransformerPoolingStrategy.ConcatAverage,
            TransformerPoolingStrategy.ConcatCLS,
        ]:
            # Join the selected layers along the feature axis.
            concatenated = tf.concat(selected_layers, axis=2)

            if transformer_pooling_strategy == TransformerPoolingStrategy.ConcatCLS:
                # Token position 0 is presumably the [CLS] token — TODO confirm.
                return concatenated[:, 0, :]

            return tf.reduce_mean(concatenated, axis=1)

        # Non-concat strategies: average element-wise across the selected layers.
        layer_mean = tf.reduce_mean(tf.stack(selected_layers, axis=0), axis=0)

        if transformer_pooling_strategy == TransformerPoolingStrategy.CLS:
            return layer_mean[:, 0, :]

        # Fallback for every other strategy: mean over the token axis as well.
        return tf.reduce_mean(layer_mean, axis=1)

In [11]:
# Shape check on the first batch of 10 examples only (hence the `break`).
for x in prepare_dataset_from_tokenizer(train, tokenizer).batch(10):
    # Each batch unpacks into the model inputs and the labels; the labels are unused here.
    text, label = x
    output = transformer(text, output_hidden_states=True)
    # Start offset 1 and end offset 0 select the last two hidden-state layers.
    output = BertPoolingLayer()(output, TransformerPooling.HiddenStates, TransformerPoolingStrategy.ConcatCLS, 1, 0) #offsets are subtracted from the last layer
    print(tf.shape(output))
    break

tf.Tensor([  2  10 512 768], shape=(4,), dtype=int32)
tf.Tensor([  10 1536], shape=(2,), dtype=int32)


In [12]:
from src.models.transformer.pooling_strategy import pooling_strategy_dictionary, TransformerPoolingStrategySelection

In [13]:
# Same first-batch shape check, this time taking the pooling arguments from the
# project's predefined strategy dictionary.
for x in prepare_dataset_from_tokenizer(train, tokenizer).batch(10):
    # Each batch unpacks into the model inputs and the labels; the labels are unused here.
    text, label = x
    output = transformer(text, output_hidden_states=True)
    # The dictionary entry is unpacked into the layer's remaining positional arguments.
    output = BertPoolingLayer()(output, *pooling_strategy_dictionary[TransformerPoolingStrategySelection.ConcatLast4LayersCLS]) #offsets are subtracted from the last layer
    print(tf.shape(output))
    break

tf.Tensor([  4  10 512 768], shape=(4,), dtype=int32)
tf.Tensor([  10 3072], shape=(2,), dtype=int32)
