# Cross Language Search

We are using [this nice dataset](https://github.com/BangBOOM/Classical-Chinese)

## Imports

In [1]:
# !pip install -Uqq git+https://github.com/raynardj/forgebox

In [2]:
from forgebox.imports import *
from forgebox.thunder.callbacks import DataFrameMetricsCallback
from forgebox.multiproc import DataFrameRowling
from gc_utils.env import *
from datasets import load_dataset
# from fastai.text.all import *
from unpackai.nlp import *
from tqdm.notebook import tqdm
import random

In [3]:
import pytorch_lightning as pl

In [4]:
import re

def remove_all_punkt(text):
    """
    Removes all punctuation from Chinese text.

    :param text: text to remove punctuation from
    :return: text with no punctuation
    """
    return re.sub(r'[^\w\s]', '', text)

In [5]:
remove_all_punkt("亳州水军千户胡进等领骑兵渡淝水，逾荆山，与宋兵战，杀获甚众，赏钞币有差。")

'亳州水军千户胡进等领骑兵渡淝水逾荆山与宋兵战杀获甚众赏钞币有差'

## Config

In [6]:
DATA = Path(sys_loc('DATA')/"nlp"/"zh"/"cc_vs_zh")
TO_CLASSICAL = False

## Download data

## Data

### Combine data

In [7]:
all_file = list(DATA.rglob("data/*"))

In [8]:
def open_file_to_lines(file):
    with open(file) as f:
        lines = f.read().splitlines()
    return lines

def pairing_the_file(files,kw):
    pairs = []
    for file in files:
        if kw not in file.name:
            file1 = file
            file2 = f"{file}{kw}"
            pairs.append((file1,file2))
    return pairs

In [9]:
pairs = pairing_the_file(all_file,"翻译")

In [10]:
def open_pairs(pairs):
    chunks = []
    for pair in tqdm(pairs, leave=False):
        file1,file2 = pair
        lines1 = open_file_to_lines(file1)
        lines2 = open_file_to_lines(file2)
        chunks.append(pd.DataFrame({"classical":lines1,"modern":lines2}))
    return pd.concat(chunks).sample(frac=1.).reset_index(drop=True)

In [11]:
data_df = open_pairs(pairs)

  0%|          | 0/27 [00:00<?, ?it/s]

In [12]:
df = data_df.rename(
    columns = dict(
        zip(["modern","classical"],
             ["source","target"] if TO_CLASSICAL else ["target","source",]))
)

In [13]:
df.head()

Unnamed: 0,source,target
0,下也。,因为在下面。
1,长乐王尉粲甚礼之。,垦銮王幽垩很礼待他。
2,太师王舜自莽篡位后，病悸剧，死。,太师王舜自王莽篡夺皇位后，得了心悸病，渐渐加剧，终于病故。
3,秋七月丙寅，以旱，亲录京城囚徒。,秋七月二十九日，因为干旱，皇上亲自审查并记录囚徒罪状。
4,乙亥，齐仪同三司元旭坐事赐死。,乙亥，北齐国仪同三司元旭因犯罪被赐死。


### Loading tokenizer

In [14]:
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    AutoModel,
    EncoderDecoderModel
    )
PRETRAINED = "bert-base-chinese"

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)

### Pytoch Dataset

In [15]:
import random

def combine_randomly(data):
    if random.random()>.5:
        a,b = data['source'],data['target']
    else:
        a,b = data['target'],data['source']
    return f"{a}{b}"

def pick_randomly(data):
    return list(data.values())[int(random.random()>.5)]

def mixup(data):
    if len(data['target'])> 70:
        th = .7
    else:
        th = .3
    if random.random()>th:
        return combine_randomly(data)
    else:
        return pick_randomly(data)

In [16]:
class XLSearch(Dataset):
    def __init__(
        self, df, tokenizer,
        max_len=128,
        no_punkt:bool = False,
        mlm_probability:float = .15,
    ):
        """
        no_punkt, do we ramdomly remove punctuation
        from source sentence
        """
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.mlm_probability = mlm_probability
        
    def __len__(self, ):
        return len(self.df)

    def __getitem__(self, idx):
        return mixup(dict(self.df.loc[idx]))

    def collate(self, data):
        inputs = self.tokenizer(
            list(data),
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return self.mlm_masking(inputs)
    
    def mlm_masking(self,inputs):
        """
        convert inputs for masked language modeling
        """
        if self.mlm_probability is None:
            return inputs
        input_ids = inputs.input_ids
        token_type_ids = inputs.token_type_ids
        
        # masking input_ids
        masked = input_ids.clone()
        masked[
            torch.rand(input_ids.shape).to(input_ids.device) < self.mlm_probability
        ] = self.tokenizer.mask_token_id
        
        labels = input_ids.clone()
        labels[token_type_ids == 1] = -100
        labels[labels==0] = -100
        token_type_ids[masked==self.tokenizer.mask_token_id] = 1
        labels[token_type_ids == 0] = -100
        
        inputs['input_ids'] = masked
        inputs['labels'] = labels
        inputs['token_type_ids'] = token_type_ids
        return inputs

    def dataloader(self, batch_size, shuffle=True):
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self.collate,
        )

    def split_train_valid(self, valid_size=0.1):
        split_index = int(len(self) * (1 - valid_size))
        cls = type(self)
        shuffled = self.df.sample(frac=1).reset_index(drop=True)
        train_set = cls(
            shuffled.iloc[:split_index],
            tokenizer=self.tokenizer,
            max_len=self.max_len,
        )
        valid_set = cls(
            shuffled.iloc[split_index:],
            tokenizer=self.tokenizer,
            max_len=self.max_len,
        )
        return train_set, valid_set

In [17]:
ds = XLSearch(df, tokenizer, )

In [18]:
ds[5]

'又将御史王金，主事马思聪、金山，参议黄宏、许效廉，布政使胡廉，参政陈杲、刘非木，佥事赖凤，指挥许金、白昂等人逮捕下狱。执御史王金，主事马思聪、金山，参议黄宏、许效廉，布政使胡廉，参政陈杲、刘棐，佥事赖凤，指挥许金、白昂等下狱。'

### Different ways of mixing and masking

### PL datamodule

In [19]:
class DataModule(pl.LightningDataModule):
    def __init__(
        self, df,
        tokenizer,
        batch_size=12,
        max_len=128,
        no_punkt:bool=False):
        super().__init__()
        self.df = df
        self.ds = XLSearch(df,
                          tokenizer,
                          max_len=max_len,)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_set, self.valid_set = self.ds.split_train_valid()

    def train_dataloader(self):
        return self.train_set.dataloader(
            batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return self.valid_set.dataloader(
            batch_size=self.batch_size*2, shuffle=False)

In [20]:
data_module = DataModule(
    df, tokenizer,
    batch_size=64,
    max_len=256,
    no_punkt=False if TO_CLASSICAL else True,)
data_module.setup()

In [21]:
inputs = next(iter(data_module.train_dataloader()))
inputs

{'input_ids': tensor([[ 101, 1282,  103,  ...,    0,    0,    0],
        [ 101, 3293, 1062,  ...,    0,    0,  103],
        [ 101,  758, 2399,  ...,    0,    0,    0],
        ...,
        [ 101, 7826,  815,  ...,  103,    0,    0],
        [ 101, 5628, 6818,  ...,    0,    0,    0],
        [ 101, 5745,  815,  ...,    0,  103,    0]]), 'token_type_ids': tensor([[0, 0, 1,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 1, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 1, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([[-100, -100, 1063,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        ...,
      

if we are doing clasical Chinese to modern Chinese, we can randomly set half of the input without any punctuation, as many data source might be

In [22]:
# tokenizer.batch_decode(
#     inputs.input_ids,skip_special_tokens=False
# )

### Load pretrained models

## Model

In [23]:
# loading pretrained model
model = AutoModelForMaskedLM.from_pretrained(PRETRAINED
)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [24]:
class MaskedLM(pl.LightningModule):
    def __init__(
        self,
        model):
        super().__init__()
        self.model = model

    def forward(self, **kwargs):
        return self.model(**kwargs)

    def accuracy(self, batch_input, outputs):
        """
        Accuracy for masked language model
        """
        mask_mask = batch_input.labels != -100
        predictions = outputs.logits.argmax(-1)[mask_mask]
        targets = batch_input.labels[mask_mask]
        return (predictions == targets).float().mean()

    def training_step(self, batch, batch_idx):
        inputs = dict(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            labels=batch.labels,
            )
        outputs = self(**inputs)
        self.log("loss", outputs.loss, prog_bar=True)
        self.log("acc",
            self.accuracy(batch, outputs),
            on_step=True, prog_bar=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        inputs = dict(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
            labels=batch.labels,
            )
        outputs = self(**inputs)
        self.log("val_loss", outputs.loss, prog_bar=True)
        self.log("val_acc",
            self.accuracy(batch, outputs),
            on_step=False, prog_bar=True)
        return outputs.loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-6)

In [25]:
module = MaskedLM(model)

## Training

In [26]:
TASK = "xlsearch_cc_zh"

In [None]:
tb_logger = pl.loggers.TensorBoardLogger(
    save_dir=f"/GCI/tensorboard/{TASK}",
    name=TASK,
    )

save_cb = pl.callbacks.ModelCheckpoint(
    dirpath=f"/GCI/transformers/weights/{TASK}",
    save_top_k=3,
    verbose=True,
    monitor="acc",
    save_weights_only=True,
    every_n_train_steps=1024,
    mode="max",
    )

trainer = pl.Trainer(
    gpus=[1,],
    max_epochs=10,
    logger = [tb_logger,],
    callbacks=[save_cb,
#                DataFrameMetricsCallback()
              ],
    )

trainer.fit(
    module,
    datamodule = data_module
    )


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type            | Params
------------------------------------------
0 | model | BertForMaskedLM | 102 M 
------------------------------------------
102 M     Trainable params
0         Non-trainable params
102 M     Total params
409.161   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

Epoch 0, global step 1023: acc reached 0.54819 (best 0.54819), saving model to "/nvme/GCI/transformers/weights/xlsearch_cc_zh/epoch=0-step=1023.ckpt" as top 3
IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



## Save

In [22]:
best = save.best

In [24]:
module.load_state_dict(torch.load(best, map_location="cpu")['state_dict'])

<All keys matched successfully>

In [61]:
# encoder_decoder.push_to_hub("raynardj/wenyanwen-chinese-translate-to-ancient")

Cloning https://huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient into local empty directory.


Upload file pytorch_model.bin:   0%|          | 32.0k/916M [00:00<?, ?B/s]

To https://user:eOwfuFZJHbcMgbzVtVPDaSGtpbpjumsgTzZtfKlrMbSECzypnCYHZGDhHVsHRsYZzvdrkcxbnnSXRROfqdNRYfMvVfaVSOTxORkEUcMnAPEWXhkWpVEDrgfUZJdmleTx@huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient
   08f3b21..5ee2133  main -> main



'https://huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient/commit/5ee213356db17dfa9577226a90d5e9bd9461b495'

In [65]:
# encoder_tokenizer.push_to_hub("raynardj/wenyanwen-chinese-translate-to-ancient")

To https://user:eOwfuFZJHbcMgbzVtVPDaSGtpbpjumsgTzZtfKlrMbSECzypnCYHZGDhHVsHRsYZzvdrkcxbnnSXRROfqdNRYfMvVfaVSOTxORkEUcMnAPEWXhkWpVEDrgfUZJdmleTx@huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient
   5ee2133..ab72fa4  main -> main



'https://huggingface.co/raynardj/wenyanwen-chinese-translate-to-ancient/commit/ab72fa41627cfeb6fef64e196d68d81b0adb6228'