<a href="https://colab.research.google.com/github/yongsun-yoon/deep-learning-paper-implementation/blob/main/03-natural-language-process/Making%20Monolingual%20Sentence%20Embeddings%20Multilingual%20using%20Knowledge%20Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation

## 0. Info

### Paper
* title: Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation
* authors: Nils Reimers and Iryna Gurevych
* url: https://arxiv.org/abs/2004.09813


### Features
* training dataset: opus100 (en-ko parallel sentences)
* evaluation dataset: KLUE STS (validation split)

## 1. Setup

In [None]:
!pip install -q transformers datasets

In [None]:
import easydict
import numpy as np
from tqdm.auto import tqdm
from scipy.stats import pearsonr, spearmanr

import torch
import torch.nn.functional as F

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from transformers import get_scheduler

In [None]:
cfg = easydict.EasyDict(
    # The teacher produces the target embeddings; the student is trained to reproduce them.
    teacher_name = 'sentence-transformers/all-mpnet-base-v2',
    student_name = 'xlm-roberta-base',

    # Use the first GPU when one is present, otherwise fall back to CPU (much slower)
    # so the notebook does not crash on a CPU-only runtime.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu',
    max_length = 256,   # tokenizer truncation length, in tokens
    batch_size = 16,    # used for training batches and for the evaluation DataLoader

    weight_decay = 1e-4,
    lr = 1e-4,
    num_warmup_steps = 500,      # warmup steps passed to the 'cosine' scheduler
    num_training_steps = 10000,  # total optimizer steps (training loop and scheduler)
)

## 2. Data

In [None]:
def get_batch(data, batch_size=32):
    """Sample ``batch_size`` distinct translation pairs uniformly at random.

    Args:
        data: indexable collection whose items look like
            ``{'translation': {'en': str, 'ko': str}}``.
        batch_size: number of distinct rows to draw.

    Returns:
        ``(ens, kos)``: two aligned lists of English and Korean sentences.

    Raises:
        ValueError: if ``batch_size`` exceeds ``len(data)``. Without this check the
            rejection-sampling loop below could never collect enough distinct
            indices and would spin forever.
    """
    n = len(data)
    if batch_size > n:
        raise ValueError(f'batch_size ({batch_size}) exceeds dataset size ({n})')

    ens, kos = [], []
    seen = set()  # O(1) duplicate check instead of scanning a list
    while len(seen) < batch_size:
        idx = np.random.randint(0, n)
        if idx in seen:
            continue
        seen.add(idx)

        item = data[idx]['translation']
        ens.append(item['en'])
        kos.append(item['ko'])

    return ens, kos

In [None]:
# opus100 en-ko parallel corpus; only its 'train' split is sampled during distillation.
data = load_dataset(path='opus100', name='en-ko')
train_data = data['train']

## 3. Model

In [None]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the non-padding positions.

    Args:
        model_output: model output whose first element holds per-token embeddings
            of shape (batch, seq, hidden).
        attention_mask: (batch, seq) mask with 1 for real tokens and 0 for padding.

    Returns:
        (batch, hidden) tensor of masked mean embeddings. An all-zero mask yields
        zeros, because the token count is clamped away from zero.
    """
    hidden = model_output[0]
    mask = attention_mask.unsqueeze(-1).float()
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

In [None]:
# Frozen teacher: its pooled embeddings of the English side are the regression targets.
teacher_tokenizer = AutoTokenizer.from_pretrained(cfg.teacher_name)
teacher_model = AutoModel.from_pretrained(cfg.teacher_name)
teacher_model.eval()
teacher_model.requires_grad_(False)
_ = teacher_model.to(cfg.device)  # assigning to `_` keeps the notebook from printing the model

In [None]:
# Trainable student: encodes both the English and the Korean side during training.
student_tokenizer = AutoTokenizer.from_pretrained(cfg.student_name)
student_model = AutoModel.from_pretrained(cfg.student_name)
student_model.train()
_ = student_model.to(cfg.device)  # assigning to `_` keeps the notebook from printing the model

In [None]:
# Only the student is optimized; the schedule spans the whole training loop.
optimizer = torch.optim.AdamW(
    student_model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)
scheduler = get_scheduler(
    'cosine',
    optimizer=optimizer,
    num_warmup_steps=cfg.num_warmup_steps,
    num_training_steps=cfg.num_training_steps,
)

## 4. Train

In [None]:
# Distillation loop: the student is trained (MSE) so that its embedding of the English
# sentence AND its embedding of the Korean translation both match the teacher's embedding
# of the English sentence.
save_every = 1000  # checkpoint interval, in optimizer steps
pbar = tqdm(range(1, cfg.num_training_steps+1))
for st in pbar:
    ens, kos = get_batch(train_data, cfg.batch_size)

    # Each model needs its own tokenizer; the student encodes both languages.
    teacher_ens_inputs = teacher_tokenizer(ens, max_length=cfg.max_length, padding=True, truncation=True, return_tensors='pt').to(cfg.device)
    student_ens_inputs = student_tokenizer(ens, max_length=cfg.max_length, padding=True, truncation=True, return_tensors='pt').to(cfg.device)
    student_kos_inputs = student_tokenizer(kos, max_length=cfg.max_length, padding=True, truncation=True, return_tensors='pt').to(cfg.device)

    # The teacher is frozen. no_grad makes that explicit and guarantees no autograd
    # graph is built for its forward pass.
    with torch.no_grad():
        teacher_ens_outputs = teacher_model(**teacher_ens_inputs)
        teacher_ens_embeds = mean_pooling(teacher_ens_outputs, teacher_ens_inputs.attention_mask)

    student_ens_outputs = student_model(**student_ens_inputs)
    student_ens_embeds = mean_pooling(student_ens_outputs, student_ens_inputs.attention_mask)
    student_kos_outputs = student_model(**student_kos_inputs)
    student_kos_embeds = mean_pooling(student_kos_outputs, student_kos_inputs.attention_mask)

    en_loss = F.mse_loss(teacher_ens_embeds, student_ens_embeds)
    ko_loss = F.mse_loss(teacher_ens_embeds, student_kos_embeds)
    # Constant rescaling of the summed loss (and therefore of the gradients).
    loss = (en_loss + ko_loss) * 10.

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    log = {'loss': loss.item(), 'en': en_loss.item(), 'ko': ko_loss.item()}
    pbar.set_postfix(log)

    # Also save on the final step, so a num_training_steps that is not a multiple of
    # save_every does not silently drop the last partial interval from the checkpoint.
    if st % save_every == 0 or st == cfg.num_training_steps:
        student_tokenizer.save_pretrained('ckpt')
        student_model.save_pretrained('ckpt')

## 5. Test

In [None]:
def mean_pooling(model_output, attention_mask):
    """Return the attention-mask-weighted mean of the token embeddings.

    ``model_output[0]`` is expected to be (batch, seq, hidden) and ``attention_mask``
    (batch, seq). Padding positions (mask 0) are excluded from the average. The
    denominator is clamped to 1e-9, so an all-padding row produces zeros.
    """
    embeddings = model_output[0]
    weight = attention_mask[..., None].float()
    total = torch.sum(embeddings * weight, dim=1)
    n_tokens = torch.clamp(torch.sum(weight, dim=1), min=1e-9)
    return total / n_tokens

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset that forwards length and indexing to a wrapped collection."""

    def __init__(self, data):
        self.data = data  # any object supporting len() and integer indexing

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch, tokenizer, max_length):
    """Tokenize both sentences of each pair and stack the gold scores.

    Args:
        batch: list of examples with 'sentence1', 'sentence2' and a 'labels' dict
            holding a numeric 'label'.
        tokenizer: callable tokenizer (e.g. a Hugging Face tokenizer).
        max_length: truncation length passed to the tokenizer.

    Returns:
        ``(sent1, sent2, labels)``: the two tokenizer outputs and a tensor of scores.
    """
    def encode(key):
        texts = [example[key] for example in batch]
        return tokenizer(texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')

    first = encode('sentence1')
    second = encode('sentence2')
    scores = torch.tensor([example['labels']['label'] for example in batch])
    return first, second, scores

def to_device(d, device):
    """Return a new plain dict with every value moved to ``device`` via ``.to()``."""
    return dict((key, value.to(device)) for key, value in d.items())

In [None]:
# KLUE STS validation split: sentence pairs plus a 'labels' score, as read by collate_fn.
data = load_dataset('klue', 'sts')['validation']
dataset = Dataset(data)


def _collate_sts(batch):
    # `tokenizer` is looked up when the DataLoader is iterated, not now. It is only
    # assigned in the following cell, which must run before the evaluation loop.
    return collate_fn(batch, tokenizer, cfg.max_length)


dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=cfg.batch_size,
    shuffle=False,
    collate_fn=_collate_sts,
)

In [None]:
# Model to evaluate: 'ckpt' is the distilled student saved by the training loop.
# Set eval_name = cfg.student_name to score the un-distilled student for comparison.
eval_name = 'ckpt'
tokenizer = AutoTokenizer.from_pretrained(eval_name)
model = AutoModel.from_pretrained(eval_name)
_ = model.eval().requires_grad_(False).to(cfg.device)

In [None]:
# Score each pair by the cosine similarity of its mean-pooled embeddings.
pred_chunks, label_chunks = [], []

for enc1, enc2, gold in tqdm(dataloader):
    enc1 = enc1.to(cfg.device)
    enc2 = enc2.to(cfg.device)

    emb1 = mean_pooling(model(**enc1), enc1.attention_mask).cpu()
    emb2 = mean_pooling(model(**enc2), enc2.attention_mask).cpu()

    pred_chunks.append(F.cosine_similarity(emb1, emb2, dim=-1))
    label_chunks.append(gold)

# Flatten the per-batch tensors into 1-D numpy arrays for scipy.
preds = torch.cat(pred_chunks, dim=0).numpy()
labels = torch.cat(label_chunks, dim=0).numpy()

In [None]:
# Correlation between predicted cosine similarities and gold scores (index 0 = statistic).
pr, spr = pearsonr(preds, labels)[0], spearmanr(preds, labels)[0]

print(f'pearsonr {pr:.3f} | spearmanr {spr:.3f}')