In [None]:
%%capture
!pip install transformers

In [None]:
# Expected SHA-1 of the input file; compare it with the digest printed by this cell:
# 6b6873ae3f441b3cca2e88d943ea91fb93965273
!sha1sum "./decode_seqs.json"   # built from weave/etc/surprisal/10.00

6b6873ae3f441b3cca2e88d943ea91fb93965273  ./decode_seqs.json


## Load the model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Single Hugging Face checkpoint id shared by the tokenizer and the weights.
MODEL_NAME = "ckip-joint/bloom-3b-zh"

# add_prefix_space=True is presumably needed because later cells feed the
# tokenizer pre-split words (is_split_into_words=True) -- TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, add_prefix_space=True)

# Load the weights in bfloat16, then move the model to the GPU.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
model = model.to("cuda")

In [None]:
## Test generation
# Smoke test: tokenize a short prompt one character at a time, sample a
# continuation of up to 20 new tokens, and display the decoded text.
# Sampling is unseeded, so the output differs between runs.
prompt_chars = list("我達達")
prompt_inputs = tokenizer(prompt_chars,
                          is_split_into_words=True,
                          return_tensors="pt").to("cuda")
generated_ids = model.generate(**prompt_inputs,
                               max_new_tokens=20,
                               do_sample=True)
tokenizer.batch_decode(generated_ids)

['我達達開講啦~我發現我這個部落格也滿難搞的,因為今天要寫']

## Load the transcripts

In [None]:
import json
from pathlib import Path
from tqdm.auto import tqdm
# Parse the decoded sequences from the JSON file in the working directory.
decode_path = Path("decode_seqs.json")
with decode_path.open(encoding="UTF-8") as decode_file:
  decodes = json.load(decode_file)

In [None]:
# Score every decoded sequence of every speaker with the language model.
#
# Result: nll_data maps speaker id -> list of records. Each record is a copy of
# the input sequence record plus an "nll" entry: a 1-D float16 numpy array with
# the negative log-likelihood of token t given tokens < t. The first token has
# no left context, so the array has (number of tokens - 1) entries.
speaker_ids = sorted(decodes)
nll_data = {}
for speaker_x in tqdm(speaker_ids):
  decode_seqs = decodes[speaker_x]
  nll_speaker = nll_data.setdefault(speaker_x, [])
  for seqdata_x in decode_seqs:
    # "sequences" is a list of word lists; flatten it into one word list.
    # (A comprehension avoids the quadratic copying of sum(list_of_lists, []).)
    words_x = [word for segment in seqdata_x["sequences"] for word in segment]
    batch = tokenizer(words_x,
                      is_split_into_words=True,
                      return_tensors="pt").to("cuda")
    with torch.no_grad():
      out = model(**batch)

    # Negative log-probability of every vocabulary item at every position.
    # NOTE(review): this runs in the dtype of the logits (the model is loaded
    # with torch_dtype=torch.bfloat16); upcasting with .float() first may be
    # more precise but would change the stored values -- confirm before changing.
    neg_logprobs = -out.logits.log_softmax(dim=-1)

    # Position t predicts token t+1: drop the last prediction and the first
    # label, then pick the score of the token that actually follows.
    shift_labels = batch["input_ids"][:, 1:]
    nlls = neg_logprobs[:, :-1, :].gather(dim=2, index=shift_labels.unsqueeze(2))

    # Index the batch dim (always 1 here) and the gather dim explicitly. A bare
    # .squeeze() would also collapse the token dim when it has length 1 and
    # return a 0-d array for a two-token sequence.
    nlls = nlls[0, :, 0].cpu().to(torch.float16).numpy()
    nll_speaker.append({
        **seqdata_x,
        "nll": nlls
    })

  0%|          | 0/55 [00:00<?, ?it/s]

In [None]:
import pickle
# Serialise the per-speaker NLL records and write them next to the notebook.
# The cell ends on write_bytes so the number of bytes written is displayed.
out_path = Path("decode_nll_bloom-zh-3b.pkl")
nll_payload = pickle.dumps(nll_data)
out_path.write_bytes(nll_payload)

4896748

In [None]:
# Print the SHA-1 digest of the pickle file at out_path ($out_path is
# interpolated from the Python variable of the same name).
!sha1sum $out_path

37b7f67f5cc0685e0d223a1fc3f733f2fde70ee6  decode_nll_bloom-zh-3b.pkl
