# Training a text generator with Word-RNN 


In [None]:
import pandas as pd
import re
import os
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


Importing data sets

In [None]:
# Load the Poetry Foundation CSV into a DataFrame.
file_path2 = r"E:\NLP project\data\PoetryFoundationData.csv"
df2 = pd.read_csv(
    filepath_or_buffer=file_path2,
    quotechar='"',
    encoding='utf-8',
)

Splitting sentences into words

In [None]:

def tokenize_sentence(sentence):
    """Split *sentence* into word tokens.

    A token is a maximal run of word characters (letters, digits,
    underscore); punctuation and whitespace act as separators and are dropped.

    Returns:
        The tokens as a list of strings, in order of appearance.
    """
    word_pattern = re.compile(r'\b\w+\b')
    return word_pattern.findall(sentence)


# Tokenize every poem once. The astype(str) coercion is needed because
# re.findall raises TypeError on non-string cells (e.g. NaN for a missing
# poem); the duplicate assignment without it has been dropped.
df2['tokenized_poem'] = df2['Poem'].astype(str).apply(tokenize_sentence)

# Show a random handful of tokenized poems as a sanity check.
print(df2['tokenized_poem'].sample(5))


Text cleanup

In [None]:

# Matches a line that holds nothing but an upper-case Roman numeral, optionally
# followed by a period (e.g. the "II" or "IV." headings that number the
# sections of a poem). The look-ahead stops the numeral part matching empty.
_ROMAN_NUMERAL_LINE = re.compile(
    r'^[ \t]*(?=[IVXLCDM])'
    r'M{0,3}(?:CM|CD|D?C{0,3})(?:XC|XL|L?X{0,3})(?:IX|IV|V?I{0,3})'
    r'\.?[ \t]*\r?$',
    re.MULTILINE,
)


def remove_roman_letters(text):
    """Remove Roman-numeral section headings from *text*.

    Only lines consisting solely of an upper-case Roman numeral (with an
    optional trailing period and surrounding spaces/tabs) are blanked; the
    line break itself is kept. Deleting every i/v/x/l/c/d/m character, as a
    plain character class would, destroys ordinary words ("vivid" -> "").

    Must run while the text still contains its line breaks, because the match
    is anchored to whole lines. A line containing only "I" is treated as a
    numeral heading and removed.

    Args:
        text: The poem text.

    Returns:
        The text with numeral-only lines emptied.
    """
    return _ROMAN_NUMERAL_LINE.sub('', text)
# Seed the cleaned column: coerce each raw poem to str and pass it through remove_roman_letters.
df2['cleaned_poem'] = [remove_roman_letters(poem) for poem in df2['Poem'].astype(str)]


def remove_carriage_return(text):
    """Replace each carriage return or newline in *text* with one space.

    A "\\r\\n" pair is two characters and therefore becomes two spaces.
    """
    line_break = r'\r|\n'
    flattened = re.sub(line_break, ' ', text)
    return flattened
# Flatten every cleaned poem onto one line by running it through remove_carriage_return.
df2['cleaned_poem'] = [remove_carriage_return(poem) for poem in df2['cleaned_poem']]


def remove_numbers(text):
    """Return *text* with every digit character deleted."""
    digit = re.compile(r'\d')
    return digit.sub('', text)
# Drop digits from every cleaned poem (coerced to str first) via remove_numbers.
df2['cleaned_poem'] = [remove_numbers(poem) for poem in df2['cleaned_poem'].astype(str)]


def remove_special_characters(text):
    """Return *text* keeping only ASCII letters, digits and whitespace.

    Args:
        text: The string to filter.

    Returns:
        The filtered string; every other character is deleted.
    """
    # No self-call here: an unconditional recursive call on the first line
    # can never terminate and only raises RecursionError.
    return re.sub(r'[^a-zA-Z0-9\s]', '', text)


# Apply the special-character filter itself (not remove_carriage_return a
# second time, which would leave punctuation in the cleaned text).
df2['cleaned_poem'] = df2['cleaned_poem'].apply(remove_special_characters)


# Preview the first few fully cleaned poems.
print(df2['cleaned_poem'].head(5))



In [None]:
pip install clean-text


In [None]:
import unicodedata

# Look up a few properties of sample characters first, then report them all
# in a single print call (one property per output line).
category = unicodedata.category('A')
numeric_value = unicodedata.numeric('1')
name = unicodedata.name('A')
is_digit = '9'.isdigit()

print(
    f'Category: {category}',
    f'Numeric Value: {numeric_value}',
    f'Name: {name}',
    f'Is Digit: {is_digit}',
    sep='\n',
)


In [None]:
from cleantext import clean

# Options for the demonstration call below, collected in one mapping.
clean_options = dict(
    fix_unicode=True,               # fix various unicode errors
    to_ascii=True,                  # transliterate to closest ASCII representation
    lower=True,                     # lowercase text
    no_line_breaks=True,            # fully strip line breaks instead of normalizing them
    no_urls=True,                   # replace URLs with the token below
    no_emails=True,                 # replace email addresses with the token below
    no_phone_numbers=True,          # replace phone numbers with the token below
    no_numbers=True,                # replace numbers with the token below
    no_digits=True,                 # replace digits with the token below
    no_currency_symbols=True,       # replace currency symbols with the token below
    no_punct=True,                  # remove punctuation
    replace_with_punct="",          # replacement used for punctuation
    replace_with_url="<URL>",
    replace_with_email="<EMAIL>",
    replace_with_phone_number="<PHONE>",
    replace_with_number="<NUMBER>",
    replace_with_digit="0",
    replace_with_currency_symbol="<CUR>",
    lang="en",                      # set to 'de' for German special handling
)

# Last expression of the cell, so the notebook displays the cleaned sample.
clean("some input", **clean_options)


Filtering to remove low-frequency words

In [None]:
import random
from collections import Counter

# Count how often each token occurs across all tokenized poems.
all_tokens2 = [token for tokens in df2['tokenized_poem'] for token in tokens]
word_counts2 = Counter(all_tokens2)

# Keep only the words that occur more than 10 times.
filtered_vocab2 = {word for word, count in word_counts2.items() if count > 10}

print("\nRandom 20 words from Vocabulary - df2:")
# random.sample() needs a sequence: passing a set raises TypeError on
# Python 3.11+. The sample size is capped so a vocabulary of fewer than
# 20 words does not raise ValueError.
vocab_words2 = sorted(filtered_vocab2)
random_words2 = random.sample(vocab_words2, min(20, len(vocab_words2)))
print(random_words2)

print(f"Total unique tokens - df2: {len(filtered_vocab2)}")


In [None]:
# Report the number of rows (poems) in the DataFrame.
print("Dataset size:", df2.shape[0])


Save the vocabulary to a file

In [None]:


# Write the filtered vocabulary to disk, one word per line.
vocab_file_path = "E:/NLP project/data/vocab.txt"
with open(vocab_file_path, 'w', encoding='utf-8') as f:
    # A set iterates in an arbitrary, run-dependent order; sorting makes the
    # saved file identical from one run to the next.
    f.writelines(word + '\n' for word in sorted(filtered_vocab2))

print(f"Vocabulary saved to: {vocab_file_path}")

In [None]:
# Device the model is moved to via `.to(device)` when it is built; fixed to the CPU.
device = 'cpu'

In [None]:
!pip install torch

In [None]:
# Hyperparameters and checkpoint settings used by the model and training cells.
hidden_size = 10   # size of hidden state
batch_size = 5    # number of start positions sampled for each training batch
step_len = 2     # number of consecutive positions processed in each training step
num_layers = 5      # number of layers in LSTM layer stack
lr = 0.002          # learning rate passed to the Adam optimizer
num_steps = 20     # exclusive upper bound of the training loop's range(1, num_steps)
gen_seq_len = 10    # number of words sampled when printing generated text
load_chk = False    # load in pre-trained checkpoint for training (not read in the visible cells)
save_path = "word_rnn_model.pt"  # file the checkpoint dict is written to with torch.save
# load_path = "word_rnn_model.pt"

In [None]:
def load_all_text_files_in_folder(path, max_files=10000):
    """Concatenate the contents of the ``.txt`` files found under *path*.

    The folder and all of its subfolders are searched. Folders and files are
    visited in sorted order so the result is reproducible, and every file's
    text is followed by an end-of-file marker line.

    Args:
        path: Folder to search. A path that is not a directory yields an
            empty corpus, because ``os.walk`` produces nothing for it.
        max_files: Maximum number of text files to read in total. Non-text
            files do not count towards the limit.

    Returns:
        One string holding the text of each file, each followed by
        ``'\\n EOF \\n'``.

    Raises:
        UnicodeDecodeError: If a text file is not valid UTF-8.
    """
    pieces = []
    files_read = 0
    for root, dirs, files in os.walk(path):
        if files_read >= max_files:
            break
        # Sorting in place makes os.walk descend in a stable order.
        dirs.sort()
        for file in sorted(files):
            if files_read >= max_files:
                break
            if not file.endswith(".txt"):
                continue
            with open(os.path.join(root, file), 'r', encoding='utf-8') as f:
                pieces.append(f.read())
            # Add 'End of File' between documents
            pieces.append('\n EOF \n')
            files_read += 1
    # Joining once avoids the quadratic cost of repeated string concatenation.
    return ''.join(pieces)

In [None]:
# NOTE(review): data_path names a single file, but
# load_all_text_files_in_folder walks its argument as a folder, so corpus
# (and therefore data_size and vocab_size) looks like it will be empty —
# confirm the intended folder. vocab_size is also derived from the corpus,
# while the indices below come from filtered_vocab2.
data_path = "E:/NLP-23-24-main/data/vocab.txt"
corpus = load_all_text_files_in_folder(data_path)
corpus_words = corpus.split()
words = sorted(set(corpus_words))
data_size, vocab_size = len(corpus_words), len(words)

# Enumerate the vocabulary in sorted order: iterating the set directly gives
# an order that can change between interpreter runs, so the same word would
# get a different index in each session (and not match a saved checkpoint).
vocab_to_index = {word: idx for idx, word in enumerate(sorted(filtered_vocab2))}
# Map each poem's tokens to indices, dropping out-of-vocabulary words.
df2['tokenized_poem_indices'] = df2['tokenized_poem'].apply(
    lambda tokens: [vocab_to_index[word] for word in tokens if word in vocab_to_index]
)
print(df2['tokenized_poem_indices'].sample(5))


In [None]:
# Build the model with vocab_size passed as its first two arguments and move it to `device`.
# NOTE(review): `RNN` is not defined in any visible cell — presumably defined
# elsewhere; confirm it is in scope before this cell runs.
rnn = RNN(vocab_size, vocab_size, hidden_size, num_layers).to(device)

In [None]:
# Adam updates all of the model's parameters at learning rate `lr`;
# cross-entropy is the training loss.
optimizer = torch.optim.Adam(params=rnn.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

In [None]:

def get_training_batch_indicies(index_list, batch_size):
    """Draw a random batch of positions together with their successors.

    Positions are sampled from *index_list* with replacement. The number
    drawn is *batch_size*, capped at ``len(index_list)``.

    Args:
        index_list: Candidate positions to sample from.
        batch_size: Requested number of positions.

    Returns:
        A tuple ``(input_indices, target_indices)`` of tensors, where every
        target index is the matching input index plus one.
    """
    sample_count = min(batch_size, len(index_list))
    drawn = random.choices(index_list, k=sample_count)
    input_batch_indices = torch.tensor(np.array(drawn))
    return input_batch_indices, input_batch_indices + 1


In [None]:
# Report the number of rows (poems) in the DataFrame.
print("Dataset size:", df2.shape[0])


In [None]:

# Training loop: each step samples a batch of start positions, trains on
# `step_len` consecutive positions, saves a checkpoint and prints a sample.
# NOTE(review): `index_list`, `data`, `ix_to_word` and `word_to_ix` are not
# defined in any visible cell — confirm they exist before this cell runs.
if len(index_list) < step_len:
    print(f"Insufficient data for training. Dataset size: {len(index_list)}, Step length: {step_len}")
else:
    # The loop belongs to this branch; an `else:` with no body is a syntax error.
    for step in range(1, num_steps):
        running_loss = 0
        hidden_state = None
        rnn.zero_grad()
        train_batch_indicies, target_batch_indicies = get_training_batch_indicies(index_list, batch_size)

        if len(train_batch_indicies) == 0:
            print(f"\nStep: {step} - Insufficient data for training.")
            continue

        for _ in range(step_len):
            # Look up the words at the current positions and at the positions
            # one ahead; both index tensors are cast to long for indexing.
            input_batch = data[train_batch_indicies.long()].squeeze()
            target_batch = data[target_batch_indicies.long()].squeeze()

            # Forward pass, carrying the hidden state across positions.
            output, hidden_state = rnn(input_batch, hidden_state)

            loss = loss_fn(output.view(-1, vocab_size), target_batch.view(-1))
            running_loss += loss.item() / step_len

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Advance every sequence in the batch by one position.
            train_batch_indicies = train_batch_indicies + 1
            target_batch_indicies = target_batch_indicies + 1

        print('\n' + '-' * 75)
        print(f"\nStep: {step} Loss: {running_loss}")

        # Save the weights together with both vocabulary mappings.
        save_dict = {
            'state_dict': rnn.state_dict(),
            'ix_to_word': ix_to_word,
            'word_to_ix': word_to_ix,
        }
        torch.save(save_dict, save_path)

        # Print a short generated sample to show training progress.
        with torch.no_grad():
            rand_index = np.random.randint(data_size - 1)
            # clone(): a plain slice is a view of `data`, so the in-place
            # write at the end of the loop below would overwrite training data.
            input_batch = data[rand_index: rand_index + 1].clone()
            hidden_state = None

            for _ in range(gen_seq_len):
                output, hidden_state = rnn(input_batch, hidden_state)

                # Sample the next word from the predicted distribution.
                output = F.softmax(torch.squeeze(output), dim=0)
                dist = Categorical(output)
                index = dist.sample()

                print(ix_to_word[index.item()], end=' ')

                # Feed the sampled word back in as the next input.
                input_batch[0][0] = index.item()
