# gRPC Text Generation Inference with Caikit+TGIS Serving

### Set the inference server URL (replace with your own address)

In [1]:
# Address of the inference server, in "host:port" form
# (for example "your_server_address:port").
# Replace the value below with the address of your own server.
inference_server_url = "caikit-tgis-example-isvc-predictor.kserve-demo.svc.cluster.local:80"

In [None]:
# Optional, requirements if they are not already present
# !pip -q install grpcio grpcio-reflection

!python -m pip install --upgrade pip
!pip install -U grpcio-reflection grpcio grpcio-tools

### Imports

In [2]:
import grpc
from grpc_reflection.v1alpha.proto_reflection_descriptor_database import ProtoReflectionDescriptorDatabase
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.message_factory import GetMessageClass

### Caikit+TGIS Stub class for text generation

In [3]:
# Service Stub definition
class CaikitTgisTextGeneration(object):
    """Client stub for the Caikit NLP text-generation RPCs.

    The protobuf message types are not imported from generated code: they are
    looked up by name in a descriptor pool backed by
    ``ProtoReflectionDescriptorDatabase(channel)``.

    Attributes:
        TextGenerationTaskRequest: An instance of the request message. Callers
            fill in its fields and pass it to one of the RPC callables.
        GeneratedTextResult: An instance of the response message type.
        TextGenerationTaskPredict: Unary-unary RPC callable.
        ServerStreamingTextGenerationTaskPredict: Unary-stream RPC callable.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        reflection_db = ProtoReflectionDescriptorDatabase(channel)
        desc_pool = DescriptorPool(reflection_db)
        request_cls = GetMessageClass(
            desc_pool.FindMessageTypeByName('caikit.runtime.Nlp.TextGenerationTaskRequest'))
        result_cls = GetMessageClass(
            desc_pool.FindMessageTypeByName('caikit_data_model.nlp.GeneratedTextResult'))

        # Kept as message *instances* so that existing callers doing
        # `request = stub.TextGenerationTaskRequest` keep working.
        self.TextGenerationTaskRequest = request_cls()
        self.GeneratedTextResult = result_cls()

        # The (de)serializers are taken from the message classes, not from the
        # instances above: a method bound to the shared request instance would
        # always serialize that instance instead of the request actually
        # handed to the RPC callable.
        self.TextGenerationTaskPredict = channel.unary_unary(
            '/caikit.runtime.Nlp.NlpService/TextGenerationTaskPredict',
            request_serializer=request_cls.SerializeToString,
            response_deserializer=result_cls.FromString,
        )
        self.ServerStreamingTextGenerationTaskPredict = channel.unary_stream(
            '/caikit.runtime.Nlp.NlpService/ServerStreamingTextGenerationTaskPredict',
            request_serializer=request_cls.SerializeToString,
            response_deserializer=result_cls.FromString,
        )

### Create the channel with self-signed certificate

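Note: to extract the certificate chain, you can use the following command:

`openssl s_client -showcerts -verify 5 -connect your_server_address:port < /dev/null | awk '/BEGIN CERTIFICATE/,/END CERTIFICATE/{ if(/BEGIN CERTIFICATE/){a++}; out="cert"a".pem"; print >out}'`

The command writes each certificate of the chain to its own file (`cert1.pem`, `cert2.pem`, ...). The cell below reads `certificate.pem`, so copy or rename the certificate you want to trust to that file name.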

In [4]:
# Read the PEM file whose content is handed to gRPC as root certificates
with open('certificate.pem', 'rb') as cert_file:
    root_certificates = cert_file.read()
creds = grpc.ssl_channel_credentials(root_certificates)

# Open a secure channel to the server address configured at the top of the notebook
server_address = inference_server_url
channel = grpc.secure_channel(server_address, creds)

### Query the service

In [5]:
# Instantiate the text-generation stub on top of the channel created above
caikit_tgis_text_generation_stub = CaikitTgisTextGeneration(channel=channel)

In [6]:
# Additional parameters needed to query the right model.
# Set model_id to the name of the model you want to query
# (another example: 'Llama-2-7b-chat-hf').
model_id = 'flan-t5-small-caikit'
# The model name is sent with each call as the "mm-model-id" metadata entry.
metadata = [("mm-model-id", model_id)]

In [7]:
# Single-shot query: the generated text is returned in one response message
request = caikit_tgis_text_generation_stub.TextGenerationTaskRequest
generation_settings = {
    'text': 'How do you bake a cake?',
    'preserve_input_text': False,
    'max_new_tokens': 200,
    'min_new_tokens': 10,
}
# Copy every setting onto the request message, in the order listed above
for setting_name, setting_value in generation_settings.items():
    setattr(request, setting_name, setting_value)

response = caikit_tgis_text_generation_stub.TextGenerationTaskPredict(request=request, metadata=metadata)
print(response.generated_text)



Baking a cake is a straightforward process that requires a few basic ingredients and some time in the oven. Here's a step-by-step guide on how to bake a cake:

1. Preheat the oven: Preheat the oven to the temperature specified in the recipe you're using. This can range from 325°F to 375°F (160°C to 190°C), depending on the type of cake you're making.

2. Prepare the cake pan: Choose a cake pan that's the right size for the recipe you're using. Grease the pan with butter or cooking spray to prevent the cake from sticking.

3. Mix the ingredients: In a large mixing bowl, combine the dry ingredients (flour,


### Query the service - Streaming answer

In [8]:
# Streaming query: the answer arrives as a sequence of partial responses
request = caikit_tgis_text_generation_stub.TextGenerationTaskRequest
generation_settings = {
    'text': 'How do you bake a cake?',
    'preserve_input_text': False,
    'max_new_tokens': 200,
    'min_new_tokens': 10,
}
# Copy every setting onto the request message, in the order listed above
for setting_name, setting_value in generation_settings.items():
    setattr(request, setting_name, setting_value)

response_stream = caikit_tgis_text_generation_stub.ServerStreamingTextGenerationTaskPredict(
    request=request,
    metadata=metadata,
)
# Print each chunk as it arrives, without a newline, so the text reads continuously
for response in response_stream:
    print(response.generated_text, end="")



Baking a cake is a straightforward process that requires a few basic ingredients and some time in the oven. Here's a step-by-step guide on how to bake a cake:

1. Preheat the oven: Preheat the oven to the temperature specified in the recipe you're using. This can range from 325°F to 375°F (160°C to 190°C), depending on the type of cake you're making.

2. Prepare the cake pan: Choose a cake pan that's the right size for the recipe you're using. Grease the pan with butter or cooking spray to prevent the cake from sticking.

3. Mix the ingredients: In a large mixing bowl, combine the dry ingredients (flour,

### To go further: service, methods and parameters discovery

In [9]:
# List the services reported by the reflection database built on the channel.
# `reflection_db` is reused by the next cell to build a descriptor pool.
reflection_db = ProtoReflectionDescriptorDatabase(channel)
services = reflection_db.get_services()
print(f'Available services: {services}')

Available services: ['caikit.runtime.Nlp.NlpService', 'caikit.runtime.Nlp.NlpTrainingService', 'caikit.runtime.training.TrainingManagement', 'grpc.reflection.v1alpha.ServerReflection', 'mmesh.ModelRuntime']


In [10]:
# Look up the NlpService descriptor and print the name of each of its methods
desc_pool = DescriptorPool(reflection_db)
nlp_service = desc_pool.FindServiceByName('caikit.runtime.Nlp.NlpService')
print('Available methods:')
for method in nlp_service.methods:
    print(method.name)

Available methods:
TextClassificationTaskPredict
TextGenerationTaskPredict
ServerStreamingTextGenerationTaskPredict
TokenizationTaskPredict
TokenClassificationTaskPredict
BidiStreamingTokenClassificationTaskPredict


In [11]:
# For the TextGenerationTaskPredict method, print every field of its input
# message as: name, numeric type id, default value.
# The type id is the protobuf field type number (e.g. 3 = int64, 8 = bool, 9 = string).
# Types reference: https://protobuf.dev/reference/csharp/api-docs/class/google/protobuf/well-known-types/field/types
method_desc = nlp_service.FindMethodByName('TextGenerationTaskPredict')
for input_field in method_desc.input_type.fields:
    print(f'{input_field.name}, {input_field.type}, default: {input_field.default_value}')

text, 9, default: 
preserve_input_text, 8, default: False
max_new_tokens, 3, default: 0
min_new_tokens, 3, default: 0
device, 9, default: 
