In [29]:
import json
import os
from typing import Tuple

import baserun
import openai
import tiktoken
from dotenv import load_dotenv
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tqdm.notebook import tqdm

Load environment variables from `.env`, set the OpenAI API key, and initialize Baserun.

In [30]:
# Load variables from a local .env file into the process environment, then
# hand the OpenAI key to the client and start Baserun.
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")
baserun.init()



Connect to the database

In [31]:
import sqlite3


def connect(name: str) -> Tuple[sqlite3.Connection, sqlite3.Cursor]:
    """Open the SQLite database at ``name``.

    Args:
        name: Path to the database file (or ``":memory:"``).

    Returns:
        A ``(connection, cursor)`` pair; the cursor belongs to the connection.
    """
    connection = sqlite3.connect(name)
    return connection, connection.cursor()

In [32]:
DB_PATH = "../db.sqlite"  # relative to the notebook's working directory
conn, c = connect(DB_PATH)

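Build the `logit_bias` map for the classifier: every token of the allowed labels (`0`–`3` and `N`) gets a bias of 100, steering the model's one-token answer toward those labels.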

In [33]:
def get_tokens(
    labels: tuple[str, ...] = ("0", "1", "2", "3", "N"),
    model: str = "gpt-4",
    bias: int = 100,
) -> dict:
    """Build a ``logit_bias`` map that pushes the model toward ``labels``.

    Each label is encoded with the tiktoken encoding for ``model``, and every
    resulting token id is mapped to ``bias``. Keys are token ids rendered as
    strings, which is the form the ``logit_bias`` request parameter takes.

    Args:
        labels: The allowed classifier outputs. Defaults to ``0``-``3`` and ``N``.
        model: Model name used only to select the tiktoken encoding. "gpt-4"
            and "gpt-3.5-turbo" both resolve to ``cl100k_base``, so the ids
            match the model used in ``DWAReference.classify``.
        bias: Bias applied to every label token (the API accepts -100..100).

    Returns:
        A ``{token_id_str: bias}`` dict with one entry per distinct token.
    """
    encoding = tiktoken.encoding_for_model(model)
    # A dict comprehension de-duplicates token ids shared between labels and
    # keeps a deterministic, label-ordered key order for printing.
    return {
        str(token): bias
        for label in labels
        for token in encoding.encode(label)
    }


logit_bias = get_tokens()
print(json.dumps(logit_bias))

{"45": 100, "15": 100, "16": 100, "17": 100, "18": 100}


~~Create the classification table if it doesn't exist already.~~

We don't need a classification table anymore: classifications are now stored in the `classification` column of `dwa_reference`.

In [34]:
def create_classification_table(c: sqlite3.Cursor, conn: sqlite3.Connection) -> None:
    """Create the ``dwa_classification`` table if it does not already exist.

    No longer called: classifications are written to ``dwa_reference``
    instead. The table is keyed by ``(onetsoc_code, task_id, dwa_id)`` with
    foreign keys into ``occupation_data``, ``task_statements`` and
    ``dwa_reference``. Commits on ``conn`` after executing the DDL on ``c``.
    """
    c.execute(
        """
    CREATE TABLE IF NOT EXISTS dwa_classification (
        onetsoc_code CHARACTER(10) NOT NULL,
        task_id DECIMAL(8,0) NOT NULL,    
        dwa_id CHARACTER VARYING(20) NOT NULL,
        classification CHARACTER(2) NOT NULL,
        FOREIGN KEY (onetsoc_code) REFERENCES occupation_data(onetsoc_code),
        FOREIGN KEY (task_id) REFERENCES task_statements(task_id),
        FOREIGN KEY (dwa_id) REFERENCES dwa_reference(dwa_id),
        PRIMARY KEY (onetsoc_code, task_id, dwa_id)
    );
    """
    )
    conn.commit()


# create_classification_table(c, conn)

Set up the prompts: load the user and system prompt templates from `../prompts/`.

In [35]:
from pathlib import Path

# Prompt templates live in ../prompts relative to the notebook's working dir.
PROMPTS_DIR = Path.cwd().parent / "prompts"
# Read as UTF-8 explicitly: read_text() otherwise uses the platform default
# encoding (e.g. cp1252 on Windows). Assumes the prompt files are UTF-8.
user_template = (PROMPTS_DIR / "user.txt").read_text(encoding="utf-8")
system_template = (PROMPTS_DIR / "system.txt").read_text(encoding="utf-8")

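Define the `DWAReference` model, which classifies a `dwa_reference` row with the LLM and saves the label back. Then fetch the rows that still need a classification from the database.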

In [36]:
class DWAReference(BaseModel):
    """A row of ``dwa_reference`` plus the label the LLM assigns to it."""

    dwa_id: str
    dwa_title: str
    # Set by ``classify``; None until the row has been classified.
    classification: str | None = None

    @classmethod
    def from_tuple(cls, tup: tuple) -> "DWAReference":
        """Build an instance from a ``(dwa_id, dwa_title, ...)`` row tuple.

        Only the first two elements are read; ``classification`` keeps its
        default of None.
        """
        return cls(dwa_id=tup[0], dwa_title=tup[1])

    # Back off exponentially with random jitter (capped at 60 s) on any error,
    # but give up after 6 attempts instead of retrying forever: permanent
    # failures (bad key, bad template field) would otherwise hang the loop.
    # reraise=True surfaces the last real exception, not tenacity's RetryError.
    @retry(
        wait=wait_random_exponential(multiplier=1, max=60),
        stop=stop_after_attempt(6),
        reraise=True,
    )
    async def classify(self) -> str:
        """Ask the chat model for this row's label and store it on ``self``.

        Formats the module-level ``user_template`` with ``task=self.dwa_title``,
        sends it after ``system_template``, and biases the single output token
        with the module-level ``logit_bias`` map (``max_tokens=1``,
        ``temperature=0``).

        Returns:
            The reply text, which is also assigned to ``self.classification``.

        Raises:
            The last exception from the OpenAI call once all attempts fail.
        """
        user_message = {
            "role": "user",
            "content": user_template.format(task=self.dwa_title),
        }
        system_message = {"role": "system", "content": system_template}

        response = await openai.ChatCompletion.acreate(
            messages=[system_message, user_message],
            logit_bias=logit_bias,
            max_tokens=1,
            temperature=0,
            model="gpt-3.5-turbo",
        )
        self.classification = response.choices[0]["message"]["content"]
        return self.classification

    def save(self, c: sqlite3.Cursor) -> None:
        """Write ``self.classification`` onto this row of ``dwa_reference``.

        Does not commit; the caller must commit the connection.
        """
        update_sql = "UPDATE dwa_reference SET classification = ? WHERE dwa_id = ?"
        c.execute(update_sql, (self.classification, self.dwa_id))

In [37]:
def fetch_rows(c: sqlite3.Cursor, num: int = -1) -> list:
    """Return ``dwa_reference`` rows that have not been classified yet.

    Rows are ordered by ``dwa_id`` and wrapped in ``DWAReference``.

    Args:
        c: Cursor on the database that holds ``dwa_reference``.
        num: Maximum number of rows to return. Any negative value (the
            default is -1) means no limit; 0 returns no rows.

    Returns:
        A list of ``DWAReference`` objects whose ``classification`` is None.
    """
    # Apply the limit in SQL: SQLite treats a negative LIMIT as "no upper
    # bound", so the -1 sentinel needs no special branch, and only the
    # requested rows are read from the database.
    select_query = """
        SELECT t.dwa_id, t.dwa_title
        FROM dwa_reference AS t
        WHERE t.classification IS NULL
        ORDER BY t.dwa_id
        LIMIT ?
    """
    c.execute(select_query, (num,))
    return [DWAReference.from_tuple(row) for row in c.fetchall()]

In [38]:
async def classify_all(
    rows: list[DWAReference], cursor: sqlite3.Cursor | None = None
) -> None:
    """Classify each row with the LLM and persist the labels one at a time.

    Args:
        rows: Rows to classify, e.g. from ``fetch_rows``.
        cursor: Cursor used to save results; its connection is committed.
            Defaults to the notebook-global cursor ``c``.

    Each row is committed as soon as it is saved. If a request ultimately
    fails or the cell is interrupted, the rows already classified are kept,
    and ``fetch_rows`` (which selects rows whose classification is NULL) will
    skip them on the next run.
    """
    cursor = c if cursor is None else cursor
    for row in tqdm(rows, desc="Classifying"):
        await row.classify()
        print(f"[{row.dwa_id}]: {row.dwa_title}: {row.classification}")
        row.save(cursor)
        # Commit per row rather than once at the end so progress survives errors.
        cursor.connection.commit()

In [42]:
rows = fetch_rows(c, 1000)
await classify_all(rows)

Classifying: 0it [00:00, ?it/s]