# Translate model

We are using [this nice dataset](https://github.com/BangBOOM/Classical-Chinese)

## Imports

In [1]:
from forgebox.imports import *
from forgebox.thunder.callbacks import DataFrameMetricsCallback
from gc_utils.env import *
from datasets import load_dataset
# from fastai.text.all import *
from unpackai.nlp import *
from tqdm.notebook import tqdm
import random

In [2]:
import pytorch_lightning as pl

In [3]:
import re

def remove_all_punkt(text):
    """
    Strip every punctuation mark out of a piece of (Chinese) text.

    A character is dropped when it is neither a regex "word" character
    nor whitespace; all remaining characters are kept in their order.

    :param text: the string to clean
    :return: the same string without punctuation characters
    """
    not_word_nor_space = re.compile(r'[^\w\s]')
    return not_word_nor_space.sub('', text)

In [4]:
remove_all_punkt("亳州水军千户胡进等领骑兵渡淝水，逾荆山，与宋兵战，杀获甚众，赏钞币有差。")

'亳州水军千户胡进等领骑兵渡淝水逾荆山与宋兵战杀获甚众赏钞币有差'

## Config

In [5]:
# Root folder of the classical / modern Chinese parallel corpus.
# NOTE(review): sys_loc is not defined in this notebook; presumably it comes
# from one of the star-imports and resolves a configured location — confirm.
DATA = Path(sys_loc('DATA')/"nlp"/"zh"/"cc_vs_zh")
# Translation direction: True = modern -> classical, False = classical -> modern
TO_CLASSICAL = False

## Download data

## Data

### Combine data

In [6]:
# Recursively collect every path under DATA that matches the "data/*" pattern
all_file = list(DATA.rglob("data/*"))

In [7]:
def open_file_to_lines(file, encoding="utf-8"):
    """
    Read a text file and return its content as a list of lines.

    The encoding is passed explicitly instead of relying on the platform
    default, so that Chinese text is decoded the same way on every OS.

    :param file: path of the text file to read
    :param encoding: text encoding of the file, UTF-8 by default
    :return: list of lines, without their line terminators
    """
    with open(file, encoding=encoding) as f:
        lines = f.read().splitlines()
    return lines

def pairing_the_file(files,kw):
    """
    Pair every original file with the path of its translated counterpart.

    A file whose name contains ``kw`` is considered a translation and is
    skipped; for every other file the counterpart path is the file's
    own path with ``kw`` appended.

    :param files: iterable of path objects exposing a ``name`` attribute
    :param kw: keyword that marks (and suffixes) the translated files
    :return: list of ``(original_path, counterpart_path_string)`` tuples
    """
    return [
        (path, f"{path}{kw}")
        for path in files
        if kw not in path.name
    ]

In [8]:
# Files whose name contains "翻译" ("translation") are the translated versions;
# every other file is paired with "<its path>翻译"
pairs = pairing_the_file(all_file,"翻译")

In [9]:
def open_pairs(pairs, random_state=None):
    """
    Load every (classical, modern) file pair into one shuffled DataFrame.

    The two files of a pair are read line by line and aligned by line
    number, so they must contain the same number of lines.

    :param pairs: iterable of ``(classical_file, modern_file)`` paths
    :param random_state: optional seed forwarded to ``DataFrame.sample``
        to make the shuffle reproducible; ``None`` keeps it random
    :return: DataFrame with the columns ``classical`` and ``modern``,
        rows shuffled and re-indexed from 0
    :raises ValueError: when the two files of a pair differ in line count
    """
    chunks = []
    for file1, file2 in tqdm(pairs, leave=False):
        lines1 = open_file_to_lines(file1)
        lines2 = open_file_to_lines(file2)
        # fail early and name the offending files, instead of letting
        # pandas complain about array lengths without any context
        if len(lines1) != len(lines2):
            raise ValueError(
                f"Line count mismatch: {file1} has {len(lines1)} lines "
                f"but {file2} has {len(lines2)}"
            )
        chunks.append(pd.DataFrame({"classical":lines1,"modern":lines2}))
    return (
        pd.concat(chunks)
        .sample(frac=1., random_state=random_state)
        .reset_index(drop=True)
    )

In [10]:
# Read all file pairs into one shuffled frame with "classical" / "modern" columns
data_df = open_pairs(pairs)

  0%|          | 0/27 [00:00<?, ?it/s]

In [11]:
# Map the language columns onto the generic source / target roles,
# depending on the translation direction
df = data_df.rename(
    columns=(
        {"modern": "source", "classical": "target"}
        if TO_CLASSICAL
        else {"modern": "target", "classical": "source"}
    )
)

In [12]:
df.head()

Unnamed: 0,source,target
0,谏议大夫宁原悌上言：以为先朝悖逆庶人以爱女骄盈而及祸，新城、宜都以庶孽抑损而获全。,谏议大夫宁原悌向唐睿宗进言认为：先朝悖逆庶人作为中宗和韦后的爱女而骄傲自满，终于难逃杀身之祸...
1,意等漏卮，江河无以充其溢。,思想像渗漏的酒器，长江、黄河无法来填满他的欲壑。
2,琥珀太多，及差，痕不灭，左颊有赤点如痣。,因琥珀用得过多，到伤愈时，邓夫人左颊疤疮没有完全去掉，脸上留下一颗象痣一样的红点。
3,督军疾进，师至阴山，遇其斥候千余帐，皆俘以随军。,于是督军疾进，军队行进到阴山，遇到颉利可汗的哨兵千余帐，把他们全部俘获，并押着他们随军行动。
4,莽曰夕阴。,王莽时叫夕阴县。


### Loading tokenizer

In [13]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel,
    EncoderDecoderModel
    )

# The encoder is a pretrained Chinese BERT, used to read the source sentence.
# BERT is short for Bidirectional **Encoder** Representations from Transformers, which consists fully of encoder blocks
ENCODER_PRETRAINED = "bert-base-chinese"
# The decoder is a pretrained Chinese GPT-2, as the decoder is the part of the model that writes the output text
DECODER_PRETRAINED = "uer/gpt2-chinese-poem"

# tokenizer for the source sentences
encoder_tokenizer = AutoTokenizer.from_pretrained(ENCODER_PRETRAINED)

# tokenizer for the target sentences
# NOTE(review): this assumes the decoder checkpoint's vocabulary ids are
# compatible with the BERT tokenizer — confirm against the decoder's config
decoder_tokenizer = AutoTokenizer.from_pretrained(
    ENCODER_PRETRAINED # notice we use the BERT's tokenizer here
)

### PyTorch Dataset

In [14]:
class Seq2Seq(Dataset):
    """
    Dataset of source / target sentence pairs stored in a DataFrame.

    Each item is one row of ``df`` as a plain dict. Tokenisation is done
    per batch in :meth:`collate`, which :meth:`dataloader` installs as
    the ``collate_fn`` of the DataLoader it builds.
    """
    def __init__(
        self, df, tokenizer, target_tokenizer,
        max_len=128,
        no_punkt:bool = False,
    ):
        """
        :param df: DataFrame with a ``source`` and a ``target`` column
        :param tokenizer: tokenizer applied to the source sentences
        :param target_tokenizer: tokenizer applied to the target sentences
        :param max_len: length every sequence is padded / truncated to
        :param no_punkt: do we randomly remove punctuation from the
            source sentences (each sentence has a 50% chance)
        """
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.target_tokenizer = target_tokenizer
        self.max_len = max_len
        self.no_punkt = no_punkt

    def __len__(self, ):
        """Number of sentence pairs."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return row ``idx`` (positional) as a ``{column: value}`` dict."""
        return dict(self.df.iloc[idx])

    def collate(self, batch):
        """
        Turn a list of row dicts into tokenised model inputs.

        :param batch: iterable of dicts with ``source`` and ``target`` keys
        :return: the source tokenizer's output, extended with
            ``decoder_input_ids`` (target token ids) and ``labels``
            (target token ids with padding replaced by -100)
        """
        batch_df = pd.DataFrame(list(batch))
        x, y = batch_df.source, batch_df.target
        # there is a random no punctuation mode for the source text,
        # as some of the classical text we get might be a whole chunk
        # of paragraph without any punctuation
        if self.no_punkt:
            x = [i if random.random() > .5 else remove_all_punkt(i)
                 for i in x]
        else:
            x = list(x)
        x_batch = self.tokenizer(
            x,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        y_batch = self.target_tokenizer(
            list(y),
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # NOTE(review): decoder inputs and labels are the same ids, which
        # relies on the decoder shifting the labels itself when computing
        # the loss — confirm for the installed transformers version
        x_batch['decoder_input_ids'] = y_batch['input_ids']
        labels = y_batch['input_ids'].clone()
        # the labels were produced by the *target* tokenizer, so its pad id
        # (not the source tokenizer's) marks the positions to mask out
        labels[labels == self.target_tokenizer.pad_token_id] = -100
        x_batch['labels'] = labels
        return x_batch

    def dataloader(self, batch_size, shuffle=True):
        """
        Build a DataLoader over this dataset using :meth:`collate`.

        :param batch_size: number of sentence pairs per batch
        :param shuffle: whether the DataLoader shuffles the items
        """
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self.collate,
        )

    def split_train_valid(self, valid_size=0.1, random_state=None):
        """
        Shuffle the rows and split them into a train and a validation set.

        :param valid_size: fraction of the rows put in the validation set
        :param random_state: optional seed forwarded to
            ``DataFrame.sample`` for a reproducible split; ``None`` keeps
            the split random
        :return: ``(train_set, valid_set)``, two datasets of the same
            class sharing this dataset's tokenizers and settings
        """
        split_index = int(len(self) * (1 - valid_size))
        shuffled = self.df.sample(
            frac=1, random_state=random_state).reset_index(drop=True)
        shared = dict(
            tokenizer=self.tokenizer,
            target_tokenizer=self.target_tokenizer,
            max_len=self.max_len,
            no_punkt=self.no_punkt,
        )
        cls = type(self)
        train_set = cls(shuffled.iloc[:split_index], **shared)
        valid_set = cls(shuffled.iloc[split_index:], **shared)
        return train_set, valid_set

### PL datamodule

In [15]:
class Seq2SeqData(pl.LightningDataModule):
    """
    Lightning data module wrapping a :class:`Seq2Seq` dataset.

    It splits the data into a train and a validation set once and
    serves the matching dataloaders.
    """
    def __init__(
        self, df,
        tokenizer,
        target_tokenizer,
        batch_size=12,
        max_len=128,
        no_punkt:bool=False):
        """
        :param df: DataFrame with a ``source`` and a ``target`` column
        :param tokenizer: tokenizer applied to the source sentences
        :param target_tokenizer: tokenizer applied to the target sentences
        :param batch_size: training batch size (validation uses twice that)
        :param max_len: length every sequence is padded / truncated to
        :param no_punkt: forwarded to :class:`Seq2Seq`
        """
        super().__init__()
        self.df = df
        self.ds = Seq2Seq(df,
                          tokenizer,
                          target_tokenizer,
                          max_len=max_len,
                          no_punkt=no_punkt)
        self.tokenizer = tokenizer
        self.target_tokenizer = target_tokenizer
        self.max_len = max_len
        self.batch_size = batch_size
        # filled in by setup(); set both back to None to force a new split
        self.train_set = None
        self.valid_set = None

    def setup(self, stage=None):
        """
        Create the train / validation split.

        The split is a random shuffle, so it is only computed on the first
        call: a later call (by hand or by the Trainer) keeps the existing
        split instead of silently moving rows between train and validation.
        """
        if self.train_set is not None and self.valid_set is not None:
            return
        self.train_set, self.valid_set = self.ds.split_train_valid()

    def train_dataloader(self):
        """Shuffled dataloader over the training set."""
        return self.train_set.dataloader(
            batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Unshuffled dataloader over the validation set, with a doubled batch size."""
        return self.valid_set.dataloader(
            batch_size=self.batch_size*2, shuffle=False)

In [16]:
# Build the data module on the source/target frame.
# Random punctuation removal is only enabled when the source side is
# classical Chinese (i.e. when TO_CLASSICAL is False).
data_module = Seq2SeqData(
    df, encoder_tokenizer,
    decoder_tokenizer,
    batch_size=28,
    max_len=256,
    no_punkt=False if TO_CLASSICAL else True,)
# create the train / validation split so a batch can be inspected below
data_module.setup()

In [17]:
# Pull one training batch to inspect what the model will receive
inputs = next(iter(data_module.train_dataloader()))
inputs

{'input_ids': tensor([[ 101, 1921, 5688,  ...,    0,    0,    0],
        [ 101, 2828, 1062,  ...,    0,    0,    0],
        [ 101, 1039, 1469,  ...,    0,    0,    0],
        ...,
        [ 101,  718,  886,  ...,    0,    0,    0],
        [ 101, 1071, 1095,  ...,    0,    0,    0],
        [ 101, 1062, 1920,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'decoder_input_ids': tensor([[ 101, 1921, 5688,  ...,    0,    0,    0],
        [ 101, 1315,  752,  ...,    0,    0,    0],
        [ 101, 1039, 1469,  ...,    0,    0,    0],
        

If we are translating classical Chinese to modern Chinese, we randomly strip all punctuation from about half of the source sentences, since much of the classical text we may receive comes without any punctuation.

In [18]:
# Decode the source token ids of the batch back to text, to check that
# part of the sentences indeed lost their punctuation
encoder_tokenizer.batch_decode(
    inputs.input_ids,skip_special_tokens=True
)

['天 节 八 星 ， 在 毕 、 附 耳 南 ， 主 使 臣 持 节 宣 威 四 方 。',
 '把 公 子 成 的 话 报 告 给 赵 武 灵 王 。 武 灵 王 说 ： 我 就 知 道 王 叔 反 对 这 件 事 。 于 是 马 上 就 去 公 子 成 家 里 ， 亲 自 向 他 阐 述 自 己 的 观 点 ： 大 凡 衣 服 是 为 了 便 于 穿 用 ， 礼 制 是 为 了 便 于 办 事 。',
 '元 和 五 年 已 前 租 赋 并 放 。',
 '凡 杀 三 人 ， 伤 五 人 ， 手 驱 郎 吏 二 十 余 人 。',
 '杨 石 二 少 年 为 民 害 简 置 狱 中 谕 以 祸 福 咸 感 悟 愿 自 赎',
 '辛 亥 诸 将 自 汉 口 开 坝 引 船 入 沦 河 先 遣 万 户 阿 剌 罕 以 兵 拒 沙 芜 口 逼 近 武 矶 巡 视 阳 罗 城 堡 径 趋 沙 芜 遂 入 大 江',
 '江 东 民 户 殷 盛 风 俗 峻 刻 强 弱 相 陵 奸 吏 蜂 起 符 书 一 下 文 摄 相 续',
 '昏 夜 ， 平 善 ， 乡 晨 ， 傅 绔 袜 欲 起 ， 因 失 衣 ， 不 能 言 ， 昼 漏 上 十 刻 而 崩 。',
 '子 十 三 篇',
 '扶 风 民 鲁 悉 达 ， 纠 合 乡 人 以 保 新 蔡 ， 力 田 蓄 谷 。',
 '明 年 ， 又 贬 武 安 军 节 度 副 使 、 永 州 安 置 。',
 '良 久 徐 曰 恬 罪 故 当 死 矣',
 '部 曲 将 田 泓 请 没 水 潜 行 趣 彭 城 ， 玄 遣 之 。',
 '必 久 停 留 ， 恐 非 天 意 也 。',
 '具 传 其 业 又 默 讲 论 义 理 五 经 诸 子 无 不 该 览 加 博 好 技 艺 算 术 卜 数 医 药 弓 弩 机 械 之 巧 皆 致 思 焉',
 '苏 秦 初 合 纵 至 燕',
 '讼 者 言 词 忿 争 理 无 所 屈',
 '高 祖 闻 之 ， 曰 ： 二 将 和 ， 师 必 济 矣 。',
 '谧 兄 谌 字 兴 伯 性 平 和',
 '平 受 诏 ， 立 复 驰 至 宫 ， 哭 殊 悲 ； 因 固 请 得 宿 卫 中 。',
 '属 淮 阴 ， 击 破 齐 历 下 军 ， 击 田 解 。',
 '惇 与 蔡 卞 将 

### Load pretrained models

In [19]:
# encoder = AutoModel.from_pretrained(ENCODER_PRETRAINED, proxies={"http":"bifrost:3128"})
# decoder = AutoModelForCausalLM.from_pretrained(DECODER_PRETRAINED, add_cross_attention=True,
#                                                proxies={"http":"bifrost:3128"})

## Model

We create a seq2seq model by using pretrained encoder + pretrained decoder

In [20]:
# loading pretrained model
encoder_decoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=ENCODER_PRETRAINED,
    decoder_pretrained_model_name_or_path=DECODER_PRETRAINED,
)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at uer/gpt2-chinese-poem and are newly initialized: ['transformer.h.6.crossa

In [21]:
class Seq2SeqTrain(pl.LightningModule):
    """
    Lightning module that trains a wrapped encoder-decoder model.

    The loss is the one returned by the model itself; every parameter
    gets its own optimizer group so learning rates can differ per part.
    """
    def __init__(self, encoder_decoder):
        """
        :param encoder_decoder: model called with the batch as keyword
            arguments; its output must expose a ``loss`` attribute
        """
        super().__init__()
        self.encoder_decoder = encoder_decoder

    def forward(self, batch):
        """Feed the batch, unpacked as keyword arguments, to the model."""
        return self.encoder_decoder(**batch)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss of one batch."""
        loss = self(batch).loss
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss of one batch."""
        loss = self(batch).loss
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """
        Build an Adam optimizer with one parameter group per parameter.

        Learning rates: 1e-5 for the encoder embeddings and encoder
        layers, 1e-3 for the encoder pooler; in the decoder 1e-3 for
        parameters named ``ln_cross_attn`` / ``crossattention``, 1e-4
        for ``lm_head`` and 1e-5 for everything else.
        """
        bert = self.encoder_decoder.encoder
        param_groups = []
        # encoder parts, in a fixed order, each with a single learning rate
        for part, rate in (
            (bert.embeddings, 1e-5),
            (bert.encoder, 1e-5),
            (bert.pooler, 1e-3),
        ):
            param_groups.extend(
                {"params": weight, "lr": rate}
                for weight in part.parameters()
            )

        # decoder: the first keyword found in the parameter name decides
        decoder_rates = (
            ('ln_cross_attn', 1e-3),
            ('crossattention', 1e-3),
            ('lm_head', 1e-4),
        )
        for name, weight in self.encoder_decoder.decoder.named_parameters():
            rate = next(
                (lr for keyword, lr in decoder_rates if keyword in name),
                1e-5,
            )
            param_groups.append({"params": weight, "lr": rate})

        return torch.optim.Adam(param_groups, lr=1e-3)

In [22]:
# Wrap the encoder-decoder model in the Lightning training module
module = Seq2SeqTrain(encoder_decoder)

## Training

In [23]:
# Keep the 2 checkpoints with the lowest 'val_loss' (the value logged in
# Seq2SeqTrain.validation_step).
# NOTE(review): the checkpoint location is a hard-coded absolute path —
# consider deriving it from a configurable directory.
save = pl.callbacks.ModelCheckpoint(
    '/GCI/transformers/weights/cc_to_zh',
    save_top_k=2,
    verbose=True,
    monitor='val_loss',
    mode='min',
)

# NOTE(review): gpus=[1] hard-codes a device selection for this machine —
# confirm it matches the hardware before re-running.
trainer = pl.Trainer(
    gpus=[1],
    max_epochs=10,
    callbacks=[save],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [24]:
# Train the module; batches come from the data module's dataloaders
trainer.fit(module, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name            | Type                | Params
--------------------------------------------------------
0 | encoder_decoder | EncoderDecoderModel | 233 M 
--------------------------------------------------------
233 M     Trainable params
0         Non-trainable params
233 M     Total params
935.203   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
