In [60]:
from spellchecker import SpellChecker
from tensorflow.keras.models import load_model
from collections import Counter
import numpy as np
import string

# General-purpose English spellchecker; SpellingCheck uses it as the last-resort corrector.
spell = SpellChecker()

# Next-word LSTMs: model_3 is fed the last three word ids (see predict_lstm3),
# model_2 the last two (see predict_lstm2).
model_3, model_2 = (
    load_model(model_path)
    for model_path in ("lstm_3word_800k_e75_final.keras", "nextword_lstm_200k.keras")
)

In [90]:
import pickle

# Load the fitted tokenizer (presumably a Keras Tokenizer -- it exposes `.word_index`).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

word_index = tokenizer.word_index                                # word -> integer id
index_word = {idx: word for word, idx in word_index.items()}     # integer id -> word
vocab = set(word_index)                                          # all known words

# Persist the vocabulary. `with` guarantees the file is flushed and closed
# (the previous bare open() leaked the handle and could leave a truncated file).
with open("vocab.pkl", "wb") as f:
    pickle.dump(vocab, f)


In [46]:
# Trigram table; predict_trigram looks it up with (second-to-last, last) word tuples.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("trigram.pkl", mode="rb") as trigram_file:
    trigram_model = pickle.load(trigram_file)


In [62]:
def preProcess(sentence):
    """
    Normalize raw user text before splitting and spell-checking.

    Deletes every character in ``string.punctuation``, turns tabs into spaces
    (so tab-separated words stay separate words), and lower-cases the result.

    Args:
        sentence (str): raw input text.
    Returns:
        str: the cleaned text.
    """
    # One translate pass instead of one str.replace scan per punctuation char.
    # Tabs map to " " -- deleting them outright glued neighbouring words together.
    table = str.maketrans({**dict.fromkeys(string.punctuation, None), "\t": " "})
    return sentence.translate(table).lower()

In [65]:
class OptimizedVocabSpellChecker:
    """
    Vocabulary-constrained word splitter and spell corrector.

    Every whitespace-separated token that is not already a vocabulary word is
    segmented by dynamic programming into pieces. Each piece is either an exact
    vocabulary word or a spell-corrected substring.

    Candidate segmentations are ranked (lower is better) by:
        1. the number of spell corrections used, then
        2. the number of words produced (so "abcde fg" beats "ab cd efg").

    Spell corrections are cached per token, so each distinct substring is sent
    to ``spell.correction()`` at most once per ``split_and_correct`` call.
    """

    def __init__(self, vocab):
        """
        Build the checker from an iterable of vocabulary words.

        Args:
            vocab: iterable of words. It is materialized into a set once, and
                that set feeds the spellchecker, so one-shot iterators and
                generators work too.
        """
        self.vocab = set(vocab)
        self.spell = SpellChecker(language=None)  # Create empty spellchecker
        # Load from self.vocab: re-reading `vocab` would hit an iterator that
        # set() above already exhausted, leaving the spellchecker empty.
        self.spell.word_frequency.load_words(self.vocab)

    def split_and_correct(self, token):
        """
        Segment ``token`` into vocabulary words, spell-correcting pieces if needed.

        Args:
            token (str): a single whitespace-free token.
        Returns:
            list: the best segmentation found. If no segmentation covers the
            whole token, returns ``[spell.correction(token)]``, or ``[token]``
            when that correction is falsy.
        """
        n = len(token)
        # dp[i] = (corrections, words) for the best segmentation of token[:i],
        # or None when token[:i] cannot be segmented.
        dp = [None] * (n + 1)
        dp[0] = (0, [])

        # Per-call cache: substring -> correction (falls back to the substring).
        correction_cache = {}

        def corrected_for(piece):
            # Call spell.correction() only the first time a substring is seen.
            if piece not in correction_cache:
                suggestion = self.spell.correction(piece)
                correction_cache[piece] = suggestion if suggestion else piece
            return correction_cache[piece]

        def relax(j, cost, words):
            # Replace dp[j] only with a strictly better candidate: fewer
            # corrections first, then fewer words. Full ties keep the earlier one.
            best = dp[j]
            if best is None or (cost, len(words)) < (best[0], len(best[1])):
                dp[j] = (cost, words)

        for i in range(n):
            if dp[i] is None:
                continue
            cost_so_far, words_so_far = dp[i]

            # Try every piece token[i:j] that starts at position i.
            for j in range(i + 1, n + 1):
                piece = token[i:j]

                # Case 1: exact vocabulary hit (single letters only for "i"/"a").
                if piece in self.vocab and (len(piece) >= 2 or piece in {"i", "a"}):
                    relax(j, cost_so_far, words_so_far + [piece])
                # Case 2: spell-correct pieces of 3+ chars. Any 3+ char vocab
                # word was already taken by Case 1, so this piece is out-of-vocab.
                elif len(piece) >= 3:
                    corrected = corrected_for(piece)
                    if corrected != piece and corrected in self.vocab:
                        relax(j, cost_so_far + 1, words_so_far + [corrected])

        if dp[n] is not None:
            return dp[n][1]
        # Last resort: no full segmentation, so correct the whole token.
        return [corrected_for(token)]

    def splitWord(self, sentence):
        """
        Split and spell-correct every token of a sentence.

        Tokens already in the vocabulary are kept verbatim. All other tokens are
        replaced by the words returned from ``split_and_correct``.

        Args:
            sentence (str): input sentence with potential misspellings.
        Returns:
            str: the corrected and split words joined by single spaces.
        """
        new_words = []
        for word in sentence.split():
            if word in self.vocab:
                new_words.append(word)
            else:
                new_words.extend(self.split_and_correct(word))
        return " ".join(new_words)

In [49]:
abc = OptimizedVocabSpellChecker(vocab)

# Smoke-test the splitter on tokens that glue several words together.
for demo_text in ("i loveinyouthe", "youthe", "iam"):
    print(abc.splitWord(demo_text))

i love in you the
you the
i am


In [50]:
def SpellingCheck(sentence):
    """
    Correct out-of-vocabulary words with the general-purpose ``spell`` checker.

    Words already in the global ``vocab`` pass through unchanged. Other words
    are replaced by ``spell.correction(word)``, or kept as-is when that
    returns a falsy value.

    Args:
        sentence (str): whitespace-separated text.
    Returns:
        str: the resulting words joined by single spaces.
    """
    def fix(word):
        if word in vocab:
            return word
        suggestion = spell.correction(word)
        return suggestion if suggestion else word

    return " ".join(fix(word) for word in sentence.split())

In [51]:
def sample_with_temperature(preds, temperature=1.0, k=3):
    """
    Return the indices of the ``k`` highest entries of ``preds``, best first.

    Despite the name, nothing is sampled and the result is deterministic. For
    temperature > 0, the temperature-scaled softmax below is a strictly
    increasing transform of ``preds``, so it does not change the ranking. It is
    kept only so the results, including tie order, match earlier runs.

    Args:
        preds: 1-D array-like of non-negative scores/probabilities.
        temperature (float): must be > 0. Zero produced NaNs, and negative
            values silently reversed the ranking.
        k (int): number of indices to return. k <= 0 yields an empty array.
    Returns:
        numpy.ndarray: integer indices into ``preds``.
    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    if k <= 0:
        # argsort()[-0:] is the WHOLE array, so k=0 used to return every index.
        return np.array([], dtype=np.intp)

    # asarray lets plain Python lists work (list + float would raise).
    scaled = np.log(np.asarray(preds) + 1e-9) / temperature
    exp_scaled = np.exp(scaled)
    probs = exp_scaled / np.sum(exp_scaled)

    return probs.argsort()[-k:][::-1]


In [52]:
def predict_trigram(sentence, k=3):
    """
    Predict up to ``k`` next words from the trigram table.

    The last two lower-cased words form the lookup key into ``trigram_model``.
    The value found there is fed to ``Counter``. It is presumably a list of
    follower words or a word->count mapping (TODO confirm against the pickle).

    Args:
        sentence (str): context text.
        k (int): maximum number of words to return.
    Returns:
        list: most frequent followers, best first; [] when there are fewer
        than two words or the key is unknown.
    """
    tokens = sentence.lower().split()
    if len(tokens) >= 2:
        context = tuple(tokens[-2:])
        if context in trigram_model:
            ranked = Counter(trigram_model[context]).most_common(k)
            return [word for word, _count in ranked]
    return []


In [53]:
def predict_lstm2(sentence, k=3):
    """
    Predict up to ``k`` next words with the 2-word-context LSTM (``model_2``).

    Args:
        sentence (str): context text; only its last two words are used.
        k (int): number of candidates to return.
    Returns:
        list: candidate words, best first ("" for ids missing from
        ``index_word``); [] when there are fewer than two words or either
        word is unknown to the tokenizer.
    """
    tokens = sentence.lower().split()
    context = tokens[-2:]
    if len(context) < 2 or not all(word in word_index for word in context):
        return []

    encoded = np.array([[word_index[word] for word in context]])
    probs = model_2.predict(encoded, verbose=0)[0]
    best = sample_with_temperature(probs, temperature=0.7, k=k)
    return [index_word.get(idx, "") for idx in best]

In [75]:
def predict_lstm3(sentence, k=3):
    """
    Predict up to ``k`` next words with the 3-word-context LSTM (``model_3``).

    Args:
        sentence (str): context text; only its last three words are used.
        k (int): number of candidates to return.
    Returns:
        list: candidate words, best first ("" for ids missing from
        ``index_word``); [] when there are fewer than three words or any of
        them is unknown to the tokenizer.
    """
    tokens = sentence.lower().split()
    context = tokens[-3:]
    if len(context) < 3:
        return []
    if not all(word in word_index for word in context):
        return []

    encoded = np.array([[word_index[word] for word in context]])
    probs = model_3.predict(encoded, verbose=0)[0]
    best = sample_with_temperature(probs, temperature=0.7, k=k)
    return [index_word.get(idx, "") for idx in best]


In [55]:
# Compare the three next-word sources on similar prompts.
demo_runs = [
    ("Tri:", predict_trigram, "what are you"),
    ("L2:", predict_lstm2, "what are you doing"),
    ("L3:", predict_lstm3, "what are you telling"),
]
for label, predictor, prompt in demo_runs:
    print(label, predictor(prompt))

Tri: ['doing', 'going', 'talking']
L2: ['we', 'get', 'have']
L3: ['me', 'the', 'them']


In [56]:
# Three common 3-word prompts, printed on one line (one list per prompt).
lstm3_prompts = ("i want to", "i need to", "do you think")
print(*[predict_lstm3(prompt) for prompt in lstm3_prompts])

['know', 'be', 'see'] ['know', 'see', 'go'] ['i', 'im', 'you']


In [57]:
from collections import defaultdict

def merge_predictions(tri_preds, l2_preds, l3_preds, k=3):
    """
    Combine three ranked candidate lists into one ranked list by weighted voting.

    A word at 0-based rank ``i`` in a list with base weight ``w`` earns
    ``max(w - i, 0)`` points. Base weights are 3 for lstm3, 2 for lstm2 and
    1 for trigram, and points are summed per word.

    The clamp at 0 matters because callers pass lists longer than the base
    weight (ultimate_predictor asks each source for 5 words). Without it, a
    low-ranked appearance subtracted points, so a word that several sources
    agreed on could rank below a word proposed by only one source.

    Args:
        tri_preds (list): trigram candidates, best first.
        l2_preds (list): lstm2 candidates, best first.
        l3_preds (list): lstm3 candidates, best first.
        k (int): number of words to return.
    Returns:
        list: up to k words, highest score first. Ties keep first-seen order
        (lstm3 words, then lstm2, then trigram).
    """
    scores = defaultdict(int)

    for base_weight, preds in ((3, l3_preds), (2, l2_preds), (1, tri_preds)):
        for rank, word in enumerate(preds):
            scores[word] += max(base_weight - rank, 0)

    # sorted() is stable (also with reverse=True), so equal scores keep insertion order.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [word for word, _ in ranked[:k]]


In [88]:
def ultimate_predictor(sentence, k=3):
    """
    Predict the next word for raw user text.

    Pipeline: preProcess -> abc.splitWord -> SpellingCheck -> next-word models.

    Source priority:
        1. lstm3   - primary, 3-word context
        2. trigram - secondary n-gram baseline
        3. lstm2   - tertiary, 2-word context

    What happens depends on which sources returned candidates (each source is
    asked for 5):
        - lstm3 and trigram both empty: return lstm2's top-k as-is.
        - only lstm3 empty: trigram's top-k, padded with unseen lstm2 words.
        - lstm3 non-empty:
            * lstm2 words are kept only if lstm3 or trigram also proposed them.
            * The three lists are rank-merged with merge_predictions.
            * If fewer than k words result, the list is padded with unseen
              lstm3 words, then unseen trigram words.

    Args:
        sentence (str): raw input text.
        k (int): maximum number of words to return.
    Returns:
        list: up to k candidate next words, best first.
    """
    def top_up(base, extras):
        # Append unseen words from `extras` until the list holds k words.
        filled = list(base)
        present = set(filled)
        for candidate in extras:
            if len(filled) >= k:
                break
            if candidate not in present:
                filled.append(candidate)
                present.add(candidate)
        return filled

    cleaned = SpellingCheck(abc.splitWord(preProcess(sentence)))

    l3_preds = predict_lstm3(cleaned, k=5)
    tri_preds = predict_trigram(cleaned, k=5)
    l2_preds = predict_lstm2(cleaned, k=5)

    if not l3_preds:
        if not tri_preds:
            # Nothing stronger available: lstm2 alone.
            return l2_preds[:k]
        # lstm3 unavailable: trigram leads, lstm2 may fill remaining slots freely.
        return top_up(tri_preds[:k], l2_preds)

    # Drop lstm2 guesses that neither stronger source supports.
    supported = set(l3_preds).union(tri_preds)
    l2_supported = [word for word in l2_preds if word in supported]

    merged = merge_predictions(tri_preds, l2_supported, l3_preds, k=k)
    return top_up(merged, l3_preds + tri_preds)

In [89]:
# End-to-end demo on a misspelled, capitalized input.
ultimate_predictor(sentence="What are you giong")

['to', 'where', 'im']