In [1]:
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LinearLR
from transformers import BertConfig, BertTokenizer, BertForMaskedLM

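## Pretrained `bert-base-chinese`: contextual embeddings of 開

Encode three sentences that each contain 開 with the pretrained model. Project the per-token embeddings to 2-D with PCA, then render the token-by-token trajectory as image frames.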

In [2]:
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import BertConfig, BertTokenizer, BertForMaskedLM

In [3]:
# Tokenizer and masked-LM model from the same "bert-base-chinese" checkpoint.
# The checkpoint's next-sentence head (cls.seq_relationship.*) is not used by
# BertForMaskedLM, which is what the warning printed by this cell reports.
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
model = BertForMaskedLM.from_pretrained("bert-base-chinese")

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Three sentences that each contain 開 (the character tracked in plot_frame)
# in a different context: "he talks while driving", "he is very happy today",
# "I'm drawing a full-sheet poster".
sent1, sent2, sent3 = "他開車時說話", "今天他很開心", "我畫全開海報"

In [5]:
def encode_sent(sent, layer=-1, encoder=None, tok=None):
    """Return unit-length contextual embeddings for the tokens of ``sent``.

    The sentence is tokenized, run through the model with hidden states
    enabled, and the requested layer's states for the single batch item are
    taken. The first and last positions are dropped (presumably the [CLS]/[SEP]
    special tokens added by the BERT tokenizer), and each remaining row is
    L2-normalised.

    Args:
        sent: a single sentence string.
        layer: index into ``hidden_states``; ``-1`` (default) is the last one.
        encoder: model to run; defaults to the notebook-global ``model``.
        tok: tokenizer to use; defaults to the notebook-global ``tokenizer``.

    Returns:
        A 2-D ``np.ndarray`` with one unit-norm row per non-special token.
    """
    encoder = model if encoder is None else encoder
    tok = tokenizer if tok is None else tok
    batch = tok(sent, return_tensors="pt")
    with torch.no_grad():
        out = encoder(**batch, output_hidden_states=True)
    # Batch item 0, without the first and last (special) positions.
    h = out.hidden_states[layer][0, 1:-1].numpy()
    # Out-of-place division: .numpy() shares memory with the model's output
    # tensor, so an in-place `/=` would silently overwrite it.
    return h / np.linalg.norm(h, axis=-1, keepdims=True)
# Per-token unit embeddings for each sentence, in sentence order.
enc1, enc2, enc3 = map(encode_sent, (sent1, sent2, sent3))

In [6]:
from matplotlib import cm
from sklearn.decomposition import PCA

# CJK-capable font so the annotated Chinese characters render.
# NOTE(review): "Microsoft YaHei" is a Windows font; other platforms may need another.
plt.rcParams["font.family"] = "Microsoft YaHei"

# Fit PCA on all token embeddings of the three sentences stacked row-wise.
pca = PCA()
proj = pca.fit_transform(np.vstack([enc1, enc2, enc3]))

# Sentence index (0, 1, 2) of each row of `proj`.
groups = [sent_idx
          for sent_idx, enc in enumerate((enc1, enc2, enc3))
          for _ in range(enc.shape[0])]

In [7]:

def plot_frame(frame_idx, target="開"):
    """Draw one frame of the token-by-token trajectory through PCA space.

    Tokens are counted across ``sent1 + sent2 + sent3``. Each token up to and
    including ``frame_idx`` is drawn as a point coloured by its sentence.
    Consecutive tokens of one sentence are joined by a line. Each revealed
    occurrence of ``target`` is circled in its sentence's colour. The newest
    token is labelled with its character.

    Reads the notebook globals ``sent1``, ``sent2``, ``sent3``, ``proj`` (one
    row of PCA coordinates per token) and ``groups`` (sentence index per
    token). Characters and ``proj`` rows share one index, which assumes the
    tokenizer produced exactly one token per character. TODO: confirm this
    when changing the sentences.

    Args:
        frame_idx: 0-based index of the newest token to reveal.
        target: character whose revealed occurrences are circled
            (default ``"開"``).

    Returns:
        The new Figure. The caller is responsible for closing it.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    sents = (sent1, sent2, sent3)
    chseqs = "".join(sents)
    # cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # pyplot.get_cmap returns the same registered colormap.
    cmap = plt.get_cmap("Set1")
    ax.set_xlim(-.75, .75)
    ax.set_ylim(-.5, 1.0)

    n_shown = frame_idx + 1
    ax.scatter(proj[:n_shown, 0], proj[:n_shown, 1],
               c=[cmap(g) for g in groups[:n_shown]])

    # One polyline per sentence, cut off at the current frame.
    start = 0
    for sent_idx, sent in enumerate(sents):
        end = start + len(sent)
        if frame_idx >= start:
            stop = min(n_shown, end)
            ax.plot(proj[start:stop, 0], proj[start:stop, 1],
                    color=cmap(sent_idx))
        start = end

    # Circle each revealed occurrence of the target character. Colour comes
    # from the sentence it belongs to, not from its occurrence order.
    for idx, ch in enumerate(chseqs):
        if ch == target and idx <= frame_idx:
            ax.add_patch(plt.Circle((proj[idx, 0], proj[idx, 1]), .05,
                                    facecolor="none",
                                    edgecolor=cmap(groups[idx]),
                                    linewidth=3))

    ax.annotate(chseqs[frame_idx], (proj[frame_idx, 0], proj[frame_idx, 1]),
                fontsize=16)
    return fig

In [8]:
# Render one JPEG per frame: trace_00.jpg, trace_01.jpg, ...
plt.ioff()  # presumably to stop the per-frame figures from being shown interactively
trace_dir = Path("../data/bert_traces")
# parents=True: also create ../data if it does not exist yet.
trace_dir.mkdir(parents=True, exist_ok=True)
for frame_idx in range(len(groups)):
    fig = plot_frame(frame_idx)
    fig.savefig(trace_dir / f"trace_{frame_idx:02d}.jpg")
    # Close this specific figure so open figures do not accumulate.
    plt.close(fig)