In [21]:
import os
from dotenv import load_dotenv

from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from ibm_watson_machine_learning.foundation_models.extensions.langchain import WatsonxLLM
from ibm_watson_machine_learning.experiment import TuneExperiment
from ibm_watson_machine_learning import APIClient

In [22]:
# Read the service credentials from the environment; load_dotenv() first merges
# in any values defined in a local .env file so nothing is hard-coded here.
load_dotenv()
api_key = os.getenv("API_KEY")
project_id = os.getenv("PROJECT_ID")

# Fail fast with an actionable message instead of handing None to the SDK and
# getting an opaque authentication/validation error further down.
missing_settings = [
    env_name
    for env_name, env_value in (("API_KEY", api_key), ("PROJECT_ID", project_id))
    if not env_value
]
if missing_settings:
    raise ValueError(
        "Missing required environment variable(s): "
        + ", ".join(missing_settings)
        + ". Define them in the environment or in a .env file."
    )

# Credentials dictionary shared by the API client and the tuning experiment.
creds = {
    "url"    : "https://us-south.ml.cloud.ibm.com",
    "apikey" : api_key
}

# General-purpose client; later cells use it to upload the training data asset.
client = APIClient(creds)

# Experiment handle scoped to the project; it is the factory for prompt tuners.
experiment = TuneExperiment(creds,
    project_id=project_id
)

# Configure (but do not start) a prompt-tuning job on FLAN-T5-XL for a binary
# classification task. The verbalizer asks the model for a bare '1' or '0',
# which is presumably why max_output_tokens is kept at 2 -- TODO confirm.
prompt_tuner = experiment.prompt_tuner(
    name="prompt tuning name",
    task_id=experiment.Tasks.CLASSIFICATION,
    base_model=ModelTypes.FLAN_T5_XL,
    accumulate_steps=32,
    batch_size=16,
    learning_rate=0.2,
    max_input_tokens=256,
    max_output_tokens=2,
    num_epochs=6,
    tuning_type=experiment.PromptTuningTypes.PT,
    verbalizer="Extract the satisfaction from the comment. Return simple '1' for satisfied customer or '0' for unsatisfied. Input: {{input}} Output: ",
    auto_update_model=True
)


In [23]:
# Read back the parameters held by the configured tuner and print them, so the
# effective settings can be checked before a run is started.
config_parameters = prompt_tuner.get_params()
print(config_parameters)

{'base_model': {'model_id': 'google/flan-t5-xl'}, 'accumulate_steps': 32, 'batch_size': 16, 'learning_rate': 0.2, 'max_input_tokens': 256, 'max_output_tokens': 2, 'num_epochs': 6, 'task_id': 'classification', 'tuning_type': 'prompt_tuning', 'verbalizer': "Extract the satisfaction from the comment. Return simple '1' for satisfied customer or '0' for unsatisfied. Input: {{input}} Output: ", 'name': 'prompt tuning name', 'description': 'Prompt tuning with SDK', 'auto_update_model': True, 'group_by_name': False}


In [24]:
# Point the client at the target project, then register the local training
# file as a data asset in that project (asset name == local file name).
client.set.default_project(project_id)
asset_details = client.data_assets.create(
    file_path="caneng1000.json",
    name="caneng1000.json",
)

Creating data asset...
SUCCESS


In [25]:
# Show the GUID of the newly created data asset; the tuning run references the
# asset through this identifier.
print(asset_details['metadata']['guid'])

1f36411e-bfaa-4dc8-acf8-28eba9630d1c


In [26]:
from ibm_watson_machine_learning.helpers import DataConnection, ContainerLocation

# Choose exactly ONE training-data source. Previously both variants were
# executed back to back (the "OR" was only a comment), which would submit two
# tuning jobs and overwrite `tuning_details` with the second one.
#   True  -> reference the data asset uploaded earlier via its GUID
#   False -> reference the file by its container location
USE_DATA_ASSET = True

if USE_DATA_ASSET:
    training_data = DataConnection(
        data_asset_id=asset_details['metadata']['guid'])
else:
    training_data = DataConnection(
        location=ContainerLocation("caneng1000.json"))

# Submit the tuning job once. background_mode=True is passed so the call is
# presumably non-blocking; progress is inspected in the following cells.
tuning_details = prompt_tuner.run(
    training_data_references=[training_data],
    background_mode=True)

Failure during training. (POST https://us-south.ml.cloud.ibm.com/ml/v4/trainings?version=2023-12-05)
Status code: 400, body: {"trace":"a9e7dc784848fbedcf576d66ba76fdd7","errors":[{"code":"pt_unavailable_for_plan","message":"To use Tuning Studio, upgrade your Watson Machine Learning service plan 'Lite' to a paid plan.","more_info":"https://cloud.ibm.com/apidocs/machine-learning"}],"status_code":"400"}


ApiRequestFailure: Failure during training. (POST https://us-south.ml.cloud.ibm.com/ml/v4/trainings?version=2023-12-05)
Status code: 400, body: {"trace":"a9e7dc784848fbedcf576d66ba76fdd7","errors":[{"code":"pt_unavailable_for_plan","message":"To use Tuning Studio, upgrade your Watson Machine Learning service plan 'Lite' to a paid plan.","more_info":"https://cloud.ibm.com/apidocs/machine-learning"}],"status_code":"400"}

In [None]:
# Ask the tuner for the status of its run and print whatever it reports.
status = prompt_tuner.get_run_status()
print(status)

In [None]:
# Retrieve the data connections attached to the tuner.
data_connections = prompt_tuner.get_data_connections()

# Read the first connection's content in binary form (binary=True).
# NOTE(review): indexing [0] assumes at least one connection exists -- confirm.
binary_data = data_connections[0].read(binary=True)

In [None]:
# Render the tuner's learning curve.
# NOTE(review): presumably requires a run to have been started -- confirm.
prompt_tuner.plot_learning_curve()

In [None]:
# Fetch the model identifier from the tuner and print it.
model_id = prompt_tuner.get_model_id()
print(model_id)

In [None]:
# Fetch the tuner's summary and print it.
results = prompt_tuner.summary()
print(results)

In [6]:
from datasets import load_dataset

# Identifier of the dataset to pull from the hub.
dataset_id = "samsum"
dataset = load_dataset(dataset_id)

# Report how many examples each split contains.
for split_label, split_key in (("Train", "train"), ("Test", "test")):
    print(f"{split_label} dataset size: {len(dataset[split_key])}")

# Expected output:
# Train dataset size: 14732
# Test dataset size: 819

NameError: name 'dataset_id' is not defined