Skip to content

Commit

Permalink
Charlemagne mid level api (#1262)
Browse files Browse the repository at this point in the history
* low-level-api

* mid-level api

* add viz file

* insert visualization in key and associated table in key_table

* compatibility with 2.7

* Cleanups, pass fields and string fields separately.

* Include color and id in default fields

* tox

* Include _vis_ids file

* Make line importable, adapt precision_recall.py

* Add histogram and roc uses new plots

* Add scatter

* Update for production implementation.

* bar plot

* New wandb.plot.* plots

* Reset wandb.plots.* to master branch

* Get rid of plots in wrong directory.

* Deprecation notices for all wandb.plots.* plots

* Clean up old plots and add docstrings

* Fix flake

* Add basic_plots and tweets standalone tests

* Fix more flake

* Fix flake again?

* Remove deprecation notice, fix Python2

* Run codemode

Co-authored-by: Shawn Lewis <shlewis@gmail.com>
Co-authored-by: Jeff Raubitschek <jeff@wandb.com>
  • Loading branch information
3 people committed Oct 7, 2020
1 parent 9288adf commit 8d913e5
Show file tree
Hide file tree
Showing 24 changed files with 9,836 additions and 9 deletions.
27 changes: 27 additions & 0 deletions standalone_tests/basic_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import wandb
import random
import math

# Smoke test: log one chart of each basic wandb.plot type to a scratch project.
wandb.init(entity='wandb', project='new-plots-test-5')

# One [step, height] row per step: a sine wave with uniform noise added.
data = []
for i in range(100):
    data.append([i, random.random() + math.sin(i / 10)])
table = wandb.Table(data=data, columns=["step", "height"])

# The same table feeds the line, histogram and scatter charts.
line_plot = wandb.plot.line(
    table, x='step', y='height', title='what a great line plot')
histogram = wandb.plot.histogram(table, value='height', title='my-histo')
scatter = wandb.plot.scatter(table, x='step', y='height', title='scatter!')

# One random "accuracy" per class name for the bar chart.
class_names = ['car', 'bus', 'road', 'person', 'cyclist', 'tree', 'sky']
bar_table = wandb.Table(
    data=[[name, random.random()] for name in class_names],
    columns=["class", "acc"])
bar = wandb.plot.bar(bar_table, label='class', value='acc', title='bar')

# Log all four charts in a single call.
charts = {
    'line1': line_plot,
    'histogram1': histogram,
    'scatter1': scatter,
    'bar1': bar,
}
wandb.log(charts)
9,289 changes: 9,289 additions & 0 deletions standalone_tests/tweets.csv

Large diffs are not rendered by default.

48 changes: 48 additions & 0 deletions standalone_tests/tweets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pandas as pd
import numpy as np
import wandb
from sklearn.metrics import confusion_matrix
wandb.init(entity='wandb', project="tweets-test-2")

# Smoke test for wandb.plot.roc_curve / wandb.plot.pr_curve: train a small
# bag-of-words Naive Bayes classifier on tweets.csv and log both curves.

# Get a pandas DataFrame object of all the data in the csv file:
df = pd.read_csv('tweets.csv')

# Get pandas Series object of the "tweet text" column:
text = df['tweet_text']

# Get pandas Series object of the "emotion" column:
target = df['is_there_an_emotion_directed_at_a_brand_or_product']

# Remove the rows with blank tweet text from both series, using one mask so
# that text and target stay aligned:
has_text = pd.notnull(text)
target = target[has_text]
text = text[has_text]

# Perform feature extraction:
from sklearn.feature_extraction.text import CountVectorizer
count_vect = CountVectorizer()
count_vect.fit(text)
counts = count_vect.transform(text)

# The first 6000 rows are the training split, the remainder is the test split.
counts_train = counts[:6000]
target_train = target[:6000]
counts_test = counts[6000:]
target_test = target[6000:]

# Train a Naive Bayes classifier on the training split only. Fitting on the
# full data would let the model see the rows it is evaluated on below.
from sklearn.naive_bayes import MultinomialNB

nb = MultinomialNB()
nb.fit(counts_train, target_train)

X_test = counts_test
y_test = target_test
y_probas = nb.predict_proba(X_test)

print("y", y_probas.shape)

# ROC
wandb.log({'roc': wandb.plot.roc_curve(y_test, y_probas, nb.classes_)})

# Precision Recall
wandb.log({'pr': wandb.plot.pr_curve(y_test, y_probas, nb.classes_)})
7 changes: 6 additions & 1 deletion wandb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@

# from wandb.core import *
from wandb.viz import visualize
from wandb import plots
from wandb import plot
from wandb import plots # deprecating this
from wandb.integration.sagemaker import sagemaker_auth


Expand Down Expand Up @@ -125,6 +126,10 @@ def _is_internal_process():
# Module-level entry points wrapping Run methods.
# NOTE(review): presumably PreInitCallable yields a placeholder for calls made
# before wandb.init() -- confirm against the _preinit module.
log_artifact = _preinit.PreInitCallable(
"wandb.log_artifact", wandb_sdk.wandb_run.Run.log_artifact
)
# wandb.plot_table is the function the wandb.plot.* helpers call to build a
# chart; set_global() in wandb/lib/module.py overwrites this attribute when it
# is given a plot_table argument.
plot_table = _preinit.PreInitCallable(
"wandb.plot_table", wandb_sdk.wandb_run.Run.plot_table
)

# record of patched libraries
patched = {"tensorboard": [], "keras": [], "gym": []}

Expand Down
3 changes: 3 additions & 0 deletions wandb/lib/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def set_global(
restore=None,
use_artifact=None,
log_artifact=None,
plot_table=None,
):
if run:
wandb.run = run
Expand All @@ -28,6 +29,8 @@ def set_global(
wandb.use_artifact = use_artifact
if log_artifact:
wandb.log_artifact = log_artifact
if plot_table:
wandb.plot_table = plot_table


def unset_globals():
Expand Down
16 changes: 16 additions & 0 deletions wandb/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from wandb.plot.bar import bar
from wandb.plot.histogram import histogram
from wandb.plot.line import line
from wandb.plot.pr_curve import pr_curve
from wandb.plot.roc_curve import roc_curve
from wandb.plot.scatter import scatter


# Public API of the wandb.plot package: the chart constructors imported above.
__all__ = [
"line",
"histogram",
"scatter",
"bar",
"roc_curve",
"pr_curve",
]
30 changes: 30 additions & 0 deletions wandb/plot/bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import wandb


def bar(table, label, value, title=None):
    """
    Build a bar chart from two columns of a table.

    Arguments:
        table (wandb.Table): Table holding the data to plot.
        label (string): Name of the column that labels each bar.
        value (string): Name of the column that gives each bar's value.
        title (string): Plot title.

    Returns:
        A plot object, to be passed to wandb.log()

    Example:
        table = wandb.Table(data=[
            ['car', random.random()],
            ['bus', random.random()],
            ['road', random.random()],
            ['person', random.random()],
        ], columns=["class", "acc"])
        wandb.log({'bar-plot1': wandb.plot.bar(table, "class", "acc")})
    """
    # Column-name mappings and plain-string settings are passed separately.
    column_fields = {'label': label, 'value': value}
    text_fields = {'title': title}
    return wandb.plot_table('wandb/bar/v0', table, column_fields, text_fields)
25 changes: 25 additions & 0 deletions wandb/plot/histogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import wandb


def histogram(table, value, title=None):
    """
    Construct a histogram plot.

    Arguments:
        table (wandb.Table): Table of data.
        value (string): Name of column to use as data for bucketing.
        title (string): Plot title.

    Returns:
        A plot object, to be passed to wandb.log()

    Example:
        data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
        table = wandb.Table(data=data, columns=["step", "height"])
        wandb.log({'histogram-plot1': wandb.plot.histogram(table, "height")})
    """
    # Column-name mappings and plain-string settings are passed separately.
    return wandb.plot_table(
        'wandb/histogram/v0',
        table,
        {'value': value},
        {'title': title})
27 changes: 27 additions & 0 deletions wandb/plot/line.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import wandb


def line(table, x, y, stroke=None, title=None):
    """
    Build a line chart from two columns of a table.

    Arguments:
        table (wandb.Table): Table holding the data to plot.
        x (string): Name of the column to use for x-axis values.
        y (string): Name of the column to use for y-axis values.
        stroke (string): Name of the column to map to the line stroke scale.
        title (string): Plot title.

    Returns:
        A plot object, to be passed to wandb.log()

    Example:
        data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
        table = wandb.Table(data=data, columns=["step", "height"])
        wandb.log({'line-plot1': wandb.plot.line(table, "step", "height")})
    """
    # 'stroke' is always forwarded, even when it is left as None.
    column_fields = {'x': x, 'y': y, 'stroke': stroke}
    text_fields = {'title': title}
    return wandb.plot_table('wandb/line/v0', table, column_fields, text_fields)
93 changes: 93 additions & 0 deletions wandb/plot/pr_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import wandb
from wandb import util
from wandb.plots.utils import test_missing, test_types


chart_limit = wandb.Table.MAX_ROWS


def pr_curve(y_true=None, y_probas=None, labels=None, classes_to_plot=None):
    """
    Computes the tradeoff between precision and recall for different thresholds.
    A high area under the curve represents both high recall and high precision,
    where high precision relates to a low false positive rate, and high recall
    relates to a low false negative rate. High scores for both show that the
    classifier is returning accurate results (high precision), as well as
    returning a majority of all positive results (high recall).
    PR curve is useful when the classes are very imbalanced.

    Arguments:
        y_true (arr): Test set labels.
        y_probas (arr): Test set predicted probabilities. Column i is used as
            the score for the i-th entry of np.unique(y_true).
        labels (list): Named labels for target variable (y). Makes plots
            easier to read by replacing integer target values with the label
            at the corresponding index. For example labels=['dog', 'cat',
            'owl'] replaces all 0s by 'dog', 1s by 'cat'. Ignored when the
            target values are not integers.
        classes_to_plot (list): Values of y_true to draw a curve for.
            Defaults to every class present in y_true.

    Returns:
        A plot object, to be passed to wandb.log(), or None when the inputs
        fail validation.

    Example:
        wandb.log({'pr-curve': wandb.plot.pr_curve(y_true, y_probas, labels)})
    """
    np = util.get_module(
        "numpy",
        required="pr_curve requires the numpy library, install with `pip install numpy`")
    util.get_module(
        "sklearn",
        required="pr_curve requires the scikit library, install with `pip install scikit-learn`")
    # Import the function from its submodule explicitly (as roc_curve does):
    # importing the top-level sklearn package does not guarantee that
    # sklearn.metrics is available as an attribute.
    from sklearn.metrics import precision_recall_curve

    # Validate the raw arguments before converting them, as roc_curve does.
    # np.array(None) is an array, not None, so converting first would hide a
    # missing argument from any `is None` check.
    if not (test_missing(y_true=y_true, y_probas=y_probas)
            and test_types(y_true=y_true, y_probas=y_probas)):
        return None

    y_true = np.array(y_true)
    y_probas = np.array(y_probas)
    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    # Each curve is down-sampled to this many evenly spaced points.
    samples = 20
    pr_curves = {}
    indices_to_plot = np.in1d(classes, classes_to_plot)
    for i, to_plot in enumerate(indices_to_plot):
        if not to_plot:
            continue
        precision, recall, _ = precision_recall_curve(
            y_true, y_probas[:, i], pos_label=classes[i])
        sample_precision = []
        sample_recall = []
        for k in range(samples):
            sample_precision.append(precision[int(len(precision) * k / samples)])
            sample_recall.append(recall[int(len(recall) * k / samples)])
        pr_curves[classes[i]] = (sample_precision, sample_recall)

    data = []
    limit_reached = False
    for class_name, (precision, recall) in pr_curves.items():
        # Integer class values are swapped for their named label once per
        # class. Any other class value (string, float, date) is used as-is.
        if labels is not None and isinstance(class_name, (int, np.integer)):
            class_name = labels[class_name]
        for p, r in zip(precision, recall):
            data.append([class_name, round(p, 3), round(r, 3)])
            if len(data) >= chart_limit:
                limit_reached = True
                break
        if limit_reached:
            # Stop the outer loop too, so the table never exceeds the row
            # limit and the warning is printed only once.
            wandb.termwarn(
                "wandb uses only the first %d datapoints to create the plots."
                % wandb.Table.MAX_ROWS)
            break

    table = wandb.Table(
        columns=['class', 'precision', 'recall'],
        data=data)
    return wandb.plot_table(
        'wandb/area-under-curve/v0',
        table,
        {'x': 'recall', 'y': 'precision', 'class': 'class'},
        {'title': 'Precision v. Recall'})
72 changes: 72 additions & 0 deletions wandb/plot/roc_curve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import wandb
from wandb import util
from wandb.plots.utils import test_missing, test_types
chart_limit = wandb.Table.MAX_ROWS


def roc_curve(y_true=None, y_probas=None, labels=None, classes_to_plot=None):
    """
    Calculates receiver operating characteristic scores and visualizes them as the
    ROC curve.

    Arguments:
        y_true (arr): Test set labels.
        y_probas (arr): Test set predicted probabilities. Column i is used as
            the score for the i-th entry of np.unique(y_true).
        labels (list): Named labels for target variable (y). Makes plots
            easier to read by replacing integer target values with the label
            at the corresponding index. For example labels=['dog', 'cat',
            'owl'] replaces all 0s by 'dog', 1s by 'cat'. Ignored when the
            target values are not integers.
        classes_to_plot (list): Values of y_true to draw a curve for.
            Defaults to every class present in y_true.

    Returns:
        A plot object, to be passed to wandb.log(), or None when the inputs
        fail validation.

    Example:
        wandb.log({'roc-curve': wandb.plot.roc_curve(y_true, y_probas, labels)})
    """
    np = util.get_module("numpy", required="roc requires the numpy library, install with `pip install numpy`")
    util.get_module("sklearn", required="roc requires the scikit library, install with `pip install scikit-learn`")
    # Aliased so the scikit-learn function does not shadow this wrapper's
    # own name inside its body.
    from sklearn.metrics import roc_curve as sk_roc_curve

    if not (test_missing(y_true=y_true, y_probas=y_probas)
            and test_types(y_true=y_true, y_probas=y_probas)):
        return None

    y_true = np.array(y_true)
    y_probas = np.array(y_probas)
    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    indices_to_plot = np.in1d(classes, classes_to_plot)

    data = []
    limit_reached = False
    for i, to_plot in enumerate(indices_to_plot):
        # Only compute the curve for classes that will actually be drawn.
        if not to_plot:
            continue
        fpr, tpr, _ = sk_roc_curve(y_true, y_probas[:, i], pos_label=classes[i])

        # Integer class values are swapped for their named label. Any other
        # class value is used as-is.
        class_name = classes[i]
        if labels is not None and isinstance(class_name, (int, np.integer)):
            class_name = labels[class_name]

        for false_pos, true_pos in zip(fpr, tpr):
            data.append([class_name, round(false_pos, 3), round(true_pos, 3)])
            if len(data) >= chart_limit:
                limit_reached = True
                break
        if limit_reached:
            # Stop the outer loop too, so the table never exceeds the row
            # limit and the warning is printed only once.
            wandb.termwarn("wandb uses only the first %d datapoints to create the plots." % wandb.Table.MAX_ROWS)
            break

    table = wandb.Table(
        columns=['class', 'fpr', 'tpr'],
        data=data)
    return wandb.plot_table(
        'wandb/area-under-curve/v0',
        table,
        {'x': 'fpr', 'y': 'tpr', 'class': 'class'},
        {'title': 'ROC', 'x-axis-title': 'False positive rate', 'y-axis-title': 'True positive rate'})

0 comments on commit 8d913e5

Please sign in to comment.