In [None]:
# Fetch the morphert repository; its ./morphert/src directory is added to
# sys.path below so the `morphert` package can be imported.
!git clone https://github.com/seantyh/morphert

Cloning into 'morphert'...
remote: Enumerating objects: 32, done.[K
remote: Counting objects: 100% (32/32), done.[K
remote: Compressing objects: 100% (28/28), done.[K
remote: Total 32 (delta 5), reused 30 (delta 3), pack-reused 0[K
Unpacking objects: 100% (32/32), done.


In [None]:
# Third-party dependencies for this (Colab) notebook.
# NOTE(review): versions are unpinned -- pin them for a reproducible run.
!pip install -q --progress-bar off transformers umap-learn opencc hdbscan functorch
!pip install -U -q gensim

[?25l
[?25h[?25l
[?25h[?25l
[?25h[?25l
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
[?25l
[?25h[?25l
[?25h[?25l
[?25h[?25l
[?25h[?25l
[?25h[?25l
[?25h  Building wheel for umap-learn (setup.py) ... [?25l[?25hdone
  Building wheel for pynndescent (setup.py) ... [?25l[?25hdone
  Building wheel for hdbscan (PEP 517) ... [?25l[?25hdone
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 24.1 MB 1.3 MB/s 
[?25h

In [None]:
import sys
# Make the cloned repository's package sources importable; the membership
# test keeps re-runs of this cell from appending duplicate entries.
MORPHERT_SRC = "./morphert/src"
if MORPHERT_SRC not in sys.path:
  sys.path.append(MORPHERT_SRC)

In [None]:
import pickle
from pathlib import Path
from itertools import groupby, combinations, permutations
from textwrap import wrap
from tqdm.auto import tqdm
from opencc import OpenCC

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.spatial.distance import pdist, squareform
from gensim.models import KeyedVectors

import torch
from functorch import jacrev, jacfwd
from transformers import BertTokenizer, BertModel, BertPreTrainedModel
from morphert.model import *

In [None]:
# Setup: paths, the pickled Tencent word vectors, and the trained Morphert model.
# N: number of dataset indices (0..N-1) used below -- presumably all 500k
# entries of tencent_small_500k.pkl; TODO confirm it does not exceed len(vocabs).
N = 500000
# Google-Drive folder (Colab mount) holding the pickled vectors and the checkpoint.
base_dir = Path("/content/drive/MyDrive/LangOn/morphert")
# OpenCC "t2s" converter: traditional -> simplified Chinese strings.
t2s = OpenCC("t2s").convert
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(base_dir/"tencent_small_500k.pkl", "rb") as fin:
    (vocabs, embs) = pickle.load(fin)   
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = MorphertModel.from_pretrained(base_dir/"morphert_500k")
# NOTE(review): collator_fn is not referenced in the cells shown here.
collator_fn = DataCollator(tokenizer)
model = model.to("cuda")
full_ds = MorphertDataset(np.arange(N), vocabs, embs)
# Stack the "vec" field of every dataset item into a single matrix
# (presumably (N, dim) -- depends on the shape of each "vec").
full_emb = np.vstack([full_ds[i]["vec"] for i in range(N)])
# Membership test against the dataset's vocabulary.
in_tencent = lambda x: x in full_ds.vocabs

In [None]:
# Sanity check: running the model on the word-embedding lookup via
# `inputs_embeds` must give the same predictions as passing the token ids.
# The ids come from the tokenizer instead of being hard-coded
# ([101, 7442, 5582, 102]), so both paths are guaranteed to see the same
# tokens. No gradients are needed for the comparison.
with torch.no_grad():
  check_tokens = tokenizer("電腦", return_tensors="pt").to("cuda")
  in_embeds = model.bert.embeddings.word_embeddings(check_tokens.input_ids)
  out_embeds = model(inputs_embeds=in_embeds)
  out_tokens = model(**check_tokens)
torch.allclose(out_embeds.predictions, out_tokens.predictions)

In [None]:
class TCKeyedVectorWrap:
  """Read-only lookup over a keyed-vector store with a simplified-Chinese fallback.

  A key is looked up verbatim first; if it is absent, its ``t2s`` conversion
  (traditional -> simplified) is tried instead.

  Args:
    kv: store supporting ``key in kv`` and ``kv[key]`` (e.g. gensim KeyedVectors).
    t2s: callable mapping a key to its simplified form.
  """

  def __init__(self, kv, t2s):
    self.kv = kv
    self.t2s = t2s

  def __contains__(self, key):
    # the simplified form is computed up front, even if `key` itself matches
    candidates = (key, self.t2s(key))
    return any(candidate in self.kv for candidate in candidates)

  def __getitem__(self, key):
    simplified = self.t2s(key)
    # the verbatim key wins over its simplified form
    for candidate in (key, simplified):
      if candidate in self.kv:
        return self.kv[candidate]
    raise KeyError("kv and its simp. not found")
# gensim KeyedVectors over the Tencent vocabulary, wrapped so traditional-
# Chinese lookups fall back to their simplified form. The vector size is read
# from the loaded embeddings instead of being hard-coded (previously 100), so
# the cell stays correct if a different-dimension embedding file is used.
tencent_kv = KeyedVectors(vector_size=len(embs[0]))
tencent_kv.add_vectors(vocabs, embs)
tencent_kv_wrap = TCKeyedVectorWrap(tencent_kv, t2s)

In [None]:
# Membership of a traditional-Chinese word with vs. without the t2s fallback.
"一審" in tencent_kv_wrap, "一審" in tencent_kv

In [None]:
def compute_token_jacobian_functorch(tgt_word, model, tokenizer):
    """Jacobian of the model's predictions w.r.t. its input token embeddings.

    ``tgt_word`` is tokenized as a batch of one (special tokens included) and
    looked up in the model's word-embedding table. ``functorch.jacrev`` then
    differentiates ``model(inputs_embeds=...).predictions`` with respect to
    those embeddings.

    Args:
        tgt_word: the word to analyse.
        model: model exposing ``bert.embeddings.word_embeddings`` and
            returning an output with a ``predictions`` field.
        tokenizer: the matching HuggingFace tokenizer.

    Returns:
        Tensor of shape ``predictions.shape + in_embeds.shape``. Presumably
        (1, out_dim, 1, seq_len, hidden_dim) -- TODO confirm the shape of
        ``predictions``. The caller squeezes it.
    """
    word_embeddings = model.bert.embeddings.word_embeddings
    # Put the tokens on the device of the embedding table instead of assuming
    # CUDA, so a CPU-resident model works too.
    tokens = tokenizer([tgt_word], return_tensors="pt").to(word_embeddings.weight.device)
    in_embeds = word_embeddings(tokens.input_ids)

    def partial_effect(x):
        # Only inputs_embeds is passed; for a single unpadded sequence this is
        # expected to match the token-id path (cf. the allclose check earlier).
        return model(inputs_embeds=x).predictions

    return jacrev(partial_effect, argnums=0)(in_embeds)

## Load char-noise dataset

In [None]:
import json
# affix_dataset: {char: [usage, ...]}; each usage is a dict whose "ex" entry
# is a list of 2-item examples (word, presumably a frequency) -- see the
# loops below. JSON text is UTF-8 by specification, so the file (which holds
# Chinese text) is opened as UTF-8 rather than in the platform default encoding.
with open(base_dir / "affix_dataset.json", "r", encoding="utf-8") as fin:
  affix_dataset = json.load(fin)

In [None]:
# First (char, usages) entry of the dataset, kept for inspection.
item_x = list(affix_dataset.items())[0]

In [None]:
# Display the sample (char, usages) entry.
item_x

In [None]:
from itertools import chain
# Split the sample entry's examples into two parallel tuples: the words and
# their second field (presumably frequencies, cf. `freq_list` in the main loop).
ex_list = [x["ex"] for x in item_x[1]]
list(zip(*chain.from_iterable(ex_list)))

### Compute Jacobians for all targets and noise samples in the dataset

In [None]:
def mark_target(x, pos):
  """Return ``x`` with the character at index ``pos`` wrapped in angle brackets.

  Example: ``mark_target("男人", 1) == "男<人>"``.
  """
  chars = [*x]
  chars[pos] = "<" + chars[pos] + ">"
  return "".join(chars)

def compute_pairwise_distances(tgt_char, words, Js, counter_position=False):
  """Pairwise L1 distances between per-character Jacobian slices.

  For every pair of words (a, b), the Jacobian slice at the position of
  ``tgt_char`` is compared. With ``counter_position=True``, the slice of the
  other character of the two-character word is compared instead. The summed
  absolute difference is stored.

  Args:
    tgt_char: the character shared by all ``words``.
    words: sequence of two-character words, each containing ``tgt_char``.
      Duplicates are allowed; each occurrence gets its own row/column.
    Js: mapping word -> array indexed as ``Js[w][:, pos, :]`` with pos in
      {0, 1}. In this notebook it is the squeezed Jacobian restricted to the
      two character tokens, presumably (out_dim, 2, hidden_dim).
    counter_position: if True, use the slice of the non-target character.

  Returns:
    Symmetric (len(words), len(words)) float array with a zero diagonal.

  Raises:
    ValueError: if a word does not contain ``tgt_char`` (from ``str.index``).
  """
  n_words = len(words)
  dists = np.zeros((n_words, n_words))

  def slice_for(word):
    pos = word.index(tgt_char)
    # words are bisyllabic, so the other character sits at 1 - pos
    if counter_position:
      pos = 1 - pos
    return Js[word][:, pos, :]

  # Positional indices instead of words.index(...): with duplicated words,
  # .index() always returned the first occurrence, leaving the later
  # occurrences' rows/columns at 0. It also made every pair an O(n) scan.
  for (i, a), (j, b) in combinations(enumerate(words), 2):
    L1norm = np.abs(slice_for(a) - slice_for(b)).sum()
    dists[i, j] = dists[j, i] = L1norm
  return dists

def compute_cross_distances(tgt_char, words, Js):
  """L1 distances between target-position and counter-position Jacobian slices.

  Entry (i, j) compares the slice of words[i] at the position of ``tgt_char``
  (the row is the canonical position) with the slice of words[j] at the other
  character's position (the column is the counter position). The diagonal is
  therefore each word's self-cross distance, and the matrix is generally
  NOT symmetric.

  Args:
    tgt_char: the character shared by all ``words``.
    words: sequence of two-character words, each containing ``tgt_char``.
      Duplicates are allowed; each occurrence gets its own row/column.
    Js: mapping word -> array indexed as ``Js[w][:, pos, :]`` with pos in {0, 1}.

  Returns:
    (len(words), len(words)) float array.

  Raises:
    ValueError: if a word does not contain ``tgt_char``.
  """
  n_words = len(words)
  dists = np.zeros((n_words, n_words))
  target_pos = [w.index(tgt_char) for w in words]
  # One double loop covers both the diagonal (self-cross) and the ordered
  # off-diagonal pairs. Positional indices keep duplicated words from
  # overwriting the first occurrence's row.
  for i, a in enumerate(words):
    canonical = Js[a][:, target_pos[i], :]
    for j, b in enumerate(words):
      counter = Js[b][:, 1 - target_pos[j], :]
      dists[i, j] = np.abs(canonical - counter).sum()
  return dists

def compute_pairwise_emb_distances(words, embs):
  """Pairwise cosine distances (1 - cosine similarity) between word vectors.

  Args:
    words: sequence of words; row/column k of the result belongs to words[k].
      Duplicates are allowed; each occurrence gets its own row/column.
    embs: object supporting ``w in embs`` and ``embs[w]`` (e.g. a dict of
      predicted vectors, or ``TCKeyedVectorWrap``).

  Returns:
    (len(words), len(words)) float array.
    - Entries for pairs where either word is missing from ``embs`` stay NaN.
    - A diagonal entry is set to 0 only once its word has been paired with
      another word present in ``embs``.
    - As a result, rows and columns of missing words are entirely NaN.
  """
  n_words = len(words)
  # np.nan rather than the np.NaN alias, which NumPy 2.0 removed.
  dists = np.full((n_words, n_words), np.nan)
  # Positional indices (not words.index) so each duplicated word gets its own
  # populated row instead of rewriting the first occurrence's.
  for (i, a), (j, b) in combinations(enumerate(words), 2):
    if a not in embs or b not in embs:
      continue
    cossim = KeyedVectors.cosine_similarities(embs[a], [embs[b]])[0]
    dists[i, j] = dists[j, i] = 1 - cossim
    dists[i, i] = dists[j, j] = 0.
  return dists

def compute_pairwise_counter_emb_distances(tgt_char, words, embs):
  """Pairwise cosine distances between the words' non-target ("counter") characters.

  For each two-character word, the counter character is the one at
  ``1 - word.index(tgt_char)``. Entry (i, j) is 1 - cosine similarity of the
  vectors of the counter characters of words[i] and words[j].

  Args:
    tgt_char: the character shared by all ``words``.
    words: sequence of two-character words, each containing ``tgt_char``.
      Duplicates are allowed; each occurrence gets its own row/column.
    embs: object supporting ``c in embs`` and ``embs[c]`` for single characters.

  Returns:
    (len(words), len(words)) float array.
    - Pairs whose counter characters are missing from ``embs`` stay NaN.
    - A diagonal entry becomes 0 once its word is part of a valid pair.

  Raises:
    ValueError: if a word does not contain ``tgt_char`` (when len(words) >= 2).
  """
  n_words = len(words)
  # np.nan rather than the np.NaN alias, which NumPy 2.0 removed.
  dists = np.full((n_words, n_words), np.nan)
  if n_words < 2:
    # no pairs to compare: keep the all-NaN result without inspecting words
    return dists
  counters = [w[1 - w.index(tgt_char)] for w in words]
  # Positional indices (not words.index) so duplicated words each get a row.
  for (i, ca), (j, cb) in combinations(enumerate(counters), 2):
    if ca not in embs or cb not in embs:
      continue
    cossim = KeyedVectors.cosine_similarities(embs[ca], [embs[cb]])[0]
    dists[i, j] = dists[j, i] = 1 - cossim
    dists[i, i] = dists[j, j] = 0.
  return dists

def compute_pairwise_counter_bert_emb_distances(tgt_char, words, tokenizer, model):
  """Pairwise cosine distances between BERT *input* embeddings of counter characters.

  For each two-character word, the counter character (the one at
  ``1 - word.index(tgt_char)``) is looked up in
  ``model.bert.embeddings.word_embeddings``. Entry (i, j) is 1 - cosine
  similarity of the embeddings of words[i]'s and words[j]'s counter characters.

  Args:
    tgt_char: the character shared by all ``words``.
    words: sequence of two-character words, each containing ``tgt_char``.
      Duplicates are allowed; each occurrence gets its own row/column.
    tokenizer: the model's HuggingFace tokenizer.
    model: model exposing ``bert.embeddings.word_embeddings``.

  Returns:
    (len(words), len(words)) float array with a zero diagonal when
    len(words) >= 2, all NaN otherwise.

  Raises:
    ValueError: if a word does not contain ``tgt_char`` (when len(words) >= 2).
  """
  n_words = len(words)
  # np.nan rather than the np.NaN alias, which NumPy 2.0 removed.
  dists = np.full((n_words, n_words), np.nan)
  if n_words < 2:
    # no pairs: same all-NaN result as before, without touching the model
    return dists
  counters = [w[1 - w.index(tgt_char)] for w in words]
  word_embeddings = model.bert.embeddings.word_embeddings
  # Embed every counter character once, in a single batch, instead of
  # re-tokenizing and re-embedding both characters for every pair. Tokens go
  # to the embedding table's device rather than an assumed "cuda".
  with torch.no_grad():
    tokens = tokenizer(counters, padding=True, return_tensors="pt").to(word_embeddings.weight.device)
    # position 0 is [CLS]; position 1 holds the (single-token) character
    counter_embs = word_embeddings(tokens.input_ids)[:, 1, :].cpu().numpy()
  # Positional indices (not words.index) so duplicated words each get a row.
  for i, j in combinations(range(n_words), 2):
    cossim = KeyedVectors.cosine_similarities(counter_embs[i], [counter_embs[j]])[0]
    dists[i, j] = dists[j, i] = 1 - cossim
    dists[i, i] = dists[j, j] = 0.
  return dists

In [None]:
# Example: cosine distances between the BERT input embeddings of the
# non-target characters 男 / 女 / 法 of three words sharing 人.
compute_pairwise_counter_bert_emb_distances("人", ["男人", "女人","法人"], tokenizer, model)

array([[0.        , 0.34849578, 0.99376063],
       [0.34849578, 0.        , 0.92560776],
       [0.99376063, 0.92560776, 0.        ]])

In [None]:
# Check: entry [1, 0] for (法人, 成人) equals 1 - cosine(法, 成), i.e. the
# distance between the two words' non-target ("counter") characters.
assert (compute_pairwise_counter_emb_distances("人", ["法人", "成人"], tencent_kv_wrap)[1,0]
        == 1-KeyedVectors.cosine_similarities(tencent_kv_wrap["法"], [tencent_kv_wrap["成"]])[0])

In [None]:
import random
from itertools import chain

rng = random.Random(123)  # NOTE(review): not used by the loop below

# For every target character, compute several (n_words x n_words) distance
# matrices over its example words. All matrices share the row order of
# `words`, and ex_labels[k] is the index of the usage that words[k] is
# listed under, so each matrix can later be scored against the usage labels.
char_dists = {}

for char, usages in tqdm(affix_dataset.items()):
  ex_list = [x["ex"] for x in usages]
  # one label per example: the index of the usage it belongs to
  # (a flat comprehension instead of the quadratic sum(list_of_lists, []))
  ex_labels = [i for i, x in enumerate(usages) for _ in x["ex"]]
  # each example is a 2-item pair (word, presumably a frequency)
  words, freq_list = list(zip(*chain.from_iterable(ex_list)))

  # compute Jacobians for each distinct word (repeated words reuse the cache)
  J_buf = {}
  for word in words:
    if word in J_buf: continue
    J = compute_token_jacobian_functorch(word, model, tokenizer).squeeze().detach().cpu().numpy()
    # keep token positions 1 and 2 -- the two characters of the (bisyllabic)
    # word; position 0 is [CLS] and position 3 is [SEP]
    J_buf[word] = J[:,1:3,:]

  ## predicted word embeddings
  with torch.no_grad():
    preds = model(**tokenizer(words, return_tensors="pt").to("cuda")).predictions.cpu().numpy()
  pred_embs = {w: preds[i,:] for i, w in enumerate(words)}
  char_dists[char] = dict(
    ex_labels = ex_labels,
    J_target_dists = compute_pairwise_distances(char, words, J_buf, False),
    J_counter_dists = compute_pairwise_distances(char, words, J_buf, True),
    J_cross_dists = compute_cross_distances(char, words, J_buf),
    pred_emb_dists = compute_pairwise_emb_distances(words, pred_embs),
    tenc_emb_dists = compute_pairwise_emb_distances(words, tencent_kv_wrap),
    counter_emb_dists = compute_pairwise_counter_emb_distances(char, words, tencent_kv_wrap),
    counter_bert_dists = compute_pairwise_counter_bert_emb_distances(char, words, tokenizer, model)
  )



  0%|          | 0/796 [00:00<?, ?it/s]

  allow_unused, accumulate_grad=False)  # Calls into the C++ engine to run the backward pass


In [None]:
## Cache the per-character distance matrices (char_dists -- not J_buf)
## to affix_dists.pkl on Drive.
# NOTE(review): the imports and base_dir below repeat the setup cells.
from pathlib import Path
import pickle
base_dir = Path("/content/drive/MyDrive/LangOn/morphert")
with open(base_dir / "affix_dists.pkl", "wb") as fout:
  pickle.dump(char_dists, fout)


## Compute Silhouette scores

In [None]:
from sklearn.metrics import silhouette_score

def compute_silhouette_score(char_item, field_name, seed=12345, n_permutations=1000):
  """Silhouette score of the usage labels plus a label-permutation null distribution.

  Args:
    char_item: one entry of ``char_dists``; must contain ``"ex_labels"`` and
      a square distance matrix under ``field_name``.
    field_name: key of the distance matrix to evaluate.
    seed: seed of the ``np.random.RandomState`` drawing the permutations.
    n_permutations: number of label shuffles used for the null distribution
      (default 1000, the previously hard-coded value).

  Returns:
    ``(real_score, (q05, q50, q95))``: the observed silhouette score and the
    5th/50th/95th percentiles of the scores under shuffled labels.

  Raises:
    ValueError: from ``silhouette_score``, e.g. when the kept words do not
      carry between 2 and n_samples - 1 distinct labels.
  """
  ex_labels = np.array(char_item["ex_labels"])
  dist_mat = char_item[field_name]
  # Drop words whose column is entirely NaN (e.g. words missing from an
  # embedding table in the *_emb_dists matrices).
  keep = ~np.isnan(dist_mat).all(axis=0)
  dist_mat = dist_mat[keep, :][:, keep]
  ex_labels = ex_labels[keep]

  real_score = silhouette_score(
                  dist_mat, ex_labels,
                  metric="precomputed")

  rng = np.random.RandomState(seed)
  rand_scores = []
  for _ in range(n_permutations):
    # drawing every label without replacement == a random permutation
    rand_labels = rng.choice(ex_labels, len(ex_labels), replace=False)
    rand_score = silhouette_score(
                  dist_mat, rand_labels,
                  metric="precomputed")
    rand_scores.append(rand_score)
  q05, q50, q95 = np.quantile(rand_scores, [.05, .50, .95])
  return real_score, (q05, q50, q95)

char_item = char_dists["一"]

# For each distance matrix, compare the observed silhouette score with the
# 5/50/95% quantiles of its label-permutation null distribution.
fields = ["J_target_dists", "J_counter_dists", "pred_emb_dists", "tenc_emb_dists"]
for field_name in fields:
  real_score, rand_qs = compute_silhouette_score(char_item, field_name, seed=12)
  print("-- " + field_name + " --")
  print("Sample score: ", real_score)
  print("Random score: ", rand_qs)


-- J_target_dists --
Sample score:  0.06990259301319653
Random score:  (-0.026598284265533328, -0.008800670786540333, 0.06652363427146155)
-- J_counter_dists --
Sample score:  0.05385726641541365
Random score:  (-0.019728099748186043, -0.0034276391316780917, 0.032545640957437956)
-- pred_emb_dists --
Sample score:  0.24976651495779315
Random score:  (-0.1241113815891531, -0.0005817292233092575, 0.2719705927771045)
-- tenc_emb_dists --
Sample score:  0.26224155810702215
Random score:  (-0.08484478347935138, -0.024120759276518026, 0.2235869859362129)
