In [None]:
from typing import List, Optional
import itertools
import requests

import pandas as pd
from pydantic import BaseModel, Field, validator
from kor import extract_from_documents, from_pydantic, create_extraction_chain
from kor.documents.html import MarkdownifyHTMLProcessor
from langchain.chat_models import ChatOpenAI
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms.openai import OpenAI

In [None]:
# gpt-3.5-turbo is cheap but extracts less accurately than larger models.
# ChatOpenAI reads the API key from the OPENAI_API_KEY environment variable,
# so no credential needs to be written into the notebook.
# temperature=0 keeps the extraction output as repeatable as possible.
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

In [None]:
class train(BaseModel):
    """One train service row extracted from the Ctrip timetable page.

    Every field is a plain string holding the text as it appears in the page
    (the few-shot examples use values such as "07:40", "2时50分", "127").
    The Chinese ``description`` values are runtime data read by kor's
    ``from_pydantic`` when it builds the extraction schema, so they stay verbatim.
    """

    # Departure station, e.g. "青岛北".
    departure: str = Field(description="出发地点")
    # Arrival station, e.g. "灌南".
    destination: str = Field(description="到达地点")
    # Departure time, e.g. "07:40".
    departure_time: str = Field(description="出发时间")
    # Arrival time, e.g. "10:30".
    arrival_time: str = Field(description="到达时间")
    # Train number, e.g. "D2916".
    train_number: str = Field(description="火车班次号码")
    # Ticket price as shown on the page, e.g. "127".
    train_money: str = Field(description="价格")
    # Total journey time, e.g. "2时50分".
    time: str = Field(description="全程时间")


# Few-shot examples for kor: (page snippet, expected record) pairs taken from the
# Ctrip listing text. The snippets are prompt text, so their exact whitespace is
# preserved.
train_examples = [
    (
        """
            抢票成功率：07:40青岛2时50分D291610:30 灌南127* 5月20日09:30开售,可预约抢票,开售自动抢
                抢* **二等座**127抢票
                * **一等座**203抢票
                * **无座**127抢票
            """,
        {
            "departure": "青岛",
            "destination": "灌南",
            "departure_time": "07:40",
            "arrival_time": "10:30",
            "train_number": "D2916",
            "train_money": "127",
            "time": "2时50分",
        },
    ),
    (
        """
            抢票成功率：10:27青岛北2时26分G155312:53 灌南123* 5月20日09:30开售,可预约抢票,开售自动抢
            抢* **二等座**123抢票
            * **一等座**197抢票
            * **商务座**370抢票
            """,
        {
            "departure": "青岛北",
            "destination": "灌南",
            "departure_time": "10:27",
            "arrival_time": "12:53",
            "train_number": "G1553",
            "train_money": "123",
            "time": "2时26分",
        },
    ),
]

# Build the kor schema for a list of `train` records (many=True) together with
# the validator that is later handed to the extraction chain.
schema, extraction_validator = from_pydantic(
    train,
    description="提取有关火车时刻表的信息，包括它们的出发、目的地、出发时间、到达时间、车次、价格和全程时间。",
    examples=train_examples,
    many=True,
)

In [None]:
# Extraction chain options: CSV-encoded model output, validation through the
# pydantic-derived validator, and kor's "triple_quotes" input formatter.
chain_options = {
    "encoder_or_encoder_class": "csv",
    "validator": extraction_validator,
    "input_formatter": "triple_quotes",
}
chain = create_extraction_chain(llm, schema, **chain_options)

In [None]:
# Render the prompt with a placeholder input to inspect exactly what the chain sends.
rendered_prompt = chain.prompt.format_prompt(text="[user input]")
print(rendered_prompt.to_string())

In [None]:
# Ctrip train listing query. The percent-encoded stations decode to
# dStation=青岛机场 and aStation=青岛北, travel date 2023-06-03.
url = "https://trains.ctrip.com/webapp/train/list?ticketType=0&dStation=%E9%9D%92%E5%B2%9B%E6%9C%BA%E5%9C%BA&aStation=%E9%9D%92%E5%B2%9B%E5%8C%97&dDate=2023-06-03&rDate=&trainsType=gaotie-dongche&hubCityName=&highSpeedOnly=0"
REQUEST_TIMEOUT_S = 30  # never let a stalled server hang the notebook

# NOTE: requests only fetches the server-delivered HTML. If the timetable rows are
# rendered client-side by JavaScript they will be missing, and a browser driver
# (e.g. Selenium or Playwright) is needed to obtain the rendered page instead.
response = requests.get(url, timeout=REQUEST_TIMEOUT_S)
# Fail loudly on HTTP errors instead of feeding an error page to the LLM.
response.raise_for_status()
# requests assumes ISO-8859-1 for text/* responses that declare no charset,
# which would garble the Chinese page text; fall back to the detected encoding.
if response.encoding is None or response.encoding.lower() == "iso-8859-1":
    response.encoding = response.apparent_encoding

In [None]:
# Convert the raw HTML into a markdown Document for the text splitter / LLM.
doc = Document(page_content=response.text)
md = MarkdownifyHTMLProcessor().process(doc)
# Show only the beginning: the whole converted page is too long to display usefully.
md.page_content[:2000]

In [None]:
# Keep only the text before the "中转方案推荐" (transfer-plan recommendations)
# heading; presumably so only direct services remain for extraction.
TRANSFER_SECTION_MARKER = "### 中转方案推荐"
md.page_content = md.page_content.partition(TRANSFER_SECTION_MARKER)[0]

In [None]:
# RecursiveCharacterTextSplitter measures chunk_size in characters, not tokens.
# Its default (4000) is large for Chinese text, where a character is often one or
# more tokens; together with the few-shot prompt this can exceed gpt-3.5-turbo's
# context window. Use smaller chunks (tune as needed).
CHUNK_SIZE = 2000
split_docs = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE).split_documents([md])
# Guard the preview: an empty page yields no chunks and [0] would raise IndexError.
if split_docs:
    print(split_docs[0].page_content)
len(split_docs)

In [None]:
from langchain.callbacks import get_openai_callback

In [None]:
with get_openai_callback() as cb:
    document_extraction_results = await extract_from_documents(
        chain, split_docs, max_concurrency=5, use_uid=False, return_exceptions=True
    )
    print(f"Total Tokens: {cb.total_tokens}")
    print(f"Prompt Tokens: {cb.prompt_tokens}")
    print(f"Completion Tokens: {cb.completion_tokens}")
    print(f"Successful Requests: {cb.successful_requests}")
    print(f"Total Cost (USD): ${cb.total_cost}")

In [None]:
# extract_from_documents(..., return_exceptions=True) leaves failed chunks in the
# result list as exception objects instead of raising. Indexing one of those with
# ["validated_data"] would raise TypeError, so report failures and skip them.
failed_extractions = [
    outcome for outcome in document_extraction_results if isinstance(outcome, BaseException)
]
if failed_extractions:
    print(
        f"{len(failed_extractions)} of {len(document_extraction_results)} chunks failed; "
        f"first error: {failed_extractions[0]!r}"
    )

# Flatten the per-chunk lists of validated `train` records into one list.
validated_data = list(
    itertools.chain.from_iterable(
        extraction["validated_data"]
        for extraction in document_extraction_results
        if not isinstance(extraction, BaseException)
    )
)

In [None]:
# Tabulate the validated records. Passing the model's field names as columns keeps
# the expected columns even when nothing was extracted, so the filters below do not
# raise KeyError on an empty result.
result = (
    pd.DataFrame(
        [record.dict() for record in validated_data],
        columns=list(train.__fields__),
    )
    # Overlapping text chunks (the splitter's chunk_overlap) can yield the same row twice.
    .drop_duplicates()
    .reset_index(drop=True)
)

# Placeholder values ("无票" = no tickets, "未知" = unknown) that presumably mark a
# missing price / journey time in the extracted rows.
UNAVAILABLE_VALUES = ["无票", "未知"]
has_price = ~result["train_money"].isin(UNAVAILABLE_VALUES)
has_duration = ~result["time"].isin(UNAVAILABLE_VALUES)
# NOTE(review): the request URL uses 青岛北 as the *arrival* station (aStation);
# confirm that filtering on departure == "青岛北" is intended.
departs_qingdao_north = result["departure"] == "青岛北"
result[has_price & has_duration & departs_qingdao_north]

In [None]:
# Display the full, unfiltered extraction table for comparison with the filtered view.
result