Alright, let’s initialize our Rust backend with Axum and get our API up and running.

🛠 Step 1: Set Up the Project

First, make sure you have Rust & Cargo installed. If not, install Rust using:

curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh

Then, create a new Rust project:

cargo new llm-playground-backend --bin
cd llm-playground-backend

🛠 Step 2: Add Dependencies

Edit Cargo.toml and add the required crates:

[dependencies]
axum = "0.6"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tower = "0.4"
tower-http = { version = "0.3", features = ["cors", "trace"] }
jsonwebtoken = "8"
sqlx = { version = "0.6", features = ["postgres", "runtime-tokio-native-tls"] }
dotenvy = "0.15"

Then fetch and compile them with:

cargo build

🛠 Step 3: Create Main Entry Point

Modify src/main.rs:

use axum::{routing::get, Router};
use std::net::SocketAddr;
use tower_http::cors::{CorsLayer, Any};

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(root))
        .layer(CorsLayer::new().allow_origin(Any));

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("🚀 Server running at http://{}", addr);

    axum::Server::bind(&addr)
        .serve(app.into_make_service())
        .await
        .unwrap();
}

async fn root() -> &'static str {
    "Hello, Rust with Axum!"
}
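
The hard-coded [127, 0, 0, 1] address only accepts connections from the local machine. As a minimal sketch, assuming a hypothetical helper named listen_addr (not part of Axum), the address could be built from a host string and a port so it is easy to change later:

```rust
use std::net::{AddrParseError, IpAddr, SocketAddr};

/// Build a socket address from a host string and a port.
/// `listen_addr` is a hypothetical helper name used only for illustration.
fn listen_addr(host: &str, port: u16) -> Result<SocketAddr, AddrParseError> {
    // Only literal IP addresses parse here; host names such as "localhost" do not.
    let ip: IpAddr = host.parse()?;
    Ok(SocketAddr::new(ip, port))
}
```

Passing "127.0.0.1" keeps the server local, while "0.0.0.0" would listen on every interface, which is usually only wanted inside a container or behind a proxy.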

🛠 Step 4: Run the Server

Now, start the backend:

cargo run

You should see:

🚀 Server running at http://127.0.0.1:3000

Visit http://127.0.0.1:3000 in your browser, and it should return:

Hello, Rust with Axum!

🔗 Next Steps

1️⃣ Set up database connection (PostgreSQL + SQLx).
2️⃣ Add authentication (JWT-based auth).
3️⃣ Implement AI API routing (OpenAI calls).
4️⃣ Build the AI chat and blogging endpoints.

And just like that, our Rust backend is alive. Ready for the next move? 🚀

Yes! Database setup is the next critical step so we can store chat history, user data, and blog posts. Let’s get it done. 🚀

🛠 Step 1: Configure Environment Variables

First, we create a .env file in the project root so the database credentials stay out of the source code.

touch .env

Inside .env, add:

DATABASE_URL=postgres://username:password@localhost/llm_playground

Replace username, password, and llm_playground with your actual PostgreSQL credentials.

Also, add .env to .gitignore to prevent committing secrets:

echo ".env" >> .gitignore
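
A missing or empty variable is the most common reason the backend fails at startup. The following is a small sketch, assuming a hypothetical helper called required_var, of reading a setting with a readable error message instead of a bare panic:

```rust
use std::env;

/// Read a required setting from the process environment.
/// `required_var` is a hypothetical helper, not something dotenvy provides.
fn required_var(key: &str) -> Result<String, String> {
    match env::var(key) {
        Ok(value) if !value.trim().is_empty() => Ok(value),
        Ok(_) => Err(format!("{key} is set but empty")),
        Err(_) => Err(format!("{key} is not set; add it to .env")),
    }
}
```

It would be called after the .env file has been loaded, for example with required_var("DATABASE_URL").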

🛠 Step 2: Install & Configure SQLx

Since we’re using PostgreSQL, install sqlx CLI and migrations support:

cargo install sqlx-cli --no-default-features --features postgres

Then, install sqlx in the project dependencies (already in our Cargo.toml but ensure it’s there):

[dependencies]
sqlx = { version = "0.6", features = ["postgres", "runtime-tokio-native-tls"] }

🛠 Step 3: Create the Database

If PostgreSQL isn’t running yet, start it:

sudo service postgresql start  # Linux
pg_ctl -D /usr/local/var/postgres start  # macOS (Homebrew)

Then, create the database:

psql -U username -c "CREATE DATABASE llm_playground;"

🛠 Step 4: Create Database Migrations

Run:

sqlx migrate add init

This will create a migrations/ directory with a new SQL file. Edit it to include our schema:

-- migrations/YYYYMMDDHHMMSS_init.sql

-- Users table
CREATE TABLE users (
    id SERIAL PRIMARY KEY,
    username TEXT NOT NULL UNIQUE,
    email TEXT NOT NULL UNIQUE,
    password_hash TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT now()
);

-- Chat history table
CREATE TABLE chats (
    id SERIAL PRIMARY KEY,
    user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
    message TEXT NOT NULL,
    response TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT now()
);

-- Blog posts table
CREATE TABLE blogs (
    id SERIAL PRIMARY KEY,
    user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
    title TEXT NOT NULL,
    content TEXT NOT NULL,
    created_at TIMESTAMP DEFAULT now()
);

Apply the migration:

sqlx migrate run
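
Migration files are named with a numeric version prefix followed by a description, and they are applied in ascending version order. The sketch below is only a model of that naming rule, with hypothetical helper names; it is not how sqlx is implemented internally:

```rust
/// Split a migration file name such as `20240101120000_init.sql` into its
/// numeric version and a readable description. Hypothetical helper.
fn parse_migration_name(file_name: &str) -> Option<(i64, String)> {
    let stem = file_name.strip_suffix(".sql")?;
    let (version, description) = stem.split_once('_')?;
    let version = version.parse::<i64>().ok()?;
    Some((version, description.replace('_', " ")))
}

/// Return the valid migration file names sorted by ascending version.
fn apply_order(mut files: Vec<&str>) -> Vec<&str> {
    files.retain(|f| parse_migration_name(f).is_some());
    files.sort_by_key(|f| parse_migration_name(f).map(|(version, _)| version));
    files
}
```

Because ordering comes from the prefix, a later schema change belongs in a new file with a newer timestamp rather than in an edit to an older one.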

🛠 Step 5: Connect Axum to PostgreSQL in Rust

Modify src/main.rs:

use axum::{Router, routing::get};
use sqlx::postgres::PgPoolOptions;
use std::env;
use tower_http::cors::{CorsLayer, Any};

#[tokio::main]
async fn main() {
    dotenvy::dotenv().ok(); // Load .env file

    // Connect to database
    let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect(&database_url)
        .await
        .expect("Failed to connect to database");

    let app = Router::new()
        .route("/", get(|| async { "Hello, database is connected!" }))
        .layer(CorsLayer::new().allow_origin(Any)); // `Any` is a unit struct, not a function

    println!("🚀 Server running at http://127.0.0.1:3000");

    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

🛠 Step 6: Test Database Connection

Run the server again:

cargo run

If successful, you’ll see:

🚀 Server running at http://127.0.0.1:3000

Now we’re fully connected to PostgreSQL! 🎉

🔜 Next Steps

1️⃣ Implement user authentication (hash passwords, JWT auth).
2️⃣ Create API routes for chat & blog storage.
3️⃣ Build AI model routing layer.

Awesome! Now, let’s implement authentication so users can register, log in, and securely access routes. We’ll use JWT (JSON Web Tokens) for authentication and argon2 for password hashing.

🛠 Step 1: Install Dependencies

We need libraries for password hashing, random salts, and token expiry timestamps. jsonwebtoken is already in Cargo.toml from the initial setup, so add only the new crates (chrono is what generate_jwt below uses to compute the expiry time):

[dependencies]
argon2 = "0.5"
rand = "0.8"
chrono = "0.4"

Then, install them:

cargo build

🛠 Step 2: Update Database for Authentication

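We already created a users table with a password_hash column, so no schema change is needed to store hashed passwords.

If you do need to change the schema later, create a new migration with sqlx migrate add <name> instead of editing the init file: sqlx records a checksum for every applied migration and refuses to run if an applied file has changed. Then apply it:

sqlx migrate run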

🛠 Step 3: Create Authentication Utilities

Inside src/auth.rs, add:

use argon2::password_hash::SaltString;
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use rand::rngs::OsRng;
use serde::{Deserialize, Serialize};

// Secret key for JWT
const SECRET_KEY: &str = "supersecretkey"; // Placeholder: use a long random value from the environment in production

// Structure for JWT claims
#[derive(Debug, Serialize, Deserialize)]
struct Claims {
    sub: String,  // User ID
    exp: usize,   // Expiry timestamp (seconds since the Unix epoch)
}

// Hash a password
pub fn hash_password(password: &str) -> String {
    // Each password gets its own random salt, which is stored inside the hash string
    let salt = SaltString::generate(&mut OsRng);
    Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .expect("Password hashing failed")
        .to_string()
}

// Verify a password
pub fn verify_password(password: &str, hash: &str) -> bool {
    // A malformed stored hash counts as a failed check instead of a panic
    match PasswordHash::new(hash) {
        Ok(parsed_hash) => Argon2::default()
            .verify_password(password.as_bytes(), &parsed_hash)
            .is_ok(),
        Err(_) => false,
    }
}

// Generate JWT token
pub fn generate_jwt(user_id: &str) -> String {
    let expiration = chrono::Utc::now() + chrono::Duration::hours(24);
    let claims = Claims {
        sub: user_id.to_owned(),
        exp: expiration.timestamp() as usize,
    };

    encode(&Header::default(), &claims, &EncodingKey::from_secret(SECRET_KEY.as_ref()))
        .expect("Token generation failed")
}

// Validate JWT token
pub fn validate_jwt(token: &str) -> Option<String> {
    decode::<Claims>(token, &DecodingKey::from_secret(SECRET_KEY.as_ref()), &Validation::default())
        .ok()
        .map(|data| data.claims.sub)
}
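
The exp claim is just a count of seconds since the Unix epoch. As a rough sketch using only the standard library (the helper names are assumptions, and a JWT library may apply extra leeway on top of this), the 24-hour expiry and the expiry check boil down to:

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// Seconds since the Unix epoch at which a token issued at `now` expires.
fn expiry_timestamp(now: SystemTime, ttl: Duration) -> u64 {
    (now + ttl)
        .duration_since(UNIX_EPOCH)
        .expect("system clock is before 1970")
        .as_secs()
}

/// A token counts as expired once the current time reaches its `exp` value.
fn is_expired(exp: u64, now: SystemTime) -> bool {
    let now_secs = now
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0);
    now_secs >= exp
}
```

Keeping the time source as a parameter makes the logic easy to test with fixed timestamps.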

🛠 Step 4: Create Auth Routes

Inside src/routes/auth.rs (also create src/routes/mod.rs containing pub mod auth; so the module is visible to main.rs):

use axum::{extract::State, http::StatusCode, routing::post, Json, Router};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;

use crate::auth::{hash_password, verify_password, generate_jwt};

#[derive(Deserialize)]
struct RegisterUser {
    username: String,
    email: String,
    password: String,
}

#[derive(Deserialize)]
struct LoginUser {
    email: String,
    password: String,
}

#[derive(Serialize)]
struct AuthResponse {
    token: String,
}

// Register route
async fn register(
    State(pool): State<PgPool>,
    Json(payload): Json<RegisterUser>,
) -> Result<Json<AuthResponse>, StatusCode> {
    let hashed_password = hash_password(&payload.password);

    // RETURNING id gives us the new primary key to use as the JWT subject
    let user = sqlx::query!(
        "INSERT INTO users (username, email, password_hash) VALUES ($1, $2, $3) RETURNING id",
        payload.username,
        payload.email,
        hashed_password,
    )
    .fetch_one(&pool)
    .await
    // Most likely cause: the username or email is already taken (UNIQUE constraint)
    .map_err(|_| StatusCode::CONFLICT)?;

    let token = generate_jwt(&user.id.to_string());
    Ok(Json(AuthResponse { token }))
}

// Login route
async fn login(
    State(pool): State<PgPool>,
    Json(payload): Json<LoginUser>,
) -> Result<Json<AuthResponse>, StatusCode> {
    let user = sqlx::query!(
        "SELECT id, password_hash FROM users WHERE email = $1",
        payload.email,
    )
    .fetch_optional(&pool)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    // Unknown email and wrong password get the same answer
    .ok_or(StatusCode::UNAUTHORIZED)?;

    if verify_password(&payload.password, &user.password_hash) {
        let token = generate_jwt(&user.id.to_string());
        Ok(Json(AuthResponse { token }))
    } else {
        // Return 401 instead of panicking inside the handler
        Err(StatusCode::UNAUTHORIZED)
    }
}

// Create authentication router
pub fn auth_routes(pool: PgPool) -> Router {
    Router::new()
        .route("/register", post(register))
        .route("/login", post(login))
        .with_state(pool)
}
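
Handlers are easier to reason about when each failure has a defined HTTP status. The following sketch is an assumption about how the failure cases could be modelled (the AuthError type and check_login helper are hypothetical, and plain u16 codes stand in for a framework status type):

```rust
/// Failure cases for the register and login handlers (hypothetical type).
#[derive(Debug, PartialEq)]
enum AuthError {
    UserExists,
    InvalidCredentials,
    Database,
}

impl AuthError {
    /// HTTP status code a handler could return for this failure.
    fn status_code(&self) -> u16 {
        match self {
            AuthError::UserExists => 409,         // Conflict
            AuthError::InvalidCredentials => 401, // Unauthorized
            AuthError::Database => 500,           // Internal Server Error
        }
    }
}

/// An unknown email and a wrong password deliberately map to the same error,
/// so the response does not reveal which accounts exist.
fn check_login(stored_hash: Option<&str>, password_matches: bool) -> Result<(), AuthError> {
    match stored_hash {
        Some(_) if password_matches => Ok(()),
        _ => Err(AuthError::InvalidCredentials),
    }
}
```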

🛠 Step 5: Add Auth Routes to main.rs

Modify src/main.rs to include authentication:

mod auth;
mod routes;

use axum::{Router, routing::get};
use sqlx::PgPool;
use std::env;
use tower_http::cors::{CorsLayer, Any};
use routes::auth::auth_routes;

#[tokio::main]
async fn main() {
    dotenvy::dotenv().ok();

    let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    let pool = PgPool::connect(&database_url).await.expect("Failed to connect to database");

    let app = Router::new()
        .route("/", get(|| async { "Welcome to the API!" }))
        .merge(auth_routes(pool.clone()))
        .layer(CorsLayer::new().allow_origin(Any)); // `Any` is a unit struct, not a function

    println!("🚀 Server running at http://127.0.0.1:3000");

    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

🛠 Step 6: Test Authentication

1️⃣ Run the server:

cargo run

2️⃣ Register a new user:

curl -X POST http://127.0.0.1:3000/register \
     -H "Content-Type: application/json" \
     -d '{ "username": "testuser", "email": "test@example.com", "password": "password123" }'

💡 Response should contain a JWT token!

3️⃣ Login with the user:

curl -X POST http://127.0.0.1:3000/login \
     -H "Content-Type: application/json" \
     -d '{ "email": "test@example.com", "password": "password123" }'

💡 Should return a JWT token if credentials are correct.

🔜 Next Steps

1️⃣ Secure routes with JWT middleware (restrict access to authenticated users).
2️⃣ Store AI chat history linked to users.
3️⃣ Implement AI model selection API.


Great choice! Securing the routes ensures that only authenticated users can access chat history, AI interactions, and blog storage. Let’s build JWT authentication middleware to enforce security on protected routes.

🛠 Step 1: Create JWT Authentication Middleware

1️⃣ Create a New File for Authentication Middleware

Inside src/middleware.rs, add:

use axum::{
    http::{Request, StatusCode},
    middleware::Next,
    response::Response,
};
use crate::auth::validate_jwt;

// Middleware function to verify JWT
pub async fn auth_middleware<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
    // Extract Authorization header
    let auth_header = req
        .headers()
        .get("Authorization")
        .and_then(|hv| hv.to_str().ok());

    // Ensure the token is present and valid
    if let Some(token) = auth_header.and_then(|h| h.strip_prefix("Bearer ")) {
        if let Some(user_id) = validate_jwt(token) {
            // Store user ID in request extensions for use in handlers
            let mut req = req;
            req.extensions_mut().insert(user_id);
            return Ok(next.run(req).await);
        }
    }

    // Unauthorized response if JWT is missing or invalid
    Err(StatusCode::UNAUTHORIZED)
}

🔹 This extracts the JWT token from the Authorization header,
🔹 Validates the token,
🔹 Injects user ID into the request, so handlers can access it.
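
The header parsing step can be looked at in isolation. This is a minimal sketch with a hypothetical helper name, bearer_token, that mirrors the strip_prefix check in the middleware and additionally rejects an empty token:

```rust
/// Extract the token from an `Authorization` header value of the form
/// `Bearer <token>`. Returns `None` for any other scheme or an empty token.
fn bearer_token(header_value: &str) -> Option<&str> {
    // The prefix match is case-sensitive, like the middleware's own check.
    let token = header_value.strip_prefix("Bearer ")?.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
```

Anything that returns None here would end up as a 401 Unauthorized response.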

🛠 Step 2: Modify main.rs to Use Middleware

Now, import and apply this middleware to protected routes in src/main.rs:

mod auth;
mod middleware;
mod routes;

use axum::{Router, routing::get, middleware::from_fn};
use sqlx::PgPool;
use std::env;
use tower_http::cors::{CorsLayer, Any};
use routes::auth::auth_routes;
use middleware::auth_middleware;

#[tokio::main]
async fn main() {
    dotenvy::dotenv().ok();

    let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    let pool = PgPool::connect(&database_url).await.expect("Failed to connect to database");

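    // Public routes: no token required
    let public = Router::new()
        .route("/", get(|| async { "Welcome to the API!" }))
        .merge(auth_routes(pool.clone()));

    // Protected routes: route_layer only wraps routes added before it, so
    // every route that needs a token (later /chat and /blog) belongs here.
    // "/protected" is a placeholder so the layer has a route to wrap.
    let protected = Router::new()
        .route("/protected", get(|| async { "You are authenticated" }))
        .route_layer(from_fn(auth_middleware));

    let app = public
        .merge(protected)
        .layer(CorsLayer::new().allow_origin(Any));

    println!("🚀 Server running at http://127.0.0.1:3000");

    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

🔹 route_layer(from_fn(auth_middleware)) wraps only the routes in the protected router, so requests without a valid token get 401 Unauthorized there.
🔹 Public routes (/, /register, and /login) live in a separate router and don’t need authentication.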
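
The ordering rule matters: route_layer applies to the routes that are already in the router when it is called. The toy model below is not Axum code, just a hypothetical ToyRouter that mimics that rule to show why public and protected routes are best built separately and merged afterwards:

```rust
/// Toy model of a router: each route records whether it sits behind auth.
#[derive(Default)]
struct ToyRouter {
    routes: Vec<(String, bool)>, // (path, protected)
}

impl ToyRouter {
    fn route(mut self, path: &str) -> Self {
        self.routes.push((path.to_string(), false));
        self
    }

    /// Mimics the ordering rule: only routes that already exist get wrapped.
    fn route_layer_auth(mut self) -> Self {
        for route in &mut self.routes {
            route.1 = true;
        }
        self
    }

    fn merge(mut self, other: ToyRouter) -> Self {
        self.routes.extend(other.routes);
        self
    }

    fn is_protected(&self, path: &str) -> Option<bool> {
        self.routes
            .iter()
            .find(|(p, _)| p.as_str() == path)
            .map(|(_, protected)| *protected)
    }
}
```

Calling the layer on a router that already contains /login would lock users out of logging in, which is exactly what the model reports for that case.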

🛠 Step 3: Use the User ID in Handlers

Since the middleware injects the user ID into each request, handlers can read it with the Extension extractor.

Inside src/routes/chat.rs (and add pub mod chat; to src/routes/mod.rs), add:

use axum::{extract::State, http::StatusCode, Json, Extension};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;

#[derive(Deserialize)]
struct ChatMessage {
    message: String,
}

#[derive(Serialize)]
struct ChatResponse {
    response: String,
}

// Protected chat route
async fn chat_handler(
    Extension(user_id): Extension<String>,  // Extract user ID from request
    State(pool): State<PgPool>,
    Json(payload): Json<ChatMessage>,
) -> Result<Json<ChatResponse>, StatusCode> {
    // chats.user_id is an INTEGER column, so the ID carried in the token must be numeric
    let user_id: i32 = user_id.parse().map_err(|_| StatusCode::UNAUTHORIZED)?;

    // Placeholder reply until the AI client is wired in
    let ai_response = format!("AI Response to '{}'", payload.message);

    // Store chat history
    sqlx::query!(
        "INSERT INTO chats (user_id, message, response) VALUES ($1, $2, $3)",
        user_id,
        payload.message,
        ai_response
    )
    .execute(&pool)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(ChatResponse { response: ai_response }))
}

🔹 This retrieves the user ID from the request and associates chats with the user.
🔹 If an unauthenticated user tries to access /chat, they get 401 Unauthorized.
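
One detail worth spelling out: the middleware stores the token subject as a String, while chats.user_id is an INTEGER column. A small sketch, assuming a hypothetical helper parse_user_id, of converting the subject before it reaches a query:

```rust
/// Convert a JWT subject into a database user ID.
/// SERIAL keys start at 1, so zero and negative values are rejected too.
fn parse_user_id(subject: &str) -> Option<i32> {
    subject.parse::<i32>().ok().filter(|id| *id > 0)
}
```

A subject that does not parse (for example an email address) would be treated as an invalid token rather than passed to the database.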

🛠 Step 4: Test the Secured Routes

1️⃣ Start the server:

cargo run

2️⃣ Make an authenticated request (assuming you already registered and got a token):

curl -X POST http://127.0.0.1:3000/chat \
     -H "Content-Type: application/json" \
     -H "Authorization: Bearer YOUR_JWT_TOKEN" \
     -d '{ "message": "Hello AI" }'

💡 If the token is valid, it stores the chat and returns an AI response.

3️⃣ Try an unauthenticated request:

curl -X POST http://127.0.0.1:3000/chat \
     -H "Content-Type: application/json" \
     -d '{ "message": "Hello AI" }'

💡 Should return: 401 Unauthorized.

🔜 Next Steps

1️⃣ Finalize AI chat processing & connect OpenAI API.
2️⃣ Allow AI model selection (OpenAI vs self-hosted).
3️⃣ Implement blog post saving & AI-assisted writing.

🔥 Our API is now locked down with authentication, and only verified users can chat or store blogs!
Ready to build out the AI chat processing next? 🚀

You’re welcome! Now, let’s build out AI chat processing so users can send messages and get responses from OpenAI (or self-hosted models in the future). 🚀

🛠 Step 1: Install Dependencies for AI API Calls

We’ll need an HTTP client to interact with OpenAI. Add reqwest to Cargo.toml (serde_json is already listed from the initial setup):

[dependencies]
reqwest = { version = "0.11", features = ["json"] }

Then, install them:

cargo build

🛠 Step 2: Add OpenAI API Key to Environment

In .env, add:

OPENAI_API_KEY=your_openai_api_key_here

There is no need to source the file: dotenvy::dotenv() reads .env when the server starts, so restarting the server is enough to pick up the new key.

🛠 Step 3: Implement AI Chat Processing

Create a new file src/ai.rs:

use reqwest::Client;
use serde_json::json;
use std::env;

pub struct OpenAIClient {
    client: Client,
    api_key: String,
}

impl OpenAIClient {
    pub fn new() -> Self {
        let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
        Self {
            client: Client::new(),
            api_key,
        }
    }

    pub async fn generate_response(&self, user_input: &str) -> Result<String, reqwest::Error> {
        let request_body = json!({
            "model": "gpt-4",  // Can be changed dynamically in the future
            "messages": [
                {"role": "system", "content": "You are an AI assistant."},
                {"role": "user", "content": user_input}
            ],
            "temperature": 0.7
        });

        let response = self
            .client
            .post("https://api.openai.com/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?
            .error_for_status()?; // Turn 4xx/5xx replies (bad key, rate limit) into errors

        let json_response: serde_json::Value = response.json().await?;
        if let Some(text) = json_response["choices"][0]["message"]["content"].as_str() {
            Ok(text.to_string())
        } else {
            Ok("AI failed to respond.".to_string())
        }
    }
}

🔹 This sends user messages to OpenAI and returns an AI-generated response.
🔹 Supports dynamic AI model selection in the future.
🔹 Uses reqwest for HTTP calls.
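
Note that the client returns the text "AI failed to respond." as if it were a normal reply when the content field is missing, which means the placeholder would also be stored in chat history. One possible alternative, sketched here with a hypothetical AiError type and helper, is to treat a missing or blank reply as an error:

```rust
/// Error for a response that carries no usable text (hypothetical type).
#[derive(Debug, PartialEq)]
enum AiError {
    EmptyResponse,
}

/// Turn the optional `content` field of the first choice into a reply,
/// treating a missing or blank value as an error.
fn require_content(content: Option<&str>) -> Result<String, AiError> {
    match content.map(str::trim) {
        Some(text) if !text.is_empty() => Ok(text.to_string()),
        _ => Err(AiError::EmptyResponse),
    }
}
```

The handler could then map AiError to a 502 or 500 status instead of saving a fake answer.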

🛠 Step 4: Integrate AI Processing into Chat Route

Modify src/routes/chat.rs:

use axum::{extract::State, http::StatusCode, routing::post, Extension, Json, Router};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use crate::ai::OpenAIClient;

#[derive(Deserialize)]
struct ChatMessage {
    message: String,
}

#[derive(Serialize)]
struct ChatResponse {
    response: String,
}

// Protected chat route
async fn chat_handler(
    Extension(user_id): Extension<String>,  // Extract user ID from JWT
    State(pool): State<PgPool>,
    Json(payload): Json<ChatMessage>,
) -> Result<Json<ChatResponse>, StatusCode> {
    // chats.user_id is an INTEGER column, so the ID carried in the token must be numeric
    let user_id: i32 = user_id.parse().map_err(|_| StatusCode::UNAUTHORIZED)?;

    let ai_client = OpenAIClient::new();
    let ai_response = ai_client.generate_response(&payload.message).await
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // Store chat history in database
    sqlx::query!(
        "INSERT INTO chats (user_id, message, response) VALUES ($1, $2, $3)",
        user_id,
        payload.message,
        ai_response
    )
    .execute(&pool)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(ChatResponse { response: ai_response }))
}

// Define chat routes
pub fn chat_routes(pool: PgPool) -> Router {
    Router::new()
        .route("/chat", post(chat_handler))
        .with_state(pool)
}

🔹 This calls OpenAI to generate a response.
🔹 Saves conversation history linked to the user in PostgreSQL.
🔹 Ensures only authenticated users can chat.

🛠 Step 5: Register Chat Route in main.rs

Modify src/main.rs to include chat routes:

mod auth;
mod middleware;
mod routes;
mod ai;

use axum::{Router, routing::get, middleware::from_fn};
use sqlx::PgPool;
use std::env;
use tower_http::cors::{CorsLayer, Any};
use routes::{auth::auth_routes, chat::chat_routes};
use middleware::auth_middleware;

#[tokio::main]
async fn main() {
    dotenvy::dotenv().ok();

    let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    let pool = PgPool::connect(&database_url).await.expect("Failed to connect to database");

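    // Public routes: no token required
    let public = Router::new()
        .route("/", get(|| async { "Welcome to the AI Chat API!" }))
        .merge(auth_routes(pool.clone()));

    // route_layer wraps only the routes already in this router, so /chat
    // requires a token while /register and /login stay public
    let protected = chat_routes(pool.clone())
        .route_layer(from_fn(auth_middleware));

    let app = public
        .merge(protected)
        .layer(CorsLayer::new().allow_origin(Any));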

    println!("🚀 Server running at http://127.0.0.1:3000");

    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

🔹 The chat API is now live and protected by authentication.

🛠 Step 6: Test AI Chat Processing

1️⃣ Start the server:

cargo run

2️⃣ Send a chat request with authentication:

curl -X POST http://127.0.0.1:3000/chat \
     -H "Content-Type: application/json" \
     -H "Authorization: Bearer YOUR_JWT_TOKEN" \
     -d '{ "message": "Hello AI!" }'

💡 If successful, the response should be an AI-generated message!

🔜 Next Steps

1️⃣ Enable AI model selection (OpenAI vs self-hosted models).
2️⃣ Improve response time with streaming responses (WebSockets).
3️⃣ Optimize query caching (Redis or local memory).
4️⃣ Integrate AI-assisted blogging features.

🔥 AI chat processing is now fully functional!
Would you like to implement model selection next, or refine anything first? 🚀

Awesome! Now, let’s implement AI model selection so users can choose between OpenAI, Llama 2, or Mistral for their AI responses. This will make our system modular and future-proof. 🚀

🛠 Step 1: Update Database to Store User Model Preferences

We need to store the user’s preferred AI model so the system remembers their choice.

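🔹 Create a new migration for the preferred_model column (editing the already-applied init migration would make sqlx reject it because its checksum changes):

sqlx migrate add add_preferred_model

Put this in the generated file:

ALTER TABLE users ADD COLUMN preferred_model TEXT DEFAULT 'openai';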

Apply the migration:

sqlx migrate run

🔹 Now, the database can store model preferences per user!

🛠 Step 2: Create API for Selecting Models

We need an API route where users can update their AI model choice.

🔹 Create src/routes/user.rs (and add pub mod user; to src/routes/mod.rs):

use axum::{extract::{State, Extension}, Json, routing::post, Router};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use axum::http::StatusCode;

#[derive(Deserialize)]
struct ModelSelection {
    model: String,  // Model choice (e.g., "openai", "llama2", "mistral")
}

#[derive(Serialize)]
struct ModelResponse {
    message: String,
}

// Route to update user's preferred AI model
async fn update_model(
    Extension(user_id): Extension<String>,
    State(pool): State<PgPool>,
    Json(payload): Json<ModelSelection>,
) -> Result<Json<ModelResponse>, StatusCode> {
    const ALLOWED_MODELS: [&str; 3] = ["openai", "llama2", "mistral"];

    if !ALLOWED_MODELS.contains(&payload.model.as_str()) {
        return Err(StatusCode::BAD_REQUEST);
    }

    // users.id is an INTEGER column, so the ID carried in the token must be numeric
    let user_id: i32 = user_id.parse().map_err(|_| StatusCode::UNAUTHORIZED)?;

    sqlx::query!(
        "UPDATE users SET preferred_model = $1 WHERE id = $2",
        payload.model,
        user_id
    )
    .execute(&pool)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(ModelResponse {
        message: format!("Model updated to {}", payload.model),
    }))
}

// Register user-related routes
pub fn user_routes(pool: PgPool) -> Router {
    Router::new()
        .route("/set-model", post(update_model))
        .with_state(pool)
}

🔹 This allows users to update their AI model preference via /set-model.
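
The allow-list of model names can also be expressed as a type, so that an unsupported name cannot travel further into the code. This is one possible sketch (the Model enum is an assumption, not something defined elsewhere in this project):

```rust
use std::str::FromStr;

/// The models a user may select (hypothetical enum mirroring the allow-list).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Model {
    OpenAi,
    Llama2,
    Mistral,
}

impl Model {
    /// The string stored in the `preferred_model` column.
    fn as_str(self) -> &'static str {
        match self {
            Model::OpenAi => "openai",
            Model::Llama2 => "llama2",
            Model::Mistral => "mistral",
        }
    }
}

impl FromStr for Model {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "openai" => Ok(Model::OpenAi),
            "llama2" => Ok(Model::Llama2),
            "mistral" => Ok(Model::Mistral),
            other => Err(format!("unsupported model: {other}")),
        }
    }
}
```

A handler could call payload.model.parse::<Model>() and return 400 Bad Request on the error branch.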

🛠 Step 3: Modify AI Processing to Use Selected Model

Now, we update the AI chat handler to check the user’s preferred model before generating a response.

🔹 Modify src/routes/chat.rs:

use axum::{extract::State, http::StatusCode, Json, Extension};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use crate::ai::{OpenAIClient, SelfHostedClient}; // Import AI clients

#[derive(Deserialize)]
struct ChatMessage {
    message: String,
}

#[derive(Serialize)]
struct ChatResponse {
    response: String,
}

// Chat route that dynamically selects AI model
async fn chat_handler(
    Extension(user_id): Extension<String>,
    State(pool): State<PgPool>,
    Json(payload): Json<ChatMessage>,
) -> Result<Json<ChatResponse>, StatusCode> {
    // users.id and chats.user_id are INTEGER columns, so the token's ID must be numeric
    let user_id: i32 = user_id.parse().map_err(|_| StatusCode::UNAUTHORIZED)?;

    // Fetch the user's preferred model; COALESCE plus the "!" override makes
    // sqlx treat the nullable column as a plain String
    let user_model = sqlx::query!(
        r#"SELECT COALESCE(preferred_model, 'openai') AS "preferred_model!" FROM users WHERE id = $1"#,
        user_id
    )
    .fetch_one(&pool)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    // The two clients return different error types, so each arm maps its own
    // error to a status code before the results are unified
    let ai_response = match user_model.preferred_model.as_str() {
        "llama2" | "mistral" => SelfHostedClient::new(user_model.preferred_model.clone())
            .generate_response(&payload.message)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
        _ => OpenAIClient::new()
            .generate_response(&payload.message)
            .await
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
    };

    // Store the chat in the database
    sqlx::query!(
        "INSERT INTO chats (user_id, message, response) VALUES ($1, $2, $3)",
        user_id,
        payload.message,
        ai_response
    )
    .execute(&pool)
    .await
    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    Ok(Json(ChatResponse { response: ai_response }))
}

// chat_routes stays the same as in the previous version of this file

🔹 Now, the chat API dynamically picks the AI model based on user preference.

🛠 Step 4: Implement Self-Hosted AI Model Client

For now, this is a placeholder for future self-hosted models like Llama 2 or Mistral.

🔹 Modify src/ai.rs:

use reqwest::Client;
use serde_json::json;
use std::env;

pub struct OpenAIClient {
    client: Client,
    api_key: String,
}

impl OpenAIClient {
    pub fn new() -> Self {
        let api_key = env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set");
        Self {
            client: Client::new(),
            api_key,
        }
    }

    pub async fn generate_response(&self, user_input: &str) -> Result<String, reqwest::Error> {
        let request_body = json!({
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": "You are an AI assistant."},
                {"role": "user", "content": user_input}
            ],
            "temperature": 0.7
        });

        let response = self.client
            .post("https://api.openai.com/v1/chat/completions")
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await?
            .error_for_status()?; // Turn 4xx/5xx replies (bad key, rate limit) into errors

        let json_response: serde_json::Value = response.json().await?;
        if let Some(text) = json_response["choices"][0]["message"]["content"].as_str() {
            Ok(text.to_string())
        } else {
            Ok("AI failed to respond.".to_string())
        }
    }
}

// Placeholder for self-hosted models
pub struct SelfHostedClient {
    model_name: String,
}

impl SelfHostedClient {
    pub fn new(model_name: String) -> Self {
        Self { model_name }
    }

    // The leading underscore marks the input as intentionally unused for now
    pub async fn generate_response(&self, _user_input: &str) -> Result<String, String> {
        // TODO: Implement local AI model handling
        Ok(format!(
            "{} model is currently in development!",
            self.model_name
        ))
    }
}

🔹 This makes the AI system modular so we can easily add self-hosted models later.
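
One way to push the modular idea further is a shared trait, so the chat handler does not need to know which client it is talking to. The sketch below is synchronous and uses stand-in types (ChatBackend, StubHostedBackend, and PlaceholderBackend are all hypothetical); the real clients are async and would need an async-capable trait:

```rust
/// Common interface for anything that can answer a chat message.
trait ChatBackend {
    fn name(&self) -> &str;
    fn generate(&self, user_input: &str) -> Result<String, String>;
}

/// Stand-in for a hosted API client; it makes no network call.
struct StubHostedBackend;

/// Stand-in for a self-hosted model that is not implemented yet.
struct PlaceholderBackend {
    model_name: String,
}

impl ChatBackend for StubHostedBackend {
    fn name(&self) -> &str {
        "openai"
    }
    fn generate(&self, user_input: &str) -> Result<String, String> {
        Ok(format!("(stub reply to: {user_input})"))
    }
}

impl ChatBackend for PlaceholderBackend {
    fn name(&self) -> &str {
        &self.model_name
    }
    fn generate(&self, _user_input: &str) -> Result<String, String> {
        Ok(format!("{} model is currently in development!", self.model_name))
    }
}

/// Pick a backend from the stored preference, defaulting to the hosted one.
fn backend_for(preferred_model: &str) -> Box<dyn ChatBackend> {
    match preferred_model {
        "llama2" | "mistral" => Box::new(PlaceholderBackend {
            model_name: preferred_model.to_string(),
        }),
        _ => Box::new(StubHostedBackend),
    }
}
```

With a single error type behind the trait, the handler only has one place where a backend failure is turned into a status code.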

🛠 Step 5: Register New API Routes in main.rs

🔹 Modify src/main.rs:

mod auth;
mod middleware;
mod routes;
mod ai;

use axum::{Router, routing::get, middleware::from_fn};
use sqlx::PgPool;
use std::env;
use tower_http::cors::{CorsLayer, Any};
use routes::{auth::auth_routes, chat::chat_routes, user::user_routes};
use middleware::auth_middleware;

#[tokio::main]
async fn main() {
    dotenvy::dotenv().ok();

    let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    let pool = PgPool::connect(&database_url).await.expect("Failed to connect to database");

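    // Public routes: no token required
    let public = Router::new()
        .route("/", get(|| async { "Welcome to the AI Chat API!" }))
        .merge(auth_routes(pool.clone()));

    // route_layer wraps only the routes already in this router, so /chat and
    // /set-model require a token while /register and /login stay public
    let protected = Router::new()
        .merge(chat_routes(pool.clone()))
        .merge(user_routes(pool.clone()))  // New route for setting AI models
        .route_layer(from_fn(auth_middleware));

    let app = public
        .merge(protected)
        .layer(CorsLayer::new().allow_origin(Any));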

    println!("🚀 Server running at http://127.0.0.1:3000");

    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

🔜 Next Steps

1️⃣ Test model switching with API requests.
2️⃣ Implement WebSocket-based streaming for real-time AI responses.
3️⃣ Optimize model performance with caching.

🔥 AI model selection is now fully functional!
Would you like to test it next, or refine anything before moving forward? 🚀

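Awesome! Now, let’s implement WebSocket-based AI response streaming so users see their AI-generated messages arrive word by word. Note that this first version still waits for the complete OpenAI response and then sends it one word at a time; true token-by-token streaming would need the API’s streaming mode. 🚀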

🛠 Step 1: Install WebSocket Support in Axum

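Axum ships its own WebSocket support behind the ws feature, so tokio-tungstenite does not have to be added directly. Update the axum entry in Cargo.toml and add futures-util for splitting the socket:

[dependencies]
axum = { version = "0.6", features = ["ws"] }
futures-util = "0.3"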

Then install it:

cargo build

🛠 Step 2: Create WebSocket Route for AI Streaming

We need to set up a WebSocket connection where AI responses are sent in real-time as they are generated.

🔹 Create src/routes/ws.rs (and add pub mod ws; to src/routes/mod.rs):

use axum::{
    extract::ws::{Message, WebSocket, WebSocketUpgrade},
    extract::{Extension, State},
    http::StatusCode,
    response::{IntoResponse, Response},
};
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
use sqlx::PgPool;
use std::sync::Arc;
use tokio::sync::Mutex;

use crate::ai::OpenAIClient;

// SplitSink is not Clone, so the write half is shared behind Arc<Mutex<..>>
type SharedSender = Arc<Mutex<SplitSink<WebSocket, Message>>>;

/// WebSocket handler for AI chat streaming
pub async fn chat_stream_handler(
    ws: WebSocketUpgrade,
    Extension(user_id): Extension<String>, // Set by auth_middleware
    State(pool): State<PgPool>,
) -> Response {
    // chats.user_id is an INTEGER column, so the token's ID must be numeric
    let user_id: i32 = match user_id.parse() {
        Ok(id) => id,
        Err(_) => return StatusCode::UNAUTHORIZED.into_response(),
    };
    ws.on_upgrade(move |socket| handle_stream(socket, user_id, pool))
}

async fn handle_stream(socket: WebSocket, user_id: i32, pool: PgPool) {
    let (sender, mut receiver) = socket.split();
    let sender: SharedSender = Arc::new(Mutex::new(sender));
    // generate_response takes &self, so a plain Arc is enough (no Mutex needed)
    let ai_client = Arc::new(OpenAIClient::new());

    // The loop ends on the first non-text frame or when the client disconnects
    while let Some(Ok(Message::Text(user_input))) = receiver.next().await {
        let sender = Arc::clone(&sender);
        let ai_client = Arc::clone(&ai_client);
        let pool = pool.clone();

        // Spawn a separate task so a slow reply does not block the read loop
        tokio::spawn(async move {
            if stream_ai_response(sender, ai_client, user_id, user_input, pool)
                .await
                .is_err()
            {
                eprintln!("Error streaming AI response");
            }
        });
    }
}

async fn stream_ai_response(
    sender: SharedSender,
    ai_client: Arc<OpenAIClient>,
    user_id: i32,
    user_input: String,
    pool: PgPool,
) -> Result<(), ()> {
    // The full reply is fetched first, then sent one word per frame
    let ai_response = ai_client
        .generate_response(&user_input)
        .await
        .map_err(|_| ())?;

    for word in ai_response.split_whitespace() {
        // Lock per word so the sender is not held across the sleep below
        if sender.lock().await.send(Message::Text(word.to_string())).await.is_err() {
            return Err(()); // Stop if client disconnects
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; // Simulate typing effect
    }

    // Save chat to database
    sqlx::query!(
        "INSERT INTO chats (user_id, message, response) VALUES ($1, $2, $3)",
        user_id,
        user_input,
        ai_response
    )
    .execute(&pool)
    .await
    .map_err(|_| ())?;

    Ok(())
}

🔹 This creates a WebSocket connection where messages are streamed word-by-word.
🔹 Uses tokio tasks to prevent blocking other WebSocket connections.
🔹 Stores the full chat history in PostgreSQL after the stream completes.
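
Splitting on whitespace discards the spaces and line breaks, so a client that simply concatenates the frames would glue the words together. A small sketch, with a hypothetical helper word_frames, of producing frames that can be joined back into readable text:

```rust
/// Split a reply into one frame per word, keeping a trailing space on every
/// frame except the last so the frames concatenate into readable text.
fn word_frames(response: &str) -> Vec<String> {
    let words: Vec<&str> = response.split_whitespace().collect();
    words
        .iter()
        .enumerate()
        .map(|(index, word)| {
            if index + 1 < words.len() {
                format!("{word} ")
            } else {
                word.to_string()
            }
        })
        .collect()
}
```

Runs of whitespace and line breaks are collapsed to single spaces here, which is a simplification; preserving the original layout would need a different splitting strategy.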

🛠 Step 3: Register WebSocket Route in main.rs

Now, we need to register the WebSocket handler in our main router.

🔹 Modify src/main.rs:

mod auth;
mod middleware;
mod routes;
mod ai;

use axum::{Router, routing::get, middleware::from_fn};
use sqlx::PgPool;
use std::env;
use tower_http::cors::{CorsLayer, Any};
use routes::{auth::auth_routes, chat::chat_routes, user::user_routes};
use routes::ws::chat_stream_handler; // ws.rs lives in src/routes/, so there is no top-level `mod ws;`
use middleware::auth_middleware;

#[tokio::main]
async fn main() {
    dotenvy::dotenv().ok();

    let database_url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
    let pool = PgPool::connect(&database_url).await.expect("Failed to connect to database");

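    // Public routes: no token required
    let public = Router::new()
        .route("/", get(|| async { "Welcome to the AI Chat API!" }))
        .merge(auth_routes(pool.clone()));

    // Protected routes: the WebSocket handler extracts State<PgPool>, so it
    // needs with_state before being merged with the other routers
    let protected = Router::new()
        .route("/ws/chat", get(chat_stream_handler))
        .with_state(pool.clone())
        .merge(chat_routes(pool.clone()))
        .merge(user_routes(pool.clone()))
        .route_layer(from_fn(auth_middleware));

    let app = public
        .merge(protected)
        .layer(CorsLayer::new().allow_origin(Any));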

    println!("🚀 Server running at http://127.0.0.1:3000");

    axum::Server::bind(&"127.0.0.1:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

🔹 Now /ws/chat is our real-time AI chat endpoint!

🛠 Step 4: Test WebSocket Streaming

1️⃣ Start the server:

cargo run

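2️⃣ Connect to the WebSocket manually using wscat. The route sits behind the auth middleware, so pass the token you got from /login:

npm install -g wscat
wscat -c ws://127.0.0.1:3000/ws/chat -H "Authorization: Bearer YOUR_JWT_TOKEN"

3️⃣ Send a message and receive the AI response one word per frame (the reply text below is only illustrative):

> Hello AI, how are you?
< I'm
< doing
< well,
< thanks!

💡 The AI response arrives word by word!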

🔜 Next Steps

1️⃣ Optimize AI streaming speed (adjust token delay dynamically).
2️⃣ Enable WebSocket reconnect handling (ensure chat persists if the connection drops).
3️⃣ Build a frontend WebSocket chat interface.

🔥 We now have real-time AI chat streaming working over WebSockets!
Would you like to refine anything before optimizing performance? 🚀