**Pré-treino em voz auxilia em tarefas de texto?**

Este notebook destina-se a verificar se a utilização dos pesos do Wav2Vec2 pré-treinado em voz auxilia o BERT em tarefas de texto *downstream*.



*   Os testes são feitos sempre utilizando a arquitetura do BERT, na qual se estuda apenas a troca dos pesos do Transformer;



---
### Bibliotecas

In [None]:
# ! pip -q install --upgrade requests  # Dependendo da runtime do COLAB, um upgrade dessas libs pode ser necessário
# ! pip -q install --upgrade urllib3
! pip -q install lightning-bolts
! pip -q install pytorch_lightning
! pip -q install transformers
! pip -q install neptune-client
! pip -q install adabelief-pytorch

IMPORTS

In [None]:
import os
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import NeptuneLogger

from transformers import BertTokenizer, Wav2Vec2CTCTokenizer
from transformers import BertModel, BertForSequenceClassification, Wav2Vec2Model
from transformers import BertConfig
from transformers import get_cosine_schedule_with_warmup

from adabelief_pytorch import AdaBelief

from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

import gc

---
# IMDB - DATASET

A priori, o dataset do IMDB é escolhido por ser pequeno, o que nos permite avaliar rapidamente a performance dos modelos em testes.

In [None]:
# Download the IMDB reviews archive (-nc: do not re-download if it already exists)
# and extract it; the loading cell below reads ./aclImdb/{train,test}/{pos,neg}.
# NOTE(review): plain-HTTP download with no checksum verification.
!wget -q -nc http://files.fast.ai/data/aclImdb.tgz 
!tar -xzf aclImdb.tgz

In [None]:
# Fix the seed of every RNG used by the data pipeline (torch, NumPy and Python's
# `random`) so the shuffled train/valid split below is reproducible across runs.
torch.manual_seed(123)
np.random.seed(123)
random.seed(123)

# Number of shuffled training reviews held out as the validation set.
max_valid = 5_000

def load_texts(folder):
    """Read every file in ``folder`` and return their contents as a list of strings.

    Files are visited in sorted filename order: ``os.listdir`` returns entries in an
    arbitrary, filesystem-dependent order, which would make the seeded shuffle applied
    afterwards produce different train/valid splits on different machines.

    Files are decoded explicitly as UTF-8 instead of the platform's locale-dependent
    default encoding (the reviews are assumed to be UTF-8 — confirm if a
    UnicodeDecodeError appears).

    Args:
        folder: directory containing one text file per sample.

    Returns:
        List of file contents, ordered by filename.
    """
    texts = []
    for filename in sorted(os.listdir(folder)):
        with open(os.path.join(folder, filename), encoding='utf-8') as f:
            texts.append(f.read())
    return texts

# Each folder holds one review per file; label True = positive, False = negative.
x_train_pos = load_texts('aclImdb/train/pos')
x_train_neg = load_texts('aclImdb/train/neg')
x_test_pos = load_texts('aclImdb/test/pos')
x_test_neg = load_texts('aclImdb/test/neg')

# Concatenated as [all positives..., all negatives...]; labels are built in the same order.
x_train = x_train_pos + x_train_neg
x_test = x_test_pos + x_test_neg
y_train = [True] * len(x_train_pos) + [False] * len(x_train_neg)
y_test = [True] * len(x_test_pos) + [False] * len(x_test_neg)

# Shuffle the training pairs (using the seeded `random` module) before the
# train/valid split; without it the validation tail would contain only negatives.
# NOTE: after zip(*c), x_train / y_train are tuples, not lists.
c = list(zip(x_train, y_train))
random.shuffle(c)
x_train, y_train = zip(*c)

# Hold out the last `max_valid` shuffled samples for validation.
x_valid = x_train[-max_valid:]
y_valid = y_train[-max_valid:]
x_train = x_train[:-max_valid]
y_train = y_train[:-max_valid]

print(len(x_train), 'amostras de treino.')
print(len(x_valid), 'amostras de desenvolvimento.')
print(len(x_test), 'amostras de teste.')

In [None]:
class IMDBDataset(Dataset):
  """Map-style dataset of (review text, integer label) pairs.

  Subclasses ``torch.utils.data.Dataset`` (imported at the top of the notebook but
  previously unused) so it follows the standard PyTorch Dataset protocol.

  Args:
    x: sequence of review texts.
    y: sequence of labels aligned with ``x`` (booleans in this notebook); each label
      is returned as ``int`` (True -> 1, False -> 0).

  Raises:
    ValueError: if ``x`` and ``y`` have different lengths.
  """

  def __init__(self, x, y):
    # Fail fast on misaligned inputs instead of an IndexError deep inside a DataLoader.
    if len(x) != len(y):
      raise ValueError(f"x and y must have the same length, got {len(x)} and {len(y)}")
    self.x = x
    self.y = y

  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    return self.x[idx], int(self.y[idx])

In [None]:
def create_dataloader(x, y, tokenizer, batch_size, shuffle=False, max_length=250):
  """Wrap (texts, labels) in a DataLoader that tokenizes each batch on the fly.

  Every batch is padded to its longest sequence and truncated to ``max_length``.
  Each yielded item is ``(input_ids, attention_mask, labels)`` with labels as a
  LongTensor.
  """
  def data_collator(batch):
    texts, labels = zip(*batch)
    encoded = tokenizer(
        texts,
        padding='longest',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    return encoded['input_ids'], encoded['attention_mask'], torch.LongTensor(labels)

  return DataLoader(
      IMDBDataset(x, y),
      batch_size=batch_size,
      shuffle=shuffle,
      collate_fn=data_collator,
  )

# OS MODELOS

### Definição do tokenizador
- Nesse caso, para o BERT base

In [None]:
# Tokenizer used by every data loader in this notebook (passed to create_dataloader).
# NOTE(review): the checkpoint name is hard-coded here, while the model weights come
# from hparams['bert_model'] (defined later, currently 'bert-base-uncased'); keep the
# two in sync if either one changes.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

### Mapeamento de pesos

Primeiramente, para utilizar os pesos do Wav2Vec2 no BERT, precisamos realizar um mapeamento dos nomes dos parâmetros e de uma função que os substitua no *state_dict*.

In [None]:
# Per-layer parameter suffix in a Wav2Vec2 encoder layer -> equivalent suffix in a
# BERT encoder layer. Both weight and bias of each sub-module are mapped. Wav2Vec2's
# `layer_norm` (after attention) pairs with BERT's attention-output LayerNorm and
# `final_layer_norm` (after the feed-forward) with BERT's output LayerNorm.
MAP_WAV2VEC_TO_BERT_NAMES = {
    f'{w2v2_module}.{param}': f'{bert_module}.{param}'
    for w2v2_module, bert_module in (
        ('attention.k_proj', 'attention.self.key'),
        ('attention.v_proj', 'attention.self.value'),
        ('attention.q_proj', 'attention.self.query'),
        ('attention.out_proj', 'attention.output.dense'),
        ('layer_norm', 'attention.output.LayerNorm'),
        ('feed_forward.intermediate_dense', 'intermediate.dense'),
        ('feed_forward.output_dense', 'output.dense'),
        ('final_layer_norm', 'output.LayerNorm'),
    )
    for param in ('weight', 'bias')
}

In [None]:
def map_model_state_from_w2v2_to_bert(w2v2_states, bert_states, name_map=None, verbose=True):
    """Copy Wav2Vec2 encoder-layer tensors into a BERT state dict under BERT names.

    Every key of ``w2v2_states`` containing ``'encoder.layers.<idx>.<suffix>'`` is
    renamed to ``'encoder.layer.<idx>.<bert suffix>'`` via ``name_map``. All other keys
    (feature extractor, positional conv, the encoder's final ``encoder.layer_norm``, ...)
    are ignored. Any prefix before ``encoder.layers.`` is accepted, so state dicts of
    wrapper models (keys such as ``'wav2vec2.encoder.layers.0...'``) work as well;
    the previous fixed-position parsing crashed on them.

    Args:
        w2v2_states: Wav2Vec2 ``state_dict`` (parameter name -> tensor).
        bert_states: BERT ``state_dict`` to update; it is modified in place.
        name_map: optional suffix mapping; defaults to ``MAP_WAV2VEC_TO_BERT_NAMES``.
        verbose: print progress messages.

    Returns:
        ``bert_states`` with the mapped entries replaced.

    Raises:
        KeyError: a Wav2Vec2 layer parameter has no entry in ``name_map``, or its
            mapped name does not exist in ``bert_states`` (e.g. the Wav2Vec2 model
            has more layers than BERT).
        ValueError: a Wav2Vec2 tensor's shape differs from the BERT tensor it replaces.
    """
    if name_map is None:
        name_map = MAP_WAV2VEC_TO_BERT_NAMES
    if verbose:
        print("Changing weights ...")
    BERT_PREFIX = 'encoder.layer'
    W2V2_MARKER = 'encoder.layers.'
    new_weights = {}
    for name, weight in w2v2_states.items():
        if W2V2_MARKER not in name:
            continue
        # Text after the marker is "<layer_idx>.<param suffix>", whatever precedes it.
        layer_idx, _, in_name = name.split(W2V2_MARKER, 1)[1].partition('.')
        if in_name not in name_map:
            raise KeyError(f"No BERT equivalent for Wav2Vec2 parameter '{name}'")
        eq_name = f'{BERT_PREFIX}.{layer_idx}.{name_map[in_name]}'
        # Validate here instead of letting load_state_dict fail later with a less
        # specific message (or silently accept an unknown key when strict=False).
        if eq_name not in bert_states:
            raise KeyError(f"'{eq_name}' (mapped from '{name}') is not a BERT parameter")
        if tuple(weight.shape) != tuple(bert_states[eq_name].shape):
            raise ValueError(
                f"Shape mismatch for '{eq_name}': Wav2Vec2 '{name}' has "
                f"{tuple(weight.shape)}, BERT expects {tuple(bert_states[eq_name].shape)}"
            )
        if verbose:
            print("Updating: ", eq_name)
        new_weights[eq_name] = weight

    bert_states.update(new_weights)
    return bert_states

## Lista de modelos para teste

### .1. BERT - modelo usado como base

In [None]:
class BERTBaseUncased(nn.Module):
    """Text classifier: pretrained BERT, dropout, then a linear head on the pooled output.

    Args:
        num_class: number of output logits.
        bert_model: checkpoint name/path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, num_class, bert_model):
        super().__init__()

        self.bert = BertModel.from_pretrained(bert_model)
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        hidden_size = self.bert.config.hidden_size
        self.classification = nn.Linear(hidden_size, num_class)

    def _set_encoder_requires_grad(self, requires_grad, exclude_layernorm):
        # Set requires_grad on every encoder parameter; with exclude_layernorm the
        # LayerNorm parameters are skipped and keep their current flag.
        for param_name, param in self.bert.encoder.named_parameters():
            if exclude_layernorm and 'LayerNorm' in param_name:
                continue
            param.requires_grad = requires_grad

    def freeze_encoder(self, exclude_layernorm=False):
        """Stop gradient updates to the encoder (optionally sparing LayerNorm)."""
        self._set_encoder_requires_grad(False, exclude_layernorm)

    def unfreeze_encoder(self, exclude_layernorm=False):
        """Re-enable gradient updates to the encoder (optionally sparing LayerNorm)."""
        self._set_encoder_requires_grad(True, exclude_layernorm)

    def forward(self, tokens, masks):
        """Return classification logits for token ids ``tokens`` and ``masks``."""
        pooled = self.bert(tokens, attention_mask=masks).pooler_output
        return self.classification(self.dropout(pooled))

### .2. BERT - base + pesos do Transformer do Wav2Vec2

In [None]:
class BERTWithWav2Vec2Weights(BERTBaseUncased):
    """Pretrained-BERT classifier whose encoder layers are overwritten with Wav2Vec2 weights.

    Embeddings, pooler and classification head come from ``BERTBaseUncased``; only the
    tensors produced by ``map_model_state_from_w2v2_to_bert`` are replaced.

    Args:
        num_class: number of output logits.
        bert_model: checkpoint name/path for the BERT backbone.
        wav2vec2_states: Wav2Vec2 ``state_dict`` supplying the encoder-layer tensors.
    """

    def __init__(self, num_class, bert_model, wav2vec2_states):
        super().__init__(num_class, bert_model)
        self.load_transformer(wav2vec2_states)

    def load_transformer(self, wav2vec2_states):
        """Overwrite BERT's encoder layers with the renamed Wav2Vec2 tensors."""
        merged_states = map_model_state_from_w2v2_to_bert(
            wav2vec2_states, self.bert.state_dict()
        )
        self.bert.load_state_dict(merged_states)

### .3. BERT - base + pesos do Transformer inicializados com Xavier

In [None]:
class BERTBaseUncasedXavierInit(BERTBaseUncased):
    """Pretrained-BERT classifier whose encoder layers are re-initialised from scratch.

    Every Linear in each encoder layer gets Xavier-uniform weights and a constant 0.01
    bias, and both LayerNorms are reset; embeddings, pooler and head are untouched.
    """

    def __init__(self, num_class, bert_model):
        super().__init__(num_class, bert_model)
        self.xavier_init_transformer()

    def xavier_init_transformer(self):
        """Re-initialise every encoder layer (Xavier weights, 0.01 biases, fresh LayerNorms)."""
        for layer_idx, layer in enumerate(self.bert.encoder.layer.children()):
            print("Xavier Init on HEAD", layer_idx)
            # Same order as before: query, key, value, attention output, intermediate, output.
            linears = (
                layer.attention.self.query,
                layer.attention.self.key,
                layer.attention.self.value,
                layer.attention.output.dense,
                layer.intermediate.dense,
                layer.output.dense,
            )
            for linear in linears:
                torch.nn.init.xavier_uniform_(linear.weight)
                linear.bias.data.fill_(0.01)
            layer.attention.output.LayerNorm.reset_parameters()
            layer.output.LayerNorm.reset_parameters()

### .4. BERT - puro + pesos do Transformer do Wav2Vec2

In [None]:
class BERTRawWithWav2Vec2Weights(nn.Module):
    """Classifier on a BERT built from the default ``BertConfig`` (not from a checkpoint),
    with its encoder layers overwritten by Wav2Vec2 weights.

    Args:
        num_class: number of output logits.
        wav2vec2_states: Wav2Vec2 ``state_dict`` whose encoder-layer tensors are copied
            into BERT via ``map_model_state_from_w2v2_to_bert``.
    """

    def __init__(self, num_class, wav2vec2_states):
        super().__init__()

        self.bert = BertModel(BertConfig())
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        hidden_size = self.bert.config.hidden_size
        self.classification = nn.Linear(hidden_size, num_class)

        self.load_transformer(wav2vec2_states)

    def load_transformer(self, wav2vec2_states):
        """Overwrite BERT's encoder layers with the renamed Wav2Vec2 tensors."""
        merged_states = map_model_state_from_w2v2_to_bert(
            wav2vec2_states, self.bert.state_dict()
        )
        self.bert.load_state_dict(merged_states)

    def _set_encoder_requires_grad(self, requires_grad, exclude_layernorm):
        # Set requires_grad on every encoder parameter; with exclude_layernorm the
        # LayerNorm parameters are skipped and keep their current flag.
        for param_name, param in self.bert.encoder.named_parameters():
            if exclude_layernorm and 'LayerNorm' in param_name:
                continue
            param.requires_grad = requires_grad

    def freeze_encoder(self, exclude_layernorm=False):
        """Stop gradient updates to the encoder (optionally sparing LayerNorm)."""
        self._set_encoder_requires_grad(False, exclude_layernorm)

    def unfreeze_encoder(self, exclude_layernorm=False):
        """Re-enable gradient updates to the encoder (optionally sparing LayerNorm)."""
        self._set_encoder_requires_grad(True, exclude_layernorm)

    def forward(self, tokens, masks):
        """Return classification logits for token ids ``tokens`` and ``masks``."""
        pooled = self.bert(tokens, attention_mask=masks).pooler_output
        return self.classification(self.dropout(pooled))

### .5. BERT - puro + pesos do Transformer inicializados com Xavier

In [None]:
class BERTRawXavierInit(nn.Module):
    """Classifier on a BERT built from the default ``BertConfig`` (not from a checkpoint),
    whose encoder layers are then Xavier-initialised.

    Args:
        num_class: number of output logits.
    """

    def __init__(self, num_class):
        super().__init__()
        self.bert = BertModel(BertConfig())
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        hidden_size = self.bert.config.hidden_size
        self.classification = nn.Linear(hidden_size, num_class)

        self.xavier_init_transformer()

    def xavier_init_transformer(self):
        """Re-initialise every encoder layer (Xavier weights, 0.01 biases, fresh LayerNorms)."""
        for layer_idx, layer in enumerate(self.bert.encoder.layer.children()):
            print("Xavier Init on HEAD", layer_idx)
            # Same order as before: query, key, value, attention output, intermediate, output.
            linears = (
                layer.attention.self.query,
                layer.attention.self.key,
                layer.attention.self.value,
                layer.attention.output.dense,
                layer.intermediate.dense,
                layer.output.dense,
            )
            for linear in linears:
                torch.nn.init.xavier_uniform_(linear.weight)
                linear.bias.data.fill_(0.01)
            layer.attention.output.LayerNorm.reset_parameters()
            layer.output.LayerNorm.reset_parameters()

    def _set_encoder_requires_grad(self, requires_grad, exclude_layernorm):
        # Set requires_grad on every encoder parameter; with exclude_layernorm the
        # LayerNorm parameters are skipped and keep their current flag.
        for param_name, param in self.bert.encoder.named_parameters():
            if exclude_layernorm and 'LayerNorm' in param_name:
                continue
            param.requires_grad = requires_grad

    def freeze_encoder(self, exclude_layernorm=False):
        """Stop gradient updates to the encoder (optionally sparing LayerNorm)."""
        self._set_encoder_requires_grad(False, exclude_layernorm)

    def unfreeze_encoder(self, exclude_layernorm=False):
        """Re-enable gradient updates to the encoder (optionally sparing LayerNorm)."""
        self._set_encoder_requires_grad(True, exclude_layernorm)

    def forward(self, tokens, masks):
        """Return classification logits for token ids ``tokens`` and ``masks``."""
        pooled = self.bert(tokens, attention_mask=masks).pooler_output
        return self.classification(self.dropout(pooled))

## Pytorch Lightning Module

In [None]:
class LiteNet(pl.LightningModule):
    """LightningModule that trains/evaluates one of the text classifiers defined above.

    The configuration dict is passed as the first positional argument
    (``LiteNet(hparams)``) or, when no positional argument is given, as keyword
    arguments. Keys read here: ``model`` (classifier class), ``model_args`` (its
    constructor kwargs), ``train_loader``, ``freeze_finetune_updates``, ``lr``,
    ``w_decay``, ``warm_up_steps``, ``max_epochs`` and ``accum_grads``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Read the config that was actually passed in. The previous version read the
        # notebook-global `hparams` here, which only worked because callers happened
        # to pass that same dict (hidden-state dependency).
        config = args[0] if args else kwargs
        self.hparams.update(config)
        self.criterion = nn.CrossEntropyLoss()

        self.model = config['model'](**config['model_args'])

        # Number of optimizer updates (global steps) during which the encoder stays frozen.
        self.freeze_finetune_updates = config["freeze_finetune_updates"]

        self.trainloader = config['train_loader']

        self.frozen = False
        if self.freeze_finetune_updates > 0:
            print("Freezing model ...")
            self.model.freeze_encoder()
            self.frozen = True

    def train_dataloader(self):
        return self.trainloader

    def setup(self, stage=None):
        """Compute ``self.train_steps``, the total number of optimizer steps, for the LR scheduler."""
        if stage == 'fit':
            train_batches = len(self.train_dataloader())
            accum = self.hparams.accum_grads
            # Ceil division per epoch: the trailing partial accumulation window at the
            # end of an epoch also triggers an optimizer step, so floor-dividing the
            # total undercounts steps and the schedule runs past its configured length.
            steps_per_epoch = -(-train_batches // accum)
            self.train_steps = self.hparams.max_epochs * steps_per_epoch

    def forward(self, tokens, mask):
        return self.model(tokens, mask)

    def training_step(self, train_batch, batch_idx):
        # Unfreeze once `freeze_finetune_updates` optimizer steps have been taken
        # (the previous `<` comparison kept the encoder frozen for one extra update).
        if self.frozen and self.global_step >= self.freeze_finetune_updates:
            print("UNFREEZING!!")
            self.frozen = False
            self.model.unfreeze_encoder()

        tokens, mask, y = train_batch

        logits = self.forward(tokens, mask)

        loss = self.criterion(logits, y)

        self.log('loss_step', loss, on_step=True, prog_bar=True)

        return loss

    def training_epoch_end(self, outputs):
        # Mean of the per-batch training losses over the epoch.
        loss = torch.stack([x['loss'] for x in outputs]).mean()
        self.log("train_loss", loss, prog_bar=True)

    def validation_step(self, val_batch, batch_idx):
        tokens, mask, y = val_batch

        logits = self.forward(tokens, mask)

        # LOSS
        loss = self.criterion(logits, y)

        # ACC: per-sample correctness, aggregated in validation_epoch_end.
        preds = logits.argmax(dim=1)
        corrects = (preds == y)

        return {"corrects": corrects, "val_loss_step": loss}

    def validation_epoch_end(self, outputs):
        # Accuracy over all validation samples; loss is the mean of per-batch means.
        acc_mean = torch.cat([x["corrects"] for x in outputs], dim=0)
        acc_mean = acc_mean.sum() / len(acc_mean)
        avg_loss = torch.stack([x["val_loss_step"] for x in outputs]).mean()

        self.log("val_acc", acc_mean, prog_bar=True)
        self.log("val_loss", avg_loss, prog_bar=True)

    def test_step(self, test_batch, batch_idx):
        tokens, mask, truth = test_batch

        out = self.forward(tokens, mask)
        preds = torch.argmax(out, dim=1)

        corrects = (preds == truth)

        return {"corrects": corrects}

    def test_epoch_end(self, outputs):
        acc_mean = torch.cat([x["corrects"] for x in outputs], dim=0)
        acc_mean = acc_mean.sum() / len(acc_mean)

        self.log("test_acc", acc_mean, prog_bar=True)

    def configure_optimizers(self):
        """AdaBelief optimizer with a per-step cosine schedule and linear warm-up."""
        optimizer = AdaBelief(self.parameters(),
                              lr=self.hparams['lr'],
                              eps=1e-16,
                              weight_decay=self.hparams["w_decay"])

        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_warmup_steps=self.hparams['warm_up_steps'],
            num_training_steps=self.train_steps
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1,
            }
        }

## Configuração dos parâmetros - hparams

- **model_name**: nome do checkpoint que será salvo;
- **bert_model**: modelo pré-treinado do BERT que será usado;
- **w2v2_model**: modelo pré-treinado do Wav2Vec2 nos quais os pesos serão utilizados;
- **max_length**: tamanho máximo de tokens;
- **nb_classes**: no caso da classificação, temos  2 classes;
- **lr**: learning-rate após o período de warm-up;
- **w_decay**: weight decay utilizado pelo otimizador (AdaBelief, no caso);
- **bs**: batch-size
- **accum_grads**: quantos acúmulos de gradientes até fazer update dos pesos;
- **patience**: paciência do *early-stop*;
- **max_epochs**: quantidade de épocas do treinamento;
- **freeze_finetune_updates**: quantidade de *updates* do otimizador (passos globais, não épocas) durante os quais o Transformer fica congelado;
- **warm_up_steps**: quantidade de steps de warm-up da learning-rate;
- **seed_value**: seed a ser utilizada;

In [None]:
model_name = "FINAL_bert_base" #@param {type: "string"}
bert_model = 'bert-base-uncased'#@param {type: "string"}
w2v2_model = "facebook/wav2vec2-base"#@param {type: "string"}
max_length =  256#@param {type: "integer"}
nb_classes = 2 #@param {type: "integer"}
lr = 5e-4 #@param {type: "number"}
w_decay =  1e-4#@param {type: "number"}
bs =  16#@param {type: "integer"}
accum_grads =  4#@param {type: "integer"}
patience =  3#@param {type: "integer"}
max_epochs =  6#@param {type: "integer"}

freeze_finetune_updates = 0#@param {type: "integer"}
warm_up_steps = 0#@param {type: "integer"}

# Must be defined because the dict below records it: leaving it commented out raised
# NameError on a fresh kernel. It is only recorded as a hyperparameter — neither
# Trainer in this notebook receives it, so no gradient clipping is applied.
clip_value = 0 #@param {type: "number"}
seed_value = 123#@param {type: "integer"}

# Define hyperparameters
hparams = {
    "model_name": model_name,
    "bert_model": bert_model,
    "w2v2_model": w2v2_model,
    "max_length": max_length,
    "nb_classes": nb_classes,
    "lr": lr,
    "w_decay": w_decay,
    "bs": bs,
    "patience": patience,
    "accum_grads": accum_grads,
    "freeze_finetune_updates": freeze_finetune_updates,
    "clip_value": clip_value,
    "max_epochs": max_epochs,
    "seed_value": seed_value,
    "warm_up_steps": warm_up_steps,
}

# Checagem: Overfit

Primeiro, escolha o modelo, descomentando o modelo necessário abaixo:

In [None]:
#####========== 1. BERT base - BASELINE
hparams.update(
    {'model': BERTBaseUncased, 'model_args': {
        'num_class': hparams['nb_classes'],
        'bert_model': hparams['bert_model']}}
)

#####========== 2. BERT base + Wav2Vec2
# w2v2_states = Wav2Vec2Model.from_pretrained(hparams['w2v2_model']).state_dict()
# hparams.update(
#     {'model': BERTWithWav2Vec2Weights, 'model_args': {
#         'num_class': hparams['nb_classes'],
#         'bert_model': hparams['bert_model'],
#         'wav2vec2_states': w2v2_states}}
# )

#####========== 3. BERT base + Xavier init
# hparams.update(
#     {'model': BERTBaseUncasedXavierInit, 'model_args': {
#         'num_class': hparams['nb_classes'],
#         'bert_model': hparams['bert_model']}}
# )

#####========== 4. BERT puro + Wav2Vec2
# w2v2_states = Wav2Vec2Model.from_pretrained(hparams['w2v2_model']).state_dict()
# hparams.update(
#     {'model': BERTRawWithWav2Vec2Weights, 'model_args': {
#         'num_class': hparams['nb_classes'],
#         'wav2vec2_states': w2v2_states}}
# )

#####========== 5. BERT puro + Xavier init
# hparams.update(
#     {'model': BERTRawXavierInit, 'model_args': {
#         'num_class': hparams['nb_classes']}}
# )

In [None]:
# Overfit sanity check: train on only 2 batches (overfit_batches=2) for 50 epochs.
# A working model/pipeline should drive the training loss towards zero.
# NOTE(review): LiteNet sizes its LR schedule from hparams['max_epochs'] and the full
# train_loader length (LiteNet.setup), not from this Trainer's max_epochs=50 /
# overfit_batches=2, so the cosine schedule likely barely decays during this check.
train_loader = create_dataloader(x_train, y_train, tokenizer, hparams['bs'], shuffle=True, max_length=hparams['max_length'])
valid_loader = create_dataloader(x_valid, y_valid, tokenizer, hparams['bs'], shuffle=False, max_length=hparams['max_length'])

hparams.update(
    {'train_loader': train_loader}
)

model = LiteNet(hparams)

trainer = pl.Trainer(gpus=1,
                     precision=16,
                     max_epochs=50,
                     accumulate_grad_batches=hparams["accum_grads"],
                     check_val_every_n_epoch=1,
                     checkpoint_callback=False, # Disable checkpoint saving.
                     overfit_batches=2)

trainer.fit(model, train_loader, valid_loader)


del model, trainer # Release the model/trainer so GPU memory is freed before the tracked run
gc.collect()
torch.cuda.empty_cache()

# Treinamento + Neptune logging

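Caso deseje salvar o checkpoint do modelo no Google Drive (a célula de treinamento abaixo usa `checkpoint_dir` como destino dos checkpoints, portanto esta célula precisa ser executada antes dela):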

In [None]:
# Mount Google Drive so checkpoints outlive the Colab session.
# `checkpoint_dir` is used as ModelCheckpoint's dirpath in the training cell, so this
# cell must run before training.
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): relative path — presumably resolves under /content/drive only when the
# working directory is /content; confirm if the notebook is run from elsewhere.
checkpoint_dir = 'drive/MyDrive/Cursos/UNICAMP - IA376E - 2S2021/Projeto Final/checkpoints/pl'
assert os.path.exists(checkpoint_dir), "Pasta ainda não existe, por favor criar!"

Credenciais do Neptune para log (não escreva a chave da API no notebook; prefira a variável de ambiente `NEPTUNE_API_TOKEN`):

In [None]:
# Neptune credentials. The API token is read from the NEPTUNE_API_TOKEN environment
# variable, or prompted for without echo, instead of being hard-coded: a key written in
# the notebook is exposed to anyone who can read it (and to version control).
# NOTE(review): the token that was previously hard-coded here should be treated as
# leaked and revoked/rotated in the Neptune UI.
import getpass

nep_api_key = os.environ.get("NEPTUNE_API_TOKEN") or getpass.getpass("Neptune API token: ")
nep_proj = "otalviana/IA376E-Projeto"
tags = ["FINAL", "IMDB", "BERT"]

In [None]:
# Re-seed through Lightning (using hparams['seed_value']) right before building the
# model and data loaders for the tracked run.
# NOTE: the train/valid split earlier in the notebook was seeded separately with the
# hard-coded value 123, not with seed_value.
pl.seed_everything(hparams['seed_value'], workers=True)

Selecione o modelo desejado:

In [None]:
# Choose the model for the tracked training run: leave exactly ONE option uncommented.
# Each option overwrites hparams['model'] (the class LiteNet instantiates) and
# hparams['model_args'] (its constructor kwargs).
# Options 2 and 4 load the Wav2Vec2 checkpoint named by hparams['w2v2_model'] and keep
# its full state_dict inside hparams['model_args'].
# NOTE(review): this repeats the selection cell of the overfit check; keep both in sync.
#####========== 1. BERT base - BASELINE
hparams.update(
    {'model': BERTBaseUncased, 'model_args': {
        'num_class': hparams['nb_classes'],
        'bert_model': hparams['bert_model']}}
)

#####========== 2. BERT base + Wav2Vec2
# w2v2_states = Wav2Vec2Model.from_pretrained(hparams['w2v2_model']).state_dict()
# hparams.update(
#     {'model': BERTWithWav2Vec2Weights, 'model_args': {
#         'num_class': hparams['nb_classes'],
#         'bert_model': hparams['bert_model'],
#         'wav2vec2_states': w2v2_states}}
# )

#####========== 3. BERT base + Xavier init
# hparams.update(
#     {'model': BERTBaseUncasedXavierInit, 'model_args': {
#         'num_class': hparams['nb_classes'],
#         'bert_model': hparams['bert_model']}}
# )

#####========== 4. BERT puro + Wav2Vec2
# w2v2_states = Wav2Vec2Model.from_pretrained(hparams['w2v2_model']).state_dict()
# hparams.update(
#     {'model': BERTRawWithWav2Vec2Weights, 'model_args': {
#         'num_class': hparams['nb_classes'],
#         'wav2vec2_states': w2v2_states}}
# )

#####========== 5. BERT puro + Xavier init
# hparams.update(
#     {'model': BERTRawXavierInit, 'model_args': {
#         'num_class': hparams['nb_classes']}}
# )

Treinamento:

In [None]:
train_loader = create_dataloader(x_train, y_train, tokenizer, hparams['bs'], shuffle=True, max_length=hparams['max_length'])
valid_loader = create_dataloader(x_valid, y_valid, tokenizer, hparams['bs'], shuffle=False, max_length=hparams['max_length'])
test_loader = create_dataloader(x_test, y_test, tokenizer, hparams['bs'], shuffle=False, max_length=hparams['max_length'])

hparams.update(
    {'train_loader': train_loader}
)

model = LiteNet(hparams)

neptune_logger = NeptuneLogger(
    api_key=nep_api_key,
    project=nep_proj,
    tags=tags,
    log_model_checkpoints=False
)

# Log only plain scalar hyperparameters plus the model class name. The full `hparams`
# dict also holds the DataLoader, the model class and (for the Wav2Vec2 variants) an
# entire state_dict inside 'model_args', none of which belong in the experiment tracker.
loggable_hparams = {k: v for k, v in hparams.items() if isinstance(v, (str, int, float, bool))}
loggable_hparams['model'] = hparams['model'].__name__
neptune_logger.log_hyperparams(params=loggable_hparams)


# PL Callbacks
checkpoint_callback = pl.callbacks.ModelCheckpoint(filename=hparams['model_name'] + "-{epoch:02d}-{val_acc:.2f}",
                                                  dirpath=checkpoint_dir,
                                                  save_top_k=1,
                                                  verbose=True,
                                                  monitor="val_loss", mode="min")

lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='step')
early_stop_callback = pl.callbacks.EarlyStopping(monitor="val_loss", patience=hparams["patience"], mode='min')

try:
    trainer = pl.Trainer(gpus=1,
                         precision=16,
                         progress_bar_refresh_rate=1,
                         max_epochs=hparams['max_epochs'],
                         accumulate_grad_batches=hparams["accum_grads"],
                         check_val_every_n_epoch=1,
                         callbacks=[early_stop_callback, checkpoint_callback, lr_monitor],
                         checkpoint_callback=True,  # Enable checkpointing (ModelCheckpoint above).
                         logger=neptune_logger,
                         log_every_n_steps=20)

    trainer.fit(model, train_loader, valid_loader)
    # Evaluate the best checkpoint (lowest val_loss) rather than the final weights.
    trainer.test(dataloaders=test_loader, ckpt_path='best')
except Exception as e:
    print("Error:", e)
    # Re-raise: swallowing the error made a failed run look like a successful one.
    raise
finally:
    # Always close the Neptune run (through the public `experiment` handle), even on failure.
    neptune_logger.experiment.stop()