# Problem Statement:
Develop a machine learning model that takes a concatenated text string without spaces (e.g., `"helloworld"`) as input and outputs the correctly spaced version (e.g., `"hello world"`). The goal is to accurately predict word boundaries and insert spaces in the appropriate positions.

# Let's create a dataset

In [None]:
from datasets import load_dataset, load_from_disk

In [None]:
book_corpus = load_dataset("bookcorpus/bookcorpus")

In [None]:
book_corpus

In [None]:
def process_text(example):
    example['text_no_space'] = example['text'].replace(" ", "")
    return example

In [None]:
text_data = book_corpus.select(range(0, 1_000_000)).map(process_text)

In [None]:
text_data

In [None]:
text_data.save_to_disk(f"data/processed_bookcorpus_0")

### Training on 1M rows

In [2]:
from datasets import load_from_disk

In [3]:
text_data = load_from_disk("data/processed_bookcorpus_0")

In [4]:
text_data[0]

{'text': 'usually , he would be tearing around the living room , playing with his toys .',
 'text_no_space': 'usually,hewouldbetearingaroundthelivingroom,playingwithhistoys.'}

In [5]:
def build_vocab(texts):
    texts = ''.join(texts)
    print(f'total chars {len(texts)}')
    chars = set(''.join(texts))
    vocab = {char: idx for idx, char in enumerate(chars, start=2)}  # Reserve 0 for padding, 1 for <unk>
    vocab["<pad>"] = 0
    vocab["<unk>"] = 1
    return vocab

In [6]:
texts = [example["text_no_space"] for example in text_data] 
vocab = build_vocab(texts)

total chars 53467665


In [7]:
# sorted(vocab.keys())
len(vocab)

67

In [50]:
num_epochs = 1
batch_size = 64
lr = 1e-4

n_vocab = len(vocab)
emd_dim = 128
hidden_size = 128
num_lstm_layers = 3

In [9]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [10]:
class WordSplitterDataset(Dataset):

    def __init__(self, data, vocab):
        self.data = data
        self.vocab = vocab
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]["text"]

        n = len(text)

        inp = []
        out = []

        for i in range(0, n):
            if text[i] == ' ':
                continue;
            inp.append(self.vocab.get(text[i], self.vocab["<unk>"]))
            if i < n - 1 and text[i + 1] == ' ':
                out.append(1)
            else:
                out.append(0)
                
        return torch.tensor(inp, dtype=torch.long), torch.tensor(out, dtype=torch.long)

In [40]:
train_data = data_splits["train"]

Dataset({
    features: ['text', 'text_no_space'],
    num_rows: 800000
})

In [41]:
data_splits = text_data.train_test_split(0.2)
test_split = data_splits["test"].train_test_split(0.5)

train_data = data_splits["train"].select(range(0, 2_000_00))
val_data = test_split["train"].select(range(0, 50_000))
test_data = test_split["test"].select(range(0, 50_000))

In [43]:
train_data, val_data, test_data

(Dataset({
     features: ['text', 'text_no_space'],
     num_rows: 200000
 }),
 Dataset({
     features: ['text', 'text_no_space'],
     num_rows: 50000
 }),
 Dataset({
     features: ['text', 'text_no_space'],
     num_rows: 50000
 }))

In [51]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch, max_length=100):
    inputs, labels = zip(*batch)

    truncated_inputs = [seq[:max_length] for seq in inputs]
    truncated_labels = [seq[:max_length] for seq in labels]

    padded_inputs = pad_sequence(truncated_inputs, batch_first=True, padding_value=0)
    padded_labels = pad_sequence(truncated_labels, batch_first=True, padding_value=-100)  # Use -100 for ignored labels

    return padded_inputs, padded_labels


In [52]:
train_dataloader = DataLoader(WordSplitterDataset(train_data, vocab), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_dataloader =  DataLoader(WordSplitterDataset(val_data, vocab), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMWordSplitterModel(nn.Module):
    def __init__(self, n_vocab, emd_dim, hidden_size,
                 num_lstm_layers):
        super(LSTMWordSplitterModel, self).__init__()

        # embedding layer
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=emd_dim,
        )

        
        # LSMT layer
        self.lstm = nn.LSTM(input_size=emd_dim, hidden_size=hidden_size,
                            num_layers=num_lstm_layers, dropout=0.1,
                            bidirectional=True, batch_first=True)

        # FCN layer
        linear_h_size = 2 * hidden_size # BiLSTM
        
        self.first_fcn = nn.Linear(linear_h_size, linear_h_size)
        self.second_fcn = nn.Linear(linear_h_size, 2)
        
        self.relu = nn.ReLU()

    def forward(self, x):
        embed = self.embedding(x)        
        out, _ = self.lstm(embed)        
        out = self.first_fcn(out)
        out = self.relu(out)
        logits = self.second_fcn(out)
        return logits

In [54]:
model = LSTMWordSplitterModel(n_vocab=n_vocab,
                              emd_dim=emd_dim,
                              hidden_size=hidden_size,
                              num_lstm_layers=num_lstm_layers
                             )

In [55]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [56]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(model)

1129602

In [57]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [58]:
model = model.to(device)

In [67]:
from tqdm.notebook import tqdm

In [71]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

In [72]:
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0

    for batch_idx, (inputs, labels) in enumerate(tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}")):
        inputs = inputs.to(device)  
        labels = labels.to(device)

        logits = model(inputs)

        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

        total_train_loss += loss.item()
        writer.add_scalar("Loss/Train Batch", loss.item(), epoch * len(train_dataloader) + batch_idx)

        if (batch_idx + 1) % 100 == 0:
            tqdm.write(f"Epoch {epoch + 1}, Batch {batch_idx + 1}, Training Loss: {loss.item():.4f}")

    avg_train_loss = total_train_loss / len(train_dataloader)
    tqdm.write(f"Epoch {epoch + 1}, Average Training Loss: {avg_train_loss:.4f}")
    writer.add_scalar("Loss/Train Epoch", avg_train_loss, epoch)

    model.eval()
    total_val_loss = 0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(tqdm(val_dataloader, desc=f"Validation Epoch {epoch + 1}")):
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)

            val_loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
            writer.add_scalar("Loss/Validation Batch", val_loss.item(), epoch * len(val_dataloader) + batch_idx)

            total_val_loss += val_loss.item()

            predictions = logits.argmax(dim=-1)
            valid_mask = labels.view(-1) != -100  # Mask to ignore padding
            valid_predictions = predictions.view(-1)[valid_mask]
            valid_labels = labels.view(-1)[valid_mask]
    
            correct_predictions += (valid_predictions == valid_labels).sum().item()
            total_predictions += valid_labels.numel()

        

    avg_val_loss = total_val_loss / len(val_dataloader)
    val_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    tqdm.write(f"Epoch {epoch + 1}, Average Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

Training Epoch 1:   0%|          | 0/3125 [00:00<?, ?it/s]

Epoch 1, Batch 100, Training Loss: 0.3983
Epoch 1, Batch 200, Training Loss: 0.2288
Epoch 1, Batch 300, Training Loss: 0.1583
Epoch 1, Batch 400, Training Loss: 0.1394
Epoch 1, Batch 500, Training Loss: 0.1312
Epoch 1, Batch 600, Training Loss: 0.1012
Epoch 1, Batch 700, Training Loss: 0.0973
Epoch 1, Batch 800, Training Loss: 0.0930
Epoch 1, Batch 900, Training Loss: 0.0865
Epoch 1, Batch 1000, Training Loss: 0.0908
Epoch 1, Batch 1100, Training Loss: 0.0888
Epoch 1, Batch 1200, Training Loss: 0.0686
Epoch 1, Batch 1300, Training Loss: 0.0858
Epoch 1, Batch 1400, Training Loss: 0.0852
Epoch 1, Batch 1500, Training Loss: 0.0807
Epoch 1, Batch 1600, Training Loss: 0.0815
Epoch 1, Batch 1700, Training Loss: 0.0708
Epoch 1, Batch 1800, Training Loss: 0.0657
Epoch 1, Batch 1900, Training Loss: 0.0675
Epoch 1, Batch 2000, Training Loss: 0.0746
Epoch 1, Batch 2100, Training Loss: 0.0533
Epoch 1, Batch 2200, Training Loss: 0.0603
Epoch 1, Batch 2300, Training Loss: 0.0552
Epoch 1, Batch 2400,

Validation Epoch 1:   0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1, Average Validation Loss: 0.0440, Validation Accuracy: 0.9840


In [73]:
import os

checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

def save_checkpoint(epoch, model, optimizer, val_loss, filepath):
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "val_loss": val_loss,
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved at {filepath}")

In [75]:
save_checkpoint(
    1,
    model,
    optimizer,
    val_loss=avg_val_loss,
    filepath=os.path.join(checkpoint_dir, f"model_epoch_{1}.pth"),
)

Checkpoint saved at checkpoints/model_epoch_1.pth


In [76]:
def predict(model, vocab, input_text, max_length=100, device="cpu"):
    model.eval()

    input_indices = [vocab.get(char, vocab["<unk>"]) for char in input_text]
    if len(input_indices) > max_length:
        input_indices = input_indices[:max_length]
    input_tensor = torch.tensor(input_indices, dtype=torch.long).unsqueeze(0)  # Shape: (1, seq_len)

    input_tensor = input_tensor.to(device)

    with torch.no_grad():
        logits = model(input_tensor)  # Shape: (1, seq_len, num_classes)
        predictions = logits.argmax(dim=-1).squeeze(0).tolist()  # Predicted labels

    split_text = ""
    for i, char in enumerate(input_text[:max_length]):
        split_text += char
        if predictions[i] == 1:
            split_text += " "

    return split_text.strip()

In [94]:
test_sentences = [
    "himynameisx",
    "thequickbrownfoxjumpsoverthelazydog",
    "machinelearningisfun",
    "iloveprogramming",
    "deepneuralnetworksarepowerful",
    "welcometopytorchtraining",
    "hellohowareyou",
    "thisisatestsentence",
    "naturallanguageprocessing",
    "How are you?",
    "Let'ssee,yourgame",
    "whatis1+1?",
    "ihave4tasks"
]

In [95]:
for sentence in test_sentences:
    predicted_text = predict(model, vocab, sentence, device=device)
    
    print(f"{sentence:<50} | {predicted_text:<50}")

himynameisx                                        | himy name isx                                     
thequickbrownfoxjumpsoverthelazydog                | the quick brown fox jumps over the lazy dog       
machinelearningisfun                               | machine learning is fun                           
iloveprogramming                                   | i love programming                                
deepneuralnetworksarepowerful                      | deep neural net works a repowerful                
welcometopytorchtraining                           | welcome to pytorchtraining                        
hellohowareyou                                     | hello how are you                                 
thisisatestsentence                                | this isa testsentence                             
naturallanguageprocessing                          | natural language processing                       
How are you?                                       | How  are  y