In [3]:
import pandas as pd
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from jiwer import wer
from torch.utils.tensorboard import SummaryWriter

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# hyperparams
log_dir = "logs/run_3"  # TensorBoard event directory for this run
writer = SummaryWriter(log_dir=log_dir, purge_step=0)

batch_size = 16
lr = 5e-5
seed = 42  # fixes the DataLoader shuffling order so runs are reproducible

model_name = "google-t5/t5-small"
model_path = "g2p_t5_model"  # output directory for the fine-tuned model + tokenizer

torch.manual_seed(seed)

# Use the Apple-silicon GPU (MPS) when present, otherwise CPU.
# To force CPU, set: device = torch.device("cpu")
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [5]:
# Load the pretrained tokenizer/model and move the model to the selected device.
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
# Recompute activations during backward to reduce memory use.
model.gradient_checkpointing_enable()

# Optimizer. transformers.AdamW is deprecated in favour of torch.optim.AdamW;
# eps=1e-6 and weight_decay=0.0 reproduce transformers.AdamW's defaults
# (torch's own defaults are eps=1e-8 and weight_decay=1e-2).
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)

# Record the run's hyperparameters next to its metrics.
writer.add_scalar("Learning Rate", lr)
writer.add_scalar("Batch Size", batch_size)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [None]:
# Build the task prompt fed to T5.
def format_input(text, char_level=False):
    """Wrap grapheme text in the "grapheme to phoneme:" task prefix.

    Args:
        text: The grapheme string to convert.
        char_level: If True, the characters of ``text`` are space-separated
            first (e.g. "abc" -> "a b c"), nudging the tokenizer towards
            per-character pieces. Defaults to False, which leaves ``text``
            unchanged.

    Returns:
        The prompt string ``"grapheme to phoneme: <text>"``.
    """
    if char_level:
        text = " ".join(text)
    return f"grapheme to phoneme: {text}"


# Custom dataset class
class G2PDataset(Dataset):
    """Grapheme-to-phoneme pairs loaded from a CSV file.

    The CSV must provide a ``text`` (graphemes) and a ``phonemes`` (target
    string) column. Rows with a missing value in either column are dropped,
    as are rows whose tokenized prompt or tokenized target is longer than
    ``max_length`` tokens. Lengths are measured with the module-level
    ``tokenizer``.

    Each item is a ``(prompt, phoneme_text)`` pair of strings, where
    ``prompt == format_input(text)``; tokenizing the pair is left to the
    DataLoader's collate function.
    """

    def __init__(self, file_path, max_length=512):
        self.data = pd.read_csv(file_path)
        self.max_length = max_length

        print(f"Dataset: {file_path}")
        print(f"Original dataset size:\t{len(self.data)}")

        # Empty CSV cells are read as NaN. The tokenizer rejects non-string
        # targets, and a NaN text would silently become the prompt
        # "grapheme to phoneme: nan", so drop incomplete rows first.
        self.data = self.data.dropna(subset=["text", "phonemes"])

        def fits(text):
            return len(tokenizer(text)["input_ids"]) <= self.max_length

        keep = [
            fits(format_input(text)) and fits(phonemes)
            for text, phonemes in zip(self.data["text"], self.data["phonemes"])
        ]
        # reset_index so labels and positions agree after filtering.
        self.data = self.data.loc[keep].reset_index(drop=True)

        print(f"Reduced dataset size:\t{len(self.data)}\n")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        # Wrap the graphemes in the task prompt; the target stays a raw string.
        return format_input(row["text"]), row["phonemes"]

# Collate function for dynamic padding
def collate_fn(batch):
    """Turn a list of ``(prompt, target)`` string pairs into model-ready tensors.

    Prompts and targets are each padded to the longest sequence in the batch
    and never truncated. ``labels`` are the padded target ids exactly as the
    tokenizer returns them, so padded positions hold ``tokenizer.pad_token_id``
    (not -100).
    """
    prompts, targets = map(list, zip(*batch))

    def _encode(texts):
        # Pad to the longest sequence in this batch only.
        return tokenizer(texts, padding=True, return_tensors="pt", truncation=False)

    prompt_enc = _encode(prompts)
    target_enc = _encode(targets)

    return dict(
        input_ids=prompt_enc.input_ids,
        attention_mask=prompt_enc.attention_mask,
        labels=target_enc.input_ids,
    )

# Load datasets and wrap them in DataLoaders.
# Note: the test-clean split doubles as the validation set here.
train_dataset, val_dataset = (
    G2PDataset(csv_path)
    for csv_path in ("clean_data/train-clean-100.csv", "clean_data/test-clean.csv")
)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, collate_fn=collate_fn)

Dataset: clean_data/train-clean-100.csv
Original dataset size:	28538
Reduced dataset size:	28536
Dataset: clean_data/test-clean.csv
Original dataset size:	2620
Reduced dataset size:	2598


In [5]:
# Training function
def train_model(model, train_loader, val_loader, epochs, writer, verbose=True):
    """Fine-tune ``model`` for ``epochs`` epochs, validating after each one.

    Relies on the module-level ``optimizer``, ``device``, ``tokenizer`` and
    ``validate_model``. After each epoch the scalars "Training Loss",
    "Validation Loss", "Exact Match" and "Average PER" are written to
    ``writer`` with the 0-based epoch index as the step.

    Args:
        model: Seq2seq model whose forward returns ``.loss`` when given ``labels``.
        train_loader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels`` tensors (labels padded with ``tokenizer.pad_token_id``).
        val_loader: Same batch format, used for per-epoch validation.
        epochs: Number of passes over ``train_loader``.
        writer: TensorBoard ``SummaryWriter`` receiving the per-epoch scalars.
        verbose: Print per-epoch summaries; also forwarded to validation.
    """
    for epoch in range(epochs):
        # validate_model() switches the model to eval mode, so training mode
        # (dropout on) must be re-enabled at the start of every epoch.
        model.train()
        total_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)
        for batch in progress_bar:
            optimizer.zero_grad()

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            # Padded target positions hold pad_token_id; -100 makes the loss ignore them.
            labels = labels.masked_fill(labels == tokenizer.pad_token_id, -100)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, use_cache=False)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            progress_bar.set_postfix(loss=loss_value)

            del input_ids, attention_mask, labels, outputs, loss
            # torch.mps.* is only meaningful on the MPS backend; the config may select CPU.
            if device.type == "mps":
                torch.mps.empty_cache()

        avg_train_loss = total_loss / len(train_loader) if len(train_loader) else float("nan")
        if verbose:
            print(f"Epoch {epoch+1}, Loss: {avg_train_loss:.4f}")

        writer.add_scalar("Training Loss", avg_train_loss, epoch)

        # Validation (verbose is passed by keyword; the writer is not a validate_model argument).
        avg_val_loss, exact_match_accuracy, avg_per = validate_model(model, val_loader, verbose=verbose)
        writer.add_scalar("Validation Loss", avg_val_loss, epoch)
        writer.add_scalar("Exact Match", exact_match_accuracy, epoch)
        writer.add_scalar("Average PER", avg_per, epoch)

        if device.type == "mps":
            torch.mps.empty_cache()


# Validation function
def validate_model(model, val_loader, verbose=True):
    """Evaluate ``model`` on ``val_loader`` using teacher-forced predictions.

    Metrics:
      * loss: mean of the per-batch model losses, with padded target
        positions excluded (labels set to -100);
      * exact match: fraction of samples whose decoded per-position argmax
        prediction equals the decoded reference (whitespace-stripped);
      * PER: ``jiwer.wer(reference, prediction)`` on the space-separated
        phoneme strings (0 for exact matches), averaged over samples.

    Predictions are the argmax of the teacher-forced logits, not the output of
    ``generate``, so these figures may be optimistic compared with free-running
    decoding.

    Relies on the module-level ``device`` and ``tokenizer``. The model's
    previous train/eval mode is restored before returning.

    Returns:
        ``(avg_val_loss, exact_match_accuracy, avg_per)``; each is NaN when
        ``val_loader`` is empty.
    """
    was_training = model.training
    model.eval()
    pad_id = tokenizer.pad_token_id

    total_loss = 0.0
    total_exact = 0
    total_samples = 0
    total_per = 0.0

    progress_bar = tqdm(val_loader, desc="Validation", leave=True)

    try:
        with torch.no_grad():
            for batch in progress_bar:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)
                is_pad = labels == pad_id

                # -100 tells the model's loss to ignore padded target positions.
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels.masked_fill(is_pad, -100),
                )
                total_loss += outputs.loss.item()

                # Per-position greedy predictions. Positions past the end of each
                # reference are forced to pad, so decode drops them instead of
                # appending stray tokens to the prediction.
                predicted_ids = torch.argmax(outputs.logits, dim=-1).masked_fill(is_pad, pad_id)
                pred_phonemes = [tokenizer.decode(pred, skip_special_tokens=True) for pred in predicted_ids]
                true_phonemes = [tokenizer.decode(label, skip_special_tokens=True) for label in labels]

                for pred, target in zip(pred_phonemes, true_phonemes):
                    if pred.strip() == target.strip():
                        total_exact += 1
                    else:
                        total_per += wer(target, pred)  # WER over space-separated phonemes

                total_samples += labels.size(0)

                # torch.mps.* is only meaningful on the MPS backend; the config may select CPU.
                if device.type == "mps":
                    torch.mps.empty_cache()
    finally:
        # Hand the model back in the mode the caller had it in.
        if was_training:
            model.train()

    nan = float("nan")
    avg_val_loss = total_loss / len(val_loader) if len(val_loader) else nan
    exact_match_accuracy = total_exact / total_samples if total_samples else nan
    avg_per = total_per / total_samples if total_samples else nan

    if verbose:
        print(f"Validation Loss: {avg_val_loss:.4f}")
        print(f"Validation Accuracy (Exact Match): {exact_match_accuracy * 100:.2f}%")
        print(f"Average Phoneme Error Rate (PER): {avg_per * 100:.2f}%")

    return avg_val_loss, exact_match_accuracy, avg_per

In [6]:
train_model(model, train_loader, val_loader, epochs=10, writer=writer)  # 10 epochs, metrics to TensorBoard

In [None]:
# Save weights and tokenizer side by side so both reload from model_path.
model.save_pretrained(save_directory=model_path)
tokenizer.save_pretrained(save_directory=model_path)

# Flush and close the TensorBoard event file for this run.
writer.close()

In [None]:
from tensorboard.backend.event_processing import event_accumulator

# Read back the events written by this run. Use log_dir from the config cell
# rather than a hard-coded (and stale) run directory.
ea = event_accumulator.EventAccumulator(log_dir)
ea.Reload()  # Load data from event files
ea.Tags()["scalars"]

In [None]:
# Sanity check: free-running generation on a single utterance.
# (Named sample_text so the builtin `input` is not shadowed in the kernel.)
sample_text = "even in middle age they were still comely and the old grey haired women at their cottage doors had a dignity not to say majesty of their own"
# Reference phonemes for sample_text:
# "IY1 V IH0 N IH1 N M IH1 D AH0 L EY1 JH DH EY1 W ER0 S T IH1 L K AH1 M L IY0 sp AE1 N D DH IY0 OW1 L D G R EY1 HH EH1 R D W IH1 M AH0 N AE1 T DH EH1 R K AA1 T IH0 JH sp D AO1 R Z sp HH AE1 D AH0 D IH1 G N AH0 T IY0 sp N AA1 T T IH0 S EY1 M AE1 JH AH0 S T IY0 AH0 V DH EH1 R OW1 N sp"
# Alternative sample and its reference:
# sample_text = "i expressed by signs my admiration and pleasure to my guides and they were greatly pleased"
# "AY1 IH0 K S P R EH1 S T sp B AY1 S AY1 N Z M AY1 AE2 D M ER0 EY1 SH AH0 N AH0 N D P L EH1 ZH ER0 T AH0 M AY1 G AY1 D Z sp AE1 N D DH EY1 W ER0 G R EY1 T L IY0 P L IY1 Z D sp"

input_text = format_input(sample_text)
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

model.eval()
with torch.no_grad():
    output_ids = model.generate(input_ids, use_cache=True, max_length=1024)

phoneme_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
phoneme_output