In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

print("Loading model and tokenizer...")
# model_name = "databricks/dolly-v2-3b"
model_name = "failspy/kappa-3-phi-abliterated"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    use_cache=True,
    #  device="cuda:0",
    #  device_map="auto",
    trust_remote_code=True,
).to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, use_cache=True)
print("Loaded model and tokenizer")

  from .autonotebook import tqdm as notebook_tqdm


Loading model and tokenizer...


`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attenton` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.10s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded model and tokenizer


# Scratch

In [65]:
from jaxtyping import Float, Int
import torch
from torch.nn import functional as F
from torch import Tensor
from typing import List, Callable, Tuple, Dict, Optional
import pandas as pd

In [43]:
# initital state
prompt = "The necromancer in his tower, what's his top problem? "
choices = ["1", "the skeleton", "the boney boys"]
choices_tokens = tokenizer(choices).input_ids
choices_tokens = [torch.tensor(c) for c in choices_tokens]
# current_tokens = torch.tensor([])

# next
input_ids = tokenizer([prompt], return_tensors="pt").to(model.device).input_ids[0]
choices_tokens

# for each next choice, continue down the tree, recording the log probs

[tensor([    1, 29871, 29896]),
 tensor([    1,   278, 18109, 11285]),
 tensor([    1,   278,   289,  4992, 12544])]

In [70]:
def get_valid_next_choices(choices_tokens, current_tokens):
    next_choices = []
    for choice_tokens in choices_tokens:
        # if we have some more slots left
        if len(current_tokens)<len(choice_tokens):
            # see if current_tokens matches
            if (choice_tokens[: len(current_tokens)] == current_tokens).all():
                c = choice_tokens[len(current_tokens)].item()
                next_choices.append(c)

    next_choices = list(set(next_choices))
    return torch.LongTensor(next_choices)


def next(
    input_ids: Int[Tensor, "seq"],
    choice: Optional[Int[Tensor, ""]] = None,
    prob: float = 1,
    current_tokens: Int[Tensor, "seq"] = torch.LongTensor([]),
    z=[],
):
    if choice is not None:
        c = choice[None].to(current_tokens.device)
        current_tokens = torch.cat([current_tokens, c], dim=-1)
        print(current_tokens, 'current_tokens')
        c = choice[None].to(input_ids.device)
        input_ids = torch.cat([input_ids, c], dim=-1)

    next_choices = get_valid_next_choices(choices_tokens, current_tokens)
    if len(next_choices) == 0:
        s = tokenizer.decode(current_tokens)
        r = dict(tokens=current_tokens.cpu(), prob=prob, choice=s)
        yield r
    else:
        o = model(input_ids[None])
        logits_constrained = o.logits[0, -1][next_choices]
        probs = F.softmax(logits_constrained, dim=-1)
        for i in range(len(next_choices)):
            next_choice = next_choices[i]
            next_prob = prob * probs[i].item()
            yield from next(
                input_ids=input_ids,
                choice=next_choice,
                prob=next_prob,
                current_tokens=current_tokens,
                z=z + [i],
            )


r = list(next(input_ids=input_ids))

tensor([1]) current_tokens
tensor([  1, 278]) current_tokens
tensor([  1, 278, 289]) current_tokens
tensor([   1,  278,  289, 4992]) current_tokens
tensor([    1,   278,   289,  4992, 12544]) current_tokens
tensor([    1,   278, 18109]) current_tokens
tensor([    1,   278, 18109, 11285]) current_tokens
tensor([    1, 29871]) current_tokens
tensor([    1, 29871, 29896]) current_tokens


In [74]:
pd.DataFrame(r).sort_values("prob", ascending=False).drop(columns=["tokens"])

Unnamed: 0,prob,choice
2,0.995732,<s> 1
0,0.004187,<s> the boney boys
1,8.1e-05,<s> the skeleton


# Continue

In [None]:
from prob_jsonformer.format import highlight_values
from prob_jsonformer.main import Jsonformer

ecomm = {
    "type": "object",
    "properties": {
        "store": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "location": {"type": "string"},
                "inventory": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "productId": {"type": "string"},
                            "name": {"type": "string"},
                            "description": {"type": "string"},
                            "category": {"type": "string"},
                            "price": {"type": "number"},
                            "inStock": {"type": "boolean"},
                            "rating": {"type": "number"},
                            "images": {"type": "array", "items": {"type": "string"}},
                        },
                    },
                },
            },
        }
    },
}


builder = Jsonformer(
    model=model,
    tokenizer=tokenizer,
    json_schema=ecomm,
    prompt="write a description about mike's ski shop which sells premium skis and snowboards",
    max_string_token_length=20,
)

print("Generating...")
output = builder()

highlight_values(output)

In [None]:
car = {
    "type": "object",
    "properties": {
        "make": {"type": "string"},
        "model": {"type": "string"},
        "year": {"type": "number"},
        "colors_available": {
            "type": "array",
            "items": {"type": "string"},
        },
    },
}

builder = Jsonformer(
    model=model,
    tokenizer=tokenizer,
    json_schema=car,
    prompt="generate an example car",
)

print("Generating...")
output = builder()

highlight_values(output)

In [None]:
complex_car = {
    "type": "object",
    "properties": {
        "car": {
            "type": "object",
            "properties": {
                "make": {"type": "string"},
                "model": {"type": "string"},
                "year": {"type": "number"},
                "colors": {"type": "array", "items": {"type": "string"}},
                "features": {
                    "type": "object",
                    "properties": {
                        "audio": {
                            "type": "object",
                            "properties": {
                                "brand": {"type": "string"},
                                "speakers": {"type": "number"},
                                "hasBluetooth": {"type": "boolean"},
                            },
                        },
                        "safety": {
                            "type": "object",
                            "properties": {
                                "airbags": {"type": "number"},
                                "parkingSensors": {"type": "boolean"},
                                "laneAssist": {"type": "boolean"},
                            },
                        },
                        "performance": {
                            "type": "object",
                            "properties": {
                                "engine": {"type": "string"},
                                "horsepower": {"type": "number"},
                                "topSpeed": {"type": "number"},
                            },
                        },
                    },
                },
            },
        },
        "owner": {
            "type": "object",
            "properties": {
                "firstName": {"type": "string"},
                "lastName": {"type": "string"},
                "age": {"type": "number"},
            },
        },
    },
}
builder = Jsonformer(
    model=model,
    tokenizer=tokenizer,
    json_schema=complex_car,
    prompt="generate an example Rolls Royce Phantom",
)

print("Generating...")
output = builder()

highlight_values(output)