
A way to improve the next_token computation #27

Open
luchangli03 opened this issue Jul 20, 2023 · 1 comment

Comments

@luchangli03

luchangli03 commented Jul 20, 2023

Replacing the existing computation with the approach below noticeably reduces the amount of work needed to compute next_token. It is meant to replace these lines of the original code:

```python
next_token_scores = self.apply_warp(next_token_scores)

probs = npsoftmax(next_token_scores.astype(np.float64), axis=1)

# Caution:
# *** ValueError: sum(pvals[:-1].astype(np.float64)) > 1.0. The pvals array is cast to 64-bit floating point prior to checking the sum. Precision changes when casting may cause problems even if the sum of the original pvals is valid.
next_token = npmultinominal2D(probs).astype(input_ids.dtype)
```
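For reference, the existing path amounts to a softmax over the whole vocabulary followed by one draw per row. The sketch below is a self-contained stand-in for that path, not the project's `npsoftmax`/`npmultinominal2D`; the name `sample_full_vocab` and the use of NumPy's `Generator.choice` are assumptions for illustration. Doing the softmax in float64 and renormalizing before sampling is one way to keep the probabilities inside NumPy's sum tolerance, which is the failure the caution comment describes.

```python
import numpy as np


def sample_full_vocab(logits_2d, rng=None):
    """Hypothetical baseline: softmax over the full vocabulary, one draw per row.

    logits_2d: array of shape [batch, vocab].
    Returns an int64 array of shape [batch, 1].
    """
    rng = np.random.default_rng() if rng is None else rng
    x = logits_2d.astype(np.float64)              # work in float64 throughout
    x = x - x.max(axis=1, keepdims=True)          # numerically stable softmax
    probs = np.exp(x)
    probs /= probs.sum(axis=1, keepdims=True)     # renormalize so each row sums to 1
    # each row costs V exponentials plus a V-way draw
    tokens = [rng.choice(probs.shape[1], p=row) for row in probs]
    return np.array(tokens, dtype=np.int64).reshape(-1, 1)
```

Every row pays for V exponentials and a V-way draw; the proposal below shrinks the softmax and the draw to k entries.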
The new computation code:

```python
next_token = post_process(next_token_scores)
```

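```python
import numpy as np

# npsoftmax and warp_temperature are existing helpers in this project.
# Like the snippet it replaces, this assumes batch size 1: the scores are
# flattened into a single vocabulary row.


def post_process(tensor, topk=3):
    tensor = tensor.reshape([-1]).astype("float32")
    tensor = warp_temperature(tensor, 1.0)

    # keep only the topk highest-scoring tokens
    topk_vals, topk_idxs = warp_topk1(tensor, topk)

    # softmax over topk entries instead of the whole vocabulary; computing in
    # float64 avoids float32 rounding pushing sum(pvals[:-1]) above 1.0
    # inside np.random.multinomial
    probs = npsoftmax(np.asarray(topk_vals, dtype=np.float64), axis=0)

    # draw one sample among the topk candidates
    max_idx = np.random.multinomial(1, probs).argmax()

    # map the sampled position back to a vocabulary index, shape [1, 1]
    next_token = topk_idxs[max_idx]
    next_token = np.array([next_token], dtype="int64").reshape([-1, 1])
    return next_token


def warp_topk1(tensor, topk):
    tensor_1d = tensor.reshape([-1])
    topk_vals, topk_idxs = get_topk(tensor_1d, topk)
    return topk_vals, topk_idxs


def get_topk(tensor_1d, topk=3):
    # values in topk_vals are kept in descending order
    topk_vals = [-float("Inf")] * topk
    topk_idxs = [0] * topk

    for idx, elem in enumerate(tensor_1d):
        # skip elements that cannot enter the current top k
        if elem > topk_vals[topk - 1]:
            for i in range(topk):
                # find the slot where elem belongs, then shift the smaller
                # entries one place to the right to make room for it
                if elem > topk_vals[i]:
                    for j in reversed(range(i, topk - 1)):
                        topk_vals[j + 1] = topk_vals[j]
                        topk_idxs[j + 1] = topk_idxs[j]

                    topk_vals[i] = elem
                    topk_idxs[i] = idx
                    break
    return topk_vals, topk_idxs
```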
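The loop in `get_topk` visits every vocabulary entry in the Python interpreter. A possible vectorized alternative, sketched below under the assumption that `np.argpartition` is acceptable here, selects the k largest scores without a full sort and then orders only those k. `get_topk_np` and `sample_topk` are hypothetical names, and sampling uses `Generator.choice` in place of `np.random.multinomial(...).argmax()`; both draw one index according to the probabilities.

```python
import numpy as np


def get_topk_np(tensor_1d, topk=3):
    """Hypothetical vectorized counterpart of get_topk.

    Returns (values, indices) with values in descending order.
    """
    x = np.asarray(tensor_1d).reshape(-1)
    # the last topk slots of the partition hold the topk largest entries,
    # in no particular order
    part = np.argpartition(x, -topk)[-topk:]
    # sort only those topk entries, largest first
    order = np.argsort(-x[part], kind="stable")
    idxs = part[order]
    return x[idxs], idxs


def sample_topk(logits_1d, topk=3, rng=None):
    """Hypothetical top-k sampler: softmax over topk logits, one draw."""
    rng = np.random.default_rng() if rng is None else rng
    vals, idxs = get_topk_np(logits_1d, topk)
    v = vals.astype(np.float64)
    p = np.exp(v - v.max())        # stable softmax over topk entries only
    p /= p.sum()
    choice = rng.choice(topk, p=p)
    # map back to a vocabulary index, shape [1, 1] like post_process
    return np.array([idxs[choice]], dtype=np.int64).reshape(-1, 1)
```

Ties may be broken differently than in `get_topk`, which keeps the earliest index because it only shifts on a strict `>`.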
@tpoisonooo
Owner

Please just open a PR directly.
