Skip to content

Commit

Permalink
Switch to google gemma-7b model
Browse files Browse the repository at this point in the history
  • Loading branch information
polyrabbit committed May 3, 2024
1 parent b42a94c commit 07849b2
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 81 deletions.
106 changes: 102 additions & 4 deletions hacker_news/llm/openai.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,120 @@
import tiktoken
import json
import logging
import re
import time
from json import JSONDecodeError

import openai
import tiktoken
import config

logger = logging.getLogger(__name__)


def context_limit(model: str):
    """Return the context-window size (in tokens) implied by a model's name.

    Falls back to 4096 when no known size marker appears in the name.
    """
    # Ordered: the first matching marker wins ('128k' before '32k').
    size_markers = (
        (('128k',), 128 * 1024),
        (('32k', 'mistral-7b'), 32 * 1024),
        (('gemma-7b',), 8 * 1024),
    )
    for markers, tokens in size_markers:
        if any(marker in model for marker in markers):
            return tokens
    return 4096


def sanitize_for_openai(text, overhead):
    """Sanitize and truncate *text* so it fits the configured model's context window.

    :param text: raw article text to embed in a prompt.
    :param overhead: token budget reserved for the prompt/function schema,
        subtracted from the model's context limit.
    :return: cleaned text, trimmed to at most ``context_limit - overhead`` tokens.
    """
    # NOTE: the scraped diff left both the old 4096-hardcoded lines and the
    # new `limit`-based lines in place; this is the reconstructed final version.
    text = text.replace('```', ' ').strip()  # in case of prompt injection

    limit = context_limit(config.openai_model)
    # one token generally corresponds to ~4 characters, from https://platform.openai.com/tokenizer
    # (using limit * 2 chars as the trigger is deliberately conservative)
    if len(text) > limit * 2:
        try:
            enc = tiktoken.encoding_for_model(config.openai_model)  # We have openai compatible apis now
        except KeyError:
            # Unknown (non-OpenAI) model name: approximate with a known encoding.
            enc = tiktoken.encoding_for_model('gpt-3.5-turbo')
        tokens = enc.encode(text)
        if len(tokens) > limit - overhead:  # reserve `overhead` tokens for the prompt
            text = enc.decode(tokens[:limit - overhead])
    return text.strip(".").strip()


def sanitize_title(title):
    """Normalize a title for safe inline embedding in a prompt.

    Double quotes become single quotes and newlines become spaces, so the
    title cannot break out of a quoted prompt sentence.
    """
    cleaned = title.replace('"', "'")
    cleaned = cleaned.replace('\n', ' ')
    return cleaned.strip()


def summarize_by_openai_family(content: str, need_json: bool):
    """Summarize an article with the configured OpenAI(-compatible) model.

    :param content: raw article text; sanitized and truncated to the model's
        context window before being embedded in the prompt.
    :param need_json: when True, force a "render" function call so the model
        returns structured JSON; the parsed arguments dict is returned then.
    :return: str summary (plain-text path), dict (successful function-call
        path), or '' so the caller's fallback summarizer can kick in.
    """
    start_time = time.time()

    # 200: function + prompt tokens (to reduce hitting rate limit)
    content = sanitize_for_openai(content, overhead=200)

    # title = sanitize_title(self.title) or 'no title'
    # Hope one day this model will be clever enough to output correct json
    # Note: sentence should end with ".", "third person" - https://news.ycombinator.com/item?id=36262670
    # prompt = f'Output only answers to following 3 steps.\n' \
    #          f'1 - Summarize the article delimited by triple backticks in 2 sentences.\n' \
    #          f'2 - Translate the summary into Chinese.\n' \
    #          f'3 - Provide a Chinese translation of sentence: "{title}".\n' \
    #          f'```{content.strip(".")}.```'

    prompt = (f'Use third person mood to summarize the main points of the following article delimited by triple backticks in 2 concise English sentences. Ensure the summary does not exceed 100 characters.\n'
              f'```{content.strip(".")}.```')

    kwargs = {'model': config.openai_model,
              # one token generally corresponds to ~4 characters
              # 'max_tokens': int(config.summary_size / 4),
              'stream': False,
              'temperature': 0,
              'n': 1,  # only one choice
              'timeout': 30}
    if need_json:
        # Function calling coerces the model into emitting structured JSON.
        kwargs['functions'] = [{"name": "render", "parameters": {
            "type": "object",
            "properties": {
                "summary": {
                    "type": "string",
                    "description": "English summary"
                },
                "summary_zh": {
                    "type": "string",
                    "description": "Chinese summary"
                },
                "translation": {
                    "type": "string",
                    "description": "Chinese translation of sentence"
                },
            },
            # "required": ["summary"] # ChatGPT only returns the required field?
        }}]
        kwargs['function_call'] = {"name": "render"}

    if config.openai_model.startswith('text-'):
        # Legacy completion endpoint for 'text-*' models.
        resp = openai.Completion.create(
            prompt=prompt,
            **kwargs
        )
        answer = resp['choices'][0]['text'].strip()
    else:
        resp = openai.ChatCompletion.create(
            messages=[
                {'role': 'user', 'content': prompt},
            ],
            **kwargs)
        message = resp["choices"][0]["message"]
        if message.get('function_call'):
            json_str = message['function_call']['arguments']
            if resp["choices"][0]['finish_reason'] == 'length':
                json_str += '"}'  # best effort to save truncated answers
            try:
                answer = json.loads(json_str)
            except JSONDecodeError as e:
                logger.warning(f'Failed to decode answer from openai, will fallback to plain text, error: {e}')
                return ''  # Let fallback code kicks in
        else:
            answer = message['content'].strip()
    logger.info(f'prompt: {prompt}')
    logger.info(f'took {time.time() - start_time}s to generate: '
                # Default str(resp) prints \u516c
                f'{json.dumps(resp.to_dict_recursive(), sort_keys=True, indent=2, ensure_ascii=False)}')
    # BUGFIX: on the successful function-call path `answer` is a dict;
    # applying re.sub/.strip() to it raised TypeError. Only post-process
    # string answers and return the parsed dict unchanged.
    if isinstance(answer, str):
        # Remove leading ': ', ' *-' etc. from answer
        answer = re.sub(r'^[^a-zA-Z0-9]+', '', answer)
        return answer.strip()
    return answer
80 changes: 3 additions & 77 deletions hacker_news/news.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import db.summary
from db.summary import Model
from hacker_news.llm.coze import summarize_by_coze
from hacker_news.llm.openai import sanitize_for_openai, sanitize_title
from hacker_news.llm.openai import summarize_by_openai_family
from page_content_extractor import parser_factory
from page_content_extractor.webimage import WebImage

Expand Down Expand Up @@ -105,7 +105,7 @@ def summarize(self, content=None) -> (str, Model):
f'No need to summarize since we have a small text of size {len(content)}')
return content, Model.FULL

summary = self.summarize_by_coze(content) or self.summarize_by_openai(content)
summary = self.summarize_by_openai(content)
if summary:
return summary, Model.OPENAI
if self.get_score() >= config.local_llm_score_threshold: # Avoid slow local inference
Expand Down Expand Up @@ -135,90 +135,16 @@ def summarize_by_openai(self, content):
logger.info("Score %d is too small, ignore openai", self.get_score())
return ''

# 200: function + prompt tokens (to reduce hitting rate limit)
content = sanitize_for_openai(content, overhead=200)

title = sanitize_title(self.title) or 'no title'
# Hope one day this model will be clever enough to output correct json
# Note: sentence should end with ".", "third person" - https://news.ycombinator.com/item?id=36262670
prompt = f'Output only answers to following 3 steps.\n' \
f'1 - Summarize the article delimited by triple backticks in 2 sentences.\n' \
f'2 - Translate the summary into Chinese.\n' \
f'3 - Provide a Chinese translation of sentence: "{title}".\n' \
f'```{content.strip(".")}.```'
try:
# Too many exceptions to support translation, give up...
# answer = self.openai_complete(prompt, True)
# summary = self.parse_step_answer(answer).strip().strip(' *-')
# if not summary: # If step parse failed, ignore the translation
summary = self.openai_complete(
f'Use third person mood to summarize the main points of the following article delimited by triple backticks in 2 concise sentences. Ensure the summary does not exceed 100 characters.\n'
f'```{content.strip(".")}.```', False)
return summary
return summarize_by_openai_family(content, False)
except Exception as e:
logger.exception(f'Failed to summarize using openai, key #{config.openai_key_index}, {e}') # Make this error explicit in the log
return ''

# TODO: move to llm module
def openai_complete(self, prompt, need_json):
    """Send *prompt* to the configured OpenAI(-compatible) model and return its answer.

    :param prompt: fully rendered prompt text.
    :param need_json: when True, force a "render" function call so the model
        returns structured JSON; the parsed dict is returned in that case.
    :return: str answer (plain-text path), dict (function-call path), or ''
        when the function-call JSON cannot be decoded (lets fallback kick in).
    """
    start_time = time.time()
    kwargs = {'model': config.openai_model,
              # one token generally corresponds to ~4 characters
              # 'max_tokens': int(config.summary_size / 4),
              'stream': False,
              'temperature': 0,
              'n': 1,  # only one choice
              'timeout': 30}
    if need_json:
        # Function calling coerces the model into emitting structured JSON.
        kwargs['functions'] = [{"name": "render", "parameters": {
            "type": "object",
            "properties": {
                "summary": {
                    "type": "string",
                    "description": "English summary"
                },
                "summary_zh": {
                    "type": "string",
                    "description": "Chinese summary"
                },
                "translation": {
                    "type": "string",
                    "description": "Chinese translation of sentence"
                },
            },
            # "required": ["summary"] # ChatGPT only returns the required field?
        }}]
        kwargs['function_call'] = {"name": "render"}
    if config.openai_model.startswith('text-'):
        # Legacy completion endpoint for 'text-*' models.
        resp = openai.Completion.create(
            prompt=prompt,
            **kwargs
        )
        answer = resp['choices'][0]['text'].strip()
    else:
        resp = openai.ChatCompletion.create(
            messages=[
                {'role': 'user', 'content': prompt},
            ],
            **kwargs)
        message = resp["choices"][0]["message"]
        if message.get('function_call'):
            # Model honored the function call: arguments arrive as a JSON string.
            json_str = message['function_call']['arguments']
            if resp["choices"][0]['finish_reason'] == 'length':
                json_str += '"}'  # best effort to save truncated answers
            try:
                answer = json.loads(json_str)
            except JSONDecodeError as e:
                logger.warning(f'Failed to decode answer from openai, will fallback to plain text, error: {e}')
                return ''  # Let fallback code kicks in
        else:
            answer = message['content'].strip()
    logger.info(f'prompt: {prompt}')
    logger.info(f'took {time.time() - start_time}s to generate: '
                # Default str(resp) prints \u516c
                f'{json.dumps(resp.to_dict_recursive(), sort_keys=True, indent=2, ensure_ascii=False)}')
    return answer

def parse_step_answer(self, answer):
if not answer or isinstance(answer, str):
return answer
Expand Down
13 changes: 13 additions & 0 deletions test/test_news_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import unittest
from unittest import TestCase, mock

import openai
import config
import db
from db.engine import session_scope
from db.summary import Model
from hacker_news.llm.coze import summarize_by_coze
from hacker_news.llm.openai import summarize_by_openai_family
from hacker_news.news import News


Expand Down Expand Up @@ -57,6 +59,17 @@ def test_summarize_by_coze(self):
self.assertGreater(len(summary), 80)
self.assertLess(len(summary), config.summary_size * 2)

@unittest.skipUnless(openai.api_key, 'openai families are disabled')
def test_summarize_by_openai_family(self):
    """Summarizing the telnet fixture should yield a clean, sensibly sized summary."""
    fixture = os.path.join(os.path.dirname(__file__), 'fixtures/telnet.txt')
    with open(fixture, 'r') as fixture_file:
        article = fixture_file.read()
    summary = summarize_by_openai_family(article, False)
    self.assertIn('Telnet', summary)
    # Leading punctuation should have been stripped by the summarizer.
    self.assertFalse(summary.startswith(': '))
    self.assertGreater(len(summary), 80)
    self.assertLess(len(summary), config.summary_size * 2)

def test_parse_step_answer(self):
news = News('The guide to software development with Guix')
self.assertEqual(news.parse_title_translation('"Guix软件开发指南"的中文翻译。'),
Expand Down

0 comments on commit 07849b2

Please sign in to comment.