Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
tatsuya.naiki committed Nov 19, 2023
0 parents commit 44cfc1a
Show file tree
Hide file tree
Showing 4 changed files with 332 additions and 0 deletions.
50 changes: 50 additions & 0 deletions claude2/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import boto3
import json
import os
import logging

logger = logging.getLogger()
logger.setLevel(logging.getLevelName(os.getenv('LOG_LEVEL', 'INFO')))
bedrock_runtime_client = boto3.client('bedrock-runtime')


def get_completion(user_prompt):
model_id = 'anthropic.claude-v2'
accept = 'application/json'
content_type = 'application/json'

body = json.dumps({
"prompt": user_prompt,
"max_tokens_to_sample": 600,
})

response = bedrock_runtime_client.invoke_model(
modelId=model_id,
accept=accept,
contentType=content_type,
body=body
)

response_body = json.loads(response.get('body').read())

print("Received response_body:" + json.dumps(response_body, ensure_ascii=False))

return response_body.get('completion')


# Lambda のハンドラー関数
def lambda_handler(event, context):
if 'AppsheetBot' not in event['headers']['user-agent']:
return {
'statusCode': 401,
'body': json.dumps('401 Unauthorized')
}
print(event)
# return get_completion(event.get('user_prompt'))

result = {
'statusCode': 200,
'body': {'completion': get_completion(json.loads(event['body'])['user_prompt'])}
}
print(result)
return result
91 changes: 91 additions & 0 deletions kendra/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import boto3
import json
import os
import logging

logger = logging.getLogger()
logger.setLevel(logging.getLevelName(os.getenv('LOG_LEVEL', 'INFO')))

kendra_client = boto3.client('kendra')
bedrock_runtime_client = boto3.client('bedrock-runtime')


# Kendra から検索結果を取得
def get_retrieval_result(query_text, index_id):
response = kendra_client.query(
QueryText=query_text,
IndexId=index_id,
AttributeFilter={
"EqualsTo": {
"Key": "_language_code",
"Value": {"StringValue": "ja"},
},
},
)

# Kendra の応答から最初の5つの結果を抽出
results = response['ResultItems'][:5] if response['ResultItems'] else []

# 結果からドキュメントの抜粋部分のテキストを抽出
for i in range(len(results)):
results[i] = results[i].get("DocumentExcerpt", {}).get("Text", "").replace('\\n', ' ')

print("Received results:" + json.dumps(results, ensure_ascii=False))

return json.dumps(results, ensure_ascii=False)


def get_completion(user_prompt):
index_id = os.getenv('KENDRA_INDEX_ID')

prompt = f"""\n\nHuman:
[参考]情報をもとに[質問]に適切に答えてください。
[質問]
{user_prompt}
[参考]
{get_retrieval_result(user_prompt, index_id)}
Assistant:
"""

# 各種パラメーターの指定
model_id = 'anthropic.claude-v2'
accept = 'application/json'
content_type = 'application/json'

body = json.dumps({
"prompt": prompt,
"max_tokens_to_sample": 600,
})

response = bedrock_runtime_client.invoke_model(
modelId=model_id,
accept=accept,
contentType=content_type,
body=body
)

response_body = json.loads(response.get('body').read())

print("Received response_body:" + json.dumps(response_body, ensure_ascii=False))

return response_body.get('completion')


# Lambda のハンドラー関数
def lambda_handler(event, context):
if 'AppsheetBot' not in event['headers']['user-agent']:
return {
'statusCode': 401,
'body': json.dumps('401 Unauthorized')
}
print(event)
# return get_completion(event.get('user_prompt'))

result = {
'statusCode': 200,
'body': {'completion': get_completion(json.loads(event['body'])['user_prompt'])}
}
print(result)
return result


52 changes: 52 additions & 0 deletions llama2/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import boto3
import json
import os
import logging

logger = logging.getLogger()
logger.setLevel(logging.getLevelName(os.getenv('LOG_LEVEL', 'INFO')))
bedrock_runtime_client = boto3.client('bedrock-runtime')


def get_completion(user_prompt):
model_id = 'meta.llama2-13b-chat-v1'
accept = 'application/json'
content_type = 'application/json'

body = json.dumps({
"prompt": user_prompt,
'max_gen_len': 1024,
'top_p': 0.9,
'temperature': 0.2
})

response = bedrock_runtime_client.invoke_model(
modelId=model_id,
accept=accept,
contentType=content_type,
body=body
)

response_body = json.loads(response.get('body').read())

print("Received response_body:" + json.dumps(response_body, ensure_ascii=False))

return response_body.get('generation')


# Lambda のハンドラー関数
def lambda_handler(event, context):
if 'AppsheetBot' not in event['headers']['user-agent']:
return {
'statusCode': 401,
'body': json.dumps('401 Unauthorized')
}
print(event)
# return get_completion(event.get('user_prompt'))

result = {
'statusCode': 200,
'body': {'completion': get_completion(json.loads(event['body'])['user_prompt'])}
}
print(result)
return result
139 changes: 139 additions & 0 deletions sd/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import base64
import datetime

import boto3
import json
import os
import logging

logger = logging.getLogger()
logger.setLevel(logging.getLevelName(os.getenv('LOG_LEVEL', 'INFO')))

bedrock_runtime_client = boto3.client('bedrock-runtime')
s3 = boto3.client('s3')


# Kendra から検索結果を取得
def get_sd_prompt(query_text):
prompt = "Human: " + query_text + """\n
あなたはStable Diffusionのプロンプトを生成するAIアシスタントです。
以下の step でStableDiffusionのプロンプトを生成してください。
<step>
* rule を理解してください。ルールは必ず守ってください。例外はありません。
* ユーザは生成して欲しい画像の要件をチャットで指示します。チャットのやり取りを全て理解してください。
* チャットのやり取りから、生成して欲しい画像の特徴を正しく認識してください。
* 画像生成において重要な要素をから順にプロンプトに出力してください。ルールで指定された文言以外は一切出力してはいけません。例外はありません。
</step>
<rule>
* プロンプトは output-format の通りに、JSON形式で出力してください。JSON以外の文字列は一切出力しないでください。JSONの前にも後にも出力禁止です。
* JSON形式以外の文言を出力することは一切禁止されています。挨拶、雑談、ルールの説明など一切禁止です。
* プロンプトは単語単位で、カンマ区切りで出力してください。長文で出力しないでください。プロンプトは必ず英語で出力してください。
* プロンプトには以下の要素を含めてください。
* 画像のクオリティ、被写体の情報、衣装・ヘアスタイル・表情・アクセサリーなどの情報、画風に関する情報、背景に関する情報、構図に関する情報、ライティングやフィルタに関する情報
* 画像に含めたくない要素については、negativePromptとして出力してください。なお、negativePromptは必ず出力してください。
* フィルタリング対象になる不適切な要素は出力しないでください。
</rule>
<output-format>
{
prompt: string,
negativePrompt: string,
}
</output-format>
\n\nAssistant:
"""

body = json.dumps(
{
"prompt": prompt,
"max_tokens_to_sample": 500,
}
)

resp = bedrock_runtime_client.invoke_model(
modelId="anthropic.claude-v2",
body=body,
contentType="application/json",
accept="application/json",
)
answer = resp["body"].read().decode()
return json.loads(json.loads(answer)['completion'])


def get_location(sd_prompt):
model_id = "stability.stable-diffusion-xl-v0"
accept = "application/json"
content_type = "application/json"

# 10回の繰り返しを実行
body = json.dumps({
"text_prompts": [
{
"text": sd_prompt['prompt'],
"weight": 1.0
},
{
"text": sd_prompt['negativePrompt'],
"weight": -1.0
}
],
"cfg_scale": 10,
"seed": 20,
"steps": 50
})

response = bedrock_runtime_client.invoke_model(
body=body, modelId=model_id, accept=accept, contentType=content_type
)
response_body = json.loads(response.get("body").read())
print(response_body['result'])

# 取得した画像データのデコード
encoded_png_data = response_body.get("artifacts")[0].get("base64")
decoded_png_data = base64.b64decode(encoded_png_data)

# ファイル名の付与
now = datetime.datetime.now()
formatted_date = now.strftime('%y%m%d-%H%M%S%f')[:-4]
file_name = f"output-{formatted_date}.png"

# S3バケットへ出力
s3.put_object(Bucket=os.getenv('BUCKET_NAME'),
Key=file_name,
Body=decoded_png_data,
ContentType="image/png")

# 署名つきURLにするとサムネが見えません
# header_location = s3.generate_presigned_url(
# ClientMethod='get_object',
# Params={'Bucket': os.getenv('BUCKET_NAME'), 'Key': file_name},
# ExpiresIn=3600,
# HttpMethod='GET'
# )
#
# return {"Location": header_location}
return 'https://{0}.s3.amazonaws.com/{1}'.format(os.getenv('BUCKET_NAME'), file_name)


# Lambda のハンドラー関数
def lambda_handler(event, context):
if 'AppsheetBot' not in event['headers']['user-agent']:
return {
'statusCode': 401,
'body': json.dumps('401 Unauthorized')
}
print(event)

# sd_prompt = get_sd_prompt("ラーメンを食べる猫")
sd_prompt = get_sd_prompt(json.loads(event['body'])['user_prompt'])
location = get_location(sd_prompt)

result = {
'statusCode': 200,
'body': {'completion': location}
}
print(result)
return result

0 comments on commit 44cfc1a

Please sign in to comment.