In [1]:
import json
import outlines
import torch
from transformers import AutoTokenizer
from textwrap import dedent

In [2]:
model_name = "Qwen/Qwen2-0.5B-Instruct"

# Choose the accelerator at runtime instead of assuming Apple Metal, so the
# notebook also runs on CUDA or CPU-only machines. MPS is checked first so
# machines that previously used 'mps' keep doing so.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): trust_remote_code=True allows code shipped with the model
# repository to run at load time — confirm it is actually required here.
model = outlines.models.transformers(
    model_name,
    device=device,
    model_kwargs={
        'torch_dtype': torch.bfloat16,
        'trust_remote_code': True
    })
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [3]:
# Load the complaint examples. The encoding is pinned to UTF-8 so decoding does
# not depend on the platform default, and json.load parses straight from the
# file handle.
with open("../examples.json", 'r', encoding="utf-8") as fin:
    complaint_data = json.load(fin)

In [4]:
from pydantic import BaseModel, Field, constr
from enum import Enum


# Department labels accepted for ComplaintData.department. The `str` mixin
# makes every member a string equal to its value.
class Department(str, Enum):
    clothing = "clothing"
    electronics = "electronics"
    kitchen = "kitchen"
    automotive = "automotive"

# Structured record extracted from one complaint; passed to
# outlines.generate.json as the output schema.
class ComplaintData(BaseModel):
    first_name: str
    last_name: str
    # One of A/D/Z, two digits, a hyphen, then four digits (e.g. "Z12-3456").
    # NOTE(review): the pattern has no ^/$ anchors — confirm whether validation
    # is meant to reject longer strings that merely contain a match.
    order_number: str = Field(pattern=r'[ADZ][0-9]{2}-[0-9]{4}')
    department: Department  # limited to the Department enum values
    
# The ComplaintData generator is built once, in its own cell right before the
# processing loop, so the schema is not compiled twice.

In [5]:
def create_prompt(complaint):
    """Build the chat-formatted prompt string for a single complaint.

    Args:
        complaint: Mapping with a ``'message'`` key holding the complaint text.

    Returns:
        The result of ``tokenizer.apply_chat_template`` (called with
        ``tokenize=False``) for a three-turn conversation: the instructions,
        a canned assistant acknowledgement, and the complaint itself.
    """
    # Allowed department values rendered as "a|b|c" for the instruction text.
    department_options = "|".join(e.value for e in Department)
    complaint_messages = [
        {
        'role': 'user',
        # Key names and the order-number layout mirror the ComplaintData
        # fields: `order_number`, two digits + hyphen + four digits.
        'content': f"""
        You are a complaint processing assistent, you aim is to process complaints and return the following intformation in this JSON format:
        {{
            'first_name': <first name>,
            'last_name': <last name>,
            'order_number': <order number has the following format (ADZ)XX-XXXX>,
            'department': <{department_options}>,
        }}
        """},
        {'role': 'assistant',
         'content': "I undersand and will process the complaints in the JSON format you described"
        },
        {'role': 'user',
        'content': complaint['message']
        }
    ]
    # The conversation ends on a user turn, so ask the template to append the
    # assistant-turn header; generation then starts inside the reply.
    complaint_prompt = tokenizer.apply_chat_template(
        complaint_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return complaint_prompt

In [6]:
# Build the JSON generator constrained to the ComplaintData schema; the
# processing loop calls it with one prompt string per complaint.
complaint_processor = outlines.generate.json(model, ComplaintData)

In [7]:
# Run the first ten complaints through the structured generator, collecting
# one parsed record per complaint in input order.
results = []
for complaint in complaint_data[:10]:
    results.append(complaint_processor(create_prompt(complaint)))

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [8]:
# Index of the example to inspect; the cell displays its raw complaint text.
idx = 4
complaint_data[idx]["message"]

'Hi, my name is Sarah Collins.I recently ordered your SmartWidget, but it hasn stopped working entirely! I just purchased the RapidCharge battery pack that does not charge at all.My order was Z123456'

In [9]:
# Show the extracted record for the same example as JSON. model_dump_json() is
# the pydantic v2 replacement for the deprecated .json() method.
results[idx].model_dump_json()

'{"first_name":"Sarah","last_name":"Collins","order_number":"Z12-3456","department":"electronics"}'