In [None]:
import torch
import tensorflow as tf
from google.cloud import aiplatform
import vertexai
from models.flant5.FlanT5 import T5FineTuner, tokenize_dataset
import os

from app.utils.flant5_client import FlanT5Client


In [None]:
# 1. Notebook configuration: cloud targets first, then local filesystem paths.

# --- Google Cloud / Vertex AI targets ---
PROJECT_ID = "deft-return-439619-h9"
REGION = "us-west1"
GCS_BUCKET = "ontologykg2"
GCS_MODEL_PATH = "gs://" + GCS_BUCKET + "/model/flant5"

# --- Local working directories (resolved to absolute paths at definition time) ---
LOCAL_MODEL_DIR = os.path.abspath("local_model")
SAVED_MODEL_DIR = os.path.abspath("saved_model")
LOCAL_EXPORT_DIR = os.path.abspath("exported_model")
# os.path.join keeps the nested path separator-agnostic; abspath normalises it.
LOCAL_CHECKPOINT_DIR = os.path.abspath(
    os.path.join("local_checkpoint", "checkpoint_1000001")
)

# --- Fine-tuned Lightning checkpoint to export ---
# Alternative checkpoints noted in earlier revisions of this notebook:
#   lightning_logs/version_34/epoch=3-step=2072-train_loss=0.35.ckpt
#   lightning_logs/version_30/final.ckpt
CKPT_PATH = "models/flant5/epoch=3-step=2072-train_loss=0.35.ckpt"

In [None]:
# Load the fine-tuned Lightning checkpoint onto the CPU.
#
# map_location="cpu" remaps any tensors that were saved from a CUDA device, so
# the checkpoint also deserialises on a machine without a GPU.
# NOTE(security): torch.load unpickles the file -- only open checkpoints that
# come from a trusted source.
checkpoint = torch.load(CKPT_PATH, map_location="cpu")
print(checkpoint.keys())

llm = T5FineTuner.load_from_checkpoint(CKPT_PATH, map_location="cpu")

llm.model.eval()  # set model to evaluation mode
llm = llm.to("cpu")  # keep everything on CPU for the export steps
print("Done")

# The wrapped network that the following cells convert and export.
model = llm.model

In [None]:
# 1. Convert PyTorch model to TensorFlow
def convert_pt_to_tf(llm):
    """Convert the fine-tuned PyTorch T5 model held by ``llm`` to TensorFlow.

    The PyTorch weights currently in ``llm.model`` are written to a temporary
    directory with ``save_pretrained`` and read back with ``from_pt=True``.
    Loading from the base model name instead would fetch the weights stored
    under that name and would not pick up the weights in ``llm.model``.

    Args:
        llm: Object exposing the PyTorch model as ``llm.model`` (with a
            ``save_pretrained`` method and a ``config`` attribute).

    Returns:
        A ``TFT5ForConditionalGeneration`` initialised from ``llm.model``.
    """
    import tempfile

    from transformers import TFT5ForConditionalGeneration

    with tempfile.TemporaryDirectory() as export_dir:
        # Force the classic PyTorch weight file so that from_pt=True finds it
        # regardless of the installed transformers' default serialisation.
        llm.model.save_pretrained(export_dir, safe_serialization=False)

        # The TF weights are materialised in memory during from_pretrained, so
        # the temporary directory can be removed once this call returns.
        tf_model = TFT5ForConditionalGeneration.from_pretrained(
            export_dir,
            from_pt=True,
            config=llm.model.config,
        )

    # Report the result of the conversion
    print("Model converted from PyTorch to TensorFlow")
    print(f"Model type: {type(tf_model)}")

    return tf_model

# Run the PyTorch -> TensorFlow conversion. `model` is bound to the same
# converted object so that the serving and export cells use the TF model.
model = tf_model = convert_pt_to_tf(llm)

In [None]:
@tf.function(input_signature=[{
    'input_ids': tf.TensorSpec(shape=(1, 512), dtype=tf.int32, name='input_ids'),
    'attention_mask': tf.TensorSpec(shape=(1, 512), dtype=tf.int32, name='attention_mask')
}])
def serving_fn(inputs):
    """Serving signature: beam-search generation for one fixed-size batch.

    Args:
        inputs: Dict with int32 tensors ``input_ids`` and ``attention_mask``,
            each of shape (1, 512) as fixed by the input signature above.

    Returns:
        Dict with a single key ``'sequences'`` holding the generated token ids.
    """
    # Bind to `tf_model` rather than the generic `model` name: the deployment
    # cell reuses `model` for a Vertex AI resource, which would break this
    # function if the cell were re-run afterwards.
    # Per-step scores are not requested because only the sequences are returned.
    outputs = tf_model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=128,
        num_beams=4,
        pad_token_id=tf_model.config.pad_token_id,
        eos_token_id=tf_model.config.eos_token_id,
        bos_token_id=tf_model.config.bos_token_id,
        use_cache=True,
        do_sample=False,
        num_return_sequences=1,
        return_dict_in_generate=True,
    )
    return {'sequences': outputs.sequences}

# 2.5 Smoke-test the serving function
def test_serving_fn():
    """Call ``serving_fn`` once on a dummy batch and return what it produces."""
    batch_shape = (1, 512)

    # Placeholder batch: all-zero token ids with a fully-on attention mask
    sample_batch = dict(
        input_ids=tf.zeros(batch_shape, dtype=tf.int32),
        attention_mask=tf.ones(batch_shape, dtype=tf.int32),
    )

    generated = serving_fn(sample_batch)
    print("Serving function test successful")
    return generated

# Exercise the signature once before exporting it
test_serving_fn()

In [None]:
# 3. Export in SavedModel format for Vertex AI, exposing serving_fn as the
#    default serving signature.
export_signatures = {'serving_default': serving_fn}
tf.saved_model.save(model, SAVED_MODEL_DIR, signatures=export_signatures)
print(f"SavedModel saved to {SAVED_MODEL_DIR}")

In [None]:
# 4. Upload the exported SavedModel to GCS and deploy it to Vertex AI.
#    GCS_MODEL_PATH, SAVED_MODEL_DIR, PROJECT_ID and REGION come from the
#    configuration cell at the top of the notebook.
print("Uploading to GCS and deploying to Vertex AI...")
vertexai.init(project=PROJECT_ID, location=REGION)

# Register the model. Passing the local SavedModel directory as artifact_uri
# makes the SDK stage those files under `staging_bucket` before creating the
# model, so the artifacts that get deployed are the ones exported above rather
# than whatever already exists in the bucket.
# The result is named `vertex_model` so it does not overwrite `model`, which
# still refers to the TensorFlow model used by the export cells.
vertex_model = aiplatform.Model.upload(
    display_name="flan-t5-base",
    artifact_uri=SAVED_MODEL_DIR,
    staging_bucket=GCS_MODEL_PATH,
    serving_container_image_uri="us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-12:latest",
)

# Deploy to an endpoint with exactly one GPU-backed replica
endpoint = vertex_model.deploy(
    machine_type="a2-highgpu-1g",  # Using GPU machine type for base
    accelerator_type="NVIDIA_TESLA_A100",
    accelerator_count=1,
    min_replica_count=1,
    max_replica_count=1,
)

print(f"Model deployed to endpoint: {endpoint.resource_name}")


    Uploading to GCS and deploying to Vertex AI...
    Creating Model
    INFO:google.cloud.aiplatform.models:Creating Model
    Create Model backing LRO: projects/371748443295/locations/us-west1/models/5844832289443282944/operations/7813018026546036736
    INFO:google.cloud.aiplatform.models:Create Model backing LRO: projects/371748443295/locations/us-west1/models/5844832289443282944/operations/7813018026546036736
    Model created. Resource name: projects/371748443295/locations/us-west1/models/5844832289443282944@1
    INFO:google.cloud.aiplatform.models:Model created. Resource name: projects/371748443295/locations/us-west1/models/5844832289443282944@1
    To use this Model in another session:
    INFO:google.cloud.aiplatform.models:To use this Model in another session:
    model = aiplatform.Model('projects/371748443295/locations/us-west1/models/5844832289443282944@1')
    INFO:google.cloud.aiplatform.models:model = aiplatform.Model('projects/371748443295/locations/us-west1/models/5844832289443282944@1')
   
    Creating Endpoint
    INFO:google.cloud.aiplatform.models:Creating Endpoint
    Create Endpoint backing LRO: projects/371748443295/locations/us-west1/endpoints/8518052919822516224/operations/8211586593568325632
    INFO:google.cloud.aiplatform.models:Create Endpoint backing LRO: projects/371748443295/locations/us-west1/endpoints/8518052919822516224/operations/8211586593568325632
    Endpoint created. Resource name: projects/371748443295/locations/us-west1/endpoints/8518052919822516224
    INFO:google.cloud.aiplatform.models:Endpoint created. Resource name: projects/371748443295/locations/us-west1/endpoints/8518052919822516224
    To use this Endpoint in another session:
    INFO:google.cloud.aiplatform.models:To use this Endpoint in another session:
    endpoint = aiplatform.Endpoint('projects/371748443295/locations/us-west1/endpoints/8518052919822516224')
    INFO:google.cloud.aiplatform.models:endpoint = aiplatform.Endpoint('projects/371748443295/locations/us-west1/endpoints/8518052919822516224')
    Deploying model to Endpoint : projects/371748443295/locations/us-west1/endpoints/8518052919822516224
    INFO:google.cloud.aiplatform.models:Deploying model to Endpoint : projects/371748443295/locations/us-west1/endpoints/8518052919822516224
    Deploy Endpoint model backing LRO: projects/371748443295/locations/us-west1/endpoints/8518052919822516224/operations/1435920954189414400
    INFO:google.cloud.aiplatform.models:Deploy Endpoint model backing LRO: projects/371748443295/locations/us-west1/endpoints/8518052919822516224/operations/1435920954189414400
    Endpoint model deployed. Resource name: projects/371748443295/locations/us-west1/endpoints/8518052919822516224
    INFO:google.cloud.aiplatform.models:Endpoint model deployed. Resource name: projects/371748443295/locations/us-west1/endpoints/8518052919822516224
    Model deployed to endpoint: projects/371748443295/locations/us-west1/endpoints/8518052919822516224

In [None]:
# Smoke-test the deployed endpoint with one prompt
test_text = "Translate to French: Hello, how are you?"

# With no arguments the client uses FLANT5_ENDPOINT from utils.rag_constants.py
# (per the original author's note -- confirm against FlanT5Client).
flant5 = FlanT5Client()
response = flant5.generate_content(test_text)
print(f"Test prediction response: {response}")