In [153]:
from typing import List
import sys
from pathlib import Path
from collections import Counter
import re
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp
from jax import vmap
from jax.experimental import sparse as jaxsparse

import numpy as np

import io 

from scipy.sparse import coo_array, linalg

from functools import partial
import pickle

from deepscratch.models.sequence.tokeniser import Tokeniser, WordTokeniser, HFTokeniser

In [154]:
# Instantiate a BERT (bert-base-uncased) tokeniser and run it over the whole book.
# The value returned by tokenise() is discarded; presumably the call is made for a
# side effect on the tokeniser (e.g. populating its vocabulary maps) — TODO confirm.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt", "rt") as f:
    tokeniser = HFTokeniser('bert-base-uncased')
    tokeniser.tokenise(f.read())

In [None]:
def get_nearest_words(embed, target_word, n):
    """
    Print the n words closest to target_word with reference to embed.

    Closeness is measured by cosine similarity between the target's embedding
    and every row of ``embed.word_embeddings`` except row 0.

    Parameters
    ----------
    embed : object exposing ``word_embeddings`` (2-D array indexed by word id)
        and ``corpus.word2idx`` (word -> id mapping that also offers a
        ``reversed`` id -> word lookup).
    target_word : word to look up; must be a key of ``embed.corpus.word2idx``.
    n : number of neighbours to report. Values above the number of candidate
        rows are clamped; values <= 0 report an empty list.

    Returns
    -------
    None. The result is written to stdout, in no particular order.
    """
    vocab = embed.corpus.word2idx
    target_embedding = embed.word_embeddings[vocab[target_word]]

    # Row 0 is excluded from the candidates (per the original note it holds NaNs).
    candidates = embed.word_embeddings[1:]
    cosine_sim = (
        (candidates * target_embedding).sum(axis=1)
        / jnp.sqrt(
            (candidates ** 2).sum(axis=1)
            * (target_embedding ** 2).sum()
        )
    )
    # Zero-norm rows give NaN similarity; push them to the bottom of the ranking
    # so they can never be reported as "nearest".
    cosine_sim = jnp.where(jnp.isnan(cosine_sim), -jnp.inf, cosine_sim)

    # Clamp n: argpartition rejects an out-of-range kth, and `[-0:]` would
    # select every row rather than none.
    n = max(0, min(int(n), cosine_sim.shape[0]))
    if n > 0:
        # add 1 as we removed the first row before ranking
        top_idx = jnp.argpartition(cosine_sim, -n)[-n:] + 1
        nearest = [vocab.reversed[i] for i in top_idx.tolist()]
    else:
        nearest = []

    # print to stdout
    print(f"Target token: {target_word}")
    print(f"Nearest words (no order): {nearest}")

In [156]:
class Embedder(ABC):
    """
    Abstract base for token embedders.

    Subclasses provide a static ``embed`` that maps a single token index to a
    vector. Constructing an instance replaces ``embed`` *on that instance* with
    a ``jax.vmap``-ed version, so instance calls take an array of indices and
    return one embedding per index.
    """

    def __init__(self):
        # Looks up whatever `embed` currently resolves to on the instance (a
        # subclass may already have bound extra arguments with functools.partial)
        # and shadows it with the batched version.
        self.embed = jax.vmap(self.embed)

    @staticmethod
    @abstractmethod
    def embed(idx: int) -> jax.Array:
        """Return the embedding vector for a single token index."""

# One-hot encoding (OHE)

In [157]:
class OHE(Embedder):
    """
    One-hot embedder: token index ``i`` maps to a length-``N`` vector that is 1
    at position ``i`` and 0 elsewhere, where ``N`` is the size of the
    tokeniser's ``idx_to_token`` map after tokenising the supplied text.
    """

    def __init__(self, f, tokeniser: Tokeniser):
        """
        Parameters
        ----------
        f : seekable text stream; it is rewound and read in full.
        tokeniser : tokeniser used to tokenise the text; kept on the instance.
        """
        self.tokeniser = tokeniser

        f.seek(0)
        self.tokens = self.tokeniser.tokenise(f.read())
        # Bind the vector length now, so the batched `embed` installed by the
        # base class only needs token indices.
        self.embed = partial(self.embed, N=len(self.tokeniser.idx_to_token))
        super().__init__()

    @staticmethod
    def embed(idx: int, N: int) -> jax.Array:
        """Return the length-``N`` one-hot vector for token index ``idx``."""
        return jnp.zeros(N).at[idx].set(1)

In [158]:
# Build a one-hot embedder from an in-memory sentence, using the `tokeniser`
# bound by an earlier cell.
with io.StringIO("The quick brown fox jumped over the lazy dog.") as f: 
    embedder = OHE(f, tokeniser)

# Look up one token's index and embed it. The instance-level `embed` is vmapped,
# so it is given a length-1 index array and returns one row per index.
token = "fox"
idx = embedder.tokeniser.token_to_idx[token]
embedder.embed(jnp.array([idx]))

Array([[0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

# Latent semantic analysis (LSA)

In [159]:
class LSA(Embedder):
    """
    Latent semantic analysis embedder.

    Holds a dense ``(n_tokens, k)`` embedding matrix obtained from a truncated
    SVD of a token-by-window count matrix (optionally PMI/PPMI weighted) and
    returns its rows as token embeddings.
    """

    def __init__(self, embeddings, tokeniser_hash, weighting, token_to_idx, idx_to_token=None):
        """
        Parameters
        ----------
        embeddings : 2-D array with one row per token index.
        tokeniser_hash : value stored alongside the embeddings to identify the
            tokeniser they were built with.
        weighting : name of the weighting scheme used to build ``embeddings``.
        token_to_idx : mapping token -> row index.
        idx_to_token : optional inverse mapping; derived from ``token_to_idx``
            when omitted.
        """
        self.embeddings = embeddings
        self.tokeniser_hash = tokeniser_hash
        self.token_to_idx = token_to_idx
        self.weighting = weighting

        if idx_to_token is not None:
            self.idx_to_token = idx_to_token
        else:
            self.idx_to_token = {v: k for k, v in token_to_idx.items()}

        # Bind the matrix now so the batched `embed` installed by the base
        # class only needs token indices.
        self.embed = partial(
            self.embed,
            embeddings=self.embeddings
        )
        super().__init__()

    @classmethod
    def from_file(cls, f, tokeniser, window_len=20, k=100, weighting="ppmi"):
        """
        Build embeddings from the text in ``f``.

        ``f`` is rewound, read in full and then closed by this method.
        ``window_len`` is the number of consecutive tokens per context window,
        ``k`` the number of singular components kept, and ``weighting`` one of
        ``"pmi"``/``"ppmi"`` (any other value leaves the raw counts unweighted).
        """
        f.seek(0)
        tokens = tokeniser.tokenise(f.read())
        f.close()

        embeddings = cls._compute_word_embeddings(tokens, tokeniser, weighting, window_len, k)
        # NOTE(review): unless the tokeniser class defines __hash__, hash() is
        # identity-based and will not match across sessions — confirm.
        return cls(embeddings, hash(tokeniser), weighting, tokeniser.token_to_idx, tokeniser.idx_to_token)

    @classmethod
    def from_cache(cls, f):
        """
        Restore an instance from a binary stream previously written by ``cache``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load caches
        from trusted sources.
        """
        tokeniser_hash, embeddings, token_to_idx, weighting = pickle.load(f)
        return cls(embeddings, tokeniser_hash, weighting, token_to_idx)

    @staticmethod
    def _compute_word_embeddings(tokens, tokeniser, weighting, window_len, k):
        """
        Compute the ``(n_words, <=k)`` embedding matrix.

        A token-by-window count matrix is built from every sliding window of
        ``window_len`` consecutive tokens, optionally re-weighted by (positive)
        pointwise mutual information, and factorised with an SVD; the left
        singular vectors scaled by their singular values are the embeddings.

        Raises
        ------
        ValueError
            If ``tokens`` is empty.
        """
        token_ids = np.asarray(tokens)
        n_tokens = len(token_ids)
        if n_tokens == 0:
            raise ValueError("cannot compute embeddings from an empty token sequence")

        n_words = len(tokeniser.token_to_idx)
        # A text shorter than one window still yields a single (short) window
        # rather than a non-positive window count.
        n_windows = max(n_tokens - window_len + 1, 1)
        tdfm_np = np.zeros((n_words, n_windows), dtype=np.float32)
        print(tdfm_np.shape)

        # Window w holds tokens w .. w + window_len - 1. Instead of visiting
        # every (window, token) pair in Python, handle one in-window offset at
        # a time: the token at offset o of window w is token_ids[w + o]. Within
        # a single offset every column index is distinct, so the fancy-indexed
        # `+=` cannot drop repeated updates.
        columns = np.arange(n_windows)
        for offset in range(min(window_len, n_tokens)):
            rows = token_ids[offset:offset + n_windows]
            tdfm_np[rows, columns[:len(rows)]] += 1

        if weighting in {"pmi", "ppmi"}:
            word_freqs = tdfm_np.sum(axis=1, keepdims=True)
            context_freqs = tdfm_np.sum(axis=0, keepdims=True)
            total_count = tdfm_np.sum()
            p_wc = tdfm_np / total_count
            p_w = word_freqs / total_count
            p_c = context_freqs / total_count
            with np.errstate(divide='ignore', invalid='ignore'):
                # (n_words, 1) @ (1, n_windows) is the outer product p(w) * p(c).
                pmi = np.log2(p_wc / (p_w @ p_c))
                pmi[np.isnan(pmi)] = 0  # Replace NaNs with 0
                pmi[np.isinf(pmi)] = 0  # Replace infinities with 0

            if weighting == "ppmi":
                pmi = np.maximum(pmi, 0)

            tdfm_np = pmi

        U, S, _ = np.linalg.svd(tdfm_np, full_matrices=False)
        U_k = U[:, :k]
        S_k = S[:k]

        embeddings = U_k * S_k[None, :]  # shape (n_words, <=k)
        embeddings = jnp.array(embeddings)
        return embeddings

    def cache(self, f):
        """Pickle the tokeniser hash, embeddings, token map and weighting to ``f``."""
        pickle.dump((self.tokeniser_hash, self.embeddings, self.token_to_idx, self.weighting), f)

    @staticmethod
    def embed(idx: int, embeddings: jax.Array) -> jax.Array:
        """
        Given a token index, returns its latent semantic vector.
        """
        return embeddings[idx]

In [160]:
# Switch to a word-level tokeniser and tokenise the whole book.
# NOTE: this rebinds the global `tokeniser` that an earlier cell bound to an HFTokeniser.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
tokeniser = WordTokeniser()
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt") as f: 
    tokens = tokeniser.tokenise(f.read())

In [147]:
# Build 50-dimensional PPMI-weighted LSA embeddings from 30-token windows.
# from_file prints the shape of the token-by-window matrix it builds.
# NOTE: this rebinds the global `embedder` (previously the OHE instance).
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt") as f: 
    embedder = LSA.from_file(f, tokeniser, window_len=30, k=50, weighting="ppmi")

(4609, 32781)


In [148]:
def nearest_words(
    token: str,
    embedder: "LSA",
    n_words: int = 10
):
    """
    Return the ``n_words`` tokens whose embeddings are closest to ``token``.

    Closeness is squared Euclidean distance in the embedder's vector space.
    The result is ordered nearest first, so the query token itself (distance 0)
    normally comes first.

    Parameters
    ----------
    token : query token; must be a key of ``embedder.token_to_idx``.
    embedder : object exposing ``embeddings``, ``token_to_idx``,
        ``idx_to_token`` and a batched ``embed``.
    n_words : number of tokens to return; clamped to the vocabulary size.

    Returns
    -------
    list of str
    """
    embeddings = embedder.embeddings
    idx = embedder.token_to_idx[token]
    # `embed` is batched over indices, so pass a length-1 array; the (1, k)
    # result broadcasts against the (n_tokens, k) embedding matrix below.
    word_embedding = embedder.embed(jnp.array([idx]))

    distances = ((embeddings - word_embedding) ** 2).sum(axis=1)
    n_nearest = max(0, min(n_words, distances.shape[0]))
    nearest_idxs = jnp.argsort(distances)[:n_nearest]

    return [embedder.idx_to_token[i] for i in nearest_idxs.tolist()]

In [None]:
# Example query: tokens nearest to "dark" under the LSA embedder built above.
token = "dark"
nearest_words(token, embedder)

['dark',
 'struck',
 'daylight',
 'drove',
 'flinging',
 'thinking',
 'confidence',
 'moon',
 'box',
 'near']

# BERT embeddings

In [55]:
# Rebind the global `tokeniser` to a fresh BERT (bert-base-uncased) tokeniser for this section.
tokeniser = HFTokeniser('bert-base-uncased')

In [70]:
class HFEmbedder:
    """
    Thin wrapper around a pretrained model that embeds token ids one at a time.

    NOTE(review): `FlaxAutoModel` is not imported anywhere in the visible
    notebook; it is presumably `transformers.FlaxAutoModel` — add the import to
    the imports cell so this runs on a fresh kernel.
    """

    def __init__(self, model_name, *args, **kwargs):
        """Load the pretrained model ``model_name``; extra arguments are forwarded to the loader."""
        pretrained = FlaxAutoModel.from_pretrained(model_name, *args, **kwargs)
        self._model = pretrained

    def embed(self, tokens):
        """
        Embed each token id independently.

        ``tokens`` is reshaped into a column, so every id is passed to the model
        as its own length-1 sequence; the hidden state at sequence position 0 of
        each is returned, giving one row per input id.
        """
        one_token_per_row = tokens.reshape(-1, 1)
        model_output = self._model(one_token_per_row)
        hidden_states = model_output.last_hidden_state
        return hidden_states[:, 0, :]

# Load pretrained bert-base-uncased weights into the embedder (see the loader message below).
hf_embedder = HFEmbedder('bert-base-uncased')

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: {('pooler', 'dense', 'kernel'), ('pooler', 'dense', 'bias')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Read the full book text.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
with open("/Users/willgilchrist/dev/deeplearning/data/books/timemachine.txt", "rt") as f:
    text = f.read()

# Unique token ids occurring in the book, and one embedding row per unique id.
corpus = jnp.unique(tokeniser.tokenise(text))
corpus_embeddings = hf_embedder.embed(corpus)

In [None]:
TOKEN = "sun"
n_words = 10
# Drop the first and last ids returned by the tokeniser (presumably the
# [CLS]/[SEP] special tokens — TODO confirm) and require that the query maps
# to exactly one token.
token_id = tokeniser.tokenise(TOKEN)[1:-1]
assert token_id.shape[0] == 1, f"{token_id}"
word_embedding = hf_embedder.embed(token_id)

# NOTE(review): this is squared Euclidean distance divided by the product of
# the two norms — neither plain Euclidean nor cosine distance. Confirm it is
# the intended metric.
distances = (
    ((corpus_embeddings - word_embedding) ** 2).sum(axis=1)
    / (jnp.linalg.norm(corpus_embeddings, axis=1) * (jnp.linalg.norm(word_embedding)))
)
# Full sort so the report is ordered nearest first; slicing also copes with a
# corpus smaller than n_words.
nearest_idxs = jnp.argsort(distances)[:n_words]
nearest_tokens = corpus[nearest_idxs]
# Named so as not to rebind `nearest_words`, which is a function defined in an
# earlier cell.
nearest_token_strs = [tokeniser.idx_to_token[i] for i in nearest_tokens.tolist()]

print(f"{n_words} nearest tokens to '{TOKEN}':\n - "+"\n - ".join(nearest_token_strs))

10 nearest tokens to 'sun':
 - sun
 - ##ng
 - dream
 - remark
 - puzzle
 - ##ath
 - cad
 - weed
 - ##ffin
 - know
