In [2]:
import pickle
import os
from tqdm import tqdm
import copy


class TokenGraphEmbedding():
    """Weighted co-occurrence graph over the tokens of a text corpus.

    Each pair of adjacent tokens on a line adds weight to the edge between
    them (in both directions by default). Nodes are kept in first-seen order
    in ``all_tokens`` and indexed by ``w2i``; edge weights live in the nested
    dict ``adj[src][dst]``.
    """

    # Directory that save()/load() resolve their file names against.
    save_dir = "./saved"

    def __init__(self, load_path=None, save_path="default.pkl"):
        """
        Args:
            load_path: if truthy, file name (relative to ``save_dir``) of a
                pickled graph whose state is restored via ``load``.
            save_path: file name used by ``save()`` when no explicit path
                is given.
        """
        self.num_nodes = 0      # number of distinct source tokens registered
        self.tot_weight = 0     # sum of every directed edge weight added
        self.all_tokens = []    # tokens in first-seen order
        self.adj = {}           # adj[src][dst] -> accumulated weight
        self.save_path = save_path
        self.w2i = {}           # token -> index into all_tokens

        if load_path:
            self.load(path=load_path)

    def clean_line(self, line, bad=(',', '.', ';', '(', ')', '/', '`', '%', '"', '-', '\\', '\'')):
        """Return ``line`` with every character that appears in ``bad`` removed."""
        # A tuple default avoids the shared-mutable-default pitfall; membership
        # semantics are the same as with the original list.
        return ''.join(c for c in line if c not in bad)

    def _tokenize(self, line):
        """Clean ``line`` and split it into non-empty tokens.

        ``str.split()`` with no separator drops the trailing newline and does
        not produce the empty strings that ``split(' ')`` yields for runs of
        spaces (e.g. where punctuation between two spaces was removed).
        """
        return self.clean_line(line).split()

    def save(self, path=None):
        """Pickle this whole object to ``save_dir/path``.

        Args:
            path: file name relative to ``save_dir``; defaults to
                ``self.save_path``.
        """
        if path is None:
            path = self.save_path
        full_path = os.path.join(self.save_dir, path)
        # Create the target directory on first save instead of failing.
        os.makedirs(os.path.dirname(full_path) or ".", exist_ok=True)
        with open(full_path, 'wb') as handle:
            pickle.dump(self, handle)

    def init_from_obj(self, b):
        """Adopt the graph state of another ``TokenGraphEmbedding`` ``b``.

        The containers are shared with ``b``, not copied.
        """
        self.num_nodes = b.num_nodes
        self.all_tokens = b.all_tokens
        self.adj = b.adj
        self.tot_weight = b.tot_weight
        # w2i[token] is always the token's position in all_tokens, so rebuild
        # it from there; this also repairs pickles saved with an empty w2i.
        self.w2i = {tok: i for i, tok in enumerate(self.all_tokens)}

    def load(self, path="default.pkl"):
        """Restore graph state from the pickle at ``save_dir/path``.

        Warning: ``pickle.load`` can execute arbitrary code; only load files
        you created yourself.
        """
        with open(os.path.join(self.save_dir, path), 'rb') as handle:
            b = pickle.load(handle)
        self.init_from_obj(b)

    def display(self, print_adj=False):
        """Print node count, total weight and token list; optionally every edge."""
        print("num_nodes: {}".format(self.num_nodes))
        print("tot_weight: {}".format(self.tot_weight))
        print("all_tokens: {}".format(self.all_tokens))

        if print_adj:
            print("\nAdjacency Matrix: \n")
            for token in self.all_tokens:
                print(token)
                for neigh in self.adj[token]:
                    print("\t {} : {}".format(neigh, self.adj[token][neigh]))
                print()

    def update_link(self, token_src, token_dst, weight, link_type="bidirectional", stop=False):
        """Add ``weight`` to the edge ``token_src -> token_dst``.

        ``token_src`` is registered as a node if unseen. With
        ``link_type="bidirectional"`` the reverse edge is updated too (the
        recursive call passes ``stop=True`` so it does not recurse again).
        Any other ``link_type`` updates only the forward edge, and
        ``token_dst`` is not registered as a node in that case.
        """
        if token_src not in self.adj:
            self.adj[token_src] = {}
            self.all_tokens.append(token_src)
            self.w2i[token_src] = self.num_nodes
            self.num_nodes += 1
        neighbours = self.adj[token_src]
        neighbours[token_dst] = neighbours.get(token_dst, 0) + weight
        self.tot_weight += weight
        if link_type == "bidirectional" and not stop:
            self.update_link(token_dst, token_src, weight, stop=True)

    def update_graph(self, from_file, encoding="utf-8"):
        """Link every pair of adjacent tokens on each line of ``from_file``, then save.

        Args:
            from_file: path to a text file, one sentence per line.
            encoding: text encoding of ``from_file``.
        """
        with open(from_file, "r", encoding=encoding) as f:
            sents = f.readlines()
        # One progress bar over lines rather than a new bar per line.
        for sent in tqdm(sents):
            tokens = self._tokenize(sent)
            for prev, cur in zip(tokens, tokens[1:]):
                self.update_link(token_src=cur, token_dst=prev, weight=1)

        # Persist this graph to save_dir/save_path.
        self.save()


In [3]:
# Build a fresh graph from the sample corpus and print its summary.
# To resume from a previously pickled graph instead, construct with
# TokenGraphEmbedding(load_path="default.pkl"); load paths are resolved
# relative to the ./saved directory, not the working directory.
t = TokenGraphEmbedding()
t.update_graph(from_file='./sample.txt')
t.display()
    

num_nodes: 55
tot_weight: 992
all_tokens: ['allows', 'PPLM', 'a', 'user', 'to', 'flexibly', 'plug', 'in', 'one', 'or', 'more', 'tiny', 'attribute', 'models', 'representing', 'the', 'desired', 'steering', 'objective', 'into', 'large', 'unconditional', 'language', 'model', 'LM', 'The', 'method', 'has', 'key', 'property', 'that', 'it', 'uses', 'as', 'is—no', 'training', 'finetuning', 'is', 'required—which', 'enables', 'researchers', 'leverage', 'bestinclass', 'LMs', 'even', 'if', 'they', 'do', 'not', 'have', 'extensive', 'hardware', 'required', 'train', 'them']


In [2]:
def remove_lbls(path, out_path=None, sep=','):
    """Write a copy of ``path`` keeping only the first ``sep``-separated field of each line.

    Args:
        path: input text/CSV file.
        out_path: destination file; defaults to ``path + '.cleaned'``.
        sep: field separator; everything from its first occurrence onward
            is dropped.

    Returns:
        The path of the written file.

    NOTE(review): naive split; a quoted first field containing ``sep`` will
    be truncated. Confirm the input format (e.g. use the ``csv`` module).
    """
    if out_path is None:
        out_path = path + '.cleaned'
    with open(path, 'r') as src, open(out_path, 'w') as dst:
        for line in src:
            # Strip the newline first: a line without ``sep`` would otherwise
            # keep its '\n' and gain a second one, emitting a blank line.
            text = line.rstrip('\r\n').split(sep, 1)[0]
            dst.write(text + '\n')
    return out_path

In [5]:
# Keep only the first comma-separated field of each line of the Yelp
# validation split; the result is written next to it as valid.csv.cleaned.
YELP_VALID_CSV = './data/annotated/yelp-new/valid.csv'
remove_lbls(YELP_VALID_CSV)