# Build an autonomous multi-agent workflow to write a picture book

## Part 1
- **Goal**: Write an outline for a story

### 1. Install dependencies

In [2]:
%pip install -q langchain_community==0.0.32 langgraph==0.0.51 langchain-aws==0.1.6

Note: you may need to restart the kernel to use updated packages.


### 2. Utility functions

#### 2.1 Structured output parser
- In our case, we need to parse the LLM output into a pydantic object, so we define a structured output parser.

In [10]:
import os
import json
import re
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from json import JSONDecodeError
def dict_to_obj(json_str:dict, target:object):
    """Build a ``target`` instance from already-parsed data.

    Args:
        json_str: The parsed data (a dict, despite the name) to load.
        target: A class exposing ``parse_obj`` (e.g. a pydantic v1 model).

    Returns:
        Whatever ``target.parse_obj(json_str)`` returns.
    """
    build = target.parse_obj
    return build(json_str)


class CustJsonOuputParser(BaseOutputParser[dict]):
    """Output parser that extracts the JSON found between ``<answer>`` tags.

    The parser is typed as producing a ``dict`` because that is what
    ``parse`` returns (the original ``str`` annotation did not match the
    returned value).
    """
    # When True, the raw LLM text is printed before parsing.
    verbose: bool = Field(default=True)

    def parse(self, text: str) -> dict:
        """Parse the first ``<answer>...</answer>`` section of ``text`` as JSON.

        Args:
            text: Raw LLM output.

        Returns:
            The decoded JSON payload, or ``{'answer': "no"}`` when ``text``
            contains no ``<answer>`` section.

        Raises:
            json.JSONDecodeError: If the section content is not valid JSON.
        """
        if self.verbose:
            print(text)
        match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        if match is None:
            return {'answer': "no"}
        # Newlines are replaced by spaces before decoding, so raw line breaks
        # inside the answer do not reach the JSON decoder.
        return json.loads(match.group(1).replace('\n', '  '))

    @property
    def _type(self) -> str:
        """Identifier string for this parser."""
        return "cust_output_parser"

class TextOuputParser(BaseOutputParser[str]):
    """Output parser that returns the text found between ``<answer>`` tags."""
    # When True, the raw LLM text is printed before parsing.
    verbose :bool = Field( default=True)

    def parse(self, text: str) -> str:
        """Return the stripped content of the first ``<answer>`` section.

        Args:
            text: Raw LLM output.

        Returns:
            The section content with surrounding whitespace removed, or an
            empty string when no ``<answer>`` section is present.
        """
        if self.verbose:
            print(text)
        found = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        if not found:
            return ''
        answer = found.group(1)
        return answer.strip()

    @property
    def _type(self) -> str:
        """Identifier string for this parser."""
        return "TextOuputParser"

#### 2.2 LLM models

In [13]:
from langchain_aws import ChatBedrock

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

# Chat model backed by Claude 3 Sonnet on Amazon Bedrock (see model_id).
# Sampling settings are supplied through model_kwargs and streaming is enabled.
# NOTE(review): the stop sequences are closing tags ('</invoke>', '</error>'),
# presumably from a tool-calling prompt format — confirm they are still needed.
# The commented-out `callbacks` line would echo streamed tokens to stdout.
llm_sonnet = ChatBedrock(model_id="anthropic.claude-3-sonnet-20240229-v1:0",
                  model_kwargs={"temperature": 0.8,
                                "top_k":250,
                                "max_tokens": 4096,
                                "top_p":0.9,
                                "stop_sequences":['</invoke>','</error>']
                               },
                  streaming=True,
                #   callbacks=[StreamingStdOutCallbackHandler()]
                )

# Chat model backed by Claude 3 Haiku on Amazon Bedrock (see model_id).
# Uses the same sampling settings and stop sequences as the Sonnet client.
# The commented-out `callbacks` line would echo streamed tokens to stdout.
llm_haiku = ChatBedrock(model_id="anthropic.claude-3-haiku-20240307-v1:0",
                  model_kwargs={"temperature": 0.8,
                                "top_k":250,
                                "max_tokens": 4096,
                                "top_p":0.9,
                                "stop_sequences":['</invoke>','</error>']
                               },
                  streaming=True,
                  # callbacks=[StreamingStdOutCallbackHandler()]
                  )

#### 2.3 Image generation model invocation code
- Invoke the Bedrock Stable Diffusion (SD) model; here we use Bedrock SD to generate portraits of the characters

In [14]:
import base64
import io
import json
import boto3
from PIL import Image
from botocore.exceptions import ClientError
from enum import Enum
from io import BytesIO

# AWS named-profile setting: empty string here, with 'default' kept as a
# commented-out alternative.
# NOTE(review): presumably passed as `profile` to ImageGenerator /
# StoryDiffusionGenerator; the profile-based boto3.Session calls visible in
# this file are commented out — confirm whether this value is still used.
# profile='default'
profile=''

class StyleEnum(Enum):
    """Style preset names for image generation.

    Each member's value is the string placed in the ``style_preset`` field of
    the Stable Diffusion request body built by ``ImageGenerator.generate_image``.
    """
    Photographic = "photographic"
    Tile_texture = "tile-texture"
    Digital_art = "digital-art"
    Origami = "origami"
    Modeling_compound = "modeling-compound"
    Anime = "anime"
    Cinematic = "cinematic"
    Model_3D = "3d-model"
    Comicbook = "comic-book"
    Enhance = "enhance"
    
class ImageError(Exception):
    """Custom exception for errors returned by SDXL.

    Args:
        message: Human-readable description of the failure; also available
            as the ``message`` attribute.
    """

    def __init__(self, message):
        # Forward the message to Exception so ``str(err)`` and ``err.args``
        # are populated even when ``message`` is passed as a keyword.
        super().__init__(message)
        self.message = message


class ImageGenerator(BaseModel):
    """Generate images with an Amazon Bedrock text-to-image model.

    Model ids starting with ``stability`` use the Stable Diffusion request
    and response format; ids starting with ``amazon`` use the ``TEXT_IMAGE``
    task format.
    """
    # Bedrock model id; its prefix selects the request format.
    model_id: str = Field(default="stability.stable-diffusion-xl-v1")
    # AWS profile name. Required, but not used: the boto3 session below is
    # created with the default credential chain.
    profile:str
    # Prompt-adherence scale sent to the model.
    cfg_scale: int = Field( default=7)
    # Number of diffusion steps (Stable Diffusion requests only).
    steps:int = Field( default=50)
    # Number of samples requested (Stable Diffusion requests only).
    samples:int = Field( default=1)

    def _generate(self, model_id, body):
        """Invoke the Bedrock model and return the first generated image.

        Args:
            model_id (str): The model ID to use.
            body (str): The JSON request body to send.

        Returns:
            bytes: The decoded bytes of the first image in the response.

        Raises:
            ImageError: If the model reports an error, filters the content,
                or returns no image.
            botocore.exceptions.ClientError: If the Bedrock call itself fails.
        """
        session = boto3.Session()
        bedrock = session.client(service_name='bedrock-runtime')

        response = bedrock.invoke_model(
            body=body,
            modelId=model_id,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())

        artifacts = response_body.get("artifacts")
        if artifacts:
            # Stable Diffusion response: a list of artifacts, each carrying a
            # finish reason and the base64 image. Check the finish reason
            # before touching the payload so failures surface as ImageError.
            first = artifacts[0]
            finish_reason = first.get("finishReason")
            if finish_reason in ('ERROR', 'CONTENT_FILTERED'):
                raise ImageError(f"Image generation error. Error code is {finish_reason}")
            base64_image = first.get("base64")
        else:
            # Amazon image models answer with an "images" list of base64
            # strings plus an optional "error" message.
            error = response_body.get("error")
            if error:
                raise ImageError(f"Image generation error. Error is {error}")
            images = response_body.get("images")
            base64_image = images[0] if images else None

        if not base64_image:
            raise ImageError("Image generation error. No image returned by the model")

        return base64.b64decode(base64_image)

    def generate_image(self, prompt, seed=0, style_preset=StyleEnum.Photographic.value):
        """Generate one 768x768 image for ``prompt``.

        Args:
            prompt: Text prompt describing the image.
            seed: Random seed forwarded to the model.
            style_preset: Style preset string (Stable Diffusion requests only).

        Returns:
            The generated ``PIL.Image.Image``, or ``None`` when generation
            failed (the error is printed).

        Raises:
            ValueError: If ``model_id`` starts with neither ``stability`` nor
                ``amazon``, since no request body can be built for it.
        """
        if self.model_id.startswith('stability'):
            body = json.dumps({
                "text_prompts": [
                    {
                        "text": prompt
                    }
                ],
                "cfg_scale": self.cfg_scale,
                "seed": seed,
                "steps": self.steps,
                "height": 768,
                "width": 768,
                "samples": self.samples,
                "style_preset": style_preset
            })
        elif self.model_id.startswith('amazon'):
            body = json.dumps({
                "taskType": "TEXT_IMAGE",
                "textToImageParams": {
                    "text": prompt
                },
                "imageGenerationConfig": {
                    "numberOfImages": 1,
                    "height": 768,
                    "width": 768,
                    "cfgScale": self.cfg_scale,
                    "seed": seed
                }
            })
        else:
            # Previously this fell through to an unbound `body`.
            raise ValueError(f"Unsupported image model id: {self.model_id}")
        print(body)

        image = None
        try:
            image_bytes = self._generate(model_id=self.model_id, body=body)
            image = Image.open(io.BytesIO(image_bytes))
        except ClientError as err:
            message = err.response["Error"]["Message"]
            print("A client error occured: " + format(message))
        except Exception as err:
            # Best effort: report the failure (ImageError included) and fall
            # through to return None.
            print(err)
        # Returned after the try block rather than from `finally`, so that
        # KeyboardInterrupt / SystemExit are no longer silently swallowed.
        return image

- Invoke the SageMaker model. Here we use StoryDiffusion to generate consecutive images

In [None]:
import io
import json
import time
from typing import Any, List, Optional

import sagemaker
from sagemaker.async_inference.waiter_config import WaiterConfig
from sagemaker.deserializers import JSONDeserializer
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.serializers import JSONSerializer



def base64_to_image(base64_string):
    """Decode a base64-encoded image and open it with PIL.

    Args:
        base64_string: The base64-encoded image data.

    Returns:
        The image returned by ``Image.open`` over the decoded bytes.
    """
    raw = base64.b64decode(base64_string)
    return Image.open(BytesIO(raw))

def get_bucket_and_key(s3uri):
    """Split an ``s3://bucket/key`` URI into its bucket and key.

    The first five characters (the ``s3://`` scheme) are skipped without
    validation. Everything up to the next ``/`` is the bucket and the
    remainder is the key.

    Args:
        s3uri: The S3 URI to split.

    Returns:
        A ``(bucket, key)`` tuple. ``key`` is an empty string when the URI
        has no path after the bucket.
    """
    # partition() copes with a URI that has no "/" after the bucket; the
    # previous find()-based slicing returned a truncated bucket and the whole
    # URI as the key in that case.
    bucket, _, key = s3uri[5:].partition("/")
    return bucket, key



class StoryDiffusionGenerator():
    """Client for a StoryDiffusion model behind a SageMaker async endpoint.

    Requests are submitted through ``AsyncPredictor`` and the result is read
    back from the S3 output location reported for the prediction.

    API schema
    class APIRequest(BaseModel):
        sd_type: Literal['RealVision','SDXL','Unstable'] = Field(default='SDXL')
        modeltype : Literal["Only Using Textual Description","Using Ref Images"] = Field(default="Only Using Textual Description")
        files: Any = Field(default=None)
        num_steps : int = Field(default=50)
        style : Literal["Japanese Anime","(No style)","Cinematic","Disney Charactor","Photographic","Comic book","Line art"] = Field(default="Comic book")
        Ip_Adapter_Strength : float = Field(default=0.5, descrition="The strength of the IP adapter. The value ranges from 0 to 1. The larger the value, the stronger the IP adapter.")
        style_strength_ratio : int = Field(default=20 ,descrition="Style strength of Ref Image (%)")
        guidance_scale: float = Field(default=5.0)
        seed_: int = Field(default=0)
        sa32_:float = Field(default=0.5)
        sa64_:float = Field(default=0.5)
        id_length_ : int = Field(default=2,descrition="Number of id images in total images")
        general_prompt : str = Field(...,descrition="Textual Description for Character")
        negative_prompt : str =  Field(default="naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation")
        prompt_array: str =  Field(...,descrition="Comic Description (each line corresponds to a frame)")
        G_height: int = Field(default=512)
        G_width: int = Field(default=512)
        comic_type : Literal['No typesetting (default)','Four Pannel','Classic Comic Style'] = Field(default='Classic Comic Style')
        font_choice :str = Field(default='Inkfree.ttf')
    """

    def __init__(self, endpoint_name, profile):
        """Create the S3 resource and the async predictor for ``endpoint_name``.

        Args:
            endpoint_name: Name of the SageMaker endpoint to call.
            profile: AWS profile name. Accepted for interface compatibility
                but not used: the session uses the default credential chain.
        """
        boto_session = boto3.Session()
        self.s3_resource = boto_session.resource("s3")
        sagemaker_session = sagemaker.Session(boto_session=boto_session)
        bucket = sagemaker_session.default_bucket()
        output_path = "s3://{0}/{1}/asyncinvoke/out/".format(bucket, "story-diffusion")
        input_path = "s3://{0}/{1}/asyncinvoke/in/".format(bucket, "story-diffusion")

        predictor_ = Predictor(
            endpoint_name=endpoint_name,
            sagemaker_session=sagemaker_session,
            model_data_input_path=input_path,
            model_data_output_path=output_path,
        )
        predictor_.serializer = JSONSerializer()
        predictor_.deserializer = JSONDeserializer()
        # Poll up to 100 times, 10 seconds apart, while waiting for a result.
        self.config = WaiterConfig(max_attempts=100, delay=10)
        self.predictor_async = AsyncPredictor(
            predictor_,
            name='story-diffusion'
        )

    def generate_images(
        self,
        general_prompt: str,
        prompt_array: str,
        id_length: int = 2,
        ref_imgs: Optional[List[Any]] = None,
        comic_type: str = 'Classic Comic Style',
        style: str = 'Japanese Anime',
        sd_type: str = "Unstable",
        height: int = 512,
        width: int = 768,
    ) -> list:
        """Submit an async generation request and return the resulting images.

        Args:
            general_prompt: Sent as ``general_prompt`` (character description
                in the API schema).
            prompt_array: Sent as ``prompt_array`` (one frame per line in the
                API schema).
            id_length: Sent as ``id_length_``.
            ref_imgs: Reference images sent as ``files``. ``None`` or an empty
                list omits the ``files`` key from the request.
            comic_type: Sent as ``comic_type``.
            style: Sent as ``style``.
            sd_type: Sent as ``sd_type``.
            height: Sent as ``G_height``.
            width: Sent as ``G_width``.

        Returns:
            A list of images decoded from the ``images_base64`` entries of the
            endpoint's JSON output.
        """
        data = {
            "general_prompt": general_prompt,
            "prompt_array": prompt_array,
            "style": style,
            "G_height": height,
            "G_width": width,
            "comic_type": comic_type,
            "files": ref_imgs,
            "id_length_": id_length,
            "sd_type": sd_type,
        }
        if not ref_imgs:
            # Without reference images the key is left out of the payload.
            del data['files']

        prediction = self.predictor_async.predict_async(data)
        print(f"Response output path: {prediction.output_path}")
        start = time.time()
        # Blocks until the result is available (or the waiter gives up); the
        # payload itself is read from S3 below.
        prediction.get_result(self.config)
        print(f"Time taken: {time.time() - start}s")

        output_bucket, output_key = get_bucket_and_key(prediction.output_path)
        output_obj = self.s3_resource.Object(output_bucket, output_key)
        body = output_obj.get()["Body"].read().decode("utf-8")

        respobj = json.loads(body)
        return [base64_to_image(img) for img in respobj['images_base64']]

- Generate an image from a model hosted on an EC2 server