In [1]:
import numpy as np
import torch
import tqdm
from desi_llm.common.utils import shard
from llm_bases.chatglm6b import ChatGML6B

chatglm6b = ChatGML6B()



  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 8/8 [00:17<00:00,  2.21s/it]


In [2]:
def get_word_embedding(word: str):
    token_id = chatglm6b.tokenizer(word)['input_ids'][0]
    embedding = chatglm6b.condgen.transformer.word_embeddings.weight[token_id]
    embedding = embedding.float().to("cuda:2")
    return embedding

In [3]:
from matplotlib.pylab import lstsq


def test_recover_word(word: str, noise_scale: float, n_shards: int=10, lstsq_batch: int=32):
    raw_embedding = get_word_embedding(word)
    print(f"Embedding scale: {torch.std(raw_embedding).item():.4f}")
    input_embedding = raw_embedding + torch.normal(0, noise_scale, raw_embedding.shape, device=raw_embedding.device, dtype=raw_embedding.dtype)
    random_shards, random_coefs = shard(input_embedding, n_shards)
    print(f"Random coefs: {random_coefs}, sum: {torch.sum(random_coefs).item()}")

    def create_tensor(x):
        return torch.tensor(x, dtype=raw_embedding.dtype, device=raw_embedding.device)

    k, d = random_shards.shape
    solved_coefs = torch.linalg.lstsq(random_shards.T, raw_embedding[:, None])[0]
    #  [d, k], [k, 1]          [d, 1]
    err = random_shards.T @ solved_coefs - raw_embedding[:, None]  # [batch, dim, 1]
    print(f"Errors by given the original embedding: {torch.sqrt(torch.mean(torch.square(err))).item():4f}, sum: {torch.sum(solved_coefs).item()}")

    least_square_errors = []
    all_coefsums = []
    for i in tqdm.tqdm(range(chatglm6b.n_tokens // lstsq_batch + 1)):
        candiate_embedding = chatglm6b.condgen.transformer.word_embeddings.weight.data[i * lstsq_batch : (i + 1) * lstsq_batch].float().to(random_shards.device)
        solved_coefs = torch.linalg.lstsq(
            torch.cat([random_shards.T[None, ...], create_tensor([[[1.]]]).expand(1, 1, k)], dim=1), 
            torch.cat([candiate_embedding[..., None], create_tensor([[[1.]]]).expand(lstsq_batch, 1, 1)], dim=1))[0]
         [batch, d, k], [batch, k, 1]          [batch, d, 1]
        # solved_coefs = torch.linalg.lstsq(random_shards.T[None, ...], candiate_embedding[..., None])[0]

        # all_coefsums.extend(torch.sum(solved_coefs[..., 0], dim=-1).tolist())
        # errs = random_shards.T[None, ...] @ solved_coefs - candiate_embedding[..., None]  # [batch, d, 1]
        # least_square_errors.extend(torch.sqrt(torch.mean(torch.square(errs[:, :, 0]), dim=-1)).tolist())
    # Check the goodness of the solution

    asc_indices = np.argsort(least_square_errors)
    pos = np.searchsorted(np.array(least_square_errors)[asc_indices], noise_scale)
    print(f"Order of the target: {pos}")
    for i in asc_indices[max(0, pos - 1000): pos + 1000]:
        if i == asc_indices[pos]:
            print("================")
        if 0.8 < all_coefsums[i] < 1.2:
            print(f"{least_square_errors[i]:.6f}\t{i}\t{chatglm6b.tokenizer.decode(i)}\t{all_coefsums[i]:.4f}")

    


In [4]:
test_recover_word("Hearing", 0.005, 50)

Embedding scale: 0.0113
Random coefs: tensor([-4.0313,  3.2110,  0.7622,  1.3049, -0.7019, -3.8832,  0.7295, -1.2794,
        -0.3520,  3.4300,  1.4151, -2.2342,  3.7410,  0.5230, -2.4572,  0.0692,
         4.4390, -0.2464, -2.2110, -4.9531,  1.4838, -1.3954, -4.6838, -0.3439,
        -2.2473,  1.8736,  2.7595,  0.6757, -1.3311, -2.1164,  2.4726,  1.8249,
        -1.9824, -1.7143, -3.5250, -3.3120,  4.3333,  2.7268,  4.7316, -3.8280,
        -0.7841, -2.5892,  0.1153, -1.8021, -3.3790, -4.6919,  2.4585,  4.5744,
         2.0065,  0.2911], device='cuda:2'), sum: -10.123047828674316
Errors by given the original embedding: 0.004526, sum: -405.811767578125


100%|██████████| 4063/4063 [00:27<00:00, 148.80it/s]


Order of the target: 1
