In [1]:
import torch
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GNNExplainer

import numpy as np
import pandas as pd
import more_itertools as mit

from typing import List
from torch_geometric.data import InMemoryDataset

import matplotlib.pyplot as plt

In [2]:
def reasonable_notebook_defaults():
    r"""Notbook defaults"""
    import seaborn as sns
    import matplotlib.pyplot as plt
    sns.set_context("paper", font_scale=1.5)
    plt.rcParams['figure.figsize'] = [20, 10]

In [3]:
reasonable_notebook_defaults()

In [4]:
doc0 = ["The", "quick", "brown", "fox", "jumped", "over", "the", "lazy", "dog"]
doc1 = ["Welcome", "to", "the", "dog", "days", "of", "summer"]
docs = [doc0, doc1]
vocab = set(doc0 + doc1)

In [5]:
vocab_to_int = {term: token for token, term in enumerate(vocab)}
int_to_vocab = {token: term for term, token in vocab_to_int.items()}

In [6]:
def tokenize(sent):
    """Map sentence to tokens"""
    return [vocab_to_int[word] for word in sent]


def untokenize(sent):
    """Map tokens to sentence"""
    return [int_to_vocab[token] for token in sent]


def one_hot_encode(sent):
    """One hot encode a sentence"""
    return np.stack([token_to_one_hot(x) for x in sent])


def decode_one_hot(encoding) -> List[str]:
    """Decode a one hot encoding back to sentence"""
    return [one_hot_word_to_token(x) for x in encoding]


def token_to_one_hot(token):
    """One hot encode a token"""
    word = np.zeros(len(vocab), dtype=int)
    word[token] = 1
    return word


def one_hot_word_to_token(one_hot_word):
    """Convert one hot encoded word to token"""
    return int_to_vocab[np.argmax(one_hot_word)]

In [7]:
tokenized = [tokenize(x) for x in docs]
sentences = [untokenize(x) for x in tokenized]

In [8]:
sentences

[['The', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'lazy', 'dog'],
 ['Welcome', 'to', 'the', 'dog', 'days', 'of', 'summer']]

In [9]:
onehot_encoded = one_hot_encode(tokenized[0])

In [10]:
one_hot_token = token_to_one_hot(tokenized[0][0])

In [11]:
one_hot_token

array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [12]:
one_hot_word_to_token(one_hot_token)

'The'

In [13]:
onehot_encoded

array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
       [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]])

In [14]:
decode_one_hot(onehot_encoded)

['The', 'quick', 'brown', 'fox', 'jumped', 'over', 'the', 'lazy', 'dog']

In [27]:
adj = torch.from_numpy(onehot_encoded)
data = Data()
data.edge_index = adj.nonzero().t()

x = np.random.random(9)
data.x = torch.from_numpy(x)
data.y = torch.randint(0, 6, (9, ), dtype=torch.long)

In [28]:
data

Data(edge_index=[2, 9], x=[9], y=[9])

In [29]:
data.edge_index

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
        [ 2, 11,  1,  4,  0, 13,  5,  3,  6]])