# Conditional text generation

Data downloaded from the [chinese-poetry](https://github.com/chinese-poetry/chinese-poetry) repository.

## Imports

In [2]:
# Forgebox Imports
from forgebox.imports import *
from gc_utils.env import *
import pytorch_lightning as pl
from transformers import (
    AutoTokenizer,
    GPT2LMHeadModel
)
import random
from typing import List
import re
from jieba import cut

In [3]:
def is_jupyter():
    """Return True only when running inside a Jupyter (ipykernel) kernel.

    ``get_ipython`` is injected as a builtin in *any* IPython session, including
    the terminal REPL, where ``tqdm.notebook`` widgets cannot render. A
    kernel-backed session carries ``IPKernelApp`` in its config, which is what
    distinguishes it from the terminal shell.

    Returns:
        bool: True for a notebook kernel, False for plain Python or terminal IPython.
    """
    try:
        shell = get_ipython()
    except NameError:  # plain Python: IPython never injected get_ipython
        return False
    if shell is None:
        return False
    return "IPKernelApp" in getattr(shell, "config", {})
    
IS_JUPYTER = is_jupyter()
# Widget progress bars inside a notebook, plain-text bars everywhere else.
if not IS_JUPYTER:
    from tqdm import tqdm
else:
    from tqdm.notebook import tqdm

## Locations

In [4]:
# Root directory of the Chinese NLP corpora. `sys_loc` comes from one of the star
# imports above (presumably a configured data-location lookup — confirm).
DATA = sys_loc("DATA")/"nlp"/"zh"
# List the sub-directories so the reader can see which corpora are available.
DATA.ls()

['cc_vs_zh', 'cctc', 'cn_shi', 'daizhigev20']

In [5]:
POET = DATA / "cn_shi"
# Every JSON file anywhere under the poetry dump.
ALL_JSON = [json_path for json_path in POET.rglob("*.json")]

## Read and transform data

In [6]:
def read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is decoded as UTF-8 explicitly: the corpus is Chinese text, and
    ``Path.read_text()`` without an ``encoding`` falls back to the platform's
    locale encoding (e.g. cp1252/gbk on Windows), which garbles or rejects it.

    Args:
        path: str or path-like pointing at a UTF-8 JSON file.
    """
    return json.loads(Path(path).read_text(encoding="utf-8"))

In [7]:
# Load every Song ci file and every shi file into its own DataFrame, keyed by path.
# rglob yields files in filesystem order, so sort to make the concatenation
# order (and therefore everything derived from all_df) reproducible.
ci_dict = {
    str(path): pd.read_json(path)
    for path in tqdm(sorted(DATA.rglob("cn_shi/ci/ci.song*.json")))
}

shi_dict = {
    str(path): pd.read_json(path)
    for path in tqdm(sorted(DATA.rglob("cn_shi/json/poet.*.json")))
}

# Fail with an actionable message instead of pandas' "No objects to concatenate".
assert ci_dict or shi_dict, f"no poetry JSON found under {DATA / 'cn_shi'}"

# Columns missing from one source (e.g. `rhythmic` in shi files) are filled with NaN by concat.
all_df = pd.concat(list(ci_dict.values()) + list(shi_dict.values()))[
    ["author", "paragraphs", "rhythmic"]].reset_index(drop=True)

  0%|          | 0/23 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

In [8]:
# Shuffle once with a fixed seed so `para` and the samples shown below are reproducible.
all_df = all_df.sample(frac=1., random_state=42).reset_index(drop=True)

In [9]:
# First poem's lines, used as a running example; .iloc avoids copying the whole column to a list.
para = all_df["paragraphs"].iloc[0]

In [11]:
def extract(
    paragraphs: List[str], puncts="，。？！?,.!", num_head=None,
    ):
    """Pull the first character of each clause out as acrostic "heads".

    The paragraphs are joined with "". The first character of the text, and the
    first character after every character in ``puncts``, is appended to
    ``heads`` and replaced by the literal ``"[CLS]"`` in the returned text.
    Once ``num_head`` heads have been collected, the rest of the text is copied
    unchanged.

    Args:
        paragraphs: lines of the poem.
        puncts: characters that end a clause.
        num_head: number of heads to collect. When None, it is drawn from
            [2, 3, 4] with ``random.choice`` (the original behaviour).

    Returns:
        tuple[str, str]: ``(heads, text_with_heads_masked)``. ``heads`` may be
        shorter than ``num_head`` when the text has fewer clauses. If
        ``num_head < 1``, no heads are taken and the joined text is returned as is.
    """
    if num_head is None:
        num_head = random.choice([2, 3, 4])
    text = "".join(paragraphs)
    if num_head < 1:
        return "", text
    heads = ""
    return_text = ""
    last_is_break = True
    for i, c in enumerate(text):
        if last_is_break:
            # This character starts a clause: record it as a head and mask it.
            heads += c
            return_text += "[CLS]"
        else:
            return_text += c
        if len(heads) >= num_head:
            # Enough heads collected: keep the remainder verbatim.
            return_text += text[i+1:]
            break
        last_is_break = c in puncts
    return heads, return_text

In [12]:
# Preview: heads taken from the sample poem, and the poem with those heads masked.
extract(paragraphs=para)

('間曜翩過',
 '[CLS]維有常度，[CLS]靈無停輈。[CLS]翩葉辭柯，[CLS]眼綠已稠。弱榦不盈尺，忽已高岑樓。念昔過庭日，朋來悉良儔。我年未成童，子少無與侔。我質本駑駘，蹇步畏阻脩。子如渥洼駒，猛氣已食牛。當時二老人，笑語懽且酬。門戶各有托，寧計才與不。登門如昨日，星紀跡再周。二老安在哉，體魄歸山丘。隔屋聞讀書，玉樹鏘琳球。呼燈使來前，秀氣炯雙眸。問之垂九齡，屬對解冥搜。感此傷我心，淚下不可收。來者日已長，逝者挽不留。其間我與子，能閲幾春秋。寧復青衿佩，與子從親游。幸子齒猶壯，有母方白頭。刷翮凌青霄，足勝負米由。而我風樹悲，耿耿何時休。四十已無聞，過是夫何求。矧復病日侵，見面良可羞。竹實不療饑，芰製非寒裘。躬耕苦勤勞，代耕多悔尤。學仙竟誰成，百年等浮漚。俛仰天地間，身世真悠悠。時雨漲綠池，好風交平疇。嚶嚶出谷鳥，汎汎川上鷗。遇景適會心，曠望聊夷猶。')

## Get tokenizer

In [13]:
# BERT Chinese tokenizer, shared by the dataset collator and `inference`.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-chinese")

In [14]:
def replace_punctuation(text):
    """Replace every character that is neither a word character nor whitespace with a space."""
    non_word = re.compile(r'[^\w\s]')
    return non_word.sub(' ', text)

def cutting(text):
    """Segment ``text`` into jieba words, dropping punctuation and whitespace.

    Punctuation is first turned into spaces by ``replace_punctuation``. jieba
    emits whitespace as its own tokens. Any token that is empty after stripping
    (space, tab, newline, ...) is dropped, not only the single space ' '.

    Returns:
        list[str]: the non-blank segmented words, in order.
    """
    return [word for word in cut(replace_punctuation(text), HMM=True) if word.strip()]

In [15]:
# Sanity check: punctuation and spaces are removed, leaving jieba words only.
cutting(text="春眠不觉晓， 处处闻啼鸟")

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.665 seconds.
Prefix dict has been built successfully.


['春眠', '不觉', '晓', '处处', '闻啼鸟']

In [20]:
def pick_and_shuffle(li, min_n:int=0, max_n:int=None, max_extra:int=10):
    """Return a random, de-duplicated subset of ``li``.

    Takes the first ``n`` items of a shuffled copy of ``li``, where
    ``n = min_n + randint(0, min(max_n - min_n, max_extra))``. Duplicates are
    then removed, keeping the shuffled order.

    Args:
        li: candidate items; left unmodified.
        min_n: minimum number of items to take before de-duplication.
        max_n: upper bound on items taken; defaults to 70% of ``len(li)``.
        max_extra: cap on how many items beyond ``min_n`` may be taken.

    Returns:
        list: unique items, in a deterministic order for a given ``random`` state.
    """
    if max_n is None:
        max_n = int(len(li) * .7)
    # Clamp at 0 so randint never receives a negative upper bound when max_n < min_n.
    n = min_n + random.randint(0, max(0, min(max_n - min_n, max_extra)))
    pool = list(li)  # shuffle a copy so the caller's list is not mutated
    random.shuffle(pool)
    # dict.fromkeys de-duplicates while keeping order; set() order of str depends
    # on PYTHONHASHSEED and would make results irreproducible across runs.
    return list(dict.fromkeys(pool[:n]))

In [21]:
def create_kw(text):
    """Sample a random, de-duplicated list of jieba keywords from ``text``."""
    words = cutting(text)
    return pick_and_shuffle(words)

In [23]:
# Keywords are a random subset of the segmented words.
create_kw(text="春眠不觉晓， 处处闻啼鸟")

['晓', '春眠', '不觉']

In [24]:
# Split the sample poem into its heads and the head-masked text.
heads, headless = extract(paragraphs=para)

In [36]:
# Heads alongside keywords drawn from the unmasked body text.
(heads, create_kw(text=headless.replace("[CLS]", "")))

('間曜翩過', ['星紀跡', '我', '何求', '駘', '身世', '風樹悲', '雙眸', '復', '托', '與子'])

## Dataset

In [48]:
class PoetDataset(Dataset):
    """Map-style dataset that turns poems into keyword-prompted training strings.

    Each item is ``"{kw1-kw2-...}《{heads}》{count}{headless}"``:
    ``heads``/``headless`` come from ``extract`` (heads masked as ``[CLS]``),
    the keywords come from ``create_kw`` on the unmasked body, and ``count`` is
    the number of heads written as 『一』..『四』.

    Args:
        df: DataFrame with a ``paragraphs`` column; shuffled on construction.
        tokenizer: Hugging Face tokenizer used by ``collate_fn``.
        p_head: stored and propagated by ``split``; not read elsewhere in this class.
        max_length: token length every batch is padded/truncated to.
    """
    def __init__(
        self,
        df,
        tokenizer,
        p_head:float=.2,
        max_length:int=256,
    ):
        self.df = df.sample(frac=1).reset_index(drop=True)
        self.tokenizer = tokenizer
        self.p_head = p_head
        self.max_length = max_length
        # 1 -> 『一』, ..., 4 -> 『四』: tells the model how many heads the prompt carries.
        self.cn_num_dict = dict((i+1, f"『{c}』") for i, c in enumerate("一二三四"))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        heads, headless = extract(row.paragraphs)
        kws = '-'.join(create_kw(headless.replace('[CLS]', "")))
        # An empty poem yields zero heads; fall back to "" instead of
        # formatting the literal string "None" into the training text.
        count = self.cn_num_dict.get(len(heads), "")
        return f"{kws}《{heads}》{count}{headless}"

    def collate_fn(self, batch):
        """Tokenize a list of strings into a padded batch with LM ``labels``."""
        encoded = self.tokenizer(
            list(batch),
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
            truncation=True,
        )
        labels = encoded['input_ids'].clone()
        # Exclude padding from the loss (-100 is ignored by the HF LM loss). Use
        # the attention mask when present rather than assuming the pad id is 0.
        attention_mask = encoded.get('attention_mask')
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        else:
            labels[labels == self.tokenizer.pad_token_id] = -100
        encoded['labels'] = labels
        return encoded

    def dataloader(self, batch_size=32, shuffle=True, num_workers=0):
        """Wrap this dataset in a DataLoader that uses ``collate_fn``."""
        return DataLoader(
            self,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=self.collate_fn,
        )

    def split(self, val_ratio=.05):
        """Shuffle and split into ``(train, val)`` datasets sharing this one's settings."""
        df = self.df.sample(frac=1).reset_index(drop=True)
        cut = int(len(df) * (1 - val_ratio))
        settings = dict(
            tokenizer=self.tokenizer, p_head=self.p_head, max_length=self.max_length)
        return PoetDataset(df[:cut], **settings), PoetDataset(df[cut:], **settings)

In [49]:
# Full (unsplit) dataset; DataModule.setup splits it into train/val later.
poet_ds = PoetDataset(df=all_df, tokenizer=tokenizer)

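We arrange each sample as keywords, then the heads inside 《》, then the head count, then the poem with its heads masked. Causal language modeling can then learn to generate the poem from that prefix.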

In [51]:
# Inspect one formatted training string: keywords, 《heads》, head count, masked poem.
poet_ds[1000]

'忍看-窈窕-孤寝-勾带-嫩-黄昏《粉度》『二』[CLS]堞云齐，[CLS]清笳、愁入暮烟林杪。素艳透春，玉骨凄凉，勾带月痕生早。江天苍莽黄昏後，依然是、粉寒香瘦。动追感、西园嫩约，夜深人悄。记得东风窈窕。曾夜踏横斜，醉携娇小。惆怅旧欢，回首俱非，忍看绿笺红豆。香销纸帐人孤寝，相思恨、花还知否。梦回处，霜飞翠楼已晓。'

In [52]:
# Small-batch loader used only to eyeball one collated batch.
dl = poet_ds.dataloader(batch_size=12)

In [53]:
# Pull a single collated batch to sanity-check tokenization and labels.
batch = next(iter(dl))

In [54]:
# Start from a GPT-2 already pretrained on Chinese poetry.
model = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path="uer/gpt2-chinese-poem")

In [55]:
class DataModule(pl.LightningDataModule):
    """Lightning data module that splits one PoetDataset into train/val exactly once.

    Args:
        dataset: object exposing ``split(val_ratio=...)`` returning
            ``(train, val)`` datasets with a ``dataloader`` method.
        batch_size: training batch size; validation uses twice this.
        val_ratio: fraction of rows held out for validation.
    """
    def __init__(self, dataset, batch_size=32, val_ratio=.05):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size
        self.val_ratio = val_ratio
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        # Lightning may call setup() for every stage/run (fit, validate, a
        # re-run after an interrupt). `split` reshuffles, so re-splitting would
        # move rows between train and val; split only on the first call.
        if self.train_dataset is None:
            self.train_dataset, self.val_dataset = self.dataset.split(
                val_ratio=self.val_ratio)

    def train_dataloader(self):
        return self.train_dataset.dataloader(
            batch_size=self.batch_size,
            shuffle=True)

    def val_dataloader(self):
        # No gradients at validation time, so a larger batch fits.
        return self.val_dataset.dataloader(
            batch_size=self.batch_size*2,
            shuffle=False)

In [56]:
class CausalLMModule(pl.LightningModule):
    """Lightning wrapper that fine-tunes a Hugging Face causal LM.

    Expects batches with ``input_ids``, ``attention_mask`` and ``labels`` keys,
    as produced by ``PoetDataset.collate_fn``. The loss is taken from the
    wrapped model's output.

    Args:
        model: a model returning an object with a ``loss`` attribute when given labels.
        lr: Adam learning rate.
    """
    def __init__(self, model, lr=1e-5):
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, **batch):
        return self.model(**batch)

    def _shared_step(self, batch):
        """Compute the LM loss for one batch (shared by train and val)."""
        # Subscript access works for plain dicts as well as BatchEncoding.
        outputs = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # Monitored by the ModelCheckpoint callback.
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [57]:
# 54 samples per training step (validation uses 108).
data_module = DataModule(dataset=poet_ds, batch_size=54)

In [58]:
# Wrap the pretrained GPT-2 for Lightning training.
module = CausalLMModule(model=model)

In [59]:
# Keep the two checkpoints with the lowest validation loss.
save = pl.callbacks.ModelCheckpoint(
    dirpath='/GCI/transformers/weights/kw_leading_po',
    monitor='val_loss',
    mode='min',
    save_top_k=2,
    verbose=True,
)

trainer = pl.Trainer(
    callbacks=[save],
    max_epochs=6,
    gpus=[1],  # train on device index 1 only
)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [161]:
# Fine-tune; checkpoints are written by the `save` callback.
trainer.fit(model=module, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type            | Params
------------------------------------------
0 | model | GPT2LMHeadModel | 103 M 
------------------------------------------
103 M     Trainable params
0         Non-trainable params
103 M     Total params
412.665   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
  f'The dataloader, {name}, does not have many workers which may be a bottleneck.'


Training: 0it [00:00, ?it/s]

  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')


In [None]:
# Restore the best checkpoint (lowest val_loss) into the Lightning module.
# ModelCheckpoint exposes the file as `best_model_path`; `best` is not a path.
# The path is empty if training stopped before any checkpoint was written.
assert save.best_model_path, "no checkpoint has been saved yet; train longer before loading"
module.load_state_dict(
        torch.load(save.best_model_path, map_location="cpu")['state_dict'])

In [26]:
# Take the fine-tuned network out of the wrapper, move it to CPU, and switch to eval mode.
model = module.model.cpu().eval()

In [27]:
# NOTE(review): `hub` is not defined in this notebook; presumably it comes from
# one of the star imports (e.g. gc_utils.env). Confirm before running.
model.save_pretrained(save_directory=hub / "kw-lead-po")

In [None]:
# Upload the fine-tuned model to the Hugging Face Hub under this repository id.
model.push_to_hub("raynardj/keywords-cangtou-chinese-poetry")

In [28]:
def inference(lead, keywords=None, *, lm=None, tok=None,
              max_length=256, num_beams=3, **generate_kwargs):
    """Generate a poem whose clauses start with the characters of ``lead``.

    The prompt follows the training strings built by ``PoetDataset``:
    ``"{kw1-kw2-...}《{lead}》"``. The keyword prefix is omitted when
    ``keywords`` is empty or None.

    Args:
        lead: the head characters placed between 《 and 》.
        keywords: optional list of keywords, joined with "-" before 《.
        lm: model to generate with; defaults to the notebook global ``model``.
        tok: tokenizer; defaults to the notebook global ``tokenizer``.
        max_length: passed to ``generate``.
        num_beams: passed to ``generate``.
        **generate_kwargs: extra ``generate`` options, e.g. ``do_sample=True, top_p=.6``.

    Returns:
        list[str]: decoded sequences with special tokens skipped.
    """
    lm = model if lm is None else lm
    tok = tokenizer if tok is None else tok
    prefix = "-".join(keywords) if keywords else ""
    leading = f"{prefix}《{lead}》"
    input_ids = tok(leading, return_tensors='pt').input_ids
    # BERT-style tokenizers append [SEP] to the prompt, but [SEP] is the eos
    # token here. Drop it so the model continues the poem instead of
    # conditioning on an "already finished" sequence.
    if input_ids.shape[-1] > 0 and input_ids[0, -1].item() == tok.sep_token_id:
        input_ids = input_ids[:, :-1]
    with torch.no_grad():
        pred = lm.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            # bos is only consulted when no prompt is given; kept as in the original.
            bos_token_id=tok.sep_token_id,
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.sep_token_id,
            **generate_kwargs,
        )
    return tok.batch_decode(pred, skip_special_tokens=True)