From 3ab74360ce7fd0905da8f9080fd00b5c52663ad5 Mon Sep 17 00:00:00 2001 From: jihan yin Date: Thu, 20 Jul 2023 00:59:15 +0000 Subject: [PATCH 1/4] example notebook --- examples/finetune_llama_2_on_science_qa.ipynb | 363 ++++++++++++++++++ 1 file changed, 363 insertions(+) create mode 100644 examples/finetune_llama_2_on_science_qa.ipynb diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb new file mode 100644 index 00000000..d9bfc131 --- /dev/null +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -0,0 +1,363 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "56096409", + "metadata": {}, + "source": [ + "# Finetune on ScienceQA\n", + "Let's use LLM Engine to fine-tune Llama-2 on ScienceQA!" + ] + }, + { + "cell_type": "markdown", + "id": "4d212455", + "metadata": {}, + "source": [ + "# Data Preparation\n", + "Let's load in the dataset using Huggingface and view the features." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7a701984", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using custom data configuration derek-thomas--ScienceQA-ca4903a3b5795914\n", + "Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/derek-thomas___parquet/derek-thomas--ScienceQA-ca4903a3b5795914/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bcbcca158cc74706a1294dd24f011cc8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
promptresponse
0Context: Sturgeons eat invertebrates, plants, ...B
1Context: People can use the engineering-design...C
2Context: Figure: Chicago.\\nChicago is known as...B
3Context: In a group of cows, some individuals ...A
4Context: Bald eagles eat fish, mammals, and ot...B
.........
2090Context: Select the best estimate.\\nQuestion: ...C
2091Context: Flat-tail horned lizards live in the ...B
2092Context: Read the description of a trait.\\nTim...B
2093Context: Garrett enjoys feeding the squirrels ...A
2094Context: Read the description of a trait.\\nRic...B
\n", + "

2095 rows × 2 columns

\n", + "" + ], + "text/plain": [ + " prompt response\n", + "0 Context: Sturgeons eat invertebrates, plants, ... B\n", + "1 Context: People can use the engineering-design... C\n", + "2 Context: Figure: Chicago.\\nChicago is known as... B\n", + "3 Context: In a group of cows, some individuals ... A\n", + "4 Context: Bald eagles eat fish, mammals, and ot... B\n", + "... ... ...\n", + "2090 Context: Select the best estimate.\\nQuestion: ... C\n", + "2091 Context: Flat-tail horned lizards live in the ... B\n", + "2092 Context: Read the description of a trait.\\nTim... B\n", + "2093 Context: Garrett enjoys feeding the squirrels ... A\n", + "2094 Context: Read the description of a trait.\\nRic... B\n", + "\n", + "[2095 rows x 2 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n", + "def format_options(options, choice_prefixes):\n", + " return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])\n", + "\n", + "def format_prompt(r, choice_prefixes):\n", + " options = format_options(r['choices'], choice_prefixes)\n", + " return f'''Context: {r[\"hint\"]}\\nQuestion: {r[\"question\"]}\\nOptions:{options}\\nAnswer:'''\n", + "\n", + "def format_label(r, choice_prefixes):\n", + " return choice_prefixes[r['answer']]\n", + "\n", + "def convert_dataset(ds):\n", + " prompts = [format_prompt(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " labels = [format_label(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " df = pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n", + " return df\n", + "\n", + "train_s3_uri = 's3://...'\n", + "val_s3_uri = 's3://...'\n", + "df_train = convert_dataset(dataset['train'])\n", + "#with smart_open(train_s3_uri, 'wb') as f:\n", + "# df.to_csv(f)\n", + " \n", + "df_train = convert_dataset(dataset['validation'])\n", + "#with smart_open(val_s3_uri, 'wb') as f:\n", + "# df.to_csv(f)\n", + " \n", + "df_train" + ] + }, + { + "cell_type": "markdown", + "id": "9188f0d1", + "metadata": {}, + "source": [ + "# Fine-tune\n", + "Now, we can fine-tune the model using LLM Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1736c00", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['SCALE_API_KEY'] = 'xxx'\n", + "\n", + "from llmengine import FineTune\n", + "\n", + "response = FineTune.create(\n", + " model=\"llama-2-7b\",\n", + " training_file=train_s3_uri,\n", + " validation_file=val_s3_uri,\n", + " hyperparameters={\n", + " 'lr':2e-4,\n", + " },\n", + " suffix='science-qa-llama'\n", + ")\n", + "run_id = response.fine_tune_id" + ] + }, + { + "cell_type": "markdown", + "id": "861ea698", + "metadata": {}, + "source": [ + "We can sleep until the job completes." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5b6cad2c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BatchJobStatus.SUCCESS\n" + ] + } + ], + "source": [ + "while True:\n", + " job_status = FineTune.get(run_id).status\n", + " print(job_status)\n", + " if job_status == 'SUCCESS':\n", + " break\n", + " time.sleep(60)\n", + " \n", + "ft_model = FineTune.get(run_id).fine_tuned_model" + ] + }, + { + "cell_type": "markdown", + "id": "34dd15b6", + "metadata": {}, + "source": [ + "# Evaluation\n", + "Let's evaluate the new fine-tuned model by running inference against it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50134463", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "# Helper function to get outputs for fine-tuned model with retries\n", + "def get_output(prompt: str, num_retry: int = 5):\n", + " for _ in range(num_retry):\n", + " try:\n", + " response = Completion.create(\n", + " model=fine_tuned_model, \n", + " prompt=prompt, \n", + " max_new_tokens=1, \n", + " temperature=0.01\n", + " )\n", + " return response.output.text.strip()\n", + " except Exception as e:\n", + " print(e)\n", + " return \"\"\n", + "\n", + "# Read the test data\n", + "test = pd.read_csv(\"\")\n", + "\n", + "test[\"prediction\"] = test[\"prompt\"].apply(get_output)\n", + "print(f\"Accuracy: {(test['response'] == test[\"prediction\"]).mean() * 100:.2f}%\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Environment (conda_pytorch_p38)", + "language": "python", + "name": "conda_pytorch_p38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From f450381fdb330910053ba3fdef54a765e7270f7f Mon Sep 17 00:00:00 2001 From: jihan yin Date: Thu, 20 Jul 2023 06:45:03 +0000 Subject: [PATCH 2/4] example notebook --- examples/finetune_llama_2_on_science_qa.ipynb | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb index d9bfc131..9d738ef4 100644 --- a/examples/finetune_llama_2_on_science_qa.ipynb +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "56096409", + "id": "a27c626d", "metadata": {}, "source": [ "# Finetune on ScienceQA\n", @@ -11,7 +11,7 @@ }, { "cell_type": "markdown", - "id": "4d212455", + "id": "c9233ca5", "metadata": {}, "source": [ "# Data Preparation\n", @@ -21,7 +21,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "7a701984", + "id": "686b8a81", "metadata": {}, "outputs": [ { @@ -81,7 +81,7 @@ }, { "cell_type": "markdown", - "id": "4c4db72d", + "id": "0bdf99e9", "metadata": {}, "source": [ "Now, let's format the dataset into what's acceptable for LLM Engine - a CSV file with 'prompt' and 'response' columns." @@ -90,7 +90,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "13a471ab", + "id": "1c124dad", "metadata": {}, "outputs": [ { @@ -234,7 +234,7 @@ }, { "cell_type": "markdown", - "id": "9188f0d1", + "id": "445690c9", "metadata": {}, "source": [ "# Fine-tune\n", @@ -244,7 +244,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f1736c00", + "id": "3a4d4d4a", "metadata": {}, "outputs": [], "source": [ @@ -267,7 +267,7 @@ }, { "cell_type": "markdown", - "id": "861ea698", + "id": "295d129a", "metadata": {}, "source": [ "We can sleep until the job completes." @@ -276,7 +276,7 @@ { "cell_type": "code", "execution_count": 6, - "id": "5b6cad2c", + "id": "3f0baac4", "metadata": {}, "outputs": [ { @@ -300,17 +300,17 @@ }, { "cell_type": "markdown", - "id": "34dd15b6", + "id": "132c5c74", "metadata": {}, "source": [ - "# Evaluation\n", + "# Inference and Evaluation\n", "Let's evaluate the new fine-tuned model by running inference against it." ] }, { "cell_type": "code", "execution_count": null, - "id": "50134463", + "id": "7720696a", "metadata": {}, "outputs": [], "source": [ From 913364a385ce5d882fae6ccc2f5b6a31cf8bdf4e Mon Sep 17 00:00:00 2001 From: jihan yin Date: Thu, 20 Jul 2023 17:53:35 +0000 Subject: [PATCH 3/4] update dataset urls to gh gist --- examples/finetune_llama_2_on_science_qa.ipynb | 269 +++++++++++++----- 1 file changed, 203 insertions(+), 66 deletions(-) diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb index 9d738ef4..3a962359 100644 --- a/examples/finetune_llama_2_on_science_qa.ipynb +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "a27c626d", + "id": "ccf0439b", "metadata": {}, "source": [ "# Finetune on ScienceQA\n", @@ -11,7 +11,7 @@ }, { "cell_type": "markdown", - "id": "c9233ca5", + "id": "c7a6abb6", "metadata": {}, "source": [ "# Data Preparation\n", @@ -21,7 +21,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "686b8a81", + "id": "f059c6c8", "metadata": {}, "outputs": [ { @@ -35,7 +35,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bcbcca158cc74706a1294dd24f011cc8", + "model_id": "71b0a2cda6744e4e96b725a13c91b8a9", "version_major": 2, "version_minor": 0 }, @@ -81,7 +81,7 @@ }, { "cell_type": "markdown", - "id": "0bdf99e9", + "id": "e003e175", "metadata": {}, "source": [ "Now, let's format the dataset into what's acceptable for LLM Engine - a CSV file with 'prompt' and 'response' columns." @@ -89,8 +89,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "1c124dad", + "execution_count": 2, + "id": "e5704d20", "metadata": {}, "outputs": [ { @@ -121,28 +121,28 @@ " \n", " \n", " 0\n", - " Context: Sturgeons eat invertebrates, plants, ...\n", + " Context: The passage below describes an experi...\n", " B\n", " \n", " \n", " 1\n", - " Context: People can use the engineering-design...\n", - " C\n", + " Context: The passage below describes an experi...\n", + " A\n", " \n", " \n", " 2\n", - " Context: Figure: Chicago.\\nChicago is known as...\n", - " B\n", + " Context: This passage describes the myotonia c...\n", + " A\n", " \n", " \n", " 3\n", - " Context: In a group of cows, some individuals ...\n", - " A\n", + " Context: The diagrams below show two pure samp...\n", + " C\n", " \n", " \n", " 4\n", - " Context: Bald eagles eat fish, mammals, and ot...\n", - " B\n", + " Context: Below is a food web from an ocean eco...\n", + " A\n", " \n", " \n", " ...\n", @@ -150,53 +150,53 @@ " ...\n", " \n", " \n", - " 2090\n", - " Context: Select the best estimate.\\nQuestion: ...\n", - " C\n", + " 6074\n", + " Context: The images below show two pairs of ma...\n", + " A\n", " \n", " \n", - " 2091\n", - " Context: Flat-tail horned lizards live in the ...\n", - " B\n", + " 6075\n", + " Context: Select the better answer.\\nQuestion: ...\n", + " A\n", " \n", " \n", - " 2092\n", - " Context: Read the description of a trait.\\nTim...\n", - " B\n", + " 6076\n", + " Context: Read the description of a trait.\\nHan...\n", + " A\n", " \n", " \n", - " 2093\n", - " Context: Garrett enjoys feeding the squirrels ...\n", + " 6077\n", + " Context: The objects are identical except for ...\n", " A\n", " \n", " \n", - " 2094\n", - " Context: Read the description of a trait.\\nRic...\n", - " B\n", + " 6078\n", + " Context: Read the description of a trait.\\nTom...\n", + " A\n", " \n", " \n", "\n", - "

2095 rows × 2 columns

\n", + "

6079 rows × 2 columns

\n", "" ], "text/plain": [ " prompt response\n", - "0 Context: Sturgeons eat invertebrates, plants, ... B\n", - "1 Context: People can use the engineering-design... C\n", - "2 Context: Figure: Chicago.\\nChicago is known as... B\n", - "3 Context: In a group of cows, some individuals ... A\n", - "4 Context: Bald eagles eat fish, mammals, and ot... B\n", + "0 Context: The passage below describes an experi... B\n", + "1 Context: The passage below describes an experi... A\n", + "2 Context: This passage describes the myotonia c... A\n", + "3 Context: The diagrams below show two pure samp... C\n", + "4 Context: Below is a food web from an ocean eco... A\n", "... ... ...\n", - "2090 Context: Select the best estimate.\\nQuestion: ... C\n", - "2091 Context: Flat-tail horned lizards live in the ... B\n", - "2092 Context: Read the description of a trait.\\nTim... B\n", - "2093 Context: Garrett enjoys feeding the squirrels ... A\n", - "2094 Context: Read the description of a trait.\\nRic... B\n", + "6074 Context: The images below show two pairs of ma... A\n", + "6075 Context: Select the better answer.\\nQuestion: ... A\n", + "6076 Context: Read the description of a trait.\\nHan... A\n", + "6077 Context: The objects are identical except for ... A\n", + "6078 Context: Read the description of a trait.\\nTom... A\n", "\n", - "[2095 rows x 2 columns]" + "[6079 rows x 2 columns]" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -219,22 +219,29 @@ " df = pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n", " return df\n", "\n", - "train_s3_uri = 's3://...'\n", - "val_s3_uri = 's3://...'\n", + "save_to_s3 = False\n", "df_train = convert_dataset(dataset['train'])\n", - "#with smart_open(train_s3_uri, 'wb') as f:\n", - "# df.to_csv(f)\n", - " \n", - "df_train = convert_dataset(dataset['validation'])\n", - "#with smart_open(val_s3_uri, 'wb') as f:\n", - "# df.to_csv(f)\n", + "if save_to_s3:\n", + " train_url = 's3://...'\n", + " val_url = 's3://...'\n", + " df_train = convert_dataset(dataset['train'])\n", + " with smart_open(train_url, 'wb') as f:\n", + " df_train.to_csv(f)\n", + "\n", + " df_val = convert_dataset(dataset['validation'])\n", + " with smart_open(val_url, 'wb') as f:\n", + " df_val.to_csv(f)\n", + "else:\n", + " # Gists of the already processed datasets\n", + " train_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_train.csv'\n", + " val_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_val.csv'\n", " \n", "df_train" ] }, { "cell_type": "markdown", - "id": "445690c9", + "id": "138a4b57", "metadata": {}, "source": [ "# Fine-tune\n", @@ -243,8 +250,8 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "3a4d4d4a", + "execution_count": 3, + "id": "d0222175", "metadata": {}, "outputs": [], "source": [ @@ -255,8 +262,8 @@ "\n", "response = FineTune.create(\n", " model=\"llama-2-7b\",\n", - " training_file=train_s3_uri,\n", - " validation_file=val_s3_uri,\n", + " training_file=train_url,\n", + " validation_file=val_url,\n", " hyperparameters={\n", " 'lr':2e-4,\n", " },\n", @@ -267,7 +274,7 @@ }, { "cell_type": "markdown", - "id": "295d129a", + "id": "1a3e80b4", "metadata": {}, "source": [ "We can sleep until the job completes." @@ -275,19 +282,124 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "3f0baac4", + "execution_count": 5, + "id": "214a9593", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", + "BatchJobStatus.RUNNING\n", "BatchJobStatus.SUCCESS\n" ] } ], "source": [ + "import time\n", + "\n", "while True:\n", " job_status = FineTune.get(run_id).status\n", " print(job_status)\n", @@ -295,12 +407,12 @@ " break\n", " time.sleep(60)\n", " \n", - "ft_model = FineTune.get(run_id).fine_tuned_model" + "fine_tuned_model = FineTune.get(run_id).fine_tuned_model" ] }, { "cell_type": "markdown", - "id": "132c5c74", + "id": "1d6614cf", "metadata": {}, "source": [ "# Inference and Evaluation\n", @@ -309,12 +421,29 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "7720696a", + "execution_count": 11, + "id": "eba61cf2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", + "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", + "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", + "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", + "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", + "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", + "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", + "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", + "Accuracy: 81.48%\n" + ] + } + ], "source": [ "import pandas as pd\n", + "from llmengine import Completion\n", "\n", "# Helper function to get outputs for fine-tuned model with retries\n", "def get_output(prompt: str, num_retry: int = 5):\n", @@ -332,11 +461,19 @@ " return \"\"\n", "\n", "# Read the test data\n", - "test = pd.read_csv(\"\")\n", + "test = pd.read_csv(val_url)\n", "\n", "test[\"prediction\"] = test[\"prompt\"].apply(get_output)\n", - "print(f\"Accuracy: {(test['response'] == test[\"prediction\"]).mean() * 100:.2f}%\")" + "print(f\"Accuracy: {(test['response'] == test['prediction']).mean() * 100:.2f}%\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af7ec1e4", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 530e7c1135ba15e3ddeeed5a517b462f2ddd29f7 Mon Sep 17 00:00:00 2001 From: jihan yin Date: Thu, 20 Jul 2023 18:53:30 +0000 Subject: [PATCH 4/4] clear notebook output --- examples/finetune_llama_2_on_science_qa.ipynb | 323 ++---------------- 1 file changed, 21 insertions(+), 302 deletions(-) diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb index 3a962359..9b4f77a4 100644 --- a/examples/finetune_llama_2_on_science_qa.ipynb +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "ccf0439b", + "id": "8d3a4214", "metadata": {}, "source": [ "# Finetune on ScienceQA\n", @@ -11,7 +11,7 @@ }, { "cell_type": "markdown", - "id": "c7a6abb6", + "id": "a3dc2a56", "metadata": {}, "source": [ "# Data Preparation\n", @@ -20,56 +20,10 @@ }, { "cell_type": "code", - "execution_count": 1, - "id": "f059c6c8", + "execution_count": null, + "id": "e06ac39e", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using custom data configuration derek-thomas--ScienceQA-ca4903a3b5795914\n", - "Found cached dataset parquet (/home/ubuntu/.cache/huggingface/datasets/derek-thomas___parquet/derek-thomas--ScienceQA-ca4903a3b5795914/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "71b0a2cda6744e4e96b725a13c91b8a9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/3 [00:00\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
promptresponse
0Context: The passage below describes an experi...B
1Context: The passage below describes an experi...A
2Context: This passage describes the myotonia c...A
3Context: The diagrams below show two pure samp...C
4Context: Below is a food web from an ocean eco...A
.........
6074Context: The images below show two pairs of ma...A
6075Context: Select the better answer.\\nQuestion: ...A
6076Context: Read the description of a trait.\\nHan...A
6077Context: The objects are identical except for ...A
6078Context: Read the description of a trait.\\nTom...A
\n", - "

6079 rows × 2 columns

\n", - "" - ], - "text/plain": [ - " prompt response\n", - "0 Context: The passage below describes an experi... B\n", - "1 Context: The passage below describes an experi... A\n", - "2 Context: This passage describes the myotonia c... A\n", - "3 Context: The diagrams below show two pure samp... C\n", - "4 Context: Below is a food web from an ocean eco... A\n", - "... ... ...\n", - "6074 Context: The images below show two pairs of ma... A\n", - "6075 Context: Select the better answer.\\nQuestion: ... A\n", - "6076 Context: Read the description of a trait.\\nHan... A\n", - "6077 Context: The objects are identical except for ... A\n", - "6078 Context: Read the description of a trait.\\nTom... A\n", - "\n", - "[6079 rows x 2 columns]" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n", "def format_options(options, choice_prefixes):\n", @@ -241,7 +87,7 @@ }, { "cell_type": "markdown", - "id": "138a4b57", + "id": "e2fc8d76", "metadata": {}, "source": [ "# Fine-tune\n", @@ -250,8 +96,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "d0222175", + "execution_count": null, + "id": "4905d447", "metadata": {}, "outputs": [], "source": [ @@ -274,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "1a3e80b4", + "id": "55074457", "metadata": {}, "source": [ "We can sleep until the job completes." @@ -282,121 +128,10 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "214a9593", + "execution_count": null, + "id": "840938dd", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.RUNNING\n", - "BatchJobStatus.SUCCESS\n" - ] - } - ], + "outputs": [], "source": [ "import time\n", "\n", @@ -412,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "1d6614cf", + "id": "31278c6d", "metadata": {}, "source": [ "# Inference and Evaluation\n", @@ -421,26 +156,10 @@ }, { "cell_type": "code", - "execution_count": 11, - "id": "eba61cf2", + "execution_count": null, + "id": "3b9d7643", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", - "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", - "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", - "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", - "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", - "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", - "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", - "upstream connect error or disconnect/reset before headers. reset reason: connection failure, transport failure reason: delayed connect error: 111\n", - "Accuracy: 81.48%\n" - ] - } - ], + "outputs": [], "source": [ "import pandas as pd\n", "from llmengine import Completion\n", @@ -470,7 +189,7 @@ { "cell_type": "code", "execution_count": null, - "id": "af7ec1e4", + "id": "9f2f3f43", "metadata": {}, "outputs": [], "source": []