In [1]:
def construct_prompt(question, article):
    """Build the single-string prompt that wraps an article and a question.

    Args:
        question: Text placed after the trailing ``Question: `` marker.
        article: Text inserted between the "article begins" and
            "article ends" markers.

    Returns:
        The assembled prompt string.
    """
    segments = (
        'You are a friendly chatbot who always responds in the style of a pirate. ',
        'Below is an article, read the article and answer my question after the article. ',
        'Now the article begins:',
        f'{article}',
        ' Now the article ends. ',
        'Select several sentences from the article to answer my question. ',
        f'Question: {question}',
    )
    return ''.join(segments)

In [2]:
import pysbd
# Shared pysbd sentence segmenter for English text, constructed with clean=False;
# used below to split the loaded article into sentences.
seg = pysbd.Segmenter(language="en", clean=False)

In [133]:
# Read the knowledge file; literal backslash-n sequences stored in it are
# turned into real newline characters.
with open('long_knowledge.txt', 'r') as f:
    search_rst = f.read().replace('\\n', '\n')

# Sentence-level view of the article (the file handle is no longer needed here).
sents_in_rst = list(seg.segment(search_rst))

question = "What's the most significant news related to cybersecurity in this week? "
full_prompt = construct_prompt(question, search_rst)

In [134]:
from transformers import LlamaForCausalLM, LlamaTokenizer

# Load the tokenizer and the causal-LM weights for the TinyLlama chat checkpoint
# named by the identifier below.
tokenizer = LlamaTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# Put the module in evaluation mode; as the last expression of the cell this
# also displays the module structure.
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Line

In [220]:
import torch  # local to this cell so it runs even before the later cell that imports torch

inputs = tokenizer(full_prompt, return_tensors="pt")

# Inference only: disable autograd so the forward pass over the whole prompt
# does not keep its computation graph alive through the hidden states and the
# cached key/values that later cells reuse.
with torch.no_grad():
    # use_cache=True is requested explicitly because the prefix search reads
    # `ouputs.past_key_values` to continue decoding from the prompt.
    ouputs = model.model(**inputs, use_cache=True, return_dict=True)
    # Per-position vocabulary logits for the full prompt.
    logits = model.lm_head(ouputs.last_hidden_state)

In [148]:
# Reverse vocabulary: token id -> token string.
id2token = {v:k for k,v in tokenizer.get_vocab().items()}

# all_prefix: every token id that appears among the first (up to) three
#             content tokens of any sentence.
# prefix2sents: nested dict keyed by token id; each node carries a 'sents'
#               list with the indices of the sentences that start with the
#               token path leading to that node.
all_prefix = set()
prefix2sents = {}
for id_, s in enumerate(sents_in_rst):
    ptr = prefix2sents
    input_ids = tokenizer(s, return_tensors="pt")['input_ids'][0]
    # Position 0 is skipped (presumably the BOS token the tokenizer prepends --
    # confirm). Slicing instead of indexing 1..3 keeps sentences that tokenize
    # to fewer than four ids from raising IndexError; they simply contribute a
    # shorter path.
    for prefix_id in input_ids[1:4].tolist():
        ptr = ptr.setdefault(prefix_id, {'sents': []})
        ptr['sents'].append(id_)
        all_prefix.add(prefix_id)
len(all_prefix)

78

In [149]:
from torch import nn
import torch
def predict_next_token(model, hidden_state):
    """Return the next-token probability distribution for each sequence.

    Args:
        model: Object exposing an ``lm_head`` callable that maps hidden states
            to vocabulary logits.
        hidden_state: Tensor indexed as ``[batch, position, hidden]``.

    Returns:
        Tensor indexed as ``[batch, vocab]`` holding the softmax of the logits
        at the last position of every sequence.
    """
    # Only the final position is used, so select it before the projection
    # instead of computing vocabulary logits for every position and then
    # discarding all but one row.
    last_hidden = hidden_state[:, -1, :]
    next_token_logits = model.lm_head(last_hidden).float()
    return torch.softmax(next_token_logits, -1)

In [150]:
def prefix_match_sents(toks):
    """Look up which sentences start with the given token-id path.

    Walks the module-level ``prefix2sents`` nested dict one token at a time.

    Args:
        toks: Sequence of token ids forming the prefix.

    Returns:
        The ``'sents'`` list stored at the node reached by ``toks``, or an
        empty list as soon as a token has no entry at the current node.
    """
    node = prefix2sents
    for tok in toks:
        if tok in node:
            node = node[tok]
        else:
            return []
    return node['sents']

In [151]:
import heapq 
import math
from tqdm import tqdm
def get_topk_prefixes(model, current_status, top_k = 2, max_steps = 5):
    """Beam-style search for sentence-opening token prefixes the model favours.

    Starting from the states in ``current_status``, each step scores every
    token id in the module-level ``all_prefix`` set with the model's
    next-token distribution, keeps the ``top_k`` most probable ones per state,
    and extends the state with each of them. A prefix that
    ``prefix_match_sents`` maps to exactly one sentence is recorded as a
    result; any prefix that still matches at least one sentence is carried
    into the next step.

    Args:
        model: Causal LM exposing ``model.model(...)`` (decoder forward) and
            whatever ``predict_next_token`` needs.
        current_status: List of ``(output, score, prev_toks)`` tuples, where
            ``output`` has ``last_hidden_state`` and ``past_key_values``,
            ``score`` is the accumulated log-probability and ``prev_toks`` the
            token ids chosen so far.
        top_k: Number of candidate tokens expanded per state.
        max_steps: Maximum number of expansion steps (previously fixed at 5).

    Returns:
        List of dicts with keys ``'toks'`` (token-id prefix), ``'score'``
        (accumulated log-probability divided by prefix length) and ``'idx'``
        (index of the uniquely matched sentence).
    """
    ret = []
    status = current_status
    # Inference only: without no_grad every branch's forward pass would keep
    # an autograd graph alive through its cached key/values.
    with torch.no_grad():
        for _ in tqdm(range(max_steps)):
            next_status = []
            for output, score, prev_toks in status:
                next_token_probs = predict_next_token(model, output.last_hidden_state)
                # Plain floats avoid comparing 0-dim tensors while ranking.
                probs = next_token_probs[0].tolist()
                # NOTE(review): candidates are drawn from every known prefix
                # token, not only the children of the current trie node, so
                # some top-k slots may be spent on tokens that cannot extend
                # `prev_toks`. Kept as in the original selection rule.
                top_k_candidates = heapq.nlargest(
                    top_k, ((probs[tok_id], tok_id) for tok_id in all_prefix)
                )
                for prob, prefix_id in top_k_candidates:
                    new_prev_toks = prev_toks + [prefix_id]
                    match_prefix = prefix_match_sents(new_prev_toks)
                    # Dead branch: no sentence starts with this prefix, so
                    # skip it before paying for a forward pass.
                    if not match_prefix:
                        continue
                    # A probability that underflowed to zero would make
                    # math.log raise; such a candidate is effectively
                    # impossible, so drop it.
                    if prob <= 0.0:
                        continue
                    new_score = score + math.log(prob)
                    if len(match_prefix) == 1:
                        ret.append({'toks': new_prev_toks, 'score': new_score/len(new_prev_toks), 'idx': match_prefix[0]})
                    next_input_ids = torch.tensor([prefix_id])[:, None]
                    next_attn_mask = torch.tensor([1])[:, None]
                    new_input = {'input_ids': next_input_ids, 'attention_mask': next_attn_mask}
                    # NOTE(review): all candidates of a state reuse the same
                    # `past_key_values` object; this assumes the forward pass
                    # does not mutate the cache it is given -- confirm for the
                    # installed transformers version.
                    new_ouput = model.model(**new_input, past_key_values=output.past_key_values, return_dict=True)
                    next_status.append((new_ouput, new_score, new_prev_toks))
            # Advance the beam once per step, after every state was expanded.
            status = next_status
    return ret

In [152]:
# Expand the three most probable prefix tokens per state, starting from the
# prompt's forward pass with an empty prefix and zero score, then order the
# hits from highest to lowest length-normalised score.
k = 3
topk_prefixes = sorted(
    get_topk_prefixes(model, [(ouputs, 0, [])], k),
    key=lambda cand: cand['score'],
    reverse=True,
)

100%|██████████| 5/5 [00:06<00:00,  1.25s/it]


In [233]:
def find_sent_index(promptToks, prefixToks):
    """Locate the first position in the prompt whose tokens spell the prefix.

    Token ids are mapped to strings through the module-level ``id2token``
    table, with the ``'▁'`` marker replaced by a space; windows are compared
    as text after stripping surrounding whitespace.

    Args:
        promptToks: Sequence of prompt token ids (each is passed through
            ``int`` before lookup).
        prefixToks: Sequence of token ids to search for.

    Returns:
        Index of the first matching window start, or -1 if none matches.
    """
    def as_text(token_ids):
        return ''.join(id2token[tok].replace('▁', ' ') for tok in token_ids)

    width = len(prefixToks)
    target = as_text(prefixToks).strip()
    for start in range(len(promptToks)):
        window = as_text(int(tok) for tok in promptToks[start:start + width])
        if window.strip() == target:
            return start
    return -1

In [234]:
def find_end_of_seq(beg_idx, logits, end_of_sent='</s>', distance=200):
    """Pick the position after ``beg_idx`` where the end token scores highest.

    Scans positions ``beg_idx`` up to (but excluding)
    ``min(beg_idx + distance, sequence length)`` and returns the one whose
    logit for ``end_of_sent`` is largest; the first such position wins ties.

    Args:
        beg_idx: First position to inspect. A negative value (for example the
            -1 "not found" result of ``find_sent_index``) is treated as
            "nothing to scan".
        logits: Per-position vocabulary scores indexed as
            ``logits[0][position][token_id]``.
        end_of_sent: Token string whose score marks a likely end.
        distance: Maximum number of positions to inspect.

    Returns:
        The chosen position, or -1 when ``beg_idx`` is negative or the scan
        window is empty.
    """
    # A negative start would otherwise wrap around and silently scan from the
    # end of the sequence.
    if beg_idx < 0:
        return -1
    # Take the sequence length from the logits being scanned rather than from
    # the global `inputs`, so the bound always matches the argument.
    len_of_prompt = len(logits[0])
    end_of_sent_id = tokenizer.get_vocab()[end_of_sent]
    max_score = float('-inf')
    max_idx = -1
    for i in range(beg_idx, min(beg_idx+distance, len_of_prompt)):
        end_score = float(logits[0][i][end_of_sent_id])
        if max_score < end_score:
            max_score = end_score
            max_idx = i
    return max_idx


In [237]:
# Plain-int copy of the prompt token ids, used for the prefix search below.
prompt_toks = [int(i) for i in inputs.input_ids[0]]
for cand in topk_prefixes:
    # Decoded prefix and its score are kept in the namespace for inspection.
    sent_prefix = tokenizer.batch_decode([cand['toks']], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    score = cand['score']
    sent_beg_idx = find_sent_index(prompt_toks, cand['toks'])
    if sent_beg_idx < 0:
        # The prefix could not be located in the prompt; a -1 start would make
        # the scan and the slice below wrap around to the end of the prompt.
        continue
    end_idx = find_end_of_seq(sent_beg_idx, logits)
    # Prompt tokens from the matched start up to the position chosen as end.
    high_quality_sec = inputs.input_ids[0][sent_beg_idx: end_idx]
    print(tokenizer.batch_decode([high_quality_sec], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])

- The US CISA was also affected by an Ivanti system vulnerability at the beginning of the year. Hackers targeted the US federal high-risk chemical critical infrastructure, infiltrating the Chemical Security Assessment Tool (CSAT) provided by CISA and successfully deploying a Web Shell. Another key infrastructure cybersecurity information tool, CISA Gateway, was also affected.
- The OWASP Foundation, well-known for publishing the top ten web application security risks, recently issued a data breach notification. Member resume files from 2006 to 2014 may have been leaked due to a configuration issue on an old Wiki web server.
