Skip to content

Commit

Permalink
feat: Introduce a 'text' argument for the TextClassificationRecord (#…
Browse files Browse the repository at this point in the history
…1246)

* feat: introduce a 'text' argument for the TextClassificationRecord

* chore: deprecate providing str for the inputs parameter

* chore: allow same input to 'text' and 'inputs'

* test: add tests

* docs: update readme and index

* chore: adapt transformers monitor

* fix: fix prepare_for_training for datasets without annotations

* docs: adapt to new TextClassificationRecord(text=...)

* docs: adapt guides

* docs: adapt tutorials

* docs: adapt code

* test: fix test

* test: add one more test

* Update docs/getting_started/concepts.rst

Co-authored-by: Francisco Aranda <francisco@recogn.ai>

* feat: introduce a 'text' argument for the TextClassificationRecord

* fix: fix tests

Co-authored-by: Francisco Aranda <francisco@recogn.ai>
(cherry picked from commit 15d00a9)
  • Loading branch information
David Fidalgo authored and frascuchon committed Mar 30, 2022
1 parent 0321b88 commit bb7d93e
Show file tree
Hide file tree
Showing 29 changed files with 213 additions and 147 deletions.
6 changes: 3 additions & 3 deletions README.md
Expand Up @@ -140,7 +140,7 @@ The following code will log one record into a dataset called `example-dataset`:
import rubrix as rb

rb.log(
rb.TextClassificationRecord(inputs="My first Rubrix example"),
rb.TextClassificationRecord(text="My first Rubrix example"),
name='example-dataset'
)
```
Expand Down Expand Up @@ -210,7 +210,7 @@ for item in dataset:

records.append(
rb.TextClassificationRecord(
inputs=item["text"],
text=item["text"],
prediction=list(zip(prediction['labels'], prediction['scores']))
)
)
Expand All @@ -237,7 +237,7 @@ rb_df = rb_df[rb_df.status == "Validated"]

# select text input and the annotated label
train_df = pd.DataFrame({
"text": rb_df.inputs.transform(lambda r: r["text"]),
"text": rb_df.text,
"label": rb_df.annotation,
})
```
Expand Down
19 changes: 8 additions & 11 deletions docs/getting_started/concepts.rst
Expand Up @@ -22,13 +22,13 @@ Let's take a look at Rubrix's components and methods:
Dataset
^^^^^^^

A dataset is a collection of records stored in Rubrix. The main things you can do with a ``Dataset`` are to ``log`` records and to ``load`` the records of a ``Dataset`` into a ``Pandas.Dataframe`` from a Python app, script, or a Jupyter/Colab notebook.
A dataset is a collection of records stored in Rubrix. The main things you can do with a ``Dataset`` are to ``log`` records and to ``load`` the records of a ``Dataset`` into a ``Pandas.Dataframe`` from a Python app, script, or a Jupyter/Colab notebook.


Record
^^^^^^

A record is a data item composed of ``inputs`` and, optionally, ``predictions`` and ``annotations``. Usually, inputs are the information your model receives (for example: 'Macbeth').
A record is a data item composed of ``text`` inputs and, optionally, ``predictions`` and ``annotations``.

Think of predictions as the classification that your system made over that input (for example: 'Virginia Woolf'), and think of annotations as the ground truth that you manually assign to that input (because you know that, in this case, it would be 'William Shakespeare'). Records are defined by the type of ``Task``\ they are related to. Let's see three different examples:

Expand All @@ -42,10 +42,9 @@ Let's see examples of a spam classifier.
.. code-block:: python
record = rb.TextClassificationRecord(
inputs={
"text": "Access this link to get free discounts!"
},
prediction = [('SPAM', 0.8), ('HAM', 0.2)]
text="Access this link to get free discounts!",
prediction = [('SPAM', 0.8), ('HAM', 0.2)],
prediction_agent = "link or reference to agent",
annotation = "SPAM",
Expand All @@ -54,7 +53,6 @@ Let's see examples of a spam classifier.
metadata={ # Information about this record
"split": "train"
},
)
Multi-label text classification record
Expand All @@ -65,9 +63,8 @@ Another similar task to Text Classification, but yet a bit different, is Multi-l
.. code-block:: python
record = rb.TextClassificationRecord(
inputs={
"text": "I can't wait to travel to Egypts and visit the pyramids"
},
text="I can't wait to travel to Egypts and visit the pyramids",
multi_label = True,
prediction = [('travel', 0.8), ('history', 0.6), ('economy', 0.3), ('sports', 0.2)],
Expand Down Expand Up @@ -127,7 +124,7 @@ A prediction is a piece information assigned to a record, a label or a set of la
Metadata
^^^^^^^^

Metada will hold extra information that you want your record to have: if it belongs to the training or the test dataset, a quick fact about something regarding that specific record... Feel free to use it as you need!
Metada will hold extra information that you want your record to have: if it belongs to the training or the test dataset, a quick fact about something regarding that specific record... Feel free to use it as you need!

Methods
-------
Expand Down
2 changes: 1 addition & 1 deletion docs/getting_started/setup&installation.rst
Expand Up @@ -65,7 +65,7 @@ The following code will log one record into a data set called ``example-dataset`
import rubrix as rb
rb.log(
rb.TextClassificationRecord(inputs="My first Rubrix example"),
rb.TextClassificationRecord(text="My first Rubrix example"),
name='example-dataset'
)
Expand Down
50 changes: 20 additions & 30 deletions docs/guides/cookbook.ipynb
Expand Up @@ -130,7 +130,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -197,7 +197,7 @@
"4 6.291863e+17 If I make a game as a #windows10 Universal App..."
]
},
"execution_count": 17,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -257,26 +257,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from datasets import Dataset\n",
"import rubrix as rb\n",
"\n",
"# load rubrix dataset\n",
"df = rb.load('zeroshot_example')\n",
"\n",
"# inputs can be dicts to support multifield classifiers, we just use the text here. \n",
"df['text'] = df.inputs.transform(lambda r: r['text'])\n",
"\n",
"# we create a dict for turning our annotations (labels) into numeric ids\n",
"label2id = {label: id for id, label in enumerate(df.annotation.unique())}\n",
"dataset_rb = rb.load('zeroshot_example', as_pandas=False)\n",
"\n",
"\n",
"# create 🤗 dataset from pandas with labels as numeric ids\n",
"dataset = Dataset.from_pandas(df[['text', 'annotation']])\n",
"dataset = dataset.map(lambda example: {'labels': label2id[example['annotation']]})"
"# create 🤗 dataset with text and labels as numeric ids\n",
"train_ds = dataset_rb.prepare_for_training() "
]
},
{
Expand All @@ -296,7 +288,7 @@
"def tokenize_function(examples):\n",
" return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
"\n",
"train_dataset = dataset.map(tokenize_function, batched=True).shuffle(seed=42)\n",
"train_dataset = train_ds.map(tokenize_function, batched=True).shuffle(seed=42)\n",
"\n",
"trainer = Trainer(model=model, train_dataset=train_dataset)\n",
"\n",
Expand Down Expand Up @@ -736,10 +728,8 @@
"train_dataset = rb.load(\"tweet_eval_emojis\", limit=limit_num)\n",
"\n",
"# 2. Pre-processing training pandas dataframe\n",
"ready_input = [row['text'] for row in train_dataset.inputs]\n",
"\n",
"train_df = pd.DataFrame()\n",
"train_df['text'] = ready_input\n",
"train_df['text'] = train_dataset['text']\n",
"train_df['label'] = train_dataset['annotation']\n",
"\n",
"# 3. Save as csv with tab delimiter\n",
Expand Down Expand Up @@ -845,8 +835,8 @@
"labels = [\"happy\", \"sad\"]\n",
"\n",
"# Create a sentence\n",
"input_text = \"I am so glad you liked it!\"\n",
"sentence = Sentence(input_text)\n",
"text = \"I am so glad you liked it!\"\n",
"sentence = Sentence(text)\n",
"\n",
"# Predict for these labels\n",
"tars.predict_zero_shot(sentence, labels)\n",
Expand All @@ -857,7 +847,7 @@
"\n",
"# Building a TextClassificationRecord\n",
"record = rb.TextClassificationRecord(\n",
" inputs=input_text,\n",
" text=text,\n",
" prediction=prediction,\n",
" prediction_agent=\"tars-base\",\n",
")\n",
Expand Down Expand Up @@ -885,13 +875,13 @@
"from flair.models import TextClassifier\n",
"from flair.data import Sentence\n",
"\n",
"input_text = \"Du erzählst immer Quatsch.\" \n",
"text = \"Du erzählst immer Quatsch.\" \n",
"\n",
"# Load our pre-trained classifier\n",
"classifier = TextClassifier.load(\"de-offensive-language\")\n",
"\n",
"# Creating Sentence object\n",
"sentence = Sentence(input_text)\n",
"sentence = Sentence(text)\n",
"\n",
"# Make the prediction\n",
"classifier.predict(sentence, return_probabilities_for_all_classes=True)\n",
Expand All @@ -901,7 +891,7 @@
"\n",
"# Building a TextClassificationRecord\n",
"record = rb.TextClassificationRecord(\n",
" inputs=input_text,\n",
" text=text,\n",
" prediction=prediction,\n",
" prediction_agent=\"de-offensive-language\",\n",
")\n",
Expand Down Expand Up @@ -994,7 +984,7 @@
"import rubrix as rb\n",
"import stanza\n",
"\n",
"input_text = (\n",
"text = (\n",
" \"There are so many NLP libraries available, I don't know which one to choose!\"\n",
")\n",
"\n",
Expand All @@ -1005,7 +995,7 @@
"nlp = stanza.Pipeline(lang=\"en\", processors=\"tokenize,sentiment\")\n",
"\n",
"# Analizing the input text\n",
"doc = nlp(input_text)\n",
"doc = nlp(text)\n",
"\n",
"# This model returns 0 for negative, 1 for neutral and 2 for positive outcome.\n",
"# We are going to log them into Rubrix using a dictionary to translate numbers to labels.\n",
Expand All @@ -1026,7 +1016,7 @@
"\n",
"# Building a TextClassificationRecord\n",
"record = rb.TextClassificationRecord(\n",
" inputs=input_text,\n",
" text=text,\n",
" prediction=entities,\n",
" prediction_agent=\"stanza/en\",\n",
")\n",
Expand Down Expand Up @@ -1150,7 +1140,7 @@
"hash": "b709380ea7d1cb2eb4650c0f11ac7e002ec6a534602815725771481b4784238c"
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -1164,9 +1154,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}
6 changes: 3 additions & 3 deletions docs/guides/datasets.ipynb
Expand Up @@ -42,7 +42,7 @@
" print(record)\n",
" \n",
"# Index into the dataset\n",
"dataset_rb[0] = rb.TextClassificationRecord(inputs=\"replace record\")\n",
"dataset_rb[0] = rb.TextClassificationRecord(text=\"replace record\")\n",
"\n",
"# log a dataset to the Rubrix web app\n",
"rb.log(dataset_rb, \"my_dataset\")"
Expand Down Expand Up @@ -154,7 +154,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -168,7 +168,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.12"
}
},
"nbformat": 4,
Expand Down
16 changes: 8 additions & 8 deletions docs/guides/metrics.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions docs/guides/task_examples.ipynb
Expand Up @@ -71,7 +71,7 @@
" # Appending to the record list\n",
" records.append(\n",
" rb.TextClassificationRecord(\n",
" inputs=record[\"text\"],\n",
" text=record[\"text\"],\n",
" prediction=prediction,\n",
" prediction_agent=\"https://huggingface.co/typeform/squeezebert-mnli\",\n",
" metadata={\"split\": \"train\"},\n",
Expand Down Expand Up @@ -148,7 +148,7 @@
" # Appending to the record list\n",
" records.append(\n",
" rb.TextClassificationRecord(\n",
" inputs=record[\"content\"],\n",
" text=record[\"content\"],\n",
" prediction=prediction,\n",
" prediction_agent=\"https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment\",\n",
" metadata={\"split\": \"train\"},\n",
Expand Down Expand Up @@ -370,7 +370,7 @@
" # Appending to the record list\n",
" records.append(\n",
" rb.TextClassificationRecord(\n",
" inputs=record[\"statement\"],\n",
" text=record[\"statement\"],\n",
" prediction=prediction,\n",
" prediction_agent=\"https://huggingface.co/typeform/squeezebert-mnli\",\n",
" metadata={\"split\": \"train\"},\n",
Expand Down Expand Up @@ -434,7 +434,7 @@
" # Appending to the record list\n",
" records.append(\n",
" rb.TextClassificationRecord(\n",
" inputs=record[\"text\"],\n",
" text=record[\"text\"],\n",
" prediction=prediction,\n",
" prediction_agent=\"https://huggingface.co/typeform/squeezebert-mnli\",\n",
" metadata={\"split\": \"train\"},\n",
Expand Down Expand Up @@ -935,7 +935,7 @@
"hash": "4cac2ad44382dcbde9c9d45667b9ac0fec163e57feefe7fa8ea5d11fd16eb612"
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -949,9 +949,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.8.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}

0 comments on commit bb7d93e

Please sign in to comment.