In [1]:
import networkx as nx
import numpy as np

from random import choices

In [2]:
vocab_G = nx.read_gml('./Graphs/corpus_vocab.gml')

In [38]:
class NetworkLM():
#     def __init__(self, graph_path, freq = 'freq', count = 'count') -> None:
    def __init__(self, G, freq = 'freq', count = 'count') -> None:
        """Initialise the language model
        Parameters:
            graph_path (str): filepath to the network GML
            freq (str): the node attribute that stores word frequency (default: freq)
            count (str): the edge attribute that stores edge frequency (default: count)
        Returns:
            None"""
        
#         self.G = nx.read_gml(graph_path)
        self.G = G
        self.freq = freq
        self.count = count
        self.SENT_BEG = '<s>'
        self.SENT_END = '</s>'
        
    def k_most_common_from(self, target, k = 10) -> dict:
        """Find the k most common words after a target word
        Parameters:
            target (str): the target word
            k (int): the limit (default: 10), if None then all are retrieved
        Returns:
            dict: next words and their probabilities sorted desc"""
        
        # find all possible next words and counts using out-edges of target
        next_words = {i[1]: i[2][self.count] for i in self.G.out_edges(target, data = True)}
        
        if len(next_words) > 0:
            total = sum(next_words.values())  # calculate total out-edges
            
            # get next words sorted desc by probability
            next_words = sorted({i: next_words[i] / total for i in next_words}.items(), key = lambda x: x[1], reverse = True)
            
            if k:
                return dict(next_words[: k])
            else:
                return dict(next_words)
            
        return dict()
    
    def k_most_common_to(self, target, k = 10) -> dict:
        """Find the k most common words before a target word
        Parameters:
            target (str): the target word
            k (int): the limit (default: 10), if None then all are retrieved
        Returns:
            dict: prev words and their probabilities sorted desc"""
        
        # find all possible prev words and counts using in-edges of target
        prev_words = {i[0]: i[2][self.count] for i in self.G.in_edges(target, data = True)}
        
        if len(prev_words) > 0:
            total = sum(prev_words.values())  # calculate total in-edges
            
            # get prev words sorted desc by probability
            prev_words = sorted({i: prev_words[i] / total for i in prev_words}.items(), key = lambda x: x[1], reverse = True)
            
            if k:
                return dict(prev_words[: k])
            else:
                return dict(prev_words)
            
        return dict()
    
    def perplexity(self, prob, n) -> float:
        """Calculate the perplexity given the probability
        Parameters:
            prob (float): the probability
        Returns:
            float: the perplexity"""
    
        return prob ** (-1 / n)
    
    def generate_sentence_shannon(self, seed, max_len = 10, mode = 1) -> (list, float):
        """Generate a sentence from a seed word
        Parameters:
            seed (str): the seed word
            max_len (int): the max sentence length (default: 10)
            mode (int): mode of operation - 1 uses out-edges, 0 uses in-edges (default: 1)
        Returns:
            list: the tokens of the generated sentence
            float: the sentence probability"""
        
        score = 0
        sentence, sent_len = [], 0 if seed == self.SENT_BEG else 1
        
        sentence.append(seed)  # append the seed to the empty sentence
        
        # generate the next words
        while sent_len <= max_len:
            words = None
            
            if mode == 1:  # get possible words using out edges
                words = self.k_most_common_from(sentence[-1], k = None)
            else:
                words = self.k_most_common_to(sentence[-1], k = None)
                
            if len(words) == 0:
                break
            
            # select word and add to sentence
            word = choices(list(words.keys()), list(words.values()))
            word = word[0]
            
            sentence.append(word)
            sent_len += 1
            
            score += np.log10(words[word])  # get probability of the selected word 
            
            # break conditions
            if mode == 1 and word == self.SENT_END:
                break
            elif mode == 0 and word == self.SENT_BEG:
                break
                
        if mode == 1:
            if sentence[-1] != self.SENT_END:
                sentence.append(self.SENT_END)

            if sentence[0] != self.SENT_BEG:
                sentence.insert(0, self.SENT_BEG)
                
        else:
            if sentence[-1] != self.SENT_BEG:
                sentence.append(self.SENT_BEG)
                
            if sentence[0] != self.SENT_END:
                sentence.insert(0, self.SENT_END)
                
        return sentence, 10 ** score
    
    def generate_sentence_inside_out(self, seed, before = 5, after = 4) -> (str, float):
        """"""
        
        sb, p = self.generate_sentence_shannon(seed, before, mode = 0)
        
        sa, p = self.generate_sentence_shannon(seed, after)
        
        print(' '.join(sb[:: -1]))
        print(' '.join(sa))
        
#     def generate_sentences(self, technique, seed, n = None, before = None, after = None) -> (list, list):
#         """"""
        
#         sentences, probs = []
        
#         if technique == 1:
#             for _ in range(n):
#                 sentence, prob = self.generate_sentence_shannon(seed, n)
                
#                 sentence.append(' '.join(sentence))
#                 probs.append(prob)
                
#         return sentences, probs

In [39]:
nlm = NetworkLM(vocab_G)

In [40]:
nlm.generate_sentence_inside_out('fogg')

<s> to a fan and mr fogg </s>
<s> fogg </s>
