In [1]:
%cd "C:/Users/Patrickn/Jupyter_notebooks/Graduation/AI_Companion/TTS/styletts"

C:\Users\Patrickn\Jupyter_notebooks\Graduation\AI_Companion\TTS\styletts


In [None]:
!git clone https://github.com/patrick-3008/StyleTTS2.git

In [22]:
import re

# --- 1. Define Your Target Phoneme Set ---
# This is a simplified set, close to IPA and what Espeak uses for Arabic.
# You might need to expand this based on your specific needs.
PHONEME_MAP = {
    'ا': 'a',  # Alif - often carrier or part of long vowel/hamza
    'ب': 'b',
    'ت': 't',
    'ث': 'θ',
    'ج': 'dʒ', # Or 'g' depending on dialect (e.g., Egyptian)
    'ح': 'ħ',
    'خ': 'χ', # Or 'x'
    'د': 'd',
    'ذ': 'ð',
    'ر': 'r',
    'ز': 'z',
    'س': 's',
    'ش': 'ʃ',
    'ص': 'sˤ', # Or 'S' for simplicity
    'ض': 'dˤ', # Or 'D'
    'ط': 'tˤ', # Or 'T'
    'ظ': 'ðˤ', # Or 'DH'
    'ع': 'ʕ',
    'غ': 'ʁ', # Or 'ɣ' or 'g'
    'ف': 'f',
    'ق': 'q', # Or 'g' depending on dialect
    'ك': 'k',
    'ل': 'l',
    'م': 'm',
    'ن': 'n',
    'ه': 'h',
    'و': 'w',  # Wau - consonant or part of long vowel/diphthong
    'ي': 'j',  # Ya - consonant or part of long vowel/diphthong
    'ء': 'ʔ',  # Hamza
    'آ': 'ʔaː', # Alif Madda
    'ى': 'aː', # Alif Maqsura (often pronounced as long 'a')
    'ة': 'h',  # Ta' Marbutah (simplified - often 't' in connected speech)

    # --- Diacritics (Tashkeel) ---
    'َ': 'a',  # Fatha
    'ُ': 'u',  # Damma
    'ِ': 'i',  # Kasra
    'ْ': '',   # Sukun (indicates no vowel)
    'ّ': ':',  # Shadda (indicates gemination - will be handled by a rule)
    'ً': 'an', # Fathatan (Tanween Fath) - simplified, often dropped in pausal
    'ٌ': 'un', # Dammatan (Tanween Damm) - simplified
    'ٍ': 'in', # Kasratan (Tanween Kasr) - simplified

    # --- Other characters to handle ---
    ' ': ' ', # Space
    # Add punctuation if needed and how to handle it (e.g., ignore, map to silence)
}

# --- 2. Define Phonological Rules ---
# Rules will be applied sequentially. Order matters!
# Each rule is a tuple: (pattern_to_find, replacement_phonemes)
# We'll use regex for pattern matching.

# Note: This is a very basic set of rules. A real system needs many more.
# Handling unvocalized text requires inferring missing diacritics, which
# is NOT covered by these simple rules.

RULES = [
    # Rule 1: Handle long vowels (short vowel + madd letter)
    # Assumes short vowels are already mapped.
    # e.g., a + ا -> aː
    (r'aا', 'aː'),
    (r'uو', 'uː'),
    (r'iي', 'iː'),
    # Handle Alif Maqsura as long 'a'
    (r'اى', 'aː'), # Assuming Alif followed by Alif Maqsura results in long a

    # Rule 2: Handle Shadda (gemination)
    # Find a phoneme followed by the shadda marker ':' and duplicate the phoneme.
    # This needs to be done carefully to avoid duplicating vowels or spaces.
    # We'll look for a consonant phoneme before the ':'
    # This regex is simplified; a real system needs more robust handling.
    (r'([btdðr z s ʃ sˤ dˤ tˤ ðˤ ʕ ʁ f q k l m n h w j ʔ]):', r'\1\1'),


    # Rule 3: Handle the definite article 'Al' (ال)
    # This is a simplified rule for Sun letters assimilation.
    # It assumes 'al' has already been mapped from 'ال'.
    # Sun letters: ت ث د ذ ر ز س ش ص ض ط ظ ل ن
    # We look for 'al' followed by a Sun letter's phoneme and assimilate.
    # This rule assumes the Sun letter's phoneme is the next character.
    # A more robust rule would check the *original* Arabic letter.
    (r'al([tθdðrzsʃsˤdˤtˤðˤln])', r'a\1\1'), # al + SunLetter -> a + doubled SunLetter

    # Rule 4: Clean up any remaining shadda markers if not processed (shouldn't happen with Rule 2)
    (r':', ''),

]

# --- 3. Implement the G2P Function ---

def arabic_g2p(text):
    phonemes = []
    for char in text:
        if char in PHONEME_MAP:
            phonemes.append(PHONEME_MAP[char])
        else:
            # Handle characters not in map (e.g., numbers, other symbols)
            # You might want to raise an error or skip them.
            print(f"Warning: Character '{char}' not found in PHONEME_MAP. Skipping or mapping to itself.")
            phonemes.append(char) # Or '' to skip

    # Join into a string to apply regex rules
    phoneme_string = "".join(phonemes)
    print(f"After initial mapping: {phoneme_string}") # Debugging

    # Step 2: Apply Phonological Rules
    processed_phonemes = phoneme_string
    for pattern, replacement in RULES:
        # Use re.sub to apply the rule globally
        processed_phonemes = re.sub(pattern, replacement, processed_phonemes)
        print(f"After rule '{pattern}': {processed_phonemes}") # Debugging

    # Step 3: Final Cleanup (if needed)
    # You might have leftover characters or need to format the output
    # For example, removing extra spaces or adding stress markers if determined.
    final_phonemes = processed_phonemes.strip() # Example: remove leading/trailing space

    return final_phonemes

# --- 4. Test the System ---

unvocalized_phrase = "انا اسمي باتريك" # Ana ismi Patrick
# This will also be challenging without diacritics.
# Espeak output was ['ˈanaː ˈismiːj bˈaːtɹiːk ']
# Our simple system will not produce this level of detail (stress, specific vowels)
phonemic_output_phrase = arabic_g2p(unvocalized_phrase)
print(f"Input: {unvocalized_phrase}")
print(f"Output: {phonemic_output_phrase}")
# Expected output (simplified and likely incorrect): ana asmy batryk


After initial mapping: ana asmj batrjk
After rule 'aا': ana asmj batrjk
After rule 'uو': ana asmj batrjk
After rule 'iي': ana asmj batrjk
After rule 'اى': ana asmj batrjk
After rule '([btdðr z s ʃ sˤ dˤ tˤ ðˤ ʕ ʁ f q k l m n h w j ʔ]):': ana asmj batrjk
After rule 'al([tθdðrzsʃsˤdˤtˤðˤln])': ana asmj batrjk
After rule ':': ana asmj batrjk
Input: انا اسمي باتريك
Output: ana asmj batrjk


# PreProcess Data

In [28]:
import json
import io
import os
import glob

WORD_PHONEME_DIR = 'C:/Users/Patrickn/Jupyter_notebooks/Graduation/AI_Companion/TTS/g2p_data'
WORD_PHONEME_FILE_PATTERN = '*.md' # Adjust if your files have a different extension

SENTENCE_DATA_DIR = 'C:/Users/Patrickn/Jupyter_notebooks/Graduation/AI_Companion/TTS/g2p_data'
SENTENCE_DATA_FILE_PATTERN = '*.txt' # Adjust if your files have a different extension


def parse_word_phoneme_pairs(raw_data):
    word_phoneme_list = []
    # Use io.StringIO to treat the string as a file
    data_file = io.StringIO(raw_data.strip()) # strip() removes leading/trailing whitespace from the whole block

    for line in data_file:
        line = line.strip() # Remove leading/trailing whitespace from the line
        if line: # Ensure the line is not empty
            # Split the line by whitespace. Assumes the first token is the word
            # and the rest is the phonemic transcription (joined by spaces if needed).
            parts = line.split(maxsplit=1) # Split only on the first whitespace
            if len(parts) == 2:
                word = parts[0]
                phonemes = parts[1].strip() # Ensure no extra whitespace around phonemes
                # Basic validation: check if phonemes string is not empty after stripping
                if phonemes:
                     word_phoneme_list.append((word, phonemes))
                else:
                     print(f"Warning: Skipping line with empty phonemes in format 1: '{line}'")
            else:
                print(f"Warning: Skipping malformed line in format 1: '{line}'")

    return word_phoneme_list

def parse_sentence_data(raw_json_data):
    try:
        sentence_list = json.loads(raw_json_data)
        # Basic validation to ensure the structure is as expected
        if not isinstance(sentence_list, list):
            print("Error: JSON data is not a list.")
            return []
        valid_sentence_list = []
        for item in sentence_list:
            if isinstance(item, dict) and "text" in item and "g2p" in item:
                 # Optional: Basic validation for non-empty strings
                 if item["text"] and item["g2p"]:
                    valid_sentence_list.append(item)
                 else:
                    print(f"Warning: Skipping item with empty text or g2p in format 2: {item}")
            else:
                print(f"Warning: Skipping malformed item in format 2: {item}")
        return valid_sentence_list
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON data for format 2: {e}")
        return []

# --- Data Loading and Preparation ---

def load_data_from_files(directory=None, file_list=None, file_pattern='*', parser_func=None):
    all_data = []
    files_to_process = []

    if file_list:
        files_to_process = file_list
    elif directory and os.path.isdir(directory):
        files_to_process = glob.glob(os.path.join(directory, file_pattern))
    else:
        print(f"Error: No valid directory or file list provided.")
        return []

    print(f"Found {len(files_to_process)} files to process.")

    for file_path in files_to_process:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                raw_data = f.read()
                parsed_data = parser_func(raw_data)
                all_data.extend(parsed_data)
            print(f"Successfully processed: {file_path}")
        except FileNotFoundError:
            print(f"Error: File not found: {file_path}")
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")

    return all_data

# --- Load and Prepare Datasets ---

print("--- Loading and Preparing Data Format 1 (Word Lexicon) ---")
# Choose one option based on your configuration above (directory or file_list)
# If using directory:
all_word_phoneme_pairs = load_data_from_files(directory=WORD_PHONEME_DIR, file_pattern=WORD_PHONEME_FILE_PATTERN, parser_func=parse_word_phoneme_pairs)
# If using file_list:
# all_word_phoneme_pairs = load_data_from_files(file_list=WORD_PHONEME_FILES, parser_func=parse_word_phoneme_pairs)


print("\n--- Loading and Preparing Data Format 2 (Sentence Pairs) ---")
# Choose one option based on your configuration above (directory or file_list)
# If using directory:
all_sentence_pairs = load_data_from_files(directory=SENTENCE_DATA_DIR, file_pattern=SENTENCE_DATA_FILE_PATTERN, parser_func=parse_sentence_data)
# If using file_list:
# all_sentence_pairs = load_data_from_files(file_list=SENTENCE_DATA_FILES, parser_func=parse_sentence_data)


print("\n" + "="*30 + "\n")
print("--- Data Loading Summary ---")
print(f"Total word-phoneme pairs loaded: {len(all_word_phoneme_pairs)}")
print(f"Total sentence pairs loaded: {len(all_sentence_pairs)}")

--- Loading and Preparing Data Format 1 (Word Lexicon) ---
Found 5 files to process.
Successfully processed: C:/Users/Patrickn/Jupyter_notebooks/Graduation/AI_Companion/TTS/g2p_data\egyptian_arabic_g2p_500.md
Successfully processed: C:/Users/Patrickn/Jupyter_notebooks/Graduation/AI_Companion/TTS/g2p_data\egyptian_arabic_g2p_500_additional.md
Successfully processed: C:/Users/Patrickn/Jupyter_notebooks/Graduation/AI_Companion/TTS/g2p_data\egyptian_arabic_g2p_new.md
Successfully processed: C:/Users/Patrickn/Jupyter_notebooks/Graduation/AI_Companion/TTS/g2p_data\egyptian_arabic_g2p_new_1000.md
Successfully processed: C:/Users/Patrickn/Jupyter_notebooks/Graduation/AI_Companion/TTS/g2p_data\egyptian_arabic_g2p_new_500.md

--- Loading and Preparing Data Format 2 (Sentence Pairs) ---
Found 15 files to process.
Error decoding JSON data for format 2: Expecting property name enclosed in double quotes: line 1595 column 5 (char 90382)
Successfully processed: C:/Users/Patrickn/Jupyter_notebooks/Grad

In [34]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import random # Import random for data splitting
import numpy as np # Import numpy for potential random seed setting

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Assume all_word_phoneme_pairs and all_sentence_pairs are defined elsewhere
# Example placeholders (replace with your actual data loading):
# all_word_phoneme_pairs = [("hello", "H EH L OW"), ("world", "W ER L D")]
# all_sentence_pairs = [{"text": "this is a test", "g2p": "DH IH S IH Z AH T EH S T"}]

# --- 1. Build Vocabularies ---

def build_vocab(items, special_tokens=['<pad>', '<unk>', '<sos>', '<eos>']): # Added <sos> and <eos> for potential future seq2seq improvements
    vocab = {tok: idx for idx, tok in enumerate(special_tokens)}
    idx = len(vocab)
    for item in items:
        if item not in vocab:
            vocab[item] = idx
            idx += 1
    return vocab

# Gather all unique chars and phonemes from both datasets
all_chars = set()
all_phonemes = set()

# --- Replace with your actual data loading and processing ---
# Example using placeholder data structure:
# for word, phonemes in all_word_phoneme_pairs:
#     all_chars.update(list(word))
#     all_phonemes.update(phonemes.split())

# for pair in all_sentence_pairs:
#     all_chars.update(list(pair['text']))
#     all_phonemes.update(pair['g2p'].split())
# --- End of placeholder example ---

# Placeholder for demonstration - replace with your actual data loading
# For this code to be runnable, we need some dummy data if the actual data isn't available
try:
    # Check if data variables exist (if loaded from elsewhere)
    _ = all_word_phoneme_pairs
    _ = all_sentence_pairs
except NameError:
    # Define dummy data if not found
    print("Using dummy data for demonstration. Replace with your actual data loading.")
    all_word_phoneme_pairs = [("hello", "H EH L OW"), ("world", "W ER L D"), ("python", "P AY TH AA N"),
                              ("example", "IH G Z AE M P AH L"), ("data", "D EY T AH"), ("science", "S AY AH N S")]
    all_sentence_pairs = [{"text": "this is a test", "g2p": "DH IH S IH Z AH T EH S T"},
                          {"text": "grapheme to phoneme", "g2p": "G R AE F IY M T UW F OW N IY M"},
                          {"text": "neural network", "g2p": "N UW R AH L N EH T W ER K"}]

# Process dummy/actual data to build vocab
for word, phonemes in all_word_phoneme_pairs:
    all_chars.update(list(word))
    all_phonemes.update(phonemes.split())

for pair in all_sentence_pairs:
    all_chars.update(list(pair['text']))
    all_phonemes.update(pair['g2p'].split())


char_vocab = build_vocab(sorted(list(all_chars))) # Convert set to list for sorting
phoneme_vocab = build_vocab(sorted(list(all_phonemes))) # Convert set to list for sorting

# --- 2. Dataset Classes ---

class G2PDataset(Dataset):
    def __init__(self, samples, char_vocab, phoneme_vocab):
        self.samples = samples # Samples are passed directly now after splitting
        self.char_vocab = char_vocab
        self.phoneme_vocab = phoneme_vocab

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        chars, phonemes = self.samples[idx]
        # Use .get with a default to handle potential missing keys gracefully, though build_vocab should cover most.
        char_ids = [self.char_vocab.get(c, self.char_vocab['<unk>']) for c in chars]
        phoneme_ids = [self.phoneme_vocab.get(p, self.phoneme_vocab['<unk>']) for p in phonemes]
        return torch.tensor(char_ids, dtype=torch.long), torch.tensor(phoneme_ids, dtype=torch.long)

def collate_fn(batch):
    char_seqs, phoneme_seqs = zip(*batch)
    char_lens = [len(seq) for seq in char_seqs]
    phoneme_lens = [len(seq) for seq in phoneme_seqs]
    max_char_len = max(char_lens)
    max_phoneme_len = max(phoneme_lens)
    pad_char = char_vocab['<pad>']
    pad_phoneme = phoneme_vocab['<pad>']

    # Pad sequences
    padded_chars = [torch.cat([seq, torch.full((max_char_len - len(seq),), pad_char, dtype=torch.long)]) for seq in char_seqs]
    padded_phonemes = [torch.cat([seq, torch.full((max_phoneme_len - len(seq),), pad_phoneme, dtype=torch.long)]) for seq in phoneme_seqs]

    return torch.stack(padded_chars), torch.stack(padded_phonemes), torch.tensor(char_lens), torch.tensor(phoneme_lens)

# --- Data Splitting ---
# Combine all samples first
all_samples = []
for word, phonemes in all_word_phoneme_pairs:
    all_samples.append((list(word), phonemes.split()))
for pair in all_sentence_pairs:
    all_samples.append((list(pair['text']), pair['g2p'].split()))

# Shuffle samples
random.shuffle(all_samples)

# Define split ratio
train_ratio = 0.8
train_size = int(len(all_samples) * train_ratio)

# Split data
train_samples = all_samples[:train_size]
val_samples = all_samples[train_size:]

# Create Dataset instances
train_dataset = G2PDataset(train_samples, char_vocab, phoneme_vocab)
val_dataset = G2PDataset(val_samples, char_vocab, phoneme_vocab)


# --- 4. Prepare DataLoaders ---

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn) # No need to shuffle validation data

# --- 3. Model ---

class SimpleLSTMG2P(nn.Module):
    def __init__(self, char_vocab_size, phoneme_vocab_size, emb_dim=64, hidden_dim=128, dropout_prob=0.3):
        super().__init__()
        self.embedding = nn.Embedding(char_vocab_size, emb_dim, padding_idx=char_vocab['<pad>'])
        # Dropout layer after embedding
        self.dropout_emb = nn.Dropout(dropout_prob)

        # Bidirectional LSTM encoder
        # dropout argument in LSTM applies dropout to the output of each layer except the last layer
        self.encoder = nn.LSTM(emb_dim, hidden_dim, batch_first=True, bidirectional=True, dropout=dropout_prob)

        # Unidirectional LSTM decoder
        # Input to decoder is the final encoder hidden state (bidirectional -> 2 * hidden_dim)
        self.decoder = nn.LSTM(hidden_dim * 2, hidden_dim, batch_first=True, dropout=dropout_prob)

        # Linear layer to map decoder output to phoneme vocabulary size
        self.fc = nn.Linear(hidden_dim, phoneme_vocab_size)
        # Dropout layer before the final linear layer
        self.dropout_fc = nn.Dropout(dropout_prob)


    def forward(self, char_seqs, char_lens, target_len):
        # Apply embedding
        emb = self.embedding(char_seqs)
        # Apply dropout to embeddings
        emb = self.dropout_emb(emb)

        # Pack padded sequences for efficient LSTM processing
        # char_lens must be on CPU for pack_padded_sequence
        packed = nn.utils.rnn.pack_padded_sequence(emb, char_lens.cpu(), batch_first=True, enforce_sorted=False)

        # Pass through encoder
        enc_out, (h, c) = self.encoder(packed)

        # Pad packed sequence back to original length
        enc_out, _ = nn.utils.rnn.pad_packed_sequence(enc_out, batch_first=True)

        # Simple decoder input: Repeat the final hidden state of the encoder
        # For bidirectional LSTM, the final hidden state (h) has shape (2 * num_layers, batch_size, hidden_dim)
        # We take the hidden state from the last layer (h[-2:, :, :]) and concatenate the forward and backward parts.
        # Then repeat it for each time step in the target sequence.
        # Note: This is a very basic decoder input strategy. More advanced models use attention.
        # h[-2, :, :] is the last layer's forward hidden state
        # h[-1, :, :] is the last layer's backward hidden state
        last_encoder_hidden = torch.cat((h[-2, :, :], h[-1, :, :]), dim=1) # Concatenate forward and backward last layer hidden states
        dec_input = last_encoder_hidden.unsqueeze(1).repeat(1, target_len, 1) # Repeat for target_len times

        # Pass through decoder
        dec_out, _ = self.decoder(dec_input)

        # Apply dropout before the final linear layer
        dec_out = self.dropout_fc(dec_out)

        # Pass through linear layer to get logits
        logits = self.fc(dec_out)

        return logits

# --- 5. Example Training Loop with Early Stopping ---

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Pass dropout_prob to the model constructor
model = SimpleLSTMG2P(len(char_vocab), len(phoneme_vocab), dropout_prob=0.3).to(device) # Added dropout_prob
criterion = nn.CrossEntropyLoss(ignore_index=phoneme_vocab['<pad>'])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 50 # Set a higher number of epochs, early stopping will stop it
patience = 5 # How many epochs to wait for improvement before stopping
best_val_loss = float('inf') # Initialize best validation loss
patience_counter = 0 # Counter for epochs without improvement

print(f"Starting training for up to {num_epochs} epochs with patience of {patience}...")

for epoch in range(num_epochs):
    # --- Training Phase ---
    model.train() # Set model to training mode (enables dropout)
    total_train_loss = 0
    for chars, phonemes, char_lens, phoneme_lens in train_dataloader:
        chars, phonemes = chars.to(device), phonemes.to(device)
        optimizer.zero_grad()

        # Ensure target_len is taken from the actual phoneme batch size for correct decoder input size
        target_len = phonemes.size(1)
        logits = model(chars, char_lens, target_len)

        # Reshape logits and targets for CrossEntropyLoss
        logits = logits.view(-1, logits.size(-1))
        targets = phonemes.view(-1)

        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_dataloader)

    # --- Validation Phase ---
    model.eval() # Set model to evaluation mode (disables dropout)
    total_val_loss = 0
    with torch.no_grad(): # Disable gradient calculation for validation
        for chars, phonemes, char_lens, phoneme_lens in val_dataloader:
            chars, phonemes = chars.to(device), phonemes.to(device)

            target_len = phonemes.size(1)
            logits = model(chars, char_lens, target_len)

            logits = logits.view(-1, logits.size(-1))
            targets = phonemes.view(-1)

            loss = criterion(logits, targets)
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_dataloader)

    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f} - Val Loss: {avg_val_loss:.4f}")

    # --- Early Stopping Logic ---
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        patience_counter = 0 # Reset patience counter
        # Optional: Save the model state dict that achieved the best validation loss
        # torch.save(model.state_dict(), 'best_g2p_model.pth')
    else:
        patience_counter += 1 # Increment patience counter
        print(f"Validation loss did not improve. Patience: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs due to no improvement in validation loss.")
            # Optional: Load the best model state dict before stopping
            # model.load_state_dict(torch.load('best_g2p_model.pth'))
            break # Stop training

print("Training finished.")



Starting training for up to 50 epochs with patience of 5...
Epoch 1/50 - Train Loss: 8.5031 - Val Loss: 8.2953
Epoch 1/50 - Train Loss: 8.5031 - Val Loss: 8.2953
Epoch 2/50 - Train Loss: 7.9112 - Val Loss: 8.2461
Epoch 2/50 - Train Loss: 7.9112 - Val Loss: 8.2461
Epoch 3/50 - Train Loss: 7.6787 - Val Loss: 8.1702
Epoch 3/50 - Train Loss: 7.6787 - Val Loss: 8.1702
Epoch 4/50 - Train Loss: 7.4267 - Val Loss: 8.1643
Epoch 4/50 - Train Loss: 7.4267 - Val Loss: 8.1643
Epoch 5/50 - Train Loss: 7.1657 - Val Loss: 8.1829
Validation loss did not improve. Patience: 1/5
Epoch 5/50 - Train Loss: 7.1657 - Val Loss: 8.1829
Validation loss did not improve. Patience: 1/5
Epoch 6/50 - Train Loss: 6.9068 - Val Loss: 8.1830
Validation loss did not improve. Patience: 2/5
Epoch 6/50 - Train Loss: 6.9068 - Val Loss: 8.1830
Validation loss did not improve. Patience: 2/5
Epoch 7/50 - Train Loss: 6.6806 - Val Loss: 8.1668
Validation loss did not improve. Patience: 3/5
Epoch 7/50 - Train Loss: 6.6806 - Val Loss

In [35]:
# Prepare the test sentence
test_text = "عامل ايه يا علي؟ أنا كويس الحمد لله. كتابك الجديد فين؟"
test_chars = [char_vocab.get(c, char_vocab['<unk>']) for c in test_text]
test_tensor = torch.tensor([test_chars], dtype=torch.long).to(device)
test_len = torch.tensor([len(test_chars)])

# Set a reasonable output length (can be same as input or a bit longer)
output_len = len(test_chars) + 5

# Run the model in eval mode
model.eval()
with torch.no_grad():
    logits = model(test_tensor, test_len, output_len)
    pred_ids = logits.argmax(-1).cpu().numpy()[0]

# Convert predicted ids to phonemes
id2phoneme = {v: k for k, v in phoneme_vocab.items()}
pred_phonemes = [id2phoneme.get(i, '<unk>') for i in pred_ids]

# Print the result
print("Input:", test_text)
print("Predicted phonemes:", " ".join(pred_phonemes))

Input: عامل ايه يا علي؟ أنا كويس الحمد لله. كتابك الجديد فين؟
Predicted phonemes: ʔeːh ʔeːh jaː ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh? ʔeːh?
