In [1]:
import numpy as np
import torch
import tqdm
from desi_llm.common.utils import shard
from llm_bases.chatglm6b import ChatGML6B

# Build the shared ChatGML6B wrapper once for the whole notebook. Later cells
# read its `tokenizer`, `condgen` (model) and `n_tokens` attributes.
# NOTE(review): presumably this loads the model weights (the recorded output
# shows checkpoint shards being loaded), so re-running this cell is slow.
chatglm6b = ChatGML6B()



  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:17<00:00,  2.21s/it]


In [2]:
def get_word_embedding(word: str, device: str = "cuda:2"):
    """Return the input-embedding row of the first token of ``word``.

    The word is passed through ``chatglm6b.tokenizer`` and only the FIRST id
    of the resulting ``input_ids`` is looked up in the model's word-embedding
    table.
    NOTE(review): a word that tokenizes into several pieces is presumably
    represented by its first piece only -- confirm this is intended.

    Args:
        word: Text to look up.
        device: Device the returned vector is moved to. Defaults to
            ``"cuda:2"``, the value that used to be hard-coded, so existing
            callers behave exactly as before; pass e.g. ``"cpu"`` or
            ``"cuda:0"`` on machines without a third GPU.

    Returns:
        The embedding row cast to float32 and placed on ``device``.
    """
    token_id = chatglm6b.tokenizer(word)['input_ids'][0]
    embedding = chatglm6b.condgen.transformer.word_embeddings.weight[token_id]
    # Cast to float32 before moving so downstream solvers get a float32 tensor.
    return embedding.float().to(device)

In [19]:
from matplotlib.pylab import lstsq


def test_recover_word(word: str, noise_scale: float, n_shards: int=10, lstsq_batch: int=32):
    """Check how well a noised, sharded embedding can be traced back to its token.

    Steps:
      1. Look up the embedding of ``word`` and add Gaussian noise with
         standard deviation ``noise_scale``.
      2. Split the noisy embedding into ``n_shards`` shards with ``shard``.
      3. For the true embedding, and then for every row of the model's
         word-embedding table, solve a least-squares problem for coefficients
         that combine the shards into that embedding. A row of ones is
         appended to the system (with target 1) so that the fit is also
         pushed towards coefficients summing to 1.
      4. Rank all candidates by RMSE of the reconstruction and print the
         candidates around the rank at which the RMSE reaches ``noise_scale``,
         keeping only those whose coefficient sum lies in (0.9, 1.1).

    Args:
        word: Word whose (first-token) embedding is the recovery target.
        noise_scale: Standard deviation of the Gaussian noise added before
            sharding; also the RMSE threshold used to centre the listing.
        n_shards: Number of shards requested from ``shard``.
        lstsq_batch: Number of candidate embeddings solved per batched
            ``torch.linalg.lstsq`` call.

    Returns:
        None. All results are printed.

    Raises:
        ValueError: If ``lstsq_batch`` is not positive.
    """
    if lstsq_batch <= 0:
        raise ValueError(f"lstsq_batch must be positive, got {lstsq_batch}")

    # Detach so that none of the solves below are tracked by autograd (the
    # embedding is indexed out of a model parameter).
    raw_embedding = get_word_embedding(word).detach()
    dtype, device = raw_embedding.dtype, raw_embedding.device
    print(f"Embedding scale: {torch.std(raw_embedding).item():.4f}")

    noise = torch.normal(0, noise_scale, raw_embedding.shape, device=device, dtype=dtype)
    input_embedding = raw_embedding + noise
    random_shards, random_coefs = shard(input_embedding, n_shards)
    print(f"Random coefs: {random_coefs}, sum: {torch.sum(random_coefs).item()}")

    k, _ = random_shards.shape  # random_shards: [k, d]
    shards_t = random_shards.T  # [d, k]

    # Augmented system: the extra all-ones row (target 1) encodes the
    # "coefficients sum to one" condition as one more least-squares equation.
    # It does not depend on the candidate, so it is built once.
    design = torch.cat([shards_t, torch.ones((1, k), dtype=dtype, device=device)], dim=0)  # [d + 1, k]

    target = torch.cat([raw_embedding[:, None], torch.ones((1, 1), dtype=dtype, device=device)], dim=0)  # [d + 1, 1]
    solved_coefs = torch.linalg.lstsq(design, target)[0]  # [k, 1]
    err = shards_t @ solved_coefs - raw_embedding[:, None]  # [d, 1]
    print(f"Errors by given the original embedding: {torch.sqrt(torch.mean(torch.square(err))).item():.6f}, sum: {torch.sum(solved_coefs).item()}")

    embedding_table = chatglm6b.condgen.transformer.word_embeddings.weight.data
    # Scan exactly the known tokens: never read past `n_tokens` (the table may
    # be padded) and never past the table itself.
    n_candidates = min(chatglm6b.n_tokens, embedding_table.shape[0])

    least_square_errors = []
    all_coefsums = []
    for start in tqdm.tqdm(range(0, n_candidates, lstsq_batch)):
        stop = min(start + lstsq_batch, n_candidates)
        candidate_embedding = embedding_table[start:stop].float().to(random_shards.device)  # [batch, d]
        # The final batch may be shorter than `lstsq_batch`; size everything
        # from the actual slice.
        batch = candidate_embedding.shape[0]

        targets = torch.cat(
            [candidate_embedding[..., None], torch.ones((batch, 1, 1), dtype=dtype, device=device)],
            dim=1)  # [batch, d + 1, 1]
        batch_coefs = torch.linalg.lstsq(design[None, ...].expand(batch, -1, -1), targets)[0]  # [batch, k, 1]

        all_coefsums.extend(torch.sum(batch_coefs[..., 0], dim=-1).tolist())
        errs = shards_t[None, ...] @ batch_coefs - candidate_embedding[..., None]  # [batch, d, 1]
        least_square_errors.extend(torch.sqrt(torch.mean(torch.square(errs[:, :, 0]), dim=-1)).tolist())

    # Check the goodness of the solution: rank candidates by RMSE. `pos` is the
    # insertion rank of `noise_scale` among the sorted errors, i.e. how many
    # candidates reconstruct with an RMSE below the noise level.
    errors = np.asarray(least_square_errors)
    asc_indices = np.argsort(errors)
    pos = int(np.searchsorted(errors[asc_indices], noise_scale))
    print(f"Order of the target: {pos}")

    # `pos` equals len(asc_indices) when every error is below `noise_scale`;
    # in that case there is no candidate at which to draw the separator.
    marker = asc_indices[pos] if pos < len(asc_indices) else None
    for i in asc_indices[max(0, pos - 1000): pos + 1000]:
        if marker is not None and i == marker:
            print("================")
        if 0.9 < all_coefsums[i] < 1.1:
            print(f"{least_square_errors[i]:.6f}\t{i}\t{chatglm6b.tokenizer.decode(int(i))}\t{all_coefsums[i]:.4f}")

    


In [23]:
# Recovery experiment for the token "Hearing": noise std 0.007, 40 shards.
test_recover_word("Hearing", noise_scale=0.007, n_shards=40)

Embedding scale: 0.0113
Random coefs: tensor([ 0.0112,  0.3430, -0.7242,  3.1021, -3.1991,  4.5910,  4.4832, -1.0496,
        -3.5800, -0.7639,  0.6073,  0.6704,  2.4500,  4.5964,  4.4412,  1.0909,
         2.1475,  3.9489, -0.6556, -3.4992,  1.0467, -3.2841,  4.9330,  0.9770,
        -2.2606,  3.8003, -1.8780,  1.6394,  0.2199, -1.9270, -0.7864,  1.7347,
        -1.1800,  3.6916, -0.1697, -3.6374,  2.7732, -4.0738, -4.4689, -0.6126],
       device='cuda:2'), sum: 15.548666954040527
Errors by given the original embedding: 0.005953, sum: 0.9999997615814209


  0%|          | 0/4063 [00:00<?, ?it/s]

100%|██████████| 4063/4063 [00:24<00:00, 163.50it/s]


Order of the target: 46
0.005891	90585	总结	1.0000
0.005900	82741	不知道	1.0000
0.005902	92799	中国的	1.0000
0.005905	130006		1.0000
0.005911	86076	等的	1.0000
0.005911	130007		1.0000
0.005912	73921	这样的	1.0000
0.005912	130003		1.0000
0.005916	84185	的中	1.0000
0.005920	73846	每	1.0000
0.005920	83503	看看	1.0000
0.005922	73845	以上	1.0000
0.005923	83453	现代	1.0000
0.005923	86190	她们	1.0000
0.005923	130002		1.0000
0.005928	57942	allClassesLink	1.0000
0.005929	82910	香港	1.0000
0.005935	3		1.0000
0.005953	33689	Hearing	1.0000
0.005954	1		1.0000
0.006021	94561	基督教上帝	1.0000
0.006102	83137	第一次	1.0000
0.006107	125749	в	1.0000
0.006158	57245	fzk	1.0000
0.006178	74291	你们	1.0000
0.006199	61911	executionOrder	1.0000
0.006208	129853		1.0000
0.006230	36903	assetBundle	1.0000
0.006344	52326	textureFormat	1.0000
0.006370	125860	не	1.0000
0.006428	52513	maxTextureSize	1.0000
0.006566	105521	基督教基督教基督教亿亿亿	1.0000
0.006589	124660	上帝阿门亿亿亿亿亿亿亿	1.0000
0.006630	93830	立即前往	1.0000
0.006649	58964	cnespace	1.0000
0.006666	125836	и	1