In [1]:
# References:
# https://github.com/karpathy/nanoGPT
# https://github.com/openai/gpt-2/
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/gpt2

# Data:
# https://github.com/Werneror/Poetry

In [2]:
# 1st attempt. Basic MLP model
# 2nd try. Added multi-head casual attention and layernorm and skip connection
# 3rd version. Model after "Attention is all you need" with multiple attention layers (exclude encoder cross attention) ->val loss 4.92
# 4th. Increase Model size/depth.
# todo: 
#drop invalid char '{', etc.
#(FAIL)train data use whole poem, per index.
#(FAIL)find max len of poem and use it as context length

In [3]:
import torch
from torch import nn
import pickle
from torch import tensor
import torch.nn.functional as F
import pandas as pd

In [4]:
with open("meta.pkl", "rb") as f:
    meta = pickle.load(f)
t2i = meta['t2i']
i2t = meta['i2t']
encode = lambda x: [t2i[c] for c in x]
decode = lambda x: "".join([i2t[i] for i in x])

In [5]:
decode([0,1,3])

'$?、'

In [6]:
encode("$?一")

[0, 1, 132]

In [24]:
batch_size = 128 # B, batch size
block_size = 48 # T, context len for poem is shorter, to set to 48
vocab_size = len(t2i.keys())
nn_emb_size = 64 # nn_emb
n_head = 16
n_layers = 8

device = "cuda"

In [8]:
def encode_pad(s):
    if len(s) >= block_size:
        sample = s[:block_size]
    else:
        sample = s
    sample = encode(s)
    sample = [0]*(block_size-len(sample)) + sample    
    inp = tensor(sample[:block_size])[None,...]
    return inp

In [9]:
targ = encode_pad("叶唐").to(device)
targ

tensor([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0, 784, 905]], device='cuda:0')

In [10]:
inp = encode_pad("黑叶唐").to(device)
inp

tensor([[   0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0, 7401,  784,  905]],
       device='cuda:0')

In [11]:
decode(inp[0].tolist())

'$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$黑叶唐'

In [12]:
df = pd.read_csv("./Poetry/唐.csv")

In [13]:
# drop invalid char "{", '□', '《'
df = df[df['内容'].str.contains("{") == False]
df = df[df['内容'].str.contains('□') == False]
df = df[df['内容'].str.contains('《') == False]

# drop context len > block_size
#df = df[df['内容'].str.len()<=block_size]

In [14]:
rare_chars = [
'㒿',
 '㗫',
 '㗻',
 '㘭',
 '㟅',
 '㟍',
 '㠔',
 '㩳',
 '㪍',
 '㪷',
 '㬠',
 '㭊',
 '㴩',
 '㵳',
 '㶁',
 '㸌',
 '㸙',
 '㸦',
 '㹀',
 '㹞',
 '㾕',
 '䆉',
 '䆗',
 '䋏',
 '䍥',
 '䍦',
 '䍲',
 '䍽',
 '䏑',
 '䏶',
 '䐑',
 '䑳',
 '䒠',
 '䔩',
 '䔫',
 '䖘',
 '䖟',
 '䗈',
 '䗫',
 '䙰',
 '䛏',
 '䜝',
 '䨴',
 '䩋',
 '䪻',
 '䫻',
 '䭀',
 '䭃',
 '䭔',
 '䯀',
 '䱐',
 '䲺',
 '䳒',
 '䴏',
 '䴥',
 '䴵',
 '䴺',
 '䶎']

for c in rare_chars:
    df = df[df['内容'].str.contains(c) == False]

In [15]:
all_data = df['内容'].tolist()
data = "".join(all_data)
train_d = data[:int(len(data)*0.9)]
valid_d = data[int(len(data)*0.9):]
train_d[:100]

'风淅淅。夜雨连云黑。滴滴。窗外芭蕉灯下客。除非魂梦到乡国。免被关山隔。忆忆。一句枕前争忘得。别路云初起，离亭叶正飞。所嗟人异雁，不作一行归。弄玉有夫皆得道，刘纲兼室尽登仙。君能仔细窥朝露，须逐云车拜洞'

In [16]:
valid_d[:100]

'草皆生。朝来门閤无事，晚下高斋有情。胡马，胡马，远放燕支山下。咆沙咆雪独嘶，东望西望路迷。迷路，迷路，边草无穷日暮。河汉，河汉，晓挂秋城漫漫。愁人起望相思，江南塞北别离。离别，离别，河汉虽同路绝。上界'

In [135]:
def adjust_idxs(idxs):
    for i, idx in enumerate(idxs):
        while True:
            idx = idxs_adjusted[i]
            if dataLoaded[idx] != "。" and idx < (len(dataLoaded)-block_size-1):
                idxs_adjusted[i] = idx+1
            else:
                break
        if dataLoaded[idx] == "。" and idx < (len(dataLoaded)-block_size-1):
            idxs_adjusted[i] = idx+1
            idx = idxs_adjusted[i]
            
def load_data(type="train"):
    if type == "train":
        dataLoaded = train_d
    else:
        dataLoaded = valid_d
    idxs = torch.randint(len(dataLoaded)-block_size-1,(batch_size,))
    #print(idxs)
    idxs = idxs_adjusted = list(idxs)

            
            
    inp = [dataLoaded[i:i+block_size] for i in idxs]
    targ = [dataLoaded[i+1:i+block_size+1] for i in idxs]
    inp = [ encode_pad(i) for i in inp]
    targ = [ encode_pad(i) for i in targ]
    return torch.cat(inp, dim=0).to(device), torch.cat(targ, dim=0).to(device)

In [136]:
i, t = load_data()
i

tensor([[1578, 1235, 7471,  ..., 5108,    4, 5527],
        [ 305, 2026, 2635,  ...,  305,  177,    4],
        [6177, 6166,    4,  ..., 4453,  802, 1714],
        ...,
        [3933, 1844,  722,  ..., 2128, 2098, 5624],
        [7471, 2419,  191,  ..., 6503, 5788, 1052],
        [4733,  160, 4074,  ...,  177, 1814, 7471]], device='cuda:0')

In [19]:
class AttentionBlock(nn.Module):
    def __init__(self, nn_emb = nn_emb_size, block_size = block_size, n_head = n_head):
        super().__init__()
        self.nn_emb = nn_emb_size
        self.block_size = block_size
        self.n_head = n_head

        self.emb_proj = nn.Linear(nn_emb, nn_emb * 3)
        self.ln_1 = nn.LayerNorm(nn_emb) 
        self.mult_head = nn.MultiheadAttention(nn_emb, n_head, dropout=0.2, batch_first=True)
        self.ln_2 = nn.LayerNorm(nn_emb) 
        self.ff = nn.Sequential(nn.Linear(nn_emb, nn_emb * 4),nn.GELU(), nn.Dropout(0.2), nn.Linear(nn_emb * 4, nn_emb), nn.GELU(), nn.Dropout(0.2))

    def forward(self,x): # (B, T, nn_emb)
        x1 = x
        x = self.emb_proj(x) # (B, T, nn_emb*3)
        q,k,v = x.split(self.nn_emb, dim=2)
        x,_ = self.mult_head(q, k, v,  key_padding_mask=None, need_weights=False, attn_mask=torch.nn.Transformer.generate_square_subsequent_mask(self.nn_emb), average_attn_weights=True, is_causal=True) # (B,T,nn_emb)
        x = x+x1
        x = self.ff(self.ln_2(x)) + x
        return x
        
        
class Model(nn.Module):
    def __init__(self, nn_emb = nn_emb_size, block_size = block_size,vocab_size = vocab_size, n_head = n_head, n_layers = n_layers): 
        super().__init__()
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.nn_emb = nn_emb
        self.n_head = n_head
        self.n_layers = n_layers
        
        self.tk_emb = nn.Embedding(vocab_size, nn_emb)
        self.pos_emb = nn.Embedding(block_size, nn_emb)
        self.ln = nn.LayerNorm(nn_emb)
        #self.emb_proj = nn.Linear(nn_emb, nn_emb * 3)
        #self.atten = nn.MultiheadAttention(nn_emb, n_head, dropout=0.2, batch_first=True)
        self.attention_blocks = nn.ModuleList( [AttentionBlock(nn_emb, block_size, n_head)] * n_layers)
        #self.h = nn.Sequential(nn.Linear(nn_emb, nn_emb),nn.GELU(), nn.Dropout(0.2), nn.Linear(nn_emb, nn_emb), nn.GELU(), nn.Dropout(0.2))
        self.ln_h = nn.Linear(nn_emb, self.vocab_size)

    def forward(self, inp, targ = None): # inp is (B, T), targ is (B, T)
        inp.to(device)
        tk = self.tk_emb(inp) # (B,T,nn_emb)
        positions = torch.arange(self.block_size).to(device)
        #print(positions)
        pos = self.pos_emb(positions) # (T,nn_emb)
        x = tk + pos # (B,T,nn_emb)
        #x = self.ln(x)        
        #a = x
        #x = self.emb_proj(x) # (B,t,nn_emb*3)
        for blk in self.attention_blocks:
            x = blk(x)
        #q,k,v = x.split(self.nn_emb, dim=2)
        #x,_ = self.atten(q, k, v,  key_padding_mask=None, need_weights=False, attn_mask=torch.nn.Transformer.generate_square_subsequent_mask(self.nn_emb), average_attn_weights=True, is_causal=True) # (B,T,nn_emb)
        #x = x + a
        #x = self.ln(x)                
        #x = x+self.h(x) # (B,T,nn_emb)
        x = self.ln(x) # (B,T,nn_emb)                
        x = self.ln_h(x) # (B,T,vocab_size)
        if targ == None:
            loss = None
        else:
            targ.to(device)
            loss = F.cross_entropy(x.view(-1, x.shape[-1]), targ.view(-1))
        return x, loss

m = Model()
m.to(device)
optim = torch.optim.AdamW(m.parameters(), lr=1e-3)

out, loss = m(inp, targ)
loss

tensor(8.5572, device='cuda:0', grad_fn=<NllLossBackward0>)

In [118]:
steps = 1000
optim = torch.optim.AdamW(m.parameters(), lr=1e-5)

In [119]:
def train():
    m.train()
    losses_train = 0
    losses_valid = 0
    for i in torch.arange(steps):
        inp, targ = load_data()
        out, loss = m(inp, targ)
        loss.backward()
        with torch.no_grad():
            optim.step()
            optim.zero_grad()
        losses_train += loss.item()
        if i % 100 == 99:
            print("Train:" + str(losses_train/100))
            losses_train = 0
        

        inp, targ = load_data("valid")
        out, loss = m(inp, targ)
        losses_valid += loss.item()
        if i % 100 == 99:
            print("Valid:" + str(losses_valid/100))
            losses_valid = 0
        
     

In [140]:
steps = 1000
optim = torch.optim.AdamW(m.parameters(), lr=1e-4)
train()   

Train:5.094790782928467
Valid:5.1303889226913455
Train:5.03735876083374
Valid:5.07416754245758
Train:5.008514714241028
Valid:5.048934197425842
Train:4.984685220718384
Valid:5.02165885925293
Train:4.967304706573486
Valid:5.005953598022461
Train:4.959701714515686
Valid:4.997688221931457
Train:4.945753922462464
Valid:4.98355993270874
Train:4.9380385255813595
Valid:4.969886789321899
Train:4.934246792793274
Valid:4.9641446352005
Train:4.919867401123047
Valid:4.956051287651062


In [163]:
with open("model_v4.pkl", "wb") as f:
    pickle.dump(m,f)

In [165]:
with open("model_v4.pkl","rb") as f:
    m=pickle.load(f)

In [168]:
torch.save(m, "model_v4t.pkl")

In [42]:
inp = '终南'
inp = encode_pad(inp).to(device)
#inp[0].shape
out, loss = m(inp)
prob = torch.softmax(out[:,-1,:], dim=-1)
g = torch.multinomial(prob, num_samples=1)
#g[0].item()
i2t[g[0].item()]

'寺'

In [158]:
top_k = 30
def generate(s, num = 60):

    for i in range(num + num):
        inp = s[-block_size:]
        inp = encode_pad(inp).to(device)
        out, loss = m(inp)
        out = out[:,-1,:]
        if top_k is not None:
            v, _ = torch.topk(out, min(top_k, out.size(-1)))
            out[out < v[:, [-1]]] = -float('Inf')        
        prob = torch.softmax(out[:,:], dim=-1)
        g = torch.multinomial(prob, num_samples=30)
        next_c = i2t[g[0][0].item()]

        if not s[-6:].find("，") and not s[-6:].find("。") :
            for c in g[0]:        
                c_d = i2t[c.item()]
                #print(c_d)
                if c_d == '，' or c_d == '。': 
                    next_c = c_d
                    break

        if next_c in s and next_c != '。' and next_c != '，':
            continue
            
        s = s + next_c

        if (len(s) > num and s[-1] == "。"):
            break
    return s
    

In [161]:
generate('终南。')

'终南。年少无事长安一种，亦不可怜心骨更。十二三巴山中路，却到潼关天畔人。夜深城晓楼船起，犹喜秋愁月，日下江湖水。孤雁多飞去，何时送宿尘。'

In [162]:
generate('灵者。')

'灵者。何日东归白头儿亦在，不知人少小相识，莫是生心身去有。今宵复此别，夜雨又无花。欲问山水滨，只应云雾中。江南千里寺三秋，风景四星霜雪曙。'