<a href="https://colab.research.google.com/github/respect5716/deep-learning-paper-implementation/blob/main/03_NLP/MirrorBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MirrorBERT

## 0. Info

### Paper
* title: Fast, Effective, and Self-Supervised: Transforming Masked Language Models into Universal Lexical and Sentence Encoder
* author: Fangyu Liu et al.
* url: https://arxiv.org/abs/2104.08027

### Features
* dataset: wikitext

### Reference
* https://github.com/cambridgeltl/mirror-bert

## 1. Setup

In [1]:
!pip install -q transformers datasets pytorch_metric_learning

[K     |████████████████████████████████| 3.4 MB 12.4 MB/s 
[K     |████████████████████████████████| 306 kB 67.0 MB/s 
[K     |████████████████████████████████| 106 kB 65.6 MB/s 
[K     |████████████████████████████████| 895 kB 66.3 MB/s 
[K     |████████████████████████████████| 3.3 MB 56.9 MB/s 
[K     |████████████████████████████████| 596 kB 66.8 MB/s 
[K     |████████████████████████████████| 61 kB 481 kB/s 
[K     |████████████████████████████████| 243 kB 69.6 MB/s 
[K     |████████████████████████████████| 132 kB 73.1 MB/s 
[K     |████████████████████████████████| 1.1 MB 53.0 MB/s 
[K     |████████████████████████████████| 271 kB 74.3 MB/s 
[K     |████████████████████████████████| 160 kB 67.8 MB/s 
[K     |████████████████████████████████| 192 kB 54.1 MB/s 
[?25h

In [2]:
import easydict

import numpy as np
from scipy import spatial
from scipy.stats.stats import pearsonr,spearmanr

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning import losses

from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset

In [3]:
config = easydict.EasyDict(
    batch_size = 32,
    max_seq_length = 64,
    mask_len = 5,
    num_train_steps = 1000,
    agg_mode = 'mean',
)

## 2. Data

In [4]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, config):
        self.config = config
        self.data = load_dataset('wikitext', 'wikitext-2-raw-v1')['train']
        self.data = self.data.filter(lambda x: len(x['text']) > 100)
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sent = self.data[idx]['text']
        inputs = self.tokenizer(sent, max_length=self.config.max_seq_length, truncation=True, padding='max_length')
        inputs['masked_input_ids'] = erase_and_mask(inputs['input_ids'], self.config.mask_len)
        return inputs


def erase_and_mask(input_ids, mask_len, mask_token_id=103):
    masked_input_ids = input_ids.copy()
    ind = np.random.randint(len(input_ids) - mask_len)
    masked_input_ids[ind:ind+mask_len] = [mask_token_id] * mask_len
    return masked_input_ids


def collate_fn(batch):
    collated = {}
    keys = batch[0].keys()
    for k in keys:
        collated[k] = torch.tensor([b[k] for b in batch])
    return collated

In [5]:
dataset = Dataset(config)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

Downloading:   0%|          | 0.00/2.03k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.25k [00:00<?, ?B/s]

Downloading and preparing dataset wikitext/wikitext-2-raw-v1 (download: 4.50 MiB, generated: 12.90 MiB, post-processed: Unknown size, total: 17.40 MiB) to /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...


Downloading:   0%|          | 0.00/4.72M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset wikitext downloaded and prepared to /root/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?ba/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

In [6]:
batch = next(iter(dataloader))
for k, v in batch.items():
    print(k, v.size())

input_ids torch.Size([8, 64])
token_type_ids torch.Size([8, 64])
attention_mask torch.Size([8, 64])
masked_input_ids torch.Size([8, 64])


## 3. Model

In [7]:
encoder = AutoModel.from_pretrained('bert-base-uncased')
encoder = encoder.to('cuda')

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
optim = torch.optim.AdamW(encoder.parameters(), lr=2e-5, weight_decay=0.01)
criterion = losses.NTXentLoss(0.04)

## 4. Train

In [9]:
def agg_fn(out, mode):
    if mode == 'cls':
        return out.last_hidden_state[:, 0]
    
    elif mode == 'mean':
        return out.last_hidden_state.mean(dim=1)

In [10]:
loss_tracker = 0.
for st, batch in enumerate(dataloader):
    batch = {k:v.to(encoder.device) for k, v in batch.items()}

    outputs1 = encoder(batch['input_ids'], batch['attention_mask'], batch['token_type_ids'])
    outputs2 = encoder(batch['masked_input_ids'], batch['attention_mask'], batch['token_type_ids'])

    embed1 = agg_fn(outputs1, config.agg_mode)
    embed2 = agg_fn(outputs2, config.agg_mode)
    embed = torch.cat([embed1, embed2], dim=0) # (bs * 2, dim)

    labels = torch.arange(embed1.size(0))
    labels = torch.cat([labels, labels], dim=0) # (bs * 2, dim)
    labels = labels.to(embed.device)

    loss = criterion(embed, labels)

    optim.zero_grad()
    loss.backward()
    optim.step()

    loss_tracker = 0.9 * loss_tracker + 0.1 * loss.item()

    if st > 0 and st % 100 == 0:
        print(f'st {st:04d} | loss: {loss_tracker:.4f}')

    if st + 1 == config.num_train_steps:
        encoder.save_pretrained('transformers')
        dataset.tokenizer.save_pretrained('transformers')
        break

st 0100 | loss: 0.0001
st 0200 | loss: 0.0003
st 0300 | loss: 0.0001
st 0400 | loss: 0.0000
st 0500 | loss: 0.0000
st 0600 | loss: 0.0000
st 0700 | loss: 0.0000
st 0800 | loss: 0.0000
st 0900 | loss: 0.0000


## 5. Eval

In [11]:
def tokenize(example, tokenizer):
    batch = {}
    sent1 = tokenizer(example['sentence1'], max_length=64, padding='max_length', truncation=True)
    for k, v in sent1.items():
        batch[f'sent1_{k}'] = v
    
    sent2 = tokenizer(example['sentence2'], max_length=64, padding='max_length', truncation=True)
    for k, v in sent2.items():
        batch[f'sent2_{k}'] = v
    
    batch['label'] = example['label']
    return batch

In [12]:
model = AutoModel.from_pretrained('transformers')
tokenizer = AutoTokenizer.from_pretrained('transformers')

model = model.cuda()

In [13]:
stsb = load_dataset('glue', 'stsb')['validation']
stsb = stsb.map(lambda x: tokenize(x, tokenizer))
stsb.set_format('torch', columns=['sent1_input_ids', 'sent1_attention_mask', 'sent1_token_type_ids', 'sent2_input_ids', 'sent2_attention_mask', 'sent2_token_type_ids', 'label'])

Downloading:   0%|          | 0.00/7.78k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/4.47k [00:00<?, ?B/s]

Downloading and preparing dataset glue/stsb (download: 784.05 KiB, generated: 1.09 MiB, post-processed: Unknown size, total: 1.86 MiB) to /root/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading:   0%|          | 0.00/803k [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/stsb/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?ex/s]

In [14]:
dataloader = torch.utils.data.DataLoader(stsb, batch_size=8, shuffle=False)

In [15]:
embeds1, embeds2, labels = [], [], []

for batch in dataloader:
    batch = {k:v.to(model.device) for k, v in batch.items()}
    with torch.no_grad():
        outputs1 = model(batch['sent1_input_ids'], batch['sent1_attention_mask'], batch['sent1_token_type_ids'])
        outputs2 = model(batch['sent2_input_ids'], batch['sent2_attention_mask'], batch['sent2_token_type_ids'])

    embed1 = agg_fn(outputs1, config.agg_mode)
    embed2 = agg_fn(outputs2, config.agg_mode)

    embeds1.append(embed1.cpu())
    embeds2.append(embed2.cpu())
    labels.append(batch['label'].cpu())

embeds1 = torch.cat(embeds1, dim=0).numpy()
embeds2 = torch.cat(embeds2, dim=0).numpy()
labels = torch.cat(labels, dim=0).numpy()

In [16]:
sims = []
for i in range(len(stsb)):
    sim = 1 - spatial.distance.cosine(embeds1[i], embeds2[i])
    sims.append(sim)

In [17]:
score = spearmanr(labels, sims)[0]
print(score)

0.663083132567378
