In [17]:
import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from torch import optim

In [18]:
class Encoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_size, dropout=0.5):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_size, dropout=dropout, batch_first=True, bidirectional=True)
    def forward(self, token_seq, mode='concat'):
        embedded = self.embedding(token_seq)
        outputs, hidden = self.rnn(embedded)
        if mode == 'concat':
            return torch.cat((hidden[0], hidden[1]), dim=1), outputs
        else:
            return hidden[0] + hidden[1] ,outputs

In [19]:
class Attention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, enc_output, dec_output):
        # a_t = h_t @ h_s
        a_t = torch.bmm(enc_output, dec_output.permute(0, 2, 1))
        a_t = torch.softmax(a_t, dim=1)
        c_t = torch.bmm(a_t.permute(0, 2, 1), enc_output)
        return c_t

In [20]:
class Decoder(nn.Module):
    def __init__(self, input_dim, emb_dim, hidden_dim, dropout, mode='concat'):
        super(Decoder, self).__init__()
        # 定义嵌入层
        self.embedding = nn.Embedding(input_dim, emb_dim)
        if mode=='add':
            self.hidden_up = nn.Linear(hidden_dim, hidden_dim * 2)
            self.rnn = nn.GRU(emb_dim, hidden_dim , dropout=dropout,
                              batch_first=True)
            self.fc = nn.Linear(hidden_dim, input_dim)
            self.atteniton_fc = nn.Linear(hidden_dim * 4, hidden_dim)
        else:
            self.rnn = nn.GRU(emb_dim, hidden_dim * 2, dropout=dropout,
                            batch_first=True)
            self.fc = nn.Linear(hidden_dim * 2, input_dim)
            self.atteniton_fc = nn.Linear(hidden_dim * 4, hidden_dim * 2)
            self.hidden_up = nn.Identity()
        self.atteniton = Attention()

    
    def forward(self, token_seq, hidden_state, enc_output):
        # print(hidden_state.unsqueeze(0).shape)
        embedded = self.embedding(token_seq)
        # print(embedded.shape)
        dec_output, hidden = self.rnn(embedded, hidden_state.unsqueeze(0))
        dec_output = self.hidden_up(dec_output)
        # print(dec_output.shape)
        # print(hidden.shape)
        # print(enc_output.shape)
        c_t = self.atteniton(enc_output, dec_output)
        dec_output = self.atteniton_fc(torch.cat((dec_output, c_t), dim=-1))
        out = torch.tanh(dec_output) # 激活
        logits = self.fc(out)
        return logits, hidden

In [21]:
class Seq2Seq(nn.Module):
    def __init__(self,
                 enc_emb_size,
                 dec_emb_size,
                 emb_dim,
                 hidden_size,
                 dropout=0.5,
                 mode='concat'
                 ):

        super().__init__()
        # encoder
        self.encoder = Encoder(enc_emb_size, emb_dim, hidden_size, dropout=dropout)
        self.decoder = Decoder(dec_emb_size, emb_dim, hidden_size, dropout=dropout, mode=mode)

    def forward(self, enc_input, dec_input, mode='concat'):
        # encoder last hidden state
        encoder_state, outputs = self.encoder(enc_input, mode=mode)
        output, hidden = self.decoder(dec_input, encoder_state, outputs)
        return output, hidden

In [22]:
with open("E:\\study\\AI\\data\\chinese-couplets\\versions\\2\\couplet\\vocabs", encoding="utf-8") as f:
    vocab_list = [line.strip() for line in f]

vocab_list = ["PAD", "UNK"] + vocab_list
# 字符->索引
evoc = {char : idx for idx, char in enumerate(vocab_list)}
# 索引->字符
dvoc = {idx : char for char, idx in evoc.items()}

In [23]:
def get_proc(evoc, dvoc):
    def collate_fn(batch):
        enc_seqs, dec_seqs = zip(*batch) # encoder输入和decoder输出
        # 把字符序列转换为索引序列
        enc_idx = [torch.tensor([evoc.get(c, evoc['UNK'] ) for c in seq]) for seq in enc_seqs]
        # 解码输入添加 <s> 起始符，解码目标添加 </s> 终止符
        dec_idx = [torch.tensor(
            [evoc['<s>']] + [dvoc.get(c, evoc['UNK']) for c in seq] + [evoc['</s>']]
        ) for seq in dec_seqs]
        # 填充序列
        enc_padded = pad_sequence(enc_idx, batch_first=True, padding_value=evoc['PAD'])
        dec_padded = pad_sequence(dec_idx, batch_first=True, padding_value=evoc['PAD'])
        targets = [seq[1:] for seq in dec_padded] # 目标去掉开头<S>
        targets_padded = pad_sequence(targets, batch_first=True, padding_value=evoc['PAD'])
        return enc_padded, dec_padded[:, :-1], targets_padded
    return collate_fn


In [24]:
def process(model, loss_fn, optimizer, writer, train_loader, epochs=20, mode='concat',name='seq2seq_cat'):
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for i, (enc_x, dec_x, target) in enumerate(tqdm(train_loader)):
            enc_x, dec_x, target = enc_x.cuda(), dec_x.cuda(), target.cuda()
            logits, _ = model(enc_x, dec_x, mode=mode)
            loss = loss_fn(logits.view(-1, logits.size(-1)), target.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            writer.add_scalar('train/loss', loss.item(), epoch * len(train_loader) + i)
        print(f'epoch: {epoch}, loss: {total_loss / len(train_loader)}')
    torch.save(model.state_dict(), f'{name}.pth')

In [25]:
with open("E:\\study\\AI\\data\\chinese-couplets\\versions\\2\\couplet\\train\\in.txt", encoding="utf-8") as f:
    inputs = [line.strip() for line in f]
with open("E:\\study\\AI\\data\\chinese-couplets\\versions\\2\\couplet\\train\\out.txt", encoding="utf-8") as f:
    outputs = [line.strip() for line in f]

train_set = list(zip(inputs, outputs))
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, collate_fn=get_proc(evoc, dvoc))

# test cat

In [30]:
device = torch.device("cuda")
model = Seq2Seq(len(vocab_list), len(vocab_list), emb_dim=128, hidden_size=128, dropout=0.5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss(ignore_index=evoc['PAD'])
writer = SummaryWriter(log_dir="homework\\week08\\cat")

In [31]:
process(model, loss_fn, optimizer, writer, train_loader, epochs=5, mode='concat', name='seq2seq_cat')

100%|██████████| 12039/12039 [04:03<00:00, 49.38it/s]


epoch: 0, loss: 0.045530947104281065


100%|██████████| 12039/12039 [03:56<00:00, 50.82it/s]


epoch: 1, loss: 0.012563692718363944


100%|██████████| 12039/12039 [03:58<00:00, 50.42it/s]


epoch: 2, loss: 0.008558550638696177


100%|██████████| 12039/12039 [04:01<00:00, 49.92it/s]


epoch: 3, loss: 0.007094684156260348


100%|██████████| 12039/12039 [03:59<00:00, 50.22it/s]

epoch: 4, loss: 0.006563768271730396





In [32]:
model_cat = Seq2Seq(
    enc_emb_size=len(evoc),
    dec_emb_size=len(evoc),
    emb_dim=128, 
    hidden_size=128,
    dropout=0.5
).to(device)
model_cat.load_state_dict(torch.load("seq2seq_cat.pth"))
upper = "无花无酒无花酒"
max_dec_len = 7
dec_tokens= []
dec_input = torch.tensor([[evoc['<s>']]]).to(device)
model_cat.eval()

with torch.no_grad():
    enc_input = torch.tensor([[evoc.get(c, evoc['UNK']) for c in upper]]).to(device)
    enc_hidden, enc_output  = model_cat.encoder(enc_input, mode='concat')

    while True:
        if len(dec_tokens) >= max_dec_len:
            break
        # print(enc_input.shape)
        # print(enc_output.shape, hidden.shape)
        logits, hidden = model_cat.decoder(dec_input, enc_hidden, enc_output)
        predicted_idx = torch.argmax(logits, dim=-1)
        print(predicted_idx)
        if dvoc[predicted_idx.squeeze().item()] == '</s>':
            break
        dec_tokens.append(predicted_idx.squeeze().item())
        dec_input = predicted_idx
    
print(''.join([dvoc[idx] for idx in dec_tokens]))

tensor([[3021]], device='cuda:0')
tensor([[4739]], device='cuda:0')
tensor([[8507]], device='cuda:0')
tensor([[4119]], device='cuda:0')
tensor([[6593]], device='cuda:0')
tensor([[616]], device='cuda:0')
tensor([[2298]], device='cuda:0')
沛褥摠傩尻鸡磊


  model_cat.load_state_dict(torch.load("seq2seq_cat.pth"))


# test add

In [33]:
device = torch.device("cuda")
model = Seq2Seq(len(vocab_list), len(vocab_list), emb_dim=128, hidden_size=128, dropout=0.5, mode='add').to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss(ignore_index=evoc['PAD'])
writer = SummaryWriter(log_dir="homework\\week08\\add")
process(model, loss_fn, optimizer, writer, train_loader, epochs=10, mode='add', name="Seq2Seq_add")

100%|██████████| 12039/12039 [03:04<00:00, 65.08it/s]


epoch: 0, loss: 0.20677100896630482


100%|██████████| 12039/12039 [03:02<00:00, 66.01it/s]


epoch: 1, loss: 0.029715389054948117


100%|██████████| 12039/12039 [03:04<00:00, 65.23it/s]


epoch: 2, loss: 0.024766155037593735


100%|██████████| 12039/12039 [02:59<00:00, 66.98it/s]


epoch: 3, loss: 0.02338384183167454


100%|██████████| 12039/12039 [02:55<00:00, 68.48it/s]


epoch: 4, loss: 0.025448132585351203


100%|██████████| 12039/12039 [02:54<00:00, 69.10it/s]


epoch: 5, loss: 0.019867335473235152


100%|██████████| 12039/12039 [02:54<00:00, 69.09it/s]


epoch: 6, loss: 0.025008262674501677


100%|██████████| 12039/12039 [02:54<00:00, 68.87it/s]


epoch: 7, loss: 0.02195845811163227


100%|██████████| 12039/12039 [02:52<00:00, 69.66it/s]


epoch: 8, loss: 0.024769679226888035


100%|██████████| 12039/12039 [02:55<00:00, 68.53it/s]

epoch: 9, loss: 0.02127216277900947





In [36]:
model_add = Seq2Seq(
    enc_emb_size=len(evoc),
    dec_emb_size=len(evoc),
    emb_dim=128, 
    hidden_size=128,
    dropout=0.5,
    mode='add'
).to(device)
model_add.load_state_dict(torch.load("Seq2Seq_add.pth"))
upper = "无花无酒无花酒"
max_dec_len = 7
dec_tokens= []
dec_input = torch.tensor([[evoc['<s>']]]).to(device)

while True:
    if len(dec_tokens) >= max_dec_len:
        break
    model_add.eval()
    with torch.no_grad():
        enc_input = torch.tensor([[evoc.get(c, evoc['UNK']) for c in upper]]).to(device)
        # print(enc_input.shape)
        enc_output, hidden = model_add.encoder(enc_input, mode='add')
        # print(enc_output.shape, hidden.shape)
        logits, hidden = model_add.decoder(dec_input, enc_output, hidden)
        predicted_idx = torch.argmax(logits, dim=-1)
        print(predicted_idx)
        if dvoc[predicted_idx.squeeze().item()] == '</s>':
            break
        dec_tokens.append(predicted_idx.squeeze().item())
        dec_input = predicted_idx
    
print(''.join([dvoc[idx] for idx in dec_tokens]))

tensor([[1]], device='cuda:0')
tensor([[3]], device='cuda:0')
UNK


  model_add.load_state_dict(torch.load("Seq2Seq_add.pth"))
