<a href="https://colab.research.google.com/github/respect5716/deep-learning-paper-implementation/blob/main/03_NLP/GPL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPL

## 0. Info
### Paper
* title: GPL: Generative Pseudo Labeling for Unsupervised Domain Adaptation of Dense Retrieval
* author: Kexin Wang et al.
* url: https://arxiv.org/abs/2112.07577

### Features
* dataset: fiqa
* num negatives: 10


### Reference
* https://github.com/UKPLab/gpl

## 1. Setup

In [None]:
!pip install -q jsonlines datasets transformers sentence-transformers beir

In [None]:
import random
import easydict
import jsonlines
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from transformers import get_scheduler
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer, CrossEncoder, InputExample

from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch 

In [None]:
# Experiment configuration; every later cell reads its settings from `cfg`.
cfg = easydict.EasyDict({
    # BEIR dataset name: used for the download URL and the local folder
    'dataset': 'fiqa',

    # checkpoints: query generator, negative-mining retriever,
    # pseudo-labeling cross-encoder and the model that gets trained
    'generator': 'BeIR/query-gen-msmarco-t5-base-v1',
    'retriever': 'msmarco-MiniLM-L-6-v3',
    'cross_encoder': 'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'base_model': 'distilbert-base-uncased',

    # data generation / training sizes
    'qs_per_passage': 3,          # sequences sampled per passage
    'num_negs': 10,               # top-k size used when mining negatives
    'batch_size': 16,             # shared by generation, mining, scoring and training
    'scale': 1.,                  # multiplier applied to dot-product scores in the loss
    'num_training_steps': 10000,

    # optimizer / scheduler
    'warmup_ratio': 0.05,
    'weight_decay': 1e-2,
    'lr': 1e-4,
    'max_grad_norm': 1.,
})

## 2. Data

In [None]:
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{cfg.dataset}.zip"
!wget "{url}"
!unzip "{cfg.dataset}.zip"

In [None]:
# Read the corpus (one JSON object per line) and persist it as 'corpus.csv'.
corpus = []
with jsonlines.open(f'{cfg.dataset}/corpus.jsonl') as f:
    for line in f:
        # 'metadata' is not used below; tolerate records that do not carry the key
        line.pop('metadata', None)
        corpus.append(line)

corpus = pd.DataFrame(corpus)
# Passage text used by all later stages: title and body joined by one space.
# fillna('') keeps a missing title/text from turning the whole passage into NaN.
corpus['concat'] = corpus['title'].fillna('') + ' ' + corpus['text'].fillna('')

corpus.to_csv('corpus.csv', index=False)
corpus.head()

## 3. Generate queries

In [None]:
# Reload the corpus saved as 'corpus.csv' so this section can start from the file on disk.
corpus = pd.read_csv('corpus.csv')

In [None]:
# Query generator: the seq2seq checkpoint named by cfg.generator and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained(cfg.generator)
generator = AutoModelForSeq2SeqLM.from_pretrained(cfg.generator)
# Use the GPU when one is available; fall back to CPU instead of failing on .to('cuda').
_ = generator.to('cuda' if torch.cuda.is_available() else 'cpu').eval()

In [None]:
# Sample cfg.qs_per_passage synthetic queries for every passage of the corpus.
queries = []
for i in tqdm(range(0, len(corpus), cfg.batch_size)):
    # only the passage text is needed, so slice the column directly
    texts = corpus['concat'].iloc[i:i+cfg.batch_size].tolist()
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    inputs = inputs.to(generator.device)

    with torch.no_grad():
        outputs = generator.generate(
            **inputs,
            do_sample = True,
            max_length = 64,
            top_k = 25,
            top_p = 0.95,
            num_return_sequences = cfg.qs_per_passage
        )

    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # The index arithmetic relies on generate() returning the sequences grouped per
    # input: output j then belongs to corpus row i + j // cfg.qs_per_passage.
    queries.extend(
        {'query': o, 'pos': i + j // cfg.qs_per_passage}
        for j, o in enumerate(outputs)
    )

In [None]:
# Turn the list of {'query', 'pos'} records into a DataFrame and save it as 'queries.csv'.
queries = pd.DataFrame(queries)
queries.to_csv('queries.csv', index=False)

## 4. Mine negatives

In [None]:
corpus = pd.read_csv('corpus.csv')
# keep_default_na=False keeps generated queries such as "", "null" or "NA" as text.
# With the default NA parsing they come back as float NaN, which the encoder cannot handle.
queries = pd.read_csv('queries.csv', keep_default_na=False)

In [None]:
# Bi-encoder (checkpoint: cfg.retriever) used below to embed passages and queries for negative mining.
retriever = SentenceTransformer(cfg.retriever)

In [None]:
# Embed every passage once; the normalized vectors are reused for all query batches.
docs = list(corpus['concat'])
doc_embs = retriever.encode(
    docs,
    batch_size=128,
    show_progress_bar=True,
    convert_to_tensor=True,
    normalize_embeddings=True,
)

In [None]:
# For each generated query, keep the cfg.num_negs most similar passages other than its positive.
negs = []
# Fetch one extra candidate: the positive passage is frequently among the top hits, and
# dropping it would otherwise leave only cfg.num_negs - 1 negatives. Never ask for more
# candidates than there are passages.
top_k = min(cfg.num_negs + 1, len(doc_embs))
for i in tqdm(range(0, len(queries), cfg.batch_size)):
    batch = queries.iloc[i:i+cfg.batch_size]
    texts = batch['query'].tolist()

    q_emb = retriever.encode(texts, show_progress_bar=False, convert_to_tensor=True, normalize_embeddings=True)
    # both sides are normalized, so the dot product is the cosine similarity
    score = torch.matmul(q_emb, doc_embs.T)
    _, indices = score.topk(k=top_k, dim=-1)
    for pos_idx, candidates in zip(batch['pos'].tolist(), indices.tolist()):
        negs.append([n for n in candidates if n != pos_idx][:cfg.num_negs])
queries['neg'] = negs

In [None]:
# Drop the positive passage from each candidate list, then flatten to one
# (query, pos, neg) row per negative and save the result.
queries['neg'] = [
    [n for n in candidates if n != pos_idx]
    for candidates, pos_idx in zip(queries['neg'], queries['pos'])
]
queries = queries.explode('neg')
queries.to_csv('negatives.csv', index=False)

## 5. Pseudo labeling

In [None]:
corpus = pd.read_csv('corpus.csv')
# keep_default_na=False: queries such as "" / "null" / "NA" stay strings instead of becoming NaN
queries = pd.read_csv('negatives.csv', keep_default_na=False)
# One row per (query, positive) pair. .copy() makes `pos` an independent frame, so adding
# the score column later does not trigger SettingWithCopyWarning.
pos = queries.drop_duplicates(subset=['query', 'pos']).copy()

In [None]:
# Cross-encoder (checkpoint: cfg.cross_encoder) whose scores are turned into the pseudo labels below.
# NOTE(review): `retokenizer` is not referenced by any other cell visible in this notebook — confirm it is still needed.
retokenizer = AutoTokenizer.from_pretrained(cfg.cross_encoder)
cross_encoder = CrossEncoder(cfg.cross_encoder)

In [None]:
# positive scores: cross-encoder score of every unique (query, positive passage) pair
score = []
for i in tqdm(range(0, len(pos), cfg.batch_size)):
    batch = pos.iloc[i:i+cfg.batch_size]
    qs = batch['query'].tolist()
    # one positional lookup per batch instead of building a row Series per passage
    docs = corpus['concat'].iloc[batch['pos'].tolist()].tolist()
    score += cross_encoder.predict(list(zip(qs, docs))).tolist()
# assign() returns a new frame (no write into a slice); errors='ignore' keeps the cell
# re-runnable after 'neg' has already been dropped.
pos = pos.drop(columns='neg', errors='ignore').assign(pos_score=score)

In [None]:
# negative scores: cross-encoder score of every (query, negative passage) row
score = []
for start in tqdm(range(0, len(queries), cfg.batch_size)):
    chunk = queries.iloc[start:start+cfg.batch_size]
    neg_texts = corpus['concat'].iloc[chunk['neg'].tolist()]
    pairs = list(zip(chunk['query'], neg_texts))
    score.extend(cross_encoder.predict(pairs).tolist())
queries['neg_score'] = score

In [None]:
# Attach each row's positive score, then label the row with the cross-encoder
# margin: positive score minus negative score.
queries = queries.merge(pos, on=['query', 'pos'], how='left')
queries['labels'] = queries['pos_score'].sub(queries['neg_score'])
queries.to_csv('pseudo-labels.csv', index=False)
queries.head()

## 6. Train

In [None]:
def get_param_groups(model, weight_decay):
    """Split a model's parameters into a decayed and a non-decayed optimizer group.

    A parameter is exempt from weight decay when its lower-cased name contains
    any of the substrings "bias", "bn", "ln" or "norm".

    Args:
        model: object exposing ``named_parameters()`` yielding ``(name, param)`` pairs.
        weight_decay: decay coefficient stored on the first (decayed) group.

    Returns:
        A two-element list of dicts with keys ``"params"`` and ``"weight_decay"``:
        first the decayed group, then the exempt group with ``weight_decay`` 0.0.
    """
    exempt_markers = ("bias", "bn", "ln", "norm")
    decayed, exempt = [], []
    for name, param in model.named_parameters():
        lowered = name.lower()
        if any(marker in lowered for marker in exempt_markers):
            exempt.append(param)
        else:
            decayed.append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]

def margin_mse_fn(qemb, pemb, nemb, labels, scale, normalize=False):
    """Margin-MSE loss between predicted and target score margins.

    The predicted margin is ``scale * (q·p - q·n)``, with dot products taken over
    the last dimension. The loss is the mean squared error against ``labels``.

    Args:
        qemb: query embeddings.
        pemb: positive-passage embeddings, same shape as ``qemb``.
        nemb: negative-passage embeddings, same shape as ``qemb``.
        labels: target margins; a tensor, list or array. They are converted to
            the dtype and device of the predicted margin before the loss is taken.
        scale: multiplier applied to both dot-product scores.
        normalize: if True, L2-normalize all three embeddings along the last
            dimension first (dot products become cosine similarities).

    Returns:
        Scalar tensor holding the mean squared error.
    """
    if normalize:
        qemb, pemb, nemb = (F.normalize(e, dim=-1) for e in (qemb, pemb, nemb))

    pscore = (qemb * pemb).sum(dim=-1) * scale
    nscore = (qemb * nemb).sum(dim=-1) * scale
    margin_pred = pscore - nscore
    # Align the targets with the predictions: float64 / CPU / plain-list labels would
    # otherwise promote the loss dtype or fail in mse_loss or its backward pass.
    labels = torch.as_tensor(labels, dtype=margin_pred.dtype, device=margin_pred.device)
    return F.mse_loss(margin_pred, labels)

In [None]:
# Passages as a plain list, indexed by the integer 'pos' / 'neg' columns of the training data.
corpus = pd.read_csv('corpus.csv')
corpus = corpus['concat'].tolist()

# keep_default_na=False: queries such as "" / "null" / "NA" stay strings instead of becoming
# float NaN, which cannot be tokenized in the training loop.
data = pd.read_csv('pseudo-labels.csv', keep_default_na=False)

In [None]:
# Model to be adapted, on the GPU when one is available (CPU fallback instead of an error).
model = SentenceTransformer(cfg.base_model).to('cuda' if torch.cuda.is_available() else 'cpu')
# Bias / normalization parameters are exempt from weight decay; each group carries its own value.
param_groups = get_param_groups(model, cfg.weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=cfg.lr, weight_decay=cfg.weight_decay)
# Linear schedule: warm up for warmup_ratio of the steps, then decay.
num_warmup_steps = int(cfg.num_training_steps * cfg.warmup_ratio)
scheduler = get_scheduler('linear', optimizer, num_training_steps=cfg.num_training_steps, num_warmup_steps=num_warmup_steps)

In [None]:
# Margin-MSE training on randomly sampled (query, positive, negative, label) rows.

# Enable training-mode behaviour (e.g. dropout) explicitly. Hugging Face loaders hand
# back modules in eval mode, and nothing else in this notebook switches the mode.
model.train()
# Device the batches must live on, taken from the model itself.
device = next(model.parameters()).device

pbar = tqdm(range(1, cfg.num_training_steps+1))
for st in pbar:
    batch_idx = np.random.randint(0, len(data), cfg.batch_size)
    batch = data.iloc[batch_idx].to_dict('records')
    query = [b['query'] for b in batch]
    pos = [corpus[b['pos']] for b in batch]
    neg = [corpus[b['neg']] for b in batch]
    labels = [b['labels'] for b in batch]

    batch = [InputExample(texts=[q, p, n], label=l) for q,p,n,l in zip(query, pos, neg, labels)]
    inputs, labels = model.smart_batching_collate(batch)
    # Depending on the sentence-transformers version, the collate function may leave the
    # tensors on CPU; moving them explicitly is a no-op when they are already in place.
    inputs = [
        {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in ipt.items()}
        for ipt in inputs
    ]
    labels = labels.to(device)

    qemb, pemb, nemb = [model(ipt)['sentence_embedding'] for ipt in inputs]
    loss = margin_mse_fn(qemb, pemb, nemb, labels, cfg.scale)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
    optimizer.step()
    scheduler.step()
    pbar.set_postfix({'loss': loss.item()})

    # Checkpoint every 1000 steps and always after the last one, so the evaluated
    # model is the final one even when the step count is not a multiple of 1000.
    if st % 1000 == 0 or st == cfg.num_training_steps:
        model.save('transformers')

## 7. Evaluate

In [None]:
# Load the "test" split of cfg.dataset via BEIR's loader.
# Note: this rebinds `corpus` and `queries`, which held other objects in earlier sections.
corpus, queries, qrels = GenericDataLoader(cfg.dataset).load(split="test")

In [None]:
# Evaluate the checkpoint that the training loop saved under 'transformers'.
# Alternative kept for reference: load cfg.base_model directly instead of the saved checkpoint.
# transformer = SentenceTransformer(cfg.base_model)
transformer = SentenceTransformer('transformers')
sbert = models.SentenceBERT(sep=' ')
# Use the same encoder for queries and for documents.
sbert.q_model = transformer
sbert.doc_model = transformer

In [None]:
# Wrap the encoder in BEIR's exact dense search and score with cosine similarity.
# Note: this rebinds `model` and `retriever`, which held other objects in earlier sections.
model = DenseRetrievalExactSearch(sbert, batch_size=16)
retriever = EvaluateRetrieval(model, score_function='cos_sim') 

In [None]:
# Run retrieval for all test queries against the test corpus; `results` feeds the metric cell below.
results = retriever.retrieve(corpus, queries)

In [None]:
# gpl 10k
# All metrics are reported at cutoff 10; the tuple on the last line is the cell's output.
k_values = [10]
ndcg, _map, recall, precision = EvaluateRetrieval.evaluate(qrels, results, k_values=k_values)
mrr = EvaluateRetrieval.evaluate_custom(qrels, results, k_values=k_values, metric='mrr')
ndcg, _map, recall, precision, mrr