# Translate model

We are using [this nice dataset](https://github.com/BangBOOM/Classical-Chinese)

## Imports

In [None]:
from forgebox.imports import *
from datasets import load_dataset
# from fastai.text.all import *
from unpackai.nlp import *
from tqdm.notebook import tqdm
import pytorch_lightning as pl

## Config

In [None]:
# Root directory for local data; the dataset path and the checkpoint directory are both derived from it.
data=Path("/some_location/data")

In [None]:
# Directory that is searched recursively for the parallel text files.
DATA = Path(data/"nlp"/"zh"/"cc_vs_zh")
# True: modern Chinese is the source and classical Chinese the target; False: the reverse.
TO_CLASSICAL = True

## Data

### Combine data

In [None]:
# Collect every path matching "data/*" anywhere below DATA.
all_file = list(DATA.rglob("data/*"))


def open_file_to_lines(file, encoding="utf-8"):
    """
    Read a text file and return its lines without line terminators.

    file: path of the file to read
    encoding: text encoding used to decode the file; defaults to UTF-8 so the
        Chinese text is decoded the same way on every platform instead of
        depending on the locale's default encoding
    Returns a list of strings, one per line (empty list for an empty file).
    """
    with open(file, encoding=encoding) as f:
        lines = f.read().splitlines()
    return lines

def pairing_the_file(files,kw):
    """
    Pair every original file with the path of its counterpart file.

    files: iterable of path objects exposing a ``name`` attribute
    kw: marker carried in the name of counterpart files; the counterpart
        path is built by appending it to the original's full path
    Returns a list of (original, counterpart_path_string) tuples. Files whose
    name already contains ``kw`` are skipped.
    """
    return [
        (original, f"{original}{kw}")
        for original in files
        if kw not in original.name
    ]

# Files whose name contains "翻译" ("translation") are treated as the counterpart of the file without it.
pairs = pairing_the_file(all_file,"翻译")

def open_pairs(pairs):
    """
    Load all file pairs into one shuffled DataFrame.

    pairs: iterable of (file1, file2) paths; the lines of file1 fill the
        "classical" column and the lines of file2 the "modern" column
    Returns the row-wise concatenation of all pairs, with rows shuffled and
    the index rebuilt from 0.
    """
    frames = [
        pd.DataFrame({
            "classical": open_file_to_lines(classical_file),
            "modern": open_file_to_lines(modern_file),
        })
        for classical_file, modern_file in tqdm(pairs, leave=False)
    ]
    combined = pd.concat(frames)
    # frac=1. draws every row once, i.e. a full shuffle
    return combined.sample(frac=1.).reset_index(drop=True)

data_df = open_pairs(pairs)

# Rename the columns to the generic source/target names read by the dataset class:
# modern -> source and classical -> target when TO_CLASSICAL is set, swapped otherwise.
df = data_df.rename(
    columns = dict(
        zip(["modern","classical"],
             ["source","target"] if TO_CLASSICAL else ["target","source",]))
)

df.head()

### Loading tokenizer

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModel,
    EncoderDecoderModel
    )

# The encoder is a pretrained Chinese BERT checkpoint; it reads the source sentence.
# BERT is short for Bidirectional **Encoder** Representations from Transformers, which consists fully of encoder blocks
ENCODER_PRETRAINED = "bert-base-chinese"
# The decoder is a pretrained Chinese GPT-2 checkpoint, as the decoder is the part of the model that writes text
DECODER_PRETRAINED = "uer/gpt2-chinese-poem"

encoder_tokenizer = AutoTokenizer.from_pretrained(ENCODER_PRETRAINED)

decoder_tokenizer = AutoTokenizer.from_pretrained(
    ENCODER_PRETRAINED # notice we use the BERT's tokenizer here
)

### PyTorch Dataset

In [None]:
class Seq2Seq(Dataset):
    """
    Dataset of source/target sentence pairs for sequence-to-sequence training.

    Items are raw dataframe rows (as dicts); tokenization is done per batch in
    ``collate``, which is wired in as the ``collate_fn`` of ``dataloader``.
    """
    def __init__(self, df, tokenizer, target_tokenizer, max_len=128):
        """
        df: DataFrame with a ``source`` and a ``target`` text column
        tokenizer: tokenizer applied to the ``source`` column
        target_tokenizer: tokenizer applied to the ``target`` column
        max_len: every sequence is padded / truncated to this many tokens
        """
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.target_tokenizer = target_tokenizer
        self.max_len = max_len

    def __len__(self, ):
        return len(self.df)

    def __getitem__(self, idx):
        # positional lookup, so it also works on slices that kept their old index
        return dict(self.df.iloc[idx])

    def collate(self, batch):
        """
        Turn a list of row dicts into a tokenized batch.

        Returns the source tokenizer's output extended with
        ``decoder_input_ids`` (target token ids) and ``labels`` (the same ids
        with padding positions replaced by -100).
        """
        batch_df = pd.DataFrame(list(batch))
        x, y = batch_df.source, batch_df.target
        x_batch = self.tokenizer(
            list(x),
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        y_batch = self.target_tokenizer(
            list(y),
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        x_batch['decoder_input_ids'] = y_batch['input_ids']
        labels = y_batch['input_ids'].clone()
        # The labels were produced by the target tokenizer, so it is that
        # tokenizer's pad id that marks padding. -100 is the default
        # ignore_index of PyTorch's cross-entropy loss.
        labels[labels == self.target_tokenizer.pad_token_id] = -100
        x_batch['labels'] = labels
        return x_batch

    def dataloader(self, batch_size, shuffle=True):
        """Return a DataLoader over this dataset that batches with ``collate``."""
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=self.collate,
        )

    def split_train_valid(self, valid_size=0.1):
        """
        Shuffle the rows and split them into a train and a validation dataset.

        valid_size: fraction of the rows assigned to the validation set
        Returns (train_set, valid_set), both of the same class as ``self`` and
        sharing its tokenizers and ``max_len``.
        """
        split_index = int(len(self) * (1 - valid_size))
        cls = type(self)
        shuffled = self.df.sample(frac=1).reset_index(drop=True)
        train_set = cls(
            shuffled.iloc[:split_index],
            tokenizer=self.tokenizer,
            target_tokenizer=self.target_tokenizer,
            max_len=self.max_len,
        )
        valid_set = cls(
            shuffled.iloc[split_index:],
            tokenizer=self.tokenizer,
            target_tokenizer=self.target_tokenizer,
            max_len=self.max_len,
        )
        return train_set, valid_set

### PL datamodule

In [None]:
class Seq2SeqData(pl.LightningDataModule):
    """
    Lightning data module wrapping a ``Seq2Seq`` dataset.

    df: DataFrame handed to ``Seq2Seq``
    tokenizer / target_tokenizer: tokenizers for the source and target text
    batch_size: training batch size; validation uses twice this value
    max_len: token length passed on to ``Seq2Seq``
    """
    def __init__(self, df, tokenizer, target_tokenizer, batch_size=12, max_len=128):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.target_tokenizer = target_tokenizer
        self.max_len = max_len
        self.batch_size = batch_size
        self.ds = Seq2Seq(
            df,
            tokenizer,
            target_tokenizer,
            max_len=max_len,
        )

    def setup(self, stage=None):
        """Split the full dataset into ``train_set`` and ``valid_set``."""
        splits = self.ds.split_train_valid()
        self.train_set, self.valid_set = splits

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        train_set = self.train_set
        return train_set.dataloader(batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation split, at double the batch size."""
        doubled = self.batch_size*2
        return self.valid_set.dataloader(batch_size=doubled, shuffle=False)

data_module = Seq2SeqData(df, encoder_tokenizer, decoder_tokenizer, batch_size=64, )
# Build the train/validation split right away so the loaders can be used before training starts.
data_module.setup()

In [None]:
# Pull one training batch to inspect what the collate function produces.
next(iter(data_module.train_dataloader()))

### Load pretrained models

In [None]:
# Assemble one encoder-decoder model from the two pretrained checkpoints:
# ENCODER_PRETRAINED for the encoder and DECODER_PRETRAINED for the decoder.
encoder_decoder = EncoderDecoderModel.from_encoder_decoder_pretrained(
    encoder_pretrained_model_name_or_path=ENCODER_PRETRAINED,
    decoder_pretrained_model_name_or_path=DECODER_PRETRAINED,
)

In [None]:
class Seq2SeqTrain(pl.LightningModule):
    """
    Lightning module that trains an encoder-decoder model on its own loss.

    encoder_decoder: model called with the collated batch as keyword
        arguments; its output must expose a ``loss`` attribute
    """
    def __init__(self, encoder_decoder):
        super().__init__()
        self.encoder_decoder = encoder_decoder

    def forward(self, batch):
        """Run the wrapped model on a batch (unpacked as keyword arguments)."""
        return self.encoder_decoder(**batch)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss of one batch."""
        loss = self(batch).loss
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss of one batch."""
        loss = self(batch).loss
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """
        Build an Adam optimizer with one parameter group per parameter.

        Learning rates: 1e-5 for the encoder embeddings and encoder layers,
        1e-3 for the encoder pooler; in the decoder, 1e-3 for parameters whose
        name contains 'ln_cross_attn' or 'crossattention', 1e-4 for 'lm_head',
        and 1e-5 for everything else.
        """
        encoder = self.encoder_decoder.encoder
        # (sub-module, learning rate), in the order the groups reach the optimizer
        encoder_parts = (
            (encoder.embeddings, 1e-5),
            (encoder.encoder, 1e-5),
            (encoder.pooler, 1e-3),
        )
        param_groups = [
            {"params": param, "lr": lr}
            for part, lr in encoder_parts
            for param in part.parameters()
        ]

        for name, param in self.encoder_decoder.decoder.named_parameters():
            if 'ln_cross_attn' in name or 'crossattention' in name:
                lr = 1e-3
            elif 'lm_head' in name:
                lr = 1e-4
            else:
                lr = 1e-5
            param_groups.append({"params": param, "lr": lr})

        return torch.optim.Adam(param_groups, lr=1e-3)

In [None]:
# Wrap the encoder-decoder model in the Lightning training module.
module = Seq2SeqTrain(encoder_decoder)

## Training

In [None]:
# Checkpoint callback: monitors `val_loss` (lower is better) and keeps the top 2 checkpoints.
save = pl.callbacks.ModelCheckpoint(
    data/'../weights/cc_to_zh',
    save_top_k=2,
    verbose=True,
    monitor='val_loss',
    mode='min',
)

# NOTE(review): `gpus=[0]` presumably selects the first GPU — confirm against the installed Lightning version.
trainer = pl.Trainer(
    gpus=[0],
    max_epochs=10,
    callbacks=[save],
)

In [None]:
# Train the module with the loaders supplied by the data module.
trainer.fit(module, datamodule=data_module)

## Inference

In [None]:
# Path of the best checkpoint according to the monitored `val_loss`.
# `ModelCheckpoint` exposes the path as `best_model_path`; the old `best`
# attribute is not a file path and cannot be passed to `torch.load`.
best = save.best_model_path
# Restore the weights on CPU regardless of the device they were saved from.
module.load_state_dict(torch.load(best, map_location="cpu")['state_dict'])


# Run inference on CPU with the model switched to evaluation mode.
encoder_decoder = encoder_decoder.cpu()
encoder_decoder = encoder_decoder.eval()

def inference(text, starter=''):
    """
    Translate one sentence with the trained encoder-decoder model.

    text: source sentence to translate
    starter: accepted for interface compatibility; not used by the body
    Returns the list of decoded strings for the single-sentence batch, with
    special tokens removed.
    """
    encoded = encoder_tokenizer(
        [text],
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors='pt',
    )
    with torch.no_grad():
        generated = encoder_decoder.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            num_beams=3,
            # NOTE(review): 101 is presumably the [CLS] id of the BERT vocabulary — confirm
            bos_token_id=101,
        )
        return decoder_tokenizer.batch_decode(generated, skip_special_tokens=True)

In [None]:
# Example input, roughly: "Let me say a word to everyone"
inference('我来跟大家说一句话')

In [None]:
# Example input, roughly: "This translation is not very smart, because there is not enough training data"
inference("这个翻译不是很聪明，因为训练数据不够")

In [None]:
# Upload the model and the tokenizer under the same repository name.
encoder_decoder.push_to_hub("raynardj/wenyanwen-chinese-translate-to-ancient")
encoder_tokenizer.push_to_hub("raynardj/wenyanwen-chinese-translate-to-ancient")