<a href="https://colab.research.google.com/github/tienhuynh96/poem-generator/blob/main/%5BDemo%5D_Poem_Generation_Transformers_From_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **1. Import libraries**

In [None]:
import math
import os
import re
import time
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator



## **2. Download and load dataset**

In [None]:
# https://drive.google.com/file/d/1KfrBAycsgQBt1mtEbzJh5pSYL8YIk0Tc/view?usp=sharing
!gdown --id 1KfrBAycsgQBt1mtEbzJh5pSYL8YIk0Tc
!unzip poem_dataset_final.zip

Downloading...
From: https://drive.google.com/uc?id=1KfrBAycsgQBt1mtEbzJh5pSYL8YIk0Tc
To: /content/poem_dataset_final.zip
100% 37.6k/37.6k [00:00<00:00, 66.8MB/s]
Archive:  poem_dataset_final.zip
  inflating: poem_final.csv          


In [None]:
# Read data
# The download cell extracts the CSV into the current working directory, so a
# relative path works on Colab (where the cwd is /content) and on a local run.
DATASET_PATH = 'poem_final.csv'
df = pd.read_csv(DATASET_PATH)
df

Unnamed: 0.1,Unnamed: 0,title,content,source,url
0,0,“Cái làm ta hạnh phúc”,Cái làm ta hạnh phúc\nThực ra cũng chẳng nhiều...,"Nguồn: Châm ngôn mới (thơ), Thái Bá Tân, NXB L...",https://www.thivien.net/Th%C3%A1i-B%C3%A1-T%C3...
1,1,“Chiều vừa xốp trên tay”,Chiều vừa xốp trên tay\nChợt nghe thoáng ong b...,"Nguồn: Lâm Huy Nhuận, Chiều có thật (thơ), NXB...",https://www.thivien.net/L%C3%A2m-Huy-Nhu%E1%BA...
2,2,“Dưới giàn hoa thiên lý...”,Dưới giàn hoa thiên lý\nMột mình anh đang ngồi...,"Nguồn: Nguyễn Nhật Ánh, Mắt biếc, NXB Trẻ, 2004",https://www.thivien.net/Nguy%E1%BB%85n-Nh%E1%B...
3,3,"“Đến, nhiều nơi để đến”","Đến, nhiều nơi để đến\nVề, trở lại với mình\nC...","Nguồn: Châm ngôn mới (thơ), Thái Bá Tân, NXB L...",https://www.thivien.net/Th%C3%A1i-B%C3%A1-T%C3...
4,4,“Đừng bao giờ dại dột”,Đừng bao giờ dại dột\nĐem chuyện riêng của mìn...,"Nguồn: Châm ngôn mới (thơ), Thái Bá Tân, NXB L...",https://www.thivien.net/Th%C3%A1i-B%C3%A1-T%C3...
...,...,...,...,...,...
185,95,Ám ảnh sông xưa,"Ôi, con sóng chết khô,\nvật vờ trong bùn quánh...",,https://www.thivien.net/%C4%90%E1%BB%97-Qu%E1%...
186,96,Áng dương không biết sầu,Áng dương không biết sầu\nNằm mãi ở trên cao\n...,"Nguồn: Lâu Văn Mua, Tôi bay vào mắt em (thơ), ...",https://www.thivien.net/L%C3%A2u-V%C4%83n-Mua/...
187,97,Anh,Cây bút gẫy trong tay\nCặn mực khô đáy lọ\nÁnh...,19-7-1973\n\n[Thông tin 2 nguồn tham khảo đã đ...,https://www.thivien.net/Xu%C3%A2n-Qu%E1%BB%B3n...
188,98,Anh biết,Không có anh để già\nLàm sao em được trẻ\nMuốn...,,https://www.thivien.net/Nguy%E1%BB%85n-Minh-D%...


In [None]:
# Lines of the first poem; empty strings mark the blank line between stanzas
df.loc[0, "content"].split("\n")

['Cái làm ta hạnh phúc',
 'Thực ra cũng chẳng nhiều',
 'Chỉ cần có ai đó',
 'Để ta thầm thương yêu',
 '',
 'Rồi thêm chút công việc',
 'Cho ta làm hàng ngày',
 'Cuối cùng, chút mơ mộng',
 'Để đưa ta lên mây']

## **3. Build vectorization function**

In [None]:
# Create text normalize function
def text_normalize(text):
  """Return `text` with leading and trailing whitespace removed."""
  return text.strip()



In [None]:
# Sanity check: normalise the first poem
text_normalize(df.loc[0, "content"])

'Cái làm ta hạnh phúc\nThực ra cũng chẳng nhiều\nChỉ cần có ai đó\nĐể ta thầm thương yêu\n\nRồi thêm chút công việc\nCho ta làm hàng ngày\nCuối cùng, chút mơ mộng\nĐể đưa ta lên mây'

In [None]:
# Normalise every poem in place; text_normalize is passed directly as the callback
df["content"] = df["content"].apply(text_normalize)

In [None]:
# Create tokenizer function
# Whitespace splitting only: punctuation stays attached to its word
def tokenizer(text):
  """Split `text` on runs of whitespace and return the list of pieces."""
  words = text.split()
  return words

# Create yield function
def yield_tokens(df):
  """Lazily yield the token list of each entry of ``df['content']``, in row order."""
  for content in df['content']:
    yield tokenizer(content)

# Build vocab
# <eol> marks the end of a poem line
special_tokens = ['<unk>', '<pad>', '<sos>', '<eos>', '<eol>']
vocab = build_vocab_from_iterator(yield_tokens(df), specials=special_tokens)

# Tokens that are not in the vocabulary map to <unk>
vocab.set_default_index(vocab['<unk>'])
vocab.get_stoi()

{'“Liệu': 2198,
 'ẩm': 2196,
 'ầm': 2195,
 'đợi,': 2193,
 'địa': 2187,
 'đền!': 2186,
 'đếm': 2185,
 'đặn': 2184,
 'đáng': 2182,
 'điên': 2178,
 'Được': 2174,
 'Đâu': 2172,
 'Đàn': 2171,
 'ô-kê': 2169,
 'Ðức': 2164,
 'Ðêm': 2162,
 'Ðã': 2161,
 'xấu': 2158,
 'xúc': 2156,
 'xoãi': 2154,
 'vụng': 2151,
 'về…': 2150,
 'vạng': 2148,
 'tột': 2144,
 'tổng': 2143,
 'tất': 2141,
 'túi': 2138,
 'tôi.': 2137,
 'tình.': 2134,
 'trưởng': 2132,
 'trách': 2129,
 'trung,': 2128,
 'trao': 2127,
 'toát': 2126,
 'thỏ': 2121,
 'thong': 2114,
 'xiêu': 2153,
 'thiểu': 2112,
 'sướng': 2110,
 'rợ.': 2104,
 'rộng': 2103,
 'rẩy': 2102,
 'rèn': 2100,
 'riết': 2096,
 'phứt': 2093,
 'phú': 2091,
 'phù.': 2090,
 'phím': 2089,
 'nứt': 2083,
 'nợ': 2081,
 'thon': 2113,
 'nề': 2079,
 'nước...': 2078,
 'vang': 2145,
 'nông': 2076,
 'nó,': 2075,
 'nâu': 2074,
 'nuốt': 2073,
 'nhằn': 2070,
 'nhạt': 2067,
 'nhưng': 2066,
 'nhung': 2065,
 'nhiệm': 2063,
 'ngó': 2058,
 'nghị': 2056,
 'mục.': 2054,
 'mổ': 2053,
 'mệt': 2051,

In [None]:
# Vocabulary size, special tokens included
len(vocab)

2201

In [None]:
# Ids of the padding and end-of-sequence special tokens
PAD_TOKEN = vocab['<pad>']
EOS_TOKEN = vocab['<eos>']

# Fixed length (in tokens) that every training sequence is padded or truncated to
MAX_SEQ_LEN = 25

# Create pad and truncate function
def pad_and_truncate(input_ids, max_seq_len, pad_token=None):
  """Return a new list of exactly `max_seq_len` ids.

  Longer inputs are cut to their first `max_seq_len` ids; shorter ones are
  right-padded. The caller's list is never modified.

  Args:
    input_ids: list of token ids.
    max_seq_len: length of the returned list.
    pad_token: id used for padding; defaults to the module-level PAD_TOKEN.

  Returns:
    A new list of length `max_seq_len`.
  """
  if pad_token is None:
    pad_token = PAD_TOKEN
  clipped = list(input_ids[:max_seq_len])
  # Build a fresh list instead of using `+=`, which would extend the caller's list in place
  return clipped + [pad_token] * (max_seq_len - len(clipped))


# Create vectorize function
def vectorize(text, max_seq_len):
  """Tokenize `text`, map each token to its vocab id, then pad/truncate to `max_seq_len`."""
  token_ids = []
  for token in tokenizer(text):
    token_ids.append(vocab[token])
  return pad_and_truncate(token_ids, max_seq_len)

# Create decode function
def decode(input_ids):
  """Map an iterable of token ids back to their token strings.

  Args:
    input_ids: iterable of ids usable as list indices (ints or 0-dim integer tensors).

  Returns:
    List of token strings, one per id.
  """
  # Fetch the index -> token table once instead of once per decoded token
  itos = vocab.get_itos()
  return [itos[token_id] for token_id in input_ids]

In [None]:
# Token stored at id 0 in the index -> token table
vocab.get_itos()[0]

'<unk>'

In [None]:
# Same lookup through the single-token API
vocab.lookup_token(0)

'<unk>'

In [None]:
# Vectorize the first line of the first poem, padded to 10 ids
first_line = df['content'][0].split('\n')[0]
print(first_line)
print(vectorize(first_line, 10))

Cái làm ta hạnh phúc
[175, 62, 39, 313, 366, 1, 1, 1, 1, 1]


## **4. Create Poem Dataset**

In [None]:
# Create Dataset Class
class PoemDataset(Dataset):
  """Next-token training samples built from the four-line stanzas of each poem.

  Each stanza is wrapped as ``<sos> line <eol> ... <eol> <eos>`` and expanded
  into one sample per prefix. The target of a sample is its input shifted left
  by one token; both go through ``vectorizer`` with ``max_seq_len``.

  Args:
    df: DataFrame with a ``content`` column holding the poem text.
    tokenizer: callable mapping a string to a list of string tokens.
    vectorizer: callable ``(text, max_seq_len) -> list of token ids``;
      expected to return lists of equal length so they can be stacked.
    max_seq_len: length forwarded to ``vectorizer`` for every sequence.
  """

  def __init__(self, df, tokenizer, vectorizer, max_seq_len):
    self.tokenizer = tokenizer
    self.vectorizer = vectorizer
    self.max_seq_len = max_seq_len
    # All samples are materialised once, as three aligned tensors
    samples = self.create_samples(df)
    self.input_seqs, self.target_seqs, self.padding_masks = samples

  def create_padding_mask(self, input_ids, pad_token_id = PAD_TOKEN):
    """Return a list with 1 for every real token and 0 for every padding id."""
    return [int(token_id != pad_token_id) for token_id in input_ids]

  def split_content(self, content):
    """Split a poem into stanzas (blank-line separated) and keep those with exactly 4 lines."""
    stanzas = (part.split('\n') for part in content.split('\n\n'))
    return [stanza for stanza in stanzas if len(stanza) == 4]

  def prepare_sample(self, sample):
    """Expand one stanza into (input, target, padding mask) lists, one entry per prefix."""
    # Wrap the stanza with the start, end-of-line and end-of-sequence markers
    tokens = self.tokenizer('<sos> ' + ' <eol> '.join(sample) + ' <eol> <eos>')

    input_seqs, target_seqs, padding_masks = [], [], []
    for end in range(1, len(tokens)):
      # Input is the first `end` tokens; target is the same window moved one token right
      prefix_ids = self.vectorizer(' '.join(tokens[:end]), self.max_seq_len)
      shifted_ids = self.vectorizer(' '.join(tokens[1:end + 1]), self.max_seq_len)

      input_seqs.append(prefix_ids)
      target_seqs.append(shifted_ids)
      padding_masks.append(self.create_padding_mask(prefix_ids))

    return input_seqs, target_seqs, padding_masks

  def create_samples(self, df):
    """Build the samples of every stanza in ``df['content']`` and stack them into tensors.

    Returns:
      (input_seqs, target_seqs, padding_masks): two long tensors and one float tensor.
    """
    all_inputs, all_targets, all_masks = [], [], []

    for content in df['content']:
      for stanza in self.split_content(content):
        stanza_inputs, stanza_targets, stanza_masks = self.prepare_sample(stanza)
        all_inputs.extend(stanza_inputs)
        all_targets.extend(stanza_targets)
        all_masks.extend(stanza_masks)

    # Convert data to tensor
    return (
        torch.tensor(all_inputs, dtype=torch.long),
        torch.tensor(all_targets, dtype=torch.long),
        torch.tensor(all_masks, dtype=torch.float),
    )

  def __len__(self):
    """Number of prefix samples."""
    return len(self.input_seqs)

  def __getitem__(self, idx):
    """Return the (input ids, target ids, padding mask) triple at `idx`."""
    return self.input_seqs[idx], self.target_seqs[idx], self.padding_masks[idx]

In [None]:
# Create dataset
# Number of prefix samples per training batch
TRAIN_BS = 256
# Build every training sample up front from the poem DataFrame
train_dataset = PoemDataset(
    df=df,
    tokenizer=tokenizer,
    vectorizer=vectorize,
    max_seq_len=MAX_SEQ_LEN
)

# Create dataloader
# NOTE(review): shuffle=False keeps the batch order fixed, which the inspection
# cells below rely on, but consecutive samples come from the same stanza.
# Consider a shuffled loader for the actual training run.
train_loader = DataLoader(
    train_dataset,
    batch_size=TRAIN_BS,
    shuffle=False
)

In [None]:
# Check value in dataloader
first_batch = next(iter(train_loader))
input_seqs, target_seqs, padding_masks = first_batch

# First sample of each tensor, then the shape of each tensor
for batch_tensor in first_batch:
    print(batch_tensor[0])
for batch_tensor in first_batch:
    print(batch_tensor.shape)

tensor([2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1])
tensor([175,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1])
tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0.])
torch.Size([256, 25])
torch.Size([256, 25])
torch.Size([256, 25])


In [None]:
# Check decode function
# MAX_SEQ_LEN is only reused here as "how many samples to preview"; cap it by the
# batch size so a smaller batch cannot raise an IndexError.
for idx in range(min(MAX_SEQ_LEN, len(input_seqs))):
    print(decode(input_seqs[idx]))
    print(decode(target_seqs[idx]))

['<sos>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['Cái', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['<sos>', 'Cái', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['Cái', 'làm', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['<sos>', 'Cái', 'làm', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>',

## **5. Create model**

In [None]:
# Create positional Encoding Class: text + position embedding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    The table is stored as a buffer of shape (max_len, 1, embedding_dims) and is
    sliced with ``x.size(0)``, so ``forward`` expects ``x`` laid out as
    (seq_len, batch, embedding_dims).

    Args:
        embedding_dims: size of the embedding axis (even or odd).
        dropout: dropout probability applied after the addition.
        max_len: number of positions in the precomputed table.
    """
    def __init__(self, embedding_dims, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dims, 2) * (-math.log(10000.0) / embedding_dims))
        angles = position * div_term    # (max_len, ceil(embedding_dims / 2))

        pe = torch.zeros(max_len, 1, embedding_dims)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd embedding size there is one cosine column fewer than sine columns
        pe[:, 0, 1::2] = torch.cos(angles[:, :embedding_dims // 2])
        # Buffer: moves with the module across devices but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` of shape (seq_len, batch, embedding_dims)."""
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [None]:
# Create class TransformerModel
class TransformerModel(nn.Module):
  """Token-level language model: embedding + positional encoding + causally
  masked Transformer encoder + linear projection onto the vocabulary.

  Args:
    vocab_size: number of tokens in the vocabulary (input and output size).
    embedding_dims: embedding / model width.
    n_heads: attention heads per encoder layer.
    hidden_dims: width of the feed-forward sub-layer.
    n_layers: number of stacked encoder layers.
    dropout: dropout probability for the positional encoding and encoder layers.
  """
  def __init__(self, vocab_size, embedding_dims, n_heads,
               hidden_dims, n_layers, dropout=0.5):
    super(TransformerModel, self).__init__()
    self.model_type = 'Transformer'
    self.embedding = nn.Embedding(vocab_size, embedding_dims)
    self.embedding_dims = embedding_dims

    self.pos_encoder = PositionalEncoding(embedding_dims, dropout)
    # Create encoder layers (batch-first: tensors are B, S, E)
    encoder_layers = nn.TransformerEncoderLayer(
        embedding_dims, n_heads,
        hidden_dims, dropout, batch_first=True
    )

    # Stack the encoder layers into the Transformer encoder
    self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_layers)
    self.linear = nn.Linear(embedding_dims, vocab_size)

    self.init_weights()

  def init_weights(self):
    """Initialise embedding and output weights uniformly in [-0.1, 0.1]; zero the output bias."""
    initrange = 0.1
    self.embedding.weight.data.uniform_(- initrange, initrange)
    self.linear.bias.data.zero_()
    self.linear.weight.data.uniform_(- initrange, initrange)

  def forward(self, src, src_mask=None, padding_mask=None):
    """Compute next-token logits.

    Args:
      src: token ids of shape (B, S).
      src_mask: optional (S, S) attention mask; a causal mask is built when None.
      padding_mask: optional (B, S) mask forwarded unchanged as
        ``src_key_padding_mask`` (PyTorch convention: boolean True marks keys to
        ignore; a float mask is added to the attention scores).

    Returns:
      Logits of shape (B, S, vocab_size).
    """
    # Get embedding from src, scaled by sqrt(embedding_dims)
    src = self.embedding(src) * math.sqrt(self.embedding_dims)
    # PositionalEncoding slices its table with x.size(0), i.e. it expects (S, B, E).
    # Feeding it the batch-first tensor directly would index positions by batch
    # element, so swap the axes around the call.
    src = self.pos_encoder(src.transpose(0, 1)).transpose(0, 1)
    # The causal mask holds 0 / -inf so that, after softmax, a position gets zero
    # weight on later tokens. It is built on the input's own device rather than
    # on a global `device` variable.
    if src_mask is None:
      src_mask = nn.Transformer.generate_square_subsequent_mask(src.size(1)).to(src.device)

    output = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=padding_mask)  # B, S, E
    output = self.linear(output)  # B, S, vocab_size

    return output


In [None]:
# Smoke test: build the model and check the output shape on random token ids
VOCAB_SIZE = len(vocab)
EMBEDDING_DIMS = 128
HIDDEN_DIMS = 128
N_LAYERS = 2
N_HEADS = 4
DROPOUT = 0.2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Two sequences of ten ids drawn from [1, 10)
input_tests = torch.randint(1, 10, (2, 10)).to(device)

model = TransformerModel(
    vocab_size=VOCAB_SIZE,
    embedding_dims=EMBEDDING_DIMS,
    n_heads=N_HEADS,
    hidden_dims=HIDDEN_DIMS,
    n_layers=N_LAYERS,
    dropout=DROPOUT,
).to(device)

with torch.no_grad():
  output = model(input_tests)
  # Expected: (batch, seq_len, vocab_size)
  print(output.shape)

torch.Size([2, 10, 2201])


## **6. Training**

In [None]:
# Set learning rate, epochs
LR = 5.0
EPOCHS = 5

# Set criterion, optimizer, scheduler
# Padded target positions carry no information, so they are excluded from the
# loss instead of being averaged in as trivially predictable <pad> targets.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
# Scheduler decays the learning rate after every epoch (lr = lr * 0.95)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.95)

In [None]:
from tqdm import tqdm
# Train model
model.train()
for epoch in tqdm(range(EPOCHS)):
  losses = []
  for input_seqs, target_seqs, padding_masks in train_loader:
    # Put data on the device; each tensor is (B, S)
    input_seqs = input_seqs.to(device)
    target_seqs = target_seqs.to(device)
    padding_masks = padding_masks.to(device)

    # The dataset marks real tokens with 1.0 and padding with 0.0, but a float
    # src_key_padding_mask is *added* to the attention scores, so passing it
    # as-is never hides padding. Convert to the additive form: 0 for tokens,
    # -inf for padding (float, matching the causal mask's dtype).
    key_padding_masks = torch.zeros_like(padding_masks).masked_fill(
        padding_masks == 0, float('-inf')
    )

    # Compute output
    output = model(input_seqs, padding_mask=key_padding_masks)  # B, S, vocab_size
    # CrossEntropyLoss expects the class axis second: (B, vocab_size, S)
    output = output.permute(0, 2, 1)
    loss = criterion(output, target_seqs)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # Clip gradients to prevent exploding gradients
    optimizer.step()

    losses.append(loss.item())

  # Mean batch loss over the epoch
  total_loss = sum(losses) / len(losses)
  print(f'EPOCH {epoch+1}\tLoss {total_loss}')
  scheduler.step()

 20%|██        | 1/5 [00:42<02:48, 42.15s/it]

EPOCH 1	Loss 4.272765425118533


 40%|████      | 2/5 [01:24<02:06, 42.12s/it]

EPOCH 2	Loss 3.5248493172905664


 60%|██████    | 3/5 [02:05<01:23, 41.79s/it]

EPOCH 3	Loss 3.2580403089523315


 80%|████████  | 4/5 [02:47<00:41, 41.81s/it]

EPOCH 4	Loss 2.758037756789814


100%|██████████| 5/5 [03:28<00:00, 41.79s/it]

EPOCH 5	Loss 2.1581580313769253





## **7. Inference**

In [None]:
# Temperature-controlled sampling of the next token
# Temperature < 1 => sharper distribution, focuses on high-probability classes
# Temperature > 1 => flatter distribution, low-probability classes become more likely
def sample_with_temperature(logits, temperature=1.0):
    """Sample one index from `logits` after temperature scaling.

    Args:
        logits: 1-D tensor of unnormalised scores.
        temperature: divisor applied to the logits before softmax. A value of
            0 is treated as the greedy limit and returns the argmax instead of
            dividing by zero.

    Returns:
        The sampled index as a Python int.
    """
    if temperature == 0:
        # Limit of temperature -> 0: all probability mass on the largest logit
        return torch.argmax(logits, dim=-1).item()

    # Softmax turns the scaled logits into probabilities
    probabilities = F.softmax(logits / temperature, dim=-1)
    # Draw one index according to those probabilities
    return torch.multinomial(probabilities, 1).item()

In [None]:
# Inference
# Switch the model to evaluation mode (disables dropout)
model.eval()
# Sampling temperature
temperature = 1.2
# Prompt: start-of-sequence marker followed by the first word
input_text = '<sos> Anh'
# Tokenize the prompt and convert the tokens to ids
input_tokens = tokenizer(input_text)
input_ids = [vocab[token] for token in input_tokens]

# Generation stops as soon as this id is sampled
eos_token_id = vocab['<eos>']
# Generated sequence, seeded with the prompt ids
generated_ids = input_ids.copy()
# Upper bound on the number of sampled tokens
MAX_GENERATION_LEN = 50

for _ in range(MAX_GENERATION_LEN):
    # Feed everything generated so far as a batch of one: (1, S)
    input_tensor = torch.tensor([generated_ids], dtype=torch.long).to(device)
    with torch.no_grad():
        outputs = model(input_tensor)   # 1, S, vocab_size
    # One output per input position: the last one predicts the next token
    last_token_logits = outputs[0, -1, :]
    next_token_id = sample_with_temperature(last_token_logits, temperature)
    generated_ids.append(next_token_id)

    # Stop once the end-of-sequence token has been sampled
    if next_token_id == eos_token_id:
        break

# Convert the generated ids back to text. Control tokens are dropped so that
# <sos>/<eos>/<pad> never show up in the printed poem; <eol> is kept as the
# line separator.
generated_tokens = [
    token for token in decode(generated_ids)
    if token not in ('<sos>', '<eos>', '<pad>')
]
generated_text = ' '.join(generated_tokens)
# Strip the spaces left around each <eol> and skip empty lines
lines = [line.strip() for line in generated_text.split('<eol>')]
for line in lines:
    if line:
        print(line)

 Anh chốn bao mùa trước 
 Tiếng 
 Một mối tình 
 khai ôm 
 phiên bạc 
 Hôm thu ngon 
 Anh cô 
 Tình như thuở xác rất trẻ rừng, ngàn đêm tuổi già 
 Em đừng vội bâng khuâng 
 Một Vết trầm kha tim đỏ 
 Nơi
