**PTE**

PTE (Predictive Text Embedding) is a graph embedding algorithm for heterogeneous graphs.

It uses text networks as its running example. To construct a heterogeneous graph, we first define its sub-graphs. In the paper, the authors define three sub-graphs: **word-word graph, word-document graph, and word-label graph**.

They construct the word-word graph with a sliding window. The weight of the edge between two words is the **number of times** the two words co-occur within context windows of a given size. Note that word-word edges are **directed**: each co-occurrence is recorded as a (center word, context word) pair. This sub-graph captures local co-occurrence **statistical information** in the documents.
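A minimal sketch of the sliding-window counting, assuming each document is already a list of tokens and that the window extends `window` positions on both sides of the center word. The function name `word_word_edges` is illustrative, not taken from the paper.

```python
from collections import Counter


def word_word_edges(docs, window=2):
    """Count directed (center, context) co-occurrences within +/- `window` tokens."""
    edges = Counter()
    for tokens in docs:
        for i, center in enumerate(tokens):
            # symmetric window, clipped at the document boundaries
            lo = max(0, i - window)
            hi = min(len(tokens), i + window + 1)
            for j in range(lo, hi):
                if j != i:
                    edges[(center, tokens[j])] += 1
    return edges


edges = word_word_edges([["a", "good", "movie"]], window=1)
print(sorted(edges.items()))
# [(('a', 'good'), 1), (('good', 'a'), 1), (('good', 'movie'), 1), (('movie', 'good'), 1)]
```

Because each pair is counted once from each side, the counts come out symmetric even though the edges are stored as directed pairs.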


Next, they construct the word-document graph by simply **connecting** each word to the documents in which it appears. The weight between a word and a document is the number of times **the word appears in that document**. This sub-graph captures document-level co-occurrence, a **higher level** of information about the corpus than local word windows.
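A minimal sketch of the word-document weights, assuming documents are token lists identified by their position in `docs`; `word_doc_edges` is an illustrative name.

```python
from collections import Counter


def word_doc_edges(docs):
    """Map (word, doc_index) -> number of times the word occurs in that document."""
    edges = {}
    for d, tokens in enumerate(docs):
        for word, n in Counter(tokens).items():
            edges[(word, d)] = n
    return edges


docs = [["great", "great", "plot"], ["dull", "plot"]]
edges = word_doc_edges(docs)
print(edges[("great", 0)], edges[("plot", 1)])  # 2 1
```

The notebook code stores each such edge in both directions (word to document and document to word); this sketch keeps one direction for brevity.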

Finally, they construct the word-label graph by connecting each word to the labels of the documents it appears in. The weight of the edge between word $v_i$ and label $j$ is defined as **$w_{ij}=\sum_{d:l_d=j}n_{di}$**, where $n_{di}$ is the term frequency of word $v_i$ in document $d$, and $l_d$ is the class label of document $d$. This sub-graph captures the **label information** of words (usually only a small fraction of the data is labeled).
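The formula can be computed directly from token lists. A minimal sketch, assuming each document has exactly one label (`labels[d]` plays the role of $l_d$); summing $n_{di}$ over the documents with label $j$ is the same as counting every occurrence of the word in those documents. `word_label_edges` is an illustrative name.

```python
from collections import Counter


def word_label_edges(docs, labels):
    """w[(word, label)] = sum over documents d with l_d == label of n_{d, word}."""
    weights = Counter()
    for tokens, label in zip(docs, labels):
        for word, n in Counter(tokens).items():
            weights[(word, label)] += n
    return weights


docs = [["great", "plot"], ["great", "great"], ["dull", "plot"]]
labels = ["pos", "pos", "neg"]
w = word_label_edges(docs, labels)
print(w[("great", "pos")], w[("plot", "neg")])  # 3 1
```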

After constructing the three sub-graphs, they apply the **same objective as LINE** to each of them, so the LINE code can be reused.
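As a reminder of what is being reused, assuming LINE's second-order objective with negative sampling: an edge $(i, j)$ with $K$ sampled negative vertices $v_n$ contributes $\log\sigma(\mathbf{u}'^{\top}_j\mathbf{u}_i) + \sum_{n=1}^{K}\log\sigma(-\mathbf{u}'^{\top}_n\mathbf{u}_i)$, where $\mathbf{u}_i$ is the vertex vector, the $\mathbf{u}'$ are context vectors, and negatives are drawn from $P_n(v) \propto d_v^{3/4}$. The pure-Python sketch below computes the negative of that quantity (the loss to minimize) for a single edge; `edge_loss` and its helpers are illustrative names.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(1 / (1 + exp(-x)))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def edge_loss(u_i, ctx_j, ctx_negs):
    """Negative-sampling loss for edge (i, j): positive term plus one term per negative."""
    loss = -log_sigmoid(dot(u_i, ctx_j))
    for ctx_n in ctx_negs:
        loss -= log_sigmoid(-dot(u_i, ctx_n))
    return loss


print(edge_loss([1.0, 0.0], [5.0, 0.0], []) < edge_loss([1.0, 0.0], [-5.0, 0.0], []))  # True
```

The notebook's `PTE.forward` computes the same loss in batch form by multiplying each inner product by a label of +1 (positive edge) or -1 (negative sample) before the log-sigmoid.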

For training, they use joint training: in each epoch, the **three sub-graphs are trained one after another**.
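The schedule can be sketched independently of the model. Everything here is hypothetical scaffolding (`joint_train`, `subgraphs`, `train_step`); it only shows the order of updates used in this notebook, where one epoch is a pass over the word-word sub-graph, then the word-document sub-graph, then the word-label sub-graph.

```python
def joint_train(subgraphs, epochs, train_step):
    """Train every sub-graph in turn within each epoch.

    subgraphs: dict mapping a name (e.g. "ww") to a zero-argument function that
    returns the mini-batches for one pass over that sub-graph.
    train_step: callable(name, batch) performing one optimizer update.
    """
    for _ in range(epochs):
        for name, make_batches in subgraphs.items():  # dicts keep insertion order
            for batch in make_batches():
                train_step(name, batch)


log = []
subgraphs = {
    "ww": lambda: ["b1", "b2"],
    "wd": lambda: ["b1"],
    "wl": lambda: ["b1"],
}
joint_train(subgraphs, epochs=2, train_step=lambda name, batch: log.append(name))
print(log)  # ['ww', 'ww', 'wd', 'wl', 'ww', 'ww', 'wd', 'wl']
```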

After training, they obtain word representations. To get an embedding for a document, they simply take the mean of the embeddings of the words in that document.
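A minimal sketch of the averaging step. It assumes (not stated above) that words without an embedding are skipped, that a repeated word counts once per occurrence, and that a document with no known words gets `None`; `doc_embedding` is an illustrative name.

```python
def doc_embedding(tokens, word_vecs):
    """Average the vectors of the in-vocabulary tokens of a document."""
    vecs = [word_vecs[w] for w in tokens if w in word_vecs]
    if not vecs:
        return None  # no known words: no embedding
    dim = len(vecs[0])
    return [sum(v[k] for v in vecs) / len(vecs) for k in range(dim)]


word_vecs = {"good": [1.0, 0.0], "movie": [3.0, 2.0]}
print(doc_embedding(["good", "movie", "unseen"], word_vecs))  # [2.0, 1.0]
```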

In [None]:
!wget http://www.cs.cornell.edu/people/pabo/movie-review-data/mix20_rand700_tokens_0211.tar.gz
!tar -zxvf mix20_rand700_tokens_0211.tar.gz

--2022-03-16 02:41:29--  http://www.cs.cornell.edu/people/pabo/movie-review-data/mix20_rand700_tokens_0211.tar.gz
Resolving www.cs.cornell.edu (www.cs.cornell.edu)... 132.236.207.36
Connecting to www.cs.cornell.edu (www.cs.cornell.edu)|132.236.207.36|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2194052 (2.1M) [application/x-gzip]
Saving to: ‘mix20_rand700_tokens_0211.tar.gz.18’


2022-03-16 02:41:30 (3.34 MB/s) - ‘mix20_rand700_tokens_0211.tar.gz.18’ saved [2194052/2194052]

diff.txt
README
tokens/
tokens/neg/
tokens/neg/cv303_tok-11557.txt
tokens/neg/cv000_tok-9611.txt
tokens/neg/cv001_tok-19324.txt
tokens/neg/cv002_tok-3321.txt
tokens/neg/cv003_tok-13044.txt
tokens/neg/cv004_tok-25944.txt
tokens/neg/cv005_tok-24602.txt
tokens/neg/cv006_tok-29539.txt
tokens/neg/cv007_tok-11669.txt
tokens/neg/cv008_tok-11555.txt
tokens/neg/cv009_tok-19587.txt
tokens/neg/cv010_tok-2188.txt
tokens/neg/cv011_tok-7845.txt
tokens/neg/cv012_tok-26965.txt
tokens/neg/cv013_tok-14854

In [None]:
import os
import re
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from collections import defaultdict
import numpy as np
from sklearn.cluster import KMeans
import networkx as nx
import matplotlib.pyplot as plt

In [None]:
def remove_tags(text):
  tags = re.compile(r'<[^>]+>')
  result = tags.sub(' ', text)
  return re.sub(r'([^\s\w]|_)+', '', result)

def get_imdb(mode):
  filedir = os.path.join('tokens', mode.lower())
  fs = []
  for filename in os.listdir(filedir):
    f = os.path.join(filedir, filename)
    fs.append(f)

  texts = []
  for f in fs:
    with open(f, encoding='utf-8') as inputf:
      try:
        texts += [remove_tags(' '.join(inputf.readlines())).lower().split()]
      except UnicodeDecodeError:
        # skip files that are not valid UTF-8
        continue
  return texts

labels = ['Neg', 'Pos']

texts = {}

for label in labels:
  texts[label] = random.sample(get_imdb(label), 20)

In [None]:
print(len(texts[labels[0]]))

20


In [None]:
item2id = {}
id2item = {}
count = 0
for label in labels:
  for text in texts[label]:
    for word in text:
      if word in item2id:
        continue
      item2id[word] = count
      id2item[count] = word
      count += 1


doc = "Doc_"
docid = 0
docsid = {}
for label in labels:
  docsid[label] = []
  for i in range(len(texts[label])):
    item2id[doc + str(docid)] = count
    id2item[count] = doc + str(docid)
    docsid[label].append(doc + str(docid))
    count += 1
    docid += 1

for label in labels:
  item2id[label] = count
  id2item[count] = label
  count += 1

print("num of items: {0}".format(len(item2id)))

num of items: 5533


In [None]:
windowsz = 2
wwedges = defaultdict(int)
for label in labels:
  for text in texts[label]:
    for i in range(len(text)):
      center = text[i]
      # symmetric window i-windowsz .. i+windowsz (range's stop is exclusive, hence +1)
      for j in range(i - windowsz, i + windowsz + 1):
        if i != j and j >= 0 and j < len(text):
          context = text[j]
          wwedges[(item2id[center], item2id[context])] += 1

wdedges = defaultdict(int)
for label in labels:
  for i, text in enumerate(texts[label]):
    document = docsid[label][i]
    for word in text:
      wdedges[(item2id[document], item2id[word])] += 1
      wdedges[(item2id[word], item2id[document])] += 1

wledges = defaultdict(int)
for label in labels:
  for text in texts[label]:
    for word in text:
      wledges[(item2id[label], item2id[word])] += 1
      wledges[(item2id[word], item2id[label])] += 1

print("num of wwedges: {0}".format(len(wwedges)))
print("num of wdedges: {0}".format(len(wdedges)))
print("num of wledges: {0}".format(len(wledges)))

num of wwedges: 56375
num of wdedges: 27032
num of wledges: 13684


In [None]:
def gen_probs(dataset, power=0.75):
  # nodes: weighted out-degree of every node; edge targets are added with degree 0
  nodes = defaultdict(int)
  edgeprobs = {}
  weightsum = 0
  for (n1, n2), w in dataset.items():
    nodes[int(n1)] += int(w)
    nodes[int(n2)] += 0
    edgeprobs[(int(n1), int(n2))] = int(w)
    weightsum += int(w)

  # negative-sampling distribution: proportional to degree^0.75, normalized over nodes
  # (normalizing by the sum of per-edge w^0.75 would not give probabilities summing to 1)
  nodeprobsum = sum(np.power(d, power) for d in nodes.values())
  nodeprobs = {node: np.power(d, power) / nodeprobsum for node, d in nodes.items()}

  # edge-sampling distribution: proportional to edge weight
  for edge in edgeprobs:
    edgeprobs[edge] /= weightsum

  return nodeprobs, edgeprobs, nodes

wwnodeprobs, wwedgeprobs, wwnodes = gen_probs(wwedges)
wdnodeprobs, wdedgeprobs, wdnodes = gen_probs(wdedges)
wlnodeprobs, wledgeprobs, wlnodes = gen_probs(wledges)
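`gen_probs` also returns the distribution used to draw negative nodes. The sketch below isolates that step under the usual word2vec/LINE choice of raising each node's weighted degree to the power 0.75 (the value hard-coded in `gen_probs`). The normalizing constant is the sum of $d_v^{0.75}$ over nodes, so the probabilities sum to 1. Compared with sampling proportionally to degree, the exponent flattens the distribution and gives rare nodes a larger share. `noise_distribution` is an illustrative name.

```python
def noise_distribution(degrees, power=0.75):
    """P_n(v) proportional to d_v ** power, normalized over all nodes."""
    weights = {v: d ** power for v, d in degrees.items()}
    total = sum(weights.values())
    return {v: w / total for v, w in weights.items()}


probs = noise_distribution({"the": 16, "plot": 1})
print(round(probs["the"], 4), round(probs["plot"], 4))  # 0.8889 0.1111
```

With raw degrees the split would be 16/17 versus 1/17 (about 0.94 and 0.06).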

In [None]:
class alias():
  """Walker/Vose alias table: O(n) construction, O(1) sampling from a discrete distribution."""
  def __init__(self, probs):
    self.n = len(probs)
    self.scaledprobs = {}
    self.table = {}
    self.aliastable = {}
    self.small = []
    self.big = []
    self.keys = list(probs.keys())

    for item in probs:
      prob = probs[item]
      self.scaledprobs[item] = prob * self.n
      if self.scaledprobs[item] > 1:
        self.big.append(item)
      elif self.scaledprobs[item] < 1:
        self.small.append(item)
      else:
        self.table[item] = 1
    
    while self.small and self.big:
      smallitem = self.small.pop()
      bigitem = self.big.pop()
      newprob = self.scaledprobs[bigitem] - (1 - self.scaledprobs[smallitem])
      self.table[smallitem] = self.scaledprobs[smallitem]
      self.aliastable[smallitem] = bigitem
      self.scaledprobs[bigitem] = newprob
      if self.scaledprobs[bigitem] > 1:
        self.big.append(bigitem)
      elif self.scaledprobs[bigitem] < 1:
        self.small.append(bigitem)
      else:
        self.table[bigitem] = 1
    
    while self.small:
      smallitem = self.small.pop()
      self.table[smallitem] = 1
    
    while self.big:
      bigitem = self.big.pop()
      self.table[bigitem] = 1

  def sampling_one(self):
    sample = random.choice(self.keys)
    if self.table[sample] >= random.uniform(0, 1):
      return sample
    else:
      return self.aliastable[sample]
  
  def sampling_n(self, n):
    samples = []
    for i in range(n):
      samples.append(self.sampling_one())
    return samples

In [None]:
wwnodealias = alias(wwnodeprobs)
wwedgealias = alias(wwedgeprobs)
wdnodealias = alias(wdnodeprobs)
wdedgealias = alias(wdedgeprobs)
wlnodealias = alias(wlnodeprobs)
wledgealias = alias(wledgeprobs)

In [None]:
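def sampling_negedge(edgealias, nodealias, batchsz, edges, nodes):
  # For each sampled positive edge (n1, n2) -> label 1, try up to batchsz times to draw
  # a negative node n3 that is not linked to n1 and is not in `nodes` -> label -1.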
  trainset = []
  edgesamples = edgealias.sampling_n(batchsz)
  nodesamples = nodealias.sampling_n(batchsz)
  for edge in edgesamples:
    n1 = edge[0]
    n2 = edge[1]
    trainset.append([n1, n2, 1])
    count = 0
    while(count < batchsz):
      n3 = random.choice(nodesamples)
      if (n1, n3) not in edges and n3 not in nodes:
        trainset.append([n1, n3, -1])
        break
      count += 1
  return trainset

def one_hot(node, nodes):
  vec = [0] * len(nodes)
  vec[node] = 1
  
  return vec

def tensor_trainset(trainset, nodes):
  vi = []
  vj = []
  labels = []
  for item in trainset:
    vi.append(one_hot(item[0], nodes))
    vj.append(one_hot(item[1], nodes))
    labels.append(item[2])
  return torch.Tensor(vi), torch.Tensor(vj), torch.tensor(labels)
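`tensor_trainset` feeds one-hot vectors of length `len(item2id)` into the bias-free `nn.Linear` layers of the model. A linear map without bias applied to a one-hot vector simply selects one column of the weight matrix, which is why `weight.T` can later be read as the embedding table; `nn.Embedding` would perform the same lookup from integer indices without materializing the one-hot vectors. The pure-Python check below illustrates the equivalence on a toy 2 x 3 weight; the helper names are illustrative.

```python
def one_hot_vector(index, size):
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def linear_no_bias(x, weight):
    """y = W x with W of shape (out_features, in_features), like torch.nn.Linear."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


weight = [[1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0]]  # 2 output features, 3 input items
print(linear_no_bias(one_hot_vector(2, 3), weight))  # [3.0, 6.0]
print([row[2] for row in weight])                    # [3.0, 6.0]
```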

In [None]:
lr = 0.01
batchsz = 64
epochs = 50
featuresz = 32

In [None]:
class PTE(nn.Module):
  def __init__(self, wordnum, nodenum, featuresz):
    super(PTE, self).__init__()
    # wordnum is not used; nodenum is the size of the one-hot inputs.
    # wwembeddings gives the vertex vector of vi for every pair; the context vector
    # of vj comes from ww/wd/wl embeddings depending on the sub-graph being trained.
    self.wwembeddings = nn.Linear(nodenum, featuresz, bias=False)
    self.wdembeddings = nn.Linear(nodenum, featuresz, bias=False)
    self.wlembeddings = nn.Linear(nodenum, featuresz, bias=False)
    self.logsigmoid = nn.LogSigmoid()
    for layer in (self.wwembeddings, self.wdembeddings, self.wlembeddings):
      # same range as uniform(-0.5, 0.5) / featuresz
      nn.init.uniform_(layer.weight, -0.5 / featuresz, 0.5 / featuresz)

  def forward(self, vi, vj, labels, traintype):
    viembeddings = self.wwembeddings(vi)
    if traintype == "ww":
      vjembeddings = self.wwembeddings(vj)
    elif traintype == "wd":
      vjembeddings = self.wdembeddings(vj)
    elif traintype == "wl":
      vjembeddings = self.wlembeddings(vj)
    else:
      raise ValueError("unknown traintype: {0}".format(traintype))
    inner_product = torch.sum(viembeddings * vjembeddings, 1)
    # labels are +1 for positive edges and -1 for negative samples
    loss = -torch.sum(self.logsigmoid(inner_product * labels))
    return loss

In [None]:
pte = PTE(len(wwnodes), len(item2id), featuresz)
optimizer = optim.SGD(pte.parameters(), lr=lr)

In [None]:
def train_subgraph(edgealias, nodealias, edgeprobs, excluded, traintype):
  """One pass over a sub-graph; returns the summed loss over all batches."""
  totalloss = 0.0
  for batch in range(len(edgeprobs) // batchsz):
    trainset = sampling_negedge(edgealias, nodealias, batchsz, edgeprobs, excluded)
    vi, vj, labels = tensor_trainset(trainset, item2id)
    optimizer.zero_grad()
    loss = pte(vi, vj, labels, traintype)
    loss.backward()
    optimizer.step()
    totalloss += loss.item()  # .item() so the autograd graph is not kept alive
  return totalloss

def train():
  pte.train()
  for epoch in range(epochs):
    # joint training: the three sub-graphs are trained one after another in each epoch
    train_subgraph(wwedgealias, wwnodealias, wwedgeprobs, {}, "ww")
    train_subgraph(wdedgealias, wdnodealias, wdedgeprobs, wwnodes, "wd")
    wlloss = train_subgraph(wledgealias, wlnodealias, wledgeprobs, wwnodes, "wl")
    # only the word-label loss is reported, averaged per sampled positive edge
    avgloss = wlloss / ((len(wledgeprobs) // batchsz) * batchsz)
    print("epoch: {0}, loss: {1}".format(epoch, avgloss))

In [None]:
train()

epoch: 0, loss: 0.6465424299240112
epoch: 1, loss: 0.43026596307754517
epoch: 2, loss: 0.3696928322315216
epoch: 3, loss: 0.33595961332321167
epoch: 4, loss: 0.3167143762111664
epoch: 5, loss: 0.31535109877586365
epoch: 6, loss: 0.2873421013355255
epoch: 7, loss: 0.28124117851257324
epoch: 8, loss: 0.27097752690315247
epoch: 9, loss: 0.2706867456436157
epoch: 10, loss: 0.2567729353904724
epoch: 11, loss: 0.25000885128974915
epoch: 12, loss: 0.24768660962581635
epoch: 13, loss: 0.23336869478225708
epoch: 14, loss: 0.23113523423671722
epoch: 15, loss: 0.22924086451530457
epoch: 16, loss: 0.22927676141262054
epoch: 17, loss: 0.21980515122413635
epoch: 18, loss: 0.21156784892082214
epoch: 19, loss: 0.2071772813796997
epoch: 20, loss: 0.19911205768585205
epoch: 21, loss: 0.19647765159606934
epoch: 22, loss: 0.18719445168972015
epoch: 23, loss: 0.1766180843114853
epoch: 24, loss: 0.17514580488204956
epoch: 25, loss: 0.16466212272644043
epoch: 26, loss: 0.15361997485160828
epoch: 27, loss: 0.

In [None]:
with torch.no_grad():
  embedding = pte.wwembeddings.weight.T    # (num_items, featuresz): one row per item
  normed = F.normalize(embedding, dim=1)
  similarity = normed @ normed.T           # pairwise cosine similarities
_, idx = torch.sort(similarity, descending=True)
k = 4
lists = idx[:, 1:k+1]                      # skip column 0, normally the item itself
for i in range(100):
  print("[{0}] is similar to ".format(id2item[i]), end="")
  for j in range(k):
    print("[{0}]".format(id2item[int(lists[i][j])]), end=" ")
  print()

[one] is similar to [course] [hand] [big] [is] 
[of] is similar to [satire] [course] [limousines] [predictability] 
[the] is similar to [audience] [lump] [ease] [subject] 
[indicator] is similar to [playboy] [inanity] [fortune] [tito] 
[badness] is similar to [deceit] [cavanaughs] [posttwin] [search] 
[in] is similar to [lump] [silence] [stinkers] [magazine] 
[film] is similar to [faded] [survivalofthefittest] [ala] [it] 
[is] is similar to [huddleston] [cool] [visiting] [corman] 
[hype] is similar to [bright] [loudmouth] [cheese] [plotline] 
[being] is similar to [sense] [second] [young] [least] 
[remembered] is similar to [bot] [miscast] [hates] [elevator] 
[more] is similar to [than] [youre] [church] [almost] 
[than] is similar to [more] [rather] [least] [gremlins] 
[itself] is similar to [away] [script] [anything] [nearly] 
[such] is similar to [being] [as] [original] [second] 
[was] is similar to [into] [wife] [had] [turtles] 
[case] is similar to [sense] [great] [makes] [good] 
[