# Project - Train Assistant for Ireland

## Libraries

In [84]:
import os
import json
from dotenv import load_dotenv
from openai import OpenAI
import gradio as gr
import sqlite3
from datetime import date, datetime
import difflib

import tiktoken
import numpy as np
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sklearn.manifold import TSNE
import plotly.graph_objects as go

# Load Env and Model Init

In [85]:
load_dotenv(override=True)

open_ai_key = os.getenv("OPENAI_API_KEY")

if open_ai_key is None:
    raise ValueError("OPENAI_API_KEY environment variable not set")
else:
    print("OPENAI_API_KEY found and set")

open_ai = OpenAI() # OPENAI instance is created here

MODEL = "gpt-4.1-mini"
DB = "irish_rail_services_2026_sample.db"
DEP_TABLE = "departure_station"
ARR_TABLE = "arrival_station"
DEP_TIME_TABLE = "departure_time"
ARR_TIME_TABLE = "arrival_time"
FARE_TABLE = "adult_single_from_eur"
DURATION_TABLE = "duration_min"
SERVICE_DATE_TABLE = "service_date"
LangChain_DB = "rail_vector_db"




OPENAI_API_KEY found and set


# System Prompt

In [None]:
system_prompt = """
You are a helpful assistant for rail travel in Ireland.
Use the provided tool to look up scheduled train services and indicative fares.
Please provide accurate and concise information.
When listing services, show up to 3 options at earliest with departure time, arrival time, duration, and the fare (as a “from” price). 
Please state that you are showing 3 options at earliest.
If the user’s date/time is missing, you should ask before providing options.
If the user says travel from dublin to belfast consider that the user refers to dublin connolly station. Do not ask for clarification.
If the user says travel from dublin to cork or galway consider that the user refers to dublin heuston station. Do not ask for clarification.
If the user does not tell you the date and time, ask before providing options.
If you don't know, say so.
"""

SYSTEM_PROMPT_TEMPLATE = """
You are a helpful assistant for rail travel in Ireland.
Use the provided tool to look up scheduled train services and indicative fares.
Please provide accurate and concise information.
If the user’s date/time is missing, you should ask before providing options.
When listing services, show up to 3 options at earliest with departure time, arrival time, duration, and the fare (as a “from” price).
Please state that you are showing 3 options at earliest.
If the user says travel from dublin to belfast consider that the user refers to dublin connolly station. Do not ask for clarification.
If the user says travel from dublin to cork or galway consider that the user refers to dublin heuston station. Do not ask for clarification.
If user says a date that is in the past, inform them that you can only provide information for current or future dates.
If user says a fuzzy date in the future you can calculate the exact date based on the current date.
If you don't know, say so.
{context}
"""


# Fetching Data From DB

In [87]:
import os
import sqlite3
import pandas as pd

# 1) Point to your CSV (edit if needed)
CSV_PATH = os.path.join(os.getcwd(), "irish_rail_services_2026_sample.csv")

# 2) Output DB path (this matches the file you showed with spaces)
DB_PATH = os.path.join(os.getcwd(), "irish rail services 2026 sample.db")

TABLE = "rail_services"

print("CSV:", CSV_PATH, "exists:", os.path.exists(CSV_PATH))
print("DB :", DB_PATH)

# Load CSV
df = pd.read_csv(CSV_PATH)

# (Optional) ensure expected columns exist
expected = {
    "service_date","operator","route_code",
    "departure_station","arrival_station",
    "departure_time","arrival_time",
    "duration_min","adult_single_from_eur","notes"
}
missing = expected - set(df.columns)
if missing:
    raise ValueError(f"CSV is missing columns: {missing}")

# Create/overwrite DB and write table
if os.path.exists(DB_PATH):
    os.remove(DB_PATH)

with sqlite3.connect(DB_PATH) as conn:
    df.to_sql(TABLE, conn, index=False)
    cur = conn.cursor()
    cur.execute(f"CREATE INDEX IF NOT EXISTS idx_service_date ON {TABLE}(service_date);")
    cur.execute(f"CREATE INDEX IF NOT EXISTS idx_route_code   ON {TABLE}(route_code);")
    cur.execute(f"CREATE INDEX IF NOT EXISTS idx_stations     ON {TABLE}(departure_station, arrival_station);")
    cur.execute(f"CREATE INDEX IF NOT EXISTS idx_times        ON {TABLE}(departure_time, arrival_time);")
    conn.commit()

# Verify
print("DB exists:", os.path.exists(DB_PATH))
print("DB size :", os.path.getsize(DB_PATH))

with sqlite3.connect(DB_PATH) as conn:
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
    print("Tables:", cur.fetchall())
    cur.execute(f"SELECT COUNT(*) FROM {TABLE};")
    print("Rows:", cur.fetchone()[0])


CSV: /Users/yasinemirkutlu/LLMProjectsUdemy/llm_engineering/week2/week2_YasinEmirKutlu/irish_rail_services_2026_sample.csv exists: True
DB : /Users/yasinemirkutlu/LLMProjectsUdemy/llm_engineering/week2/week2_YasinEmirKutlu/irish rail services 2026 sample.db
DB exists: True
DB size : 5263360
Tables: [('rail_services',)]
Rows: 18149


In [None]:
import sqlite3
import difflib
import re
import json
from datetime import date, datetime, timedelta
from typing import Optional
DB = "irish rail services 2026 sample.db"
TABLE = "rail_services"
STATION_ALIASES = {
    "dublin": ["Dublin Heuston", "Dublin Connolly"],
    "cork": ["Cork (Kent)"],
    "galway": ["Galway (Ceannt)"],
    "limerick": ["Limerick (Colbert)"],
    "waterford": ["Waterford (Plunkett)"],
    "belfast": ["Belfast (Grand Central)"],
}
_STATIONS_CACHE = None
_DATE_BOUNDS_CACHE = None
# ---- Ticket cache for "return the ticket that was inquired" ----
_LAST_SEARCH_CACHE = {
    "query": None,
    "tickets": []
}
def _make_ticket_id(row, service_date: str) -> str:
    # Stable-enough identifier for the notebook demo.
    return f"{service_date}|{row['route_code']}|{row['departure_time']}|{row['arrival_time']}"
def get_inquired_ticket(selection: int = 1, ticket_id: Optional[str] = None) -> str:
    """Return the ticket from the most recent inquiry.
    - If ticket_id is provided, returns the matching ticket.
    - Else uses 1-based selection: 1 = first option shown in the last search.
    Returns a JSON string.
    """
    tickets = _LAST_SEARCH_CACHE.get("tickets") or []
    if not tickets:
        return "No ticket found yet. Please inquire a route first (e.g., 'Dublin to Cork tomorrow after 14:00')."
    if ticket_id:
        for t in tickets:
            if t.get("ticket_id") == ticket_id:
                return json.dumps(t, ensure_ascii=False, indent=2)
        return f"I couldn't find a ticket with ticket_id='{ticket_id}'."
    try:
        idx = int(selection) - 1
    except Exception:
        idx = 0
    if idx < 0 or idx >= len(tickets):
        return f"Selection out of range. Choose between 1 and {len(tickets)}."
    return json.dumps(tickets[idx], ensure_ascii=False, indent=2)
def _normalize(text: str) -> str:
    return " ".join(text.lower().strip().split())
def _load_stations():
    global _STATIONS_CACHE
    if _STATIONS_CACHE is not None:
        return _STATIONS_CACHE
    with sqlite3.connect(DB) as conn:
        cur = conn.cursor()
        cur.execute(f"SELECT DISTINCT departure_station FROM {TABLE}")
        deps = [r[0] for r in cur.fetchall()]
        cur.execute(f"SELECT DISTINCT arrival_station FROM {TABLE}")
        arrs = [r[0] for r in cur.fetchall()]
    _STATIONS_CACHE = sorted(set(deps + arrs))
    return _STATIONS_CACHE
def _get_date_bounds():
    global _DATE_BOUNDS_CACHE
    if _DATE_BOUNDS_CACHE is not None:
        return _DATE_BOUNDS_CACHE
    with sqlite3.connect(DB) as conn:
        cur = conn.cursor()
        cur.execute(f"SELECT MIN(service_date), MAX(service_date) FROM {TABLE}")
        mn, mx = cur.fetchone()
    _DATE_BOUNDS_CACHE = (mn, mx)
    return _DATE_BOUNDS_CACHE
def _clamp_date(service_date: str) -> str:
    mn, mx = _get_date_bounds()
    if service_date < mn:
        return mn
    if service_date > mx:
        return mx
    return service_date
def _candidate_stations(user_text: str):
    if not user_text:
        return []
    key = _normalize(user_text)
    stations = _load_stations()
    if key in STATION_ALIASES:
        return STATION_ALIASES[key]
    for st in stations:
        if _normalize(st) == key:
            return [st]
    subs = [st for st in stations if key in _normalize(st)]
    if subs:
        return subs[:5]
    return difflib.get_close_matches(user_text, stations, n=5, cutoff=0.55)
def search_train_services(departure_station: str,
                          arrival_station: str,
                          service_date: str | None = None,
                          depart_after: str | None = None,
                          limit: int = 3) -> str:
    print(
        f"DB TOOL CALLED: search_train_services({departure_station=}, {arrival_station=}, {service_date=}, {depart_after=}, {limit=})",
        flush=True
    )
    service_date = _clamp_date(_parse_service_date(service_date))
    if not depart_after or not str(depart_after).strip():
        depart_after = "00:00"
    depart_after = str(depart_after).strip()
    dep_candidates = _candidate_stations(departure_station)
    arr_candidates = _candidate_stations(arrival_station)
    if not dep_candidates:
        return f"I couldn't match the departure station '{departure_station}'. Try: Dublin Heuston, Dublin Connolly, Cork (Kent), Galway (Ceannt), Limerick (Colbert), Waterford (Plunkett), Belfast (Grand Central)."
    if not arr_candidates:
        return f"I couldn't match the arrival station '{arrival_station}'. Try: Dublin Heuston, Dublin Connolly, Cork (Kent), Galway (Ceannt), Limerick (Colbert), Waterford (Plunkett), Belfast (Grand Central)."
    with sqlite3.connect(DB) as conn:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()
        for dep in dep_candidates:
            for arr in arr_candidates:
                cur.execute(
                    f"""
                    SELECT service_date, operator, route_code,
                           departure_station, arrival_station,
                           departure_time, arrival_time,
                           duration_min, adult_single_from_eur
                    FROM {TABLE}
                    WHERE service_date = ?
                      AND departure_station = ?
                      AND arrival_station = ?
                      AND departure_time >= ?
                    ORDER BY departure_time
                    LIMIT ?
                    """,
                    (service_date, dep, arr, depart_after, int(limit)),
                )
                rows = cur.fetchall()
                if rows:
                    tickets = []
                    for r in rows:
                        price_val = None if r["adult_single_from_eur"] is None else float(r["adult_single_from_eur"])
                        ticket = {
                            "ticket_id": _make_ticket_id(r, service_date),
                            "service_date": r["service_date"],
                            "operator": r["operator"],
                            "route_code": r["route_code"],
                            "departure_station": r["departure_station"],
                            "arrival_station": r["arrival_station"],
                            "departure_time": r["departure_time"],
                            "arrival_time": r["arrival_time"],
                            "duration_min": int(r["duration_min"]) if r["duration_min"] is not None else None,
                            "adult_single_from_eur": price_val,
                            "currency": "EUR",
                            "booking_url": "https://booking.cf.irishrail.ie/",
                            "type": "adult_single_from_price",
                        }
                        tickets.append(ticket)
                    # cache last search so we can later "return the ticket"
                    _LAST_SEARCH_CACHE["query"] = {
                        "departure_station": dep,
                        "arrival_station": arr,
                        "service_date": service_date,
                        "depart_after": depart_after,
                    }
                    _LAST_SEARCH_CACHE["tickets"] = tickets
                    header = f"Next {len(rows)} train(s) {dep} → {arr} on {service_date} (after {depart_after}):"
                    lines = []
                    for i, t in enumerate(tickets, start=1):
                        price = f"€{t['adult_single_from_eur']:.2f}+" if t["adult_single_from_eur"] is not None else "price N/A"
                        lines.append(
                            f"- Option {i}: {t['departure_time']}–{t['arrival_time']} ({t['duration_min']} min), {price} [{t['route_code']}]"
                        )
                    return header + "\n" + "\n".join(lines)
    return f"No direct scheduled services found for {dep_candidates[0]} → {arr_candidates[0]} on {service_date} after {depart_after}."

In [89]:
import os, sqlite3

print("CWD:", os.getcwd())
print("DB absolute:", os.path.abspath(DB))
print("DB exists:", os.path.exists(DB))
print("DB size:", os.path.getsize(DB) if os.path.exists(DB) else None)

with sqlite3.connect(DB) as conn:
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
    print("Tables:", cur.fetchall())


CWD: /Users/yasinemirkutlu/LLMProjectsUdemy/llm_engineering/week2/week2_YasinEmirKutlu
DB absolute: /Users/yasinemirkutlu/LLMProjectsUdemy/llm_engineering/week2/week2_YasinEmirKutlu/irish rail services 2026 sample.db
DB exists: True
DB size: 5263360
Tables: [('rail_services',)]


In [123]:
def get_current_date(timezone="UTC", include_time=True):
    now = datetime.now()
    # Always return the full date and time
    return now.strftime("%A, %B %d, %Y at %H:%M:%S")  # e.g., "Sunday, January 19, 2026 at 14:30:45"

# Function Structure (JSON)

In [122]:
train_search_function = {
    "name": "search_train_services",
    "description": "Find scheduled train services in Ireland for a given route and date/time, including duration and indicative fare.",
    "parameters": {
        "type": "object",
        "properties": {
            "departure_station": {
                "type": "string",
                "description": "Departure station or city (e.g., 'Dublin Heuston' or 'Dublin')"
            },
            "arrival_station": {
                "type": "string",
                "description": "Arrival station or city (e.g., 'Cork (Kent)' or 'Cork')"
            },
            "service_date": {
                "type": "string",
                "description": "Travel date (supports YYYY-MM-DD, 'tomorrow', 'Thursday', 'next Thursday', etc.)"
            },
            "depart_after": {
                "type": "string",
                "description": "Only show trains departing at or after this time (HH:MM, optional)"
            },
            "limit": {
                "type": "integer",
                "description": "Maximum number of services to return (default 5)"
            }
        },
        "required": ["departure_station", "arrival_station"],
        "additionalProperties": False
    }
}

get_ticket_function = {
    "name": "get_inquired_ticket",
    "description": "Return the ticket from the most recent inquiry (by option number or ticket_id).",
    "parameters": {
        "type": "object",
        "properties": {
            "selection": {
                "type": "integer",
                "description": "1-based option number (1 = first option)"
            },
            "ticket_id": {
                "type": "string",
                "description": "Exact ticket_id (optional)"
            }
        },
        "additionalProperties": False
    }
}

get_current_date_function = {
        "name": "get_current_date",
        "description": "Returns the current date and time including day of week and hour. Use this when you need to know what day/time it is.",
        "input_schema": {
            "type": "object",
            "properties": {
                "timezone": {
                    "type": "string",
                    "description": "Timezone (e.g., 'UTC', 'America/New_York'). Defaults to UTC.",
                    "default": "UTC"
                },
                "include_time": {
                    "type": "boolean",
                    "description": "Whether to include the time (hour, minute, second). Defaults to true.",
                    "default": True
                }
            }
        }
    }

tools = [
    {"type": "function", "function": train_search_function},
    {"type": "function", "function": get_ticket_function},
    {"type": "function", "function": get_current_date_function},
]

tools

[{'type': 'function',
  'function': {'name': 'search_train_services',
   'description': 'Find scheduled train services in Ireland for a given route and date/time, including duration and indicative fare.',
   'parameters': {'type': 'object',
    'properties': {'departure_station': {'type': 'string',
      'description': "Departure station or city (e.g., 'Dublin Heuston' or 'Dublin')"},
     'arrival_station': {'type': 'string',
      'description': "Arrival station or city (e.g., 'Cork (Kent)' or 'Cork')"},
     'service_date': {'type': 'string',
      'description': "Travel date (supports YYYY-MM-DD, 'tomorrow', 'Thursday', 'next Thursday', etc.)"},
     'depart_after': {'type': 'string',
      'description': 'Only show trains departing at or after this time (HH:MM, optional)'},
     'limit': {'type': 'integer',
      'description': 'Maximum number of services to return (default 5)'}},
    'required': ['departure_station', 'arrival_station'],
    'additionalProperties': False}}},
 {'ty

In [92]:
def chat_with_flight_bot(message, history):
    messages = [{"role": "system", "content": SYSTEM_PROMPT_TEMPLATE}] + history + [{"role": "user", "content": message}]
    response = open_ai.chat.completions.create(model=MODEL, messages=messages, tools=tools)

    while response.choices[0].finish_reason=="tool_calls":
        message = response.choices[0].message
        responses = handle_tool_calls(message)
        messages.append(message)
        messages.extend(responses)
        response = open_ai.chat.completions.create(model=MODEL, messages=messages, tools=tools)
    
    return response.choices[0].message.content

In [137]:
def handle_tool_calls(message):
    responses = []
    for tool_call in message.tool_calls:
        if tool_call.function.name == "search_train_services":
            arguments = json.loads(tool_call.function.arguments)
            dep = arguments.get("departure_station")
            arr = arguments.get("arrival_station")
            service_date = arguments.get("service_date")
            depart_after = arguments.get("depart_after")
            limit = arguments.get("limit", 5000)

            details = search_train_services(
                departure_station=dep,
                arrival_station=arr,
                service_date=service_date,
                depart_after=depart_after,
                limit=limit,
            )

            responses.append({
                "role": "tool",
                "content": details,
                "tool_call_id": tool_call.id
            })

        elif tool_call.function.name == "get_inquired_ticket":
            arguments = json.loads(tool_call.function.arguments or "{}")
            selection = arguments.get("selection", 1)
            ticket_id = arguments.get("ticket_id")

            details = get_inquired_ticket(selection=selection, ticket_id=ticket_id)

            responses.append({
                "role": "tool",
                "content": details,
                "tool_call_id": tool_call.id
            })
            
        elif tool_call.function.name == "get_current_date":
            arguments = json.loads(tool_call.function.arguments or "{}")
            timezone = arguments.get("timezone", "GMT")
            include_time = arguments.get("include_time", True)

            details = get_current_date(timezone=timezone, include_time=include_time)

            

            responses.append({
                "role": "tool",
                "content": details,
                "tool_call_id": tool_call.id
            })

    return responses

In [136]:
gr.ChatInterface(fn=chat_with_flight_bot, title="Irish Rail Chatbot", type="messages").launch()

* Running on local URL:  http://127.0.0.1:7882
* To create a public link, set `share=True` in `launch()`.




# Multi-modal Chatbot

# Talker Function

In [98]:
def talker(message):
    response = open_ai.audio.speech.create(
      model="gpt-4o-mini-tts",
      voice="coral",    # Also, try replacing onyx with alloy or coral
      input=message
    )
    return response.content

In [99]:
def talker(message, accent="Irish"):
    response = open_ai.audio.speech.create(
        model="gpt-4o-mini-tts",
        voice="marin",  # or try "marin"/"cedar" (often very natural)
        input=message,
        instructions=f"Speak in a natural {accent} English accent."
    )
    return response.content


In [100]:
def chat(history):
    messages = [{"role": "system", "content": SYSTEM_PROMPT_TEMPLATE}] + history
    response = open_ai.chat.completions.create(model=MODEL, messages=messages, tools=tools)
    cities = []
    image = None

    while response.choices[0].finish_reason=="tool_calls":
        message = response.choices[0].message
        responses, cities = handle_tool_calls_and_return_cities(message)
        messages.append(message)
        messages.extend(responses)
        response = open_ai.chat.completions.create(model=MODEL, messages=messages, tools=tools)

    reply = response.choices[0].message.content
    history += [{"role":"assistant", "content":reply}]

    voice = talker(reply)

    if cities:
        image = artist(cities[0])
    
    return history, voice, image


In [138]:
def handle_tool_calls_and_return_cities(message):
    responses = []
    cities = []
    for tool_call in message.tool_calls:
        if tool_call.function.name == "search_train_services":
            arguments = json.loads(tool_call.function.arguments)
            dep = arguments.get("departure_station")
            arr = arguments.get("arrival_station")
            service_date = arguments.get("service_date")
            depart_after = arguments.get("depart_after")
            limit = arguments.get("limit", 5000)

            # Keep "cities" list for optional image generation (arrival is most relevant)
            if arr:
                cities.append(arr)

            details = search_train_services(
                departure_station=dep,
                arrival_station=arr,
                service_date=service_date,
                depart_after=depart_after,
                limit=limit,
            )

            responses.append({
                "role": "tool",
                "content": details,
                "tool_call_id": tool_call.id
            })

        elif tool_call.function.name == "get_inquired_ticket":
            arguments = json.loads(tool_call.function.arguments or "{}")
            selection = arguments.get("selection", 1)
            ticket_id = arguments.get("ticket_id")

            details = get_inquired_ticket(selection=selection, ticket_id=ticket_id)

            responses.append({
                "role": "tool",
                "content": details,
                "tool_call_id": tool_call.id
            })

        elif tool_call.function.name == "get_current_date":
            arguments = json.loads(tool_call.function.arguments or "{}")
            timezone = arguments.get("timezone", "GMT")
            include_time = arguments.get("include_time", True)

            details = get_current_date(timezone=timezone, include_time=include_time)

            responses.append({
                "role": "tool",
                "content": details,
                "tool_call_id": tool_call.id
            })

    return responses, cities

# Knowledge_Base Folder

In [139]:
import glob as glob_mod


knowledge_base_path = "knowledge_base/**/*.md"
files = glob_mod.glob(knowledge_base_path, recursive=True)
print(f"Found {len(files)} files in the knowledge base")

entire_knowledge_base = ""

for file_path in files:
    with open(file_path, 'r', encoding='utf-8') as f:
        entire_knowledge_base += f.read()
        entire_knowledge_base += "\n\n"

print(f"Total characters in knowledge base: {len(entire_knowledge_base):,}")

Found 21 files in the knowledge base
Total characters in knowledge base: 19,635


In [103]:
encoding_folders = tiktoken.encoding_for_model(MODEL)
token_folders = encoding_folders.encode(entire_knowledge_base)
token_counts = len(token_folders)
print(f"Total tokens for {MODEL}: {token_counts:,}")

Total tokens for gpt-4.1-mini: 4,989


In [104]:
# Load in everything in the knowledgebase using LangChain's loaders

folders = glob_mod.glob("knowledge_base/*")

documents = []
for folder in folders:
    doc_type = os.path.basename(folder)
    loader = DirectoryLoader(folder, glob="**/*.md", loader_cls=TextLoader, loader_kwargs={'encoding': 'utf-8'})
    folder_docs = loader.load()
    for doc in folder_docs:
        doc.metadata["doc_type"] = doc_type
        documents.append(doc)

print(f"Loaded {len(documents)} documents")

Loaded 21 documents


In [105]:
documents[:2]  # Show first 2 documents

[Document(metadata={'source': 'knowledge_base/stations/galway_ceannt.md', 'doc_type': 'stations'}, page_content='---\ntype: station\nstation: Galway (Ceannt)\naliases: ["Galway Ceannt", "Ceannt Station", "Galway Station", "Céannt Station"]\noperator: Iarnród Éireann (Irish Rail)\nlat: 53.27384\nlon: -9.04749\nsource_urls: ["https://www.irishrail.ie/en-ie/station/galway-ceannt"]\nlast_updated: 2026-01-14\n---\n# Galway (Ceannt)\n\nCeannt Station is Galway’s central station near Eyre Square.\n\n## Travel alerts (station page)\n- The station page may show temporary notices (e.g., booking office closure or parking restrictions). For near-term travel, check the station page.\n\n## Location & contact\n- **Address:** Céannt Station, Galway, H91 T9CE (from station page).\n- Local booking office/ticket enquiries phone and lost-property email are listed on the station page.\n\n## Hours (published on station page)\n- **Staffing hours:** Mon–Fri 06:00–23:30\n- **Booking office:** Mon–Fri 06:00–13:

# Divide into Chunks

In [106]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=100,
    chunk_overlap=100)

chunks = text_splitter.split_documents(documents)

print(f"Divided into {len(chunks)} chunks")
print(f"First chunk:\n\n{chunks[0]}")

Divided into 447 chunks
First chunk:

page_content='---
type: station
station: Galway (Ceannt)' metadata={'source': 'knowledge_base/stations/galway_ceannt.md', 'doc_type': 'stations'}


# Produce vectors and store in Chroma

In [107]:
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")

if os.path.exists(LangChain_DB):
    print(f"Loading existing vector store from {LangChain_DB}")
    Chroma(persist_directory=LangChain_DB, embedding_function=embeddings).delete_collection()

vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=LangChain_DB)
print(f"Vectorstore created with {vectorstore._collection.count()} documents")

Loading existing vector store from rail_vector_db
Vectorstore created with 447 documents


In [108]:
from langchain_openai import ChatOpenAI


retriever = vectorstore.as_retriever()
llm = ChatOpenAI(temperature=0, model_name=MODEL)

# Investigate the Vectors

In [109]:
collection = vectorstore._collection
count= collection.count()

sample_embedding = collection.get(limit=1, include=["embeddings"])["embeddings"][0]
dimensions = len(sample_embedding)
print(f"There are {count:,} vectors with {dimensions:,} dimensions in the vector store")

There are 447 vectors with 3,072 dimensions in the vector store


# Visualise the Vector Space

In [110]:
result = collection.get(include=['embeddings', 'documents', 'metadatas'])
vectors = np.array(result['embeddings'])
documents = result['documents']
metadatas = result['metadatas']
doc_types = [metadata['doc_type'] for metadata in metadatas]
colors = [['blue', 'green', 'red', 'orange'][['helpers', 'policies', 'routes', 'stations'].index(t)] for t in doc_types]

In [111]:
# Reduce the dimensionality of the vectors to 2D using t-SNE
# (t-distributed stochastic neighbor embedding)

tsne = TSNE(n_components=2, random_state=42)
reduced_vectors = tsne.fit_transform(vectors)

# Create the 2D scatter plot
fig = go.Figure(data=[go.Scatter(
    x=reduced_vectors[:, 0],
    y=reduced_vectors[:, 1],
    mode='markers',
    marker=dict(size=5, color=colors, opacity=0.8),
    text=[f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(doc_types, documents)],
    hoverinfo='text'
)])

fig.update_layout(title='2D Chroma Vector Store Visualization',
    scene=dict(xaxis_title='x',yaxis_title='y'),
    width=800,
    height=600,
    margin=dict(r=20, b=10, l=10, t=40)
)

fig.show()

In [112]:
# Let's try 3D!

tsne = TSNE(n_components=3, random_state=42)
reduced_vectors = tsne.fit_transform(vectors)

# Create the 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=reduced_vectors[:, 0],
    y=reduced_vectors[:, 1],
    z=reduced_vectors[:, 2],
    mode='markers',
    marker=dict(size=5, color=colors, opacity=0.8),
    text=[f"Type: {t}<br>Text: {d[:100]}..." for t, d in zip(doc_types, documents)],
    hoverinfo='text'
)])

fig.update_layout(
    title='3D Chroma Vector Store Visualization',
    scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),
    width=900,
    height=700,
    margin=dict(r=10, b=10, l=10, t=40)
)

fig.show()

# Connect to Chroma using OPENAI Embedding

In [113]:
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
vectorstore= Chroma(persist_directory=LangChain_DB, embedding_function=embeddings)

In [114]:
from langchain_openai import ChatOpenAI


retriever = vectorstore.as_retriever()
llm = ChatOpenAI(temperature=0, model_name=MODEL)

In [115]:
retriever.invoke("Can you tell me something about heuston station in dublin?")

[Document(id='f2498f4e-c5bd-4efa-bdab-b353584f24d8', metadata={'doc_type': 'stations', 'source': 'knowledge_base/stations/dublin_heuston.md'}, page_content='---\ntype: station\nstation: Dublin Heuston'),
 Document(id='09ab19ba-b877-42e6-bf5f-5374b3c2489c', metadata={'source': 'knowledge_base/stations/dublin_heuston.md', 'doc_type': 'stations'}, page_content='last_updated: 2026-01-14\n---\n# Dublin Heuston (Dublin)'),
 Document(id='d72f874d-71fd-491f-8046-f106ad703b95', metadata={'source': 'knowledge_base/stations/dublin_heuston.md', 'doc_type': 'stations'}, page_content='Heuston is Dublin’s main **InterCity** station for routes to the south, southwest and west (Cork,'),
 Document(id='e532c3d6-646c-4fd1-a6b0-2e7ae86b6a00', metadata={'source': 'knowledge_base/stations/dublin_heuston.md', 'doc_type': 'stations'}, page_content='lat: 53.34655\nlon: -6.29162\nsource_urls: ["https://www.irishrail.ie/en-ie/station/dublin-heuston"]')]

In [116]:
from langchain_core.messages import SystemMessage, HumanMessage

def answer_question(question: str, history):
    docs = retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in docs)
    system_prompt = SYSTEM_PROMPT_TEMPLATE.format(context=context)
    response = llm.invoke([SystemMessage(content=system_prompt), HumanMessage(content=question)])
    return response.content

In [117]:
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage

def _history_to_messages(history, max_turns: int = 8):
    """
    Converts Gradio 'messages' history -> LangChain messages.
    Keeps only the last max_turns user+assistant turns.
    """
    if not history:
        return []

    msgs = []
    # keep only the last N messages (2 per turn approx)
    tail = history[-max_turns*100:]

    for m in tail:
        role = (m.get("role") or "").lower()
        content = m.get("content") or ""
        if not content.strip():
            continue

        if role == "user":
            msgs.append(HumanMessage(content=content))
        elif role == "assistant":
            msgs.append(AIMessage(content=content))
        # ignore other roles if any (system/tool)

    return msgs

def answer_question(question: str, history):
    docs = retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in docs)

    system_prompt = SYSTEM_PROMPT_TEMPLATE.format(context=context)

    messages = [SystemMessage(content=system_prompt)]
    messages += _history_to_messages(history, max_turns=100)  # <-- add chat history here
    messages += [HumanMessage(content=question)]

    response = llm.invoke(messages)
    return response.content


# Gradio UI with Talker

In [174]:
import gradio as gr

# -------------------------
# Callbacks (along with chat() above)
# -------------------------

def put_message_in_chatbot(message, history):
    if history is None:
        history = []
    if message is None or not str(message).strip():
        return "", history
    return "", history + [{"role": "user", "content": str(message).strip()}]


# -------------------------
# UI Styling
# -------------------------

CSS = """
#app-title { font-weight: 800; font-size: 1.25rem; }
#app-subtitle { opacity: 0.85; margin-top: 2px; margin-bottom: 4px; }
.gradio-container { max-width: 1600px !important; }
#chat-wrap { border-radius: 14px; margin-top: -40px; }
"""

# This helps autoplay work in many browsers by "unlocking" audio after first interaction
AUDIO_UNLOCK_HTML = """
<script>
(function() {
  if (window.__audioUnlockedInit) return;
  window.__audioUnlockedInit = true;

  const unlock = async () => {
    try {
      const AudioCtx = window.AudioContext || window.webkitAudioContext;
      if (!AudioCtx) return;

      const ctx = new AudioCtx();
      const buffer = ctx.createBuffer(1, 1, 22050);
      const src = ctx.createBufferSource();
      src.buffer = buffer;
      src.connect(ctx.destination);
      src.start(0);

      if (ctx.state === "suspended") await ctx.resume();
    } catch (e) {
      // Some browsers (esp. iOS Safari) may still require user to press play once.
    }
  };

  // First click or keypress unlocks audio
  document.addEventListener("click", unlock, { once: true });
  document.addEventListener("keydown", unlock, { once: true });
})();
</script>
"""

# -------------------------
# UI Definition
# -------------------------
with gr.Blocks(css=CSS, title="☘️ AI-powered Train Assistant for Ireland", fill_width=True) as ui:
    gr.Markdown(
        """
        <div style="text-align: center;font-size:28px"><strong>☘️🚆🇮🇪 AI-powered Train Assistant for Ireland ☘️🚆🇮🇪</strong></div>
        <div style="text-align: center;font-size:20px"> Powered by frontier models, GPT-4.1-mini and GPT-4o-mini-TTS, with RAG for accurate information.</div>
        <div style="text-align: center;font-size:18px">Timetables • Stations • Tickets • Bikes • Accessibility</div>
        """,
        elem_id="header",
    )

    # inject unlock script (no visible UI change)
    gr.HTML(AUDIO_UNLOCK_HTML)

    with gr.Row():
        chatbot = gr.Chatbot(
            height=360,
            type="messages",
            show_label=False,
            elem_id="chat-wrap",
        )

    with gr.Row():
        # ✅ Important: filepath output is most reliable for autoplay on web
        audio_output = gr.Audio(
            autoplay=True,
            type="filepath",
            label="🔊 Spoken reply",
        )

    with gr.Row():
        message = gr.Textbox(
            label="💬 Plan your journey",
            placeholder="Type a message and press Enter… (e.g., Dublin to Belfast on 25 Jan after 14:00)",
            lines=1,
            autofocus=True,
        )

    # Enter-to-send
    message.submit(
        put_message_in_chatbot,
        inputs=[message, chatbot],
        outputs=[message, chatbot],
    ).then(
        chat,  # MUST return exactly: (updated_history, audio_filepath_or_None)
        inputs=chatbot,
        outputs=[chatbot, audio_output],
    )

ui.queue()
ui.launch(inbrowser=True)


* Running on local URL:  http://127.0.0.1:7912
* To create a public link, set `share=True` in `launch()`.





A function (chat) returned too many output values (needed: 2, returned: 3). Ignoring extra values.
    Output components:
        [chatbot, audio]
    Output values returned:
        [[{'role': 'user', 'metadata': None, 'content': 'Hi there', 'options': None}, {'role': 'assistant', 'metadata': None, 'content': 'Hello! How can I assist you with your rail travel plans in Ireland today?', 'options': None}, {'role': 'user', 'metadata': None, 'content': 'I would like to go on a trip to Belfast from dublin', 'options': None}, {'role': 'assistant', 'metadata': None, 'content': 'Could you please provide the date and time you plan to travel from Dublin Connolly to Belfast? This will help me find the best train options for you.', 'options': None}, {'role': 'user', 'metadata': None, 'content': 'I would like to travel on next friday after 4 pm', 'options': None}, {'role': 'assistant', 'content': 'Here are the nearest 3 train options from Dublin Connolly to Belfast on next Friday, January 23rd, de

DB TOOL CALLED: search_train_services(departure_station='Dublin Connolly', arrival_station='Belfast', service_date='2026-01-23', depart_after='18:00', limit=3)



A function (chat) returned too many output values (needed: 2, returned: 3). Ignoring extra values.
    Output components:
        [chatbot, audio]
    Output values returned:
        [[{'role': 'user', 'metadata': None, 'content': 'Hi there', 'options': None}, {'role': 'assistant', 'metadata': None, 'content': 'Hello! How can I assist you with your rail travel in Ireland today?', 'options': None}, {'role': 'user', 'metadata': None, 'content': 'ı would like to go to belfast from dublin', 'options': None}, {'role': 'assistant', 'metadata': None, 'content': 'Could you please provide the date and time you would like to travel from Dublin to Belfast? This will help me find the best train options for you.', 'options': None}, {'role': 'user', 'metadata': None, 'content': 'next friday after 6 pm', 'options': None}, {'role': 'assistant', 'content': 'Here are the nearest 2 train options from Dublin Connolly to Belfast on next Friday, January 23, after 6 pm:\n\n1. Departure at 18:50, arrival at 