Skip to content

Commit 7514de6

Browse files
committed
generated file: main.py
1 parent 03458c1 commit 7514de6

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

main.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from fastapi import FastAPI, Request
2+
from fastapi.responses import JSONResponse
3+
from fastapi.middleware.cors import CORSMiddleware
4+
from fastapi.encoders import jsonable_encoder
5+
from typing import Optional
6+
from pydantic import BaseModel, validator
7+
import os
8+
import json
9+
from openai import OpenAI
10+
from sqlalchemy import create_engine
11+
from sqlalchemy.orm import sessionmaker
12+
from sqlalchemy.ext.declarative import declarative_base
13+
from .utils.logger import get_logger
14+
from .core.models.models import RequestBody, ResponseModel
15+
from .core.services.openai_service import OpenAIService
16+
from .db.models import User
17+
from .db.schemas import UserSchema
18+
from .auth.jwt_handler import create_access_token
19+
20+
app = FastAPI()
21+
logger = get_logger(__name__)
22+
23+
# Load environment variables
24+
if os.environ.get("ENV") == "production":
25+
logger.info("Production environment detected, loading environment variables.")
26+
from dotenv import load_dotenv
27+
load_dotenv()
28+
29+
# Define the database engine
30+
DATABASE_URL = os.environ.get("DATABASE_URL")
31+
32+
engine = create_engine(DATABASE_URL)
33+
34+
# Define a session factory
35+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
36+
37+
# Define a base model for database models
38+
Base = declarative_base()
39+
40+
# Create a database session
41+
def get_db():
42+
db = SessionLocal()
43+
try:
44+
yield db
45+
finally:
46+
db.close()
47+
48+
# CORS middleware for development
49+
app.add_middleware(
50+
CORSMiddleware,
51+
allow_origins=["*"],
52+
allow_credentials=True,
53+
allow_methods=["*"],
54+
allow_headers=["*"],
55+
)
56+
57+
@app.on_event("startup")
58+
async def startup_event():
59+
Base.metadata.create_all(bind=engine)
60+
61+
# Initialize the OpenAI service
62+
openai_service = OpenAIService(os.environ.get("OPENAI_API_KEY"))
63+
64+
# --- Authentication ---
65+
66+
@app.post("/register", response_model=UserSchema)
67+
async def register(request_body: UserSchema, db: Session = fastapi.Depends(get_db)):
68+
new_user = User(**request_body.dict())
69+
db.add(new_user)
70+
db.commit()
71+
db.refresh(new_user)
72+
return new_user
73+
74+
@app.post("/login", response_model=ResponseModel)
75+
async def login(request_body: UserSchema, db: Session = fastapi.Depends(get_db)):
76+
user = db.query(User).filter(User.username == request_body.username, User.password == request_body.password).first()
77+
if user:
78+
access_token = create_access_token(data={"sub": user.username})
79+
return ResponseModel(text=f"Login successful! Access token: {access_token}")
80+
else:
81+
return ResponseModel(text="Invalid username or password")
82+
83+
# --- API Endpoints ---
84+
85+
@app.post("/generate", response_model=ResponseModel, dependencies=[fastapi.Depends(get_db)])
86+
async def generate_text(request_body: RequestBody, db: Session = fastapi.Depends(get_db), request: Request = fastapi.Depends()):
87+
try:
88+
# Extract authentication token
89+
auth_header = request.headers.get("Authorization")
90+
if not auth_header:
91+
return JSONResponse(status_code=401, content={"message": "Unauthorized"})
92+
93+
token = auth_header.split(" ")[1]
94+
if not token:
95+
return JSONResponse(status_code=401, content={"message": "Unauthorized"})
96+
97+
# TODO: Implement token validation (using JWT) to check for valid authentication
98+
# ...
99+
100+
# Perform text generation
101+
response = await openai_service.generate_text(request_body.text, request_body.model)
102+
return response
103+
except Exception as e:
104+
logger.error(f"Error generating text: {e}")
105+
return ResponseModel(text=f"Error generating text: {e}")
106+
107+
@app.post("/translate", response_model=ResponseModel, dependencies=[fastapi.Depends(get_db)])
108+
async def translate_text(request_body: RequestBody, db: Session = fastapi.Depends(get_db), request: Request = fastapi.Depends()):
109+
try:
110+
# Extract authentication token
111+
auth_header = request.headers.get("Authorization")
112+
if not auth_header:
113+
return JSONResponse(status_code=401, content={"message": "Unauthorized"})
114+
115+
token = auth_header.split(" ")[1]
116+
if not token:
117+
return JSONResponse(status_code=401, content={"message": "Unauthorized"})
118+
119+
# TODO: Implement token validation (using JWT) to check for valid authentication
120+
# ...
121+
122+
# Perform text translation
123+
response = await openai_service.translate_text(request_body.text, request_body.model)
124+
return response
125+
except Exception as e:
126+
logger.error(f"Error translating text: {e}")
127+
return ResponseModel(text=f"Error translating text: {e}")
128+
129+
# --- Health Check ---
130+
131+
@app.get("/health")
132+
async def health_check():
133+
return {"status": "healthy"}
134+
135+
if __name__ == "__main__":
136+
import uvicorn
137+
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

0 commit comments

Comments
 (0)