In [1]:
import os
import time
from collections import deque
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import json
import glob


In [2]:
# Output directories for generated-poem previews and model checkpoints.
# exist_ok=True replaces the check-then-create pattern: it is race-free and
# keeps the cell safely re-runnable.
for out_dir in ('./previews', './checkpoints'):
    os.makedirs(out_dir, exist_ok=True)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
n_words = 20  # characters per poem: five-character quatrain (4 lines x 5 characters)
n_class = 0   # placeholder; reassigned from the encoded data's channel dimension further down
n_noise = 64  # length of the noise vector passed to the generator

In [3]:
def poet_filter(x):
    """Return True when a poem record is a five-character quatrain.

    Both the joined text (``x['paragraphs']``) and the joined tone pattern
    (``x['strains']``) must be exactly 20 characters long. Punctuation is
    counted if present, so this is meant to run on records whose commas,
    periods and semicolons have already been stripped.
    """
    # Check the text first; the tone pattern is only inspected when the text passes.
    if len(''.join(x['paragraphs'])) != 20:
        return False
    return len(''.join(x['strains'])) == 20

def poet_preprocess(x):
    """Strip punctuation from a poem record, in place.

    ``x['paragraphs']`` and ``x['strains']`` are each joined into a single
    string and the full-width comma, period and semicolon are removed.
    The same (mutated) dict is returned so the function can be used in ``map``.
    """
    punctuation = ('，', '。', '；')
    for field in ('paragraphs', 'strains'):
        joined = ''.join(x[field])
        x[field] = ''.join(ch for ch in joined if ch not in punctuation)
    return x

def poet_data_reader(search='./chinese-poetry-master/json/poet.song.*.json', filters=lambda x: True):
    """Load poem records from every JSON file matching ``search``.

    Each file is expected to contain a JSON array of poem dicts. Every record
    is passed through ``poet_preprocess`` (punctuation removal) and kept only
    if ``filters(record)`` is truthy.

    Args:
        search: glob pattern selecting the JSON files to read.
        filters: predicate applied to each preprocessed record; the default
            keeps everything.

    Returns:
        A list of the preprocessed records that passed ``filters``; empty when
        no file matches ``search``.
    """
    data = []
    # glob returns files in filesystem order; sorting makes the corpus order,
    # and anything derived from first-seen order downstream, reproducible.
    for fname in sorted(glob.glob(search)):
        # The corpus is Chinese text: read it as UTF-8 explicitly instead of
        # relying on the platform's locale default.
        with open(fname, 'r', encoding='utf-8') as fp:
            records = json.load(fp)
        # Honour the caller-supplied predicate (it was previously ignored in
        # favour of a hard-coded poet_filter).
        data.extend(filter(filters, map(poet_preprocess, records)))
    return data
def gen_dict(poets):
    """Build the character vocabulary of a poem collection.

    Characters are numbered 0, 1, 2, ... in the order they are first seen
    while walking ``poet['paragraphs']`` of each record.

    Returns:
        (char_set, char_set_inv): character -> index and index -> character.
    """
    char_set, char_set_inv = {}, {}
    for poet in poets:
        for ch in poet['paragraphs']:
            if ch in char_set:
                continue  # already numbered
            idx = len(char_set)
            char_set[ch] = idx
            char_set_inv[idx] = ch
    return char_set, char_set_inv
def encode_context(context, charset):
    """Map each character of ``context`` to its index in ``charset``.

    Characters missing from ``charset`` are encoded as 0.
    NOTE(review): 0 is also a regular index in dictionaries built by
    ``gen_dict``, so unknown characters cannot be told apart from that entry.
    """
    return [charset.get(ch, 0) for ch in context]
def one_hot(x, n_class):
    """One-hot encode a sequence of class indices.

    Args:
        x: sequence of integer class indices.
        n_class: number of classes (width of the encoding).

    Returns:
        uint8 array of shape (len(x), n_class) with a single 1 per row.
    """
    n_rows = len(x)
    encoded = np.zeros((n_rows, n_class), dtype=np.uint8)
    row_idx = np.arange(n_rows)
    encoded[row_idx, x] = 1  # row i gets its 1 in column x[i]
    return encoded
def str2ohe(x, charset):
    """One-hot encode the string ``x`` using the vocabulary ``charset``.

    Returns an array of shape (len(x), len(charset)).
    """
    indices = encode_context(x, charset)
    return one_hot(indices, len(charset))
def ohe2str(x, charset_inv):
    """Decode per-position class scores back into a string.

    For each row of ``x`` the highest-scoring index along the last axis is
    looked up in ``charset_inv`` (index -> character) and the characters are
    concatenated.
    """
    best_indices = np.argmax(x, axis=-1)
    return ''.join(charset_inv[idx] for idx in best_indices)

In [4]:
# Load the filtered poem records and build the character vocabulary from them.
raw_data = poet_data_reader(
    search='./chinese-poetry-master/json/poet.song.*.json',
    filters=poet_filter,
)
charset, charset_inv = gen_dict(raw_data)
strainset = {'平': 0, '仄': 1}  # the two tone classes used for the tone-pattern labels

# Persist both vocabularies as JSON.
for json_path, mapping in (('./charset.json', charset), ('./strainset.json', strainset)):
    with open(json_path, 'w') as fp:
        json.dump(mapping, fp)

In [5]:
# One-hot encode every poem, then move the vocabulary axis in front of the
# position axis: (poems, positions, vocab) -> (poems, vocab, positions).
data_ohe = torch.from_numpy(
    np.asarray(
        [str2ohe(poem['paragraphs'], charset) for poem in raw_data],
        dtype=np.float32,
    ).transpose(0, 2, 1)
)
data_ohe = data_ohe * 2 - 1  # rescale {0, 1} -> {-1, +1}
data_ohe.size()

torch.Size([13260, 5798, 20])

In [6]:
# One-hot encode every tone pattern against strainset, laid out like data_ohe:
# (poems, tone classes, positions).
label_ohe = torch.from_numpy(
    np.asarray(
        [str2ohe(poem['strains'], strainset) for poem in raw_data],
        dtype=np.float32,
    ).transpose(0, 2, 1)
)
label_ohe = label_ohe * 2 - 1  # rescale {0, 1} -> {-1, +1}
label_ohe.size()

torch.Size([13260, 2, 20])

In [7]:
# Replace the placeholder with the real class count: dimension 1 of the encoded poems.
n_class = data_ohe.shape[1]

In [8]:
# Pair each encoded poem with its encoded tone pattern.
poet_dataset = torch.utils.data.TensorDataset(
    data_ohe,
    label_ohe,
)

In [9]:
# Shuffled mini-batches of (poem, tone pattern) pairs, loaded by one worker process.
data_loader = torch.utils.data.DataLoader(
    dataset=poet_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
def inf_data_gen():
    """Yield poem batches forever, restarting ``data_loader`` whenever it is exhausted.

    Only the poem tensor of each batch is yielded; the tone-pattern labels are dropped.
    """
    while True:
        for poems, _tones in data_loader:
            yield poems


# Shared endless batch stream; consume it with next(gen).
gen = inf_data_gen()

In [10]:
def weights_init(m):
    """Custom weight initialisation for the generator and critic layers.

    Modules whose class name contains 'Conv' or 'Linear' get their weights
    redrawn from a normal distribution with mean 0 and std 0.001; every other
    module is left untouched. Intended for use with ``Module.apply``.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type or 'Linear' in layer_type:
        nn.init.normal_(m.weight.data, 0, 0.001)

def wgan_div_gp(real, d_real, device, p):
    """Per-sample gradient penalty term: ||d(d_real)/d(real)||_2 ** p.

    Args:
        real: input batch that ``d_real`` was computed from; must require grad.
        d_real: critic output for ``real``.
        device: device on which the all-ones ``grad_outputs`` tensor is created.
        p: exponent applied to the gradient's L2 norm.

    Returns:
        1-D tensor with one penalty value per sample. The result stays attached
        to the autograd graph (create_graph=True) so it can be back-propagated.
    """
    grad_seed = torch.ones_like(d_real, device=device, requires_grad=False)
    (grads,) = torch.autograd.grad(
        outputs=d_real,
        inputs=real,
        grad_outputs=grad_seed,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    # Flatten everything but the batch axis, then take the squared L2 norm per sample.
    sq_norm = grads.view(grads.size(0), -1).pow(2).sum(1)
    return sq_norm ** (p / 2)

class C(nn.Module):
    """Critic: maps a batch of encoded poems to one scalar score each.

    Input is taken as (batch, n_class, n_words) and the output is (batch, 1).
    The final reshape requires the convolution stack to reduce the sequence
    axis to length 1, which holds for the default n_words=20.
    """

    def __init__(self, n_words=20, n_class=5000):
        super().__init__()
        self.n_words = n_words
        self.n_class = n_class

        # (in_channels, out_channels, kernel_size, stride, padding);
        # trailing comments give the sequence length after the layer for n_words=20.
        conv_specs = [
            (512, 64, 3, 1, 1),    # 20
            (64, 128, 5, 2, 2),    # 10
            (128, 256, 3, 1, 1),   # 10
            (256, 256, 3, 1, 1),   # 10
            (256, 512, 3, 2, 1),   # 5
        ]

        # 1x1 convolution over the class axis acts as an embedding.
        layers = [nn.Conv1d(self.n_class, 512, kernel_size=1, padding=0, bias=False)]
        for c_in, c_out, kernel, stride, pad in conv_specs:
            layers += [
                nn.Conv1d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad, bias=False),
                nn.InstanceNorm1d(c_out),
                nn.LeakyReLU(0.1),
            ]
        # Collapse the remaining positions into a single score.
        layers.append(nn.Conv1d(512, 1, kernel_size=n_words // 4, padding=0, bias=False))

        self.net = nn.Sequential(*layers)
        self.net.apply(weights_init)

    def forward(self, x):
        scores = self.net(x)
        return scores.view(scores.size(0), 1)

class G(nn.Module):
    """Generator: maps noise vectors to poem-shaped score tensors.

    Input is taken as (batch, n_noise). A linear layer produces a
    (batch, 64, n_words // 4) seed, two stride-2 transposed convolutions each
    double its length, and the output has n_class channels squashed to
    [-1, 1] by tanh.
    """

    def __init__(self, n_words=20, n_class=5000, n_noise=128):
        super().__init__()
        self.n_words = n_words
        self.n_class = n_class
        self.n_noise = n_noise

        seed_len = self.n_words // 4
        self.fc1 = nn.Linear(self.n_noise, 64 * seed_len, bias=False)
        weights_init(self.fc1)

        # Two upsampling stages: each doubles the sequence length (5 -> 10 -> 20 for n_words=20).
        upsample = []
        for c_in, c_out in ((64, 128), (128, 256)):
            upsample += [
                nn.ConvTranspose1d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.InstanceNorm1d(c_out),
                nn.LeakyReLU(0.1),
            ]
        # Length-preserving refinement followed by a 1x1 projection onto the classes.
        refine = [
            nn.Conv1d(256, 512, kernel_size=5, stride=1, padding=2, bias=False),
            nn.InstanceNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.Conv1d(512, self.n_class, kernel_size=1, padding=0, bias=False),
        ]
        self.net = nn.Sequential(*upsample, *refine)
        self.net.apply(weights_init)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = hidden.view(hidden.size(0), 64, self.n_words // 4)
        return torch.tanh(self.net(hidden))


In [11]:
seed = 3  # fixed seed, kept for debugging
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Build the generator before the critic so the seeded RNG is consumed in the same order.
G_net = G(n_words, n_class, n_noise).to(device)
C_net = C(n_words, n_class).to(device)

# Both networks share the same Adam settings.
adam_config = dict(lr=0.0002, betas=(0.5, 0.999))
opt_C = optim.Adam(C_net.parameters(), **adam_config)
opt_G = optim.Adam(G_net.parameters(), **adam_config)

In [12]:
from tqdm import tqdm_notebook

# ---- training hyper-parameters ----
iterations = 50000    # number of generator updates
preview_iter = 500    # log, write previews and save checkpoints every this many iterations
preview_n = 8         # number of poems written per preview file
d_iter = 15           # critic updates per generator update
std = 1.0             # scale applied to the (clamped) noise
lambda_1 , lambda_2 = 10 , 0.2  # NOTE(review): not referenced anywhere in this loop
M = 0.05                        # NOTE(review): not referenced anywhere in this loop
k, p = 2, 6           # gradient-penalty weight and norm exponent

# Fixed noise so that previews from different iterations are comparable.
samples_preview = torch.randn(preview_n, n_noise).clamp(-2,2) * std

for ite in tqdm_notebook(range(1, iterations+1)):
    start_train_ts = time.time()
    # ---- train the critic ----
    G_net.eval()
    C_net.train()
    d_loss_mean = 0.0
    for _ in range(d_iter):
        opt_C.zero_grad()
        # Inputs need requires_grad so the penalty can differentiate the critic w.r.t. them.
        real = next(gen).to(device).requires_grad_(True)
        sample = torch.randn(real.size(0), n_noise, device=device).clamp(-2,2) * std
        with torch.no_grad():
            fake   = G_net(sample).detach() # not to touch G_net
        fake.requires_grad_(True)
        d_real = C_net(real)
        d_fake = C_net(fake)
        d_loss_real = d_real.mean()
        d_loss_fake = -d_fake.mean()
        d_real_gp = wgan_div_gp(real, d_real, device, p)
        d_fake_gp = wgan_div_gp(fake, d_fake, device, p)
        d_gp_loss = (d_real_gp+d_fake_gp).mean() * k / 2
        d_loss = d_loss_real + d_loss_fake + d_gp_loss
        # A single backward pass over the summed loss accumulates the same
        # parameter gradients as three separate retain_graph passes, with one
        # graph traversal and without keeping the graph alive afterwards.
        d_loss.backward()
        opt_C.step()
        d_loss_mean += d_loss.item()
    d_loss_mean /= d_iter
    D_update_ts = time.time()
    # ---- train the generator ----
    G_net.train()
    C_net.train() # critic is already in train mode; C defines no Dropout layers
    opt_G.zero_grad()
    sample = torch.randn(batch_size, n_noise, device=device).clamp(-2,2) * std
    generated = G_net(sample)
    g_loss = C_net(generated).mean()
    g_loss.backward()
    opt_G.step()
    g_loss_mean = g_loss.item()
    G_update_ts = time.time()
    if ite%preview_iter==0:
        print('[{}/{}] G: {:.4f}, D:{:.4f} -- elapsed_G: {:.4f}s -- elapsed_D: {:.4f}s'.format(ite, iterations, g_loss_mean, d_loss_mean, (G_update_ts-D_update_ts), (D_update_ts-start_train_ts) ))

        with torch.no_grad():
            G_net.eval() # evaluation state
            generated = G_net(samples_preview.to(device)).detach().cpu().numpy()
            generated = generated.transpose(0, 2, 1) # (preview_n, n_words, n_class)
            recovered = list(map(lambda x: ohe2str(x, charset_inv), generated))
            # The previews are Chinese text: write them as UTF-8 explicitly so the
            # cell does not fail on platforms whose default encoding cannot represent it.
            with open('./previews/iter-{:d}.txt'.format(ite), 'w', encoding='utf-8') as fp:
                for poet in recovered:
                    # Lay the 20 characters out as four 5-character lines.
                    fp.write(poet[:5]+'，')
                    fp.write(poet[5:10]+'。')
                    fp.write(poet[10:15]+'，')
                    fp.write(poet[15:20]+'。')
                    fp.write('\n')

        # Save both networks (the critic is stored under the "-D" suffix).
        torch.save(G_net.state_dict(), './checkpoints/iter-{:d}-G.ckpt'.format(ite))
        torch.save(C_net.state_dict(), './checkpoints/iter-{:d}-D.ckpt'.format(ite))

HBox(children=(IntProgress(value=0, max=50000), HTML(value='')))

[500/50000] G: 9.5918, D:-19.5586 -- elapsed_G: 0.0217s -- elapsed_D: 1.7426s
[1000/50000] G: -1.7624, D:-0.8696 -- elapsed_G: 0.0215s -- elapsed_D: 1.6139s
[1500/50000] G: -0.7173, D:-0.0249 -- elapsed_G: 0.0216s -- elapsed_D: 1.6110s
[2000/50000] G: -3.2163, D:-0.0150 -- elapsed_G: 0.0216s -- elapsed_D: 1.6099s
[2500/50000] G: 0.5349, D:0.0098 -- elapsed_G: 0.0211s -- elapsed_D: 1.5801s
[3000/50000] G: -2.2498, D:-0.0135 -- elapsed_G: 0.0210s -- elapsed_D: 1.5802s
[3500/50000] G: -1.1548, D:-0.0000 -- elapsed_G: 0.0211s -- elapsed_D: 1.5820s
[4000/50000] G: 0.8413, D:-0.0045 -- elapsed_G: 0.0210s -- elapsed_D: 1.5817s
[4500/50000] G: -40.9502, D:-0.4438 -- elapsed_G: 0.0210s -- elapsed_D: 1.5801s
[5000/50000] G: -32.1468, D:-0.4804 -- elapsed_G: 0.0213s -- elapsed_D: 1.5830s
[5500/50000] G: -21.4478, D:0.0362 -- elapsed_G: 0.0211s -- elapsed_D: 1.5826s
[6000/50000] G: -21.5161, D:-0.6145 -- elapsed_G: 0.0210s -- elapsed_D: 1.5816s
[6500/50000] G: -21.8664, D:-0.0305 -- elapsed_G: 0.0

KeyboardInterrupt: 