In [1]:
# References:
# https://github.com/karpathy/nanoGPT
# https://github.com/openai/gpt-2/
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/gpt2

# Data:
# https://github.com/Werneror/Poetry

In [2]:
import torch
from torch import nn
import pickle
from torch import tensor
import torch.nn.functional as F
import pandas as pd

In [3]:
# Load the Tang-dynasty poems CSV (poem text lives in the '内容' column used below).
df = pd.read_csv(filepath_or_buffer="./Poetry/唐.csv")

In [4]:
# Drop every poem whose text contains one of the characters treated as
# invalid: "{", "□", "《", "◇", "「".
# regex=False matches each character literally ("{" is a regex metacharacter),
# and na=True counts missing text as a match so those rows are dropped too,
# exactly as the previous `contains(...) == False` comparison did.
for invalid_char in ["{", "□", "《", "◇", "「"]:
    df = df[~df['内容'].str.contains(invalid_char, regex=False, na=True)]

In [5]:
# Rare characters (CJK Extension A glyphs) to exclude from the corpus: any poem
# containing one of them is dropped by the loop below, presumably to keep the
# character vocabulary small.
rare_chars = [
 '㑻',
 '㓤',
 '㓨',
 '㔩',
 '㕙',
 '㕭',
 '㖇',
 '㖷',
 '㗀',
 '㗧',
 '㘈',
 '㙞',
 '㙧',
 '㙻',
 '㜕',
 '㟏',
 '㟼',
 '㠁',
 '㠢',
 '㠥',
 '㡊',
 '㤝',
 '㤞',
 '㥄',
 '㦎',
 '㦨',
 '㦬',
 '㧞',
 '㨖',
 '㨝',
 '㩧',
 '㩻',
 '㪇',
 '㯕',
 '㰂',
 '㰕',
 '㰤',
 '㰹',
 '㱔',
 '㱥',
 '㲉',
 '㲚',
 '㲪',
 '㴲',
 '㵝',
 '㹠',
 '㺄',
 '㺑',
 '㺗',
 '㺠',
 '㺦',
 '㻞',
 '㼐',
 '㼚',
 '㽅',
 '㽬',
 '㾌',
 '㾓',
 '㿉',
 '䀣',
 '䀨',
 '䁠',
 '䂍',
 '䂓',
 '䃂',
 '䃔',
 '䃧',
 '䃸',
 '䅎',
 '䆕',
 '䆛',
 '䆷',
 '䇹',
 '䈉',
 '䈴',
 '䉛',
 '䉦',
 '䌨',
 '䌰',
 '䍡',
 '䑔',
 '䔿',
 '䕓',
 '䕭',
 '䕹',
 '䗁',
 '䗱',
 '䘧',
 '䘨',
 '䙀',
 '䛠',
 '䜕',
 '䝟',
 '䞋',
 '䟃',
 '䟐',
 '䢇',
 '䤨',
 '䥓',
 '䥫',
 '䦪',
 '䦰',
 '䦱',
 '䨟',
 '䨥',
 '䪗',
 '䪜',
 '䪥',
 '䫇',
 '䫜',
 '䬓',
 '䬔',
 '䬘',
 '䬴',
 '䮂',
 '䮧',
 '䯔',
 '䯗',
 '䯱',
 '䰀',
 '䱙',
 '䱜',
 '䱬',
 '䴊',
 '䶉',    
'㒿',
 '㗫',
 '㗻',
 '㘭',
 '㟅',
 '㟍',
 '㠔',
 '㩳',
 '㪍',
 '㪷',
 '㬠',
 '㭊',
 '㴩',
 '㵳',
 '㶁',
 '㸌',
 '㸙',
 '㸦',
 '㹀',
 '㹞',
 '㾕',
 '䆉',
 '䆗',
 '䋏',
 '䍥',
 '䍦',
 '䍲',
 '䍽',
 '䏑',
 '䏶',
 '䐑',
 '䑳',
 '䒠',
 '䔩',
 '䔫',
 '䖘',
 '䖟',
 '䗈',
 '䗫',
 '䙰',
 '䛏',
 '䜝',
 '䨴',
 '䩋',
 '䪻',
 '䫻',
 '䭀',
 '䭃',
 '䭔',
 '䯀',
 '䱐',
 '䲺',
 '䳒',
 '䴏',
 '䴥',
 '䴵',
 '䴺',
 '䶎']

# Drop every poem that contains any rare character in a single pass: one set
# lookup per poem instead of one full boolean-mask filter (and regex scan) of
# the whole frame per rare character. Non-string (missing) text is dropped as
# well, matching the behaviour of the previous `contains(...) == False` filter.
rare_char_set = set(rare_chars)
df = df[df['内容'].map(lambda text: isinstance(text, str) and rare_char_set.isdisjoint(text)).astype(bool)]

In [6]:
# Flatten the filtered poems into one continuous character stream
# (poems are concatenated with no separator between them).
all_data = list(df['内容'])
data = "".join(all_data)

In [7]:
# Character-level vocabulary: the padding token "$" at id 0, followed by every
# distinct corpus character in sorted order.
# NOTE(review): assumes "$" never occurs in the corpus; if it did, it would be
# listed twice and t2i would map it to its second position.
vocab = ["$", *sorted(set(data))]

In [9]:
# Peek at the first ten vocabulary entries (padding token first).
vocab[0:10]

['$', '?', '、', '。', '一', '丁', '七', '万', '丈', '三']

In [10]:
# id -> character lookup, then show its first five entries.
i2t = dict(enumerate(vocab))
dict(list(i2t.items())[:5])

{0: '$', 1: '?', 2: '、', 3: '。', 4: '一'}

In [11]:
# character -> id lookup (inverse of i2t), then show its first five entries.
t2i = dict(zip(vocab, range(len(vocab))))
dict(list(t2i.items())[:5])

{'$': 0, '?': 1, '、': 2, '。': 3, '一': 4}

In [12]:
def encode(s):
    """Map each character of ``s`` to its vocabulary id (KeyError if unknown)."""
    return list(map(t2i.__getitem__, s))
encode("$?。")

[0, 1, 3]

In [13]:
def decode(l):
    """Turn a sequence of vocabulary ids back into the corresponding string."""
    return "".join(map(i2t.__getitem__, l))
decode([0, 1,3])

'$?。'

In [14]:
batch_size = 4  # B: number of sequences per batch
block_size = 8  # T: context length in characters (author's note: poems are short; may raise to 48)
vocab_size = len(t2i)  # number of distinct tokens, including the "$" padding token
nn_emb_size = 64  # width of the embeddings and hidden layers

# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [15]:
def encode_pad(s):
    """Encode ``s`` into a (1, block_size) tensor of token ids, left-padded with 0.

    Id 0 is the "$" padding token. A string longer than ``block_size`` keeps
    only its *last* ``block_size`` characters, so the final position always
    holds the most recent character. That is the position whose logits
    (``out[:, -1, :]``) are used to sample the next character.

    Args:
        s: string made of vocabulary characters (unknown characters raise KeyError).

    Returns:
        Integer tensor of shape (1, block_size).
    """
    # Keep the tail, not the head: padding is on the left, so the context is
    # right-aligned and the newest characters must survive truncation.
    tail = s[max(len(s) - block_size, 0):]
    ids = encode(tail)
    ids = [0] * (block_size - len(ids)) + ids
    return tensor(ids)[None, ...]

In [16]:
# Demo: a short string is left-padded with id 0 up to block_size.
targ = encode_pad(s="叶唐")
targ = targ.to(device=device)
targ

tensor([[  0,   0,   0,   0,   0,   0, 647, 768]], device='cuda:0')

In [17]:
# Demo: one more leading character shifts the padded ids one slot to the left.
inp = encode_pad(s="黑叶唐")
inp = inp.to(device=device)
inp

tensor([[   0,    0,    0,    0,    0, 7119,  647,  768]], device='cuda:0')

In [18]:
# Contiguous 90/10 split of the corpus: train on the head, validate on the tail.
train_d = data[: int(len(data) * 0.9)]
valid_d = data[int(len(data) * 0.9) :]
train_d[0:10]

'风淅淅。夜雨连云黑。'

In [19]:
# First characters of the validation split.
valid_d[0:10]

'雪。未老莫还乡，还乡'

In [20]:
def load_data(type="train"):
    """Sample a random batch of (input, target) character windows.

    Args:
        type: "train" samples from ``train_d``; any other value samples from
            ``valid_d``.

    Returns:
        (inp, targ): tensors of shape (batch_size, block_size) on ``device``.
        ``targ`` holds the same windows shifted one character to the right,
        so ``targ[:, t]`` is the character following ``inp[:, t]``.

    Raises:
        ValueError: if the selected split is not longer than ``block_size``.
    """
    split = train_d if type == "train" else valid_d
    # Only draw start indices whose target window (i+1 .. i+block_size) fits in
    # the text. Drawing from the full length produced short windows near the
    # end; left-padding input and target independently then misaligned them.
    max_start = len(split) - block_size
    if max_start <= 0:
        raise ValueError(
            f"split {type!r} has {len(split)} characters; need more than block_size={block_size}"
        )
    idxs = torch.randint(max_start, (batch_size,)).tolist()
    inp = torch.cat([encode_pad(split[i:i + block_size]) for i in idxs], dim=0)
    targ = torch.cat([encode_pad(split[i + 1:i + block_size + 1]) for i in idxs], dim=0)
    return inp.to(device), targ.to(device)

In [21]:
# Peek at one random training batch of input ids (t holds the same windows shifted by one).
i, t = load_data(type="train")
i

tensor([[7112, 2876, 7182, 4071,  361,  420, 4511, 1289],
        [2470, 2545, 1550, 1550, 1117, 7182,   12, 1698],
        [4085,  121, 7182, 3108,  671, 1201, 6614, 5465],
        [  26, 3857,   11,  568, 1289, 7182, 1331,  381]], device='cuda:0')

In [22]:
class Model(nn.Module):
    """Character-level next-token model.

    Architecture: token embedding + learned position embedding, a
    position-wise MLP (Linear-GELU-Linear-GELU), then a linear projection to
    vocabulary logits.

    Note: every layer in ``self.h`` acts only on the last (embedding)
    dimension. The logits at position t therefore depend only on the token and
    the position at t. Nothing (no attention) mixes information across the
    context.

    Args:
        nn_emb: width of the embeddings and hidden layers.
        block_size: maximum context length (number of position embeddings).
        vocab_size: number of tokens / output classes.
    """

    def __init__(self, nn_emb = nn_emb_size, block_size = block_size,vocab_size = vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.nn_emb = nn_emb
        self.tk_emb = nn.Embedding(vocab_size, nn_emb)
        self.pos_emb = nn.Embedding(block_size, nn_emb)
        self.h = nn.Sequential(nn.Linear(nn_emb, nn_emb), nn.GELU(), nn.Linear(nn_emb, nn_emb), nn.GELU())
        # Output projection to vocabulary logits. It is a Linear layer despite the
        # "ln_" prefix; the attribute name is kept so state_dict keys stay stable.
        self.ln_h = nn.Linear(nn_emb, self.vocab_size)

    def forward(self, inp, targ = None):
        """Compute logits and, when targets are given, the cross-entropy loss.

        Args:
            inp: token ids of shape (B, T) with T <= block_size.
            targ: optional token ids of shape (B, T); the next token for each
                position of ``inp``.

        Returns:
            (logits, loss): logits of shape (B, T, vocab_size); loss is a scalar
            tensor, or None when ``targ`` is None.
        """
        # Tensor.to() is not in-place, so rebind. Use the parameters' device
        # rather than a global, so the module works wherever it was moved.
        dev = self.tk_emb.weight.device
        inp = inp.to(dev)
        T = inp.shape[-1]
        tk = self.tk_emb(inp)  # (B, T, nn_emb)
        # Positions for the actual input length, so inputs shorter than
        # block_size broadcast correctly against the token embeddings.
        pos = self.pos_emb(torch.arange(T, device=dev))  # (T, nn_emb)
        x = tk + pos  # (B, T, nn_emb)
        x = self.h(x)  # (B, T, nn_emb)
        x = self.ln_h(x)  # (B, T, vocab_size)
        if targ is None:
            return x, None
        targ = targ.to(dev)
        loss = F.cross_entropy(x.reshape(-1, x.shape[-1]), targ.reshape(-1))
        return x, loss

m = Model()
m = m.to(device)  # nn.Module.to moves the parameters in place and returns the module itself
optim = torch.optim.AdamW(m.parameters(), lr=1e-3)

# Smoke test: one forward pass on the demo tensors from the cells above.
out, loss = m(inp=inp, targ=targ)
loss

tensor(8.7705, device='cuda:0', grad_fn=<NllLossBackward0>)

In [24]:
steps = 1_000  # optimisation steps per call to train()

In [26]:
def train(log_every=1):
    """Run ``steps`` AdamW updates of ``m`` on random training batches.

    On each logged step, prints the loss of the training batch just used
    ("Train:<loss>") and the loss on one random validation batch
    ("Valid:<loss>").

    Args:
        log_every: positive int; log (and evaluate) only every
            ``log_every``-th step. The default of 1 logs every step. Larger
            values such as 100 keep the cell output readable.
    """
    for step in range(steps):
        m.train()
        inp, targ = load_data()
        out, loss = m(inp, targ)
        # Clear stale gradients before backprop. optimizer.step() already runs
        # without gradient tracking, so no explicit no_grad is needed around it.
        optim.zero_grad()
        loss.backward()
        optim.step()

        if (step + 1) % log_every != 0:
            continue
        print("Train:" + str(loss.item()))

        # Validation only needs the loss value: skip building an autograd graph.
        m.eval()
        with torch.no_grad():
            inp, targ = load_data("valid")
            out, loss = m(inp, targ)
        m.train()
        print("Valid:" + str(loss.item()))

In [30]:
# Run one training session of `steps` optimisation steps, printing training and validation losses.
train()

Train:6.360795974731445
Valid:6.608985900878906
Train:6.899512767791748
Valid:6.94514274597168
Train:6.600334167480469
Valid:6.650232315063477
Train:6.770755767822266
Valid:7.279598712921143
Train:7.250162601470947
Valid:7.040035247802734
Train:6.905786991119385
Valid:6.217785358428955
Train:6.6282196044921875
Valid:6.585540294647217
Train:6.821276664733887
Valid:6.883055210113525
Train:7.163849353790283
Valid:6.628798484802246
Train:6.221627712249756
Valid:6.786205291748047
Train:6.728028774261475
Valid:6.542478561401367
Train:6.982696056365967
Valid:6.716482639312744
Train:6.5275492668151855
Valid:7.1761794090271
Train:6.832442760467529
Valid:6.932262420654297
Train:6.247867107391357
Valid:6.226308345794678
Train:6.806585311889648
Valid:5.916360378265381
Train:6.850294589996338
Valid:6.7022247314453125
Train:6.2391767501831055
Valid:7.083147048950195
Train:6.488284111022949
Valid:6.199264049530029
Train:6.4270853996276855
Valid:6.910810947418213
Train:6.987454414367676
Valid:6.564404

In [31]:
# Sample one next character after the prompt from the distribution at the last position.
inp = encode_pad('终南').to(device)
out, loss = m(inp)
prob = out[:, -1, :].softmax(dim=-1)
g = torch.multinomial(prob, num_samples=1)
i2t[g[0].item()]

'红'

In [32]:
def generate(s, num = 16):
    """Extend ``s`` autoregressively by ``num`` characters sampled from ``m``.

    Each step conditions on at most the last ``block_size`` characters of the
    text generated so far (the model only has ``block_size`` positions). It
    samples the next character from the softmax at the final position.

    Args:
        s: prompt string made of vocabulary characters.
        num: number of characters to append.

    Returns:
        The prompt followed by the ``num`` sampled characters.
    """
    with torch.no_grad():  # inference only: no autograd graph needed
        for _ in range(num):
            # Slice to the most recent characters explicitly. Otherwise, once the
            # text outgrows the block, the context would stop advancing.
            context = encode_pad(s[-block_size:] if block_size > 0 else "").to(device)
            logits, _ = m(context)
            prob = torch.softmax(logits[:, -1, :], dim=-1)
            next_id = torch.multinomial(prob, num_samples=1)
            s = s + i2t[next_id[0].item()]
    return s

In [86]:
# Continue the prompt "终南" with 16 sampled characters.
generate('终南', num=16)

'终南刘忽。秋归凡日。有化之名寻浪一，'

In [87]:
# Continue the prompt "灵者" with 16 sampled characters.
generate('灵者', num=16)

'灵者香今斗，绮谢双径滨。虹炎百在旆江'