-
Notifications
You must be signed in to change notification settings - Fork 621
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* low-level-api * mid-level api * add viz file * insert visualization in key and associated table in key_table * compatibility with 2.7 * Cleanups, pass fields and string fields separately. * Include color and id in default fields * tox * Include _vis_ids file * Make line importable, adapt precision_recall.py * Add histogram and roc uses new plots * Add scatter * Update for production implementation. * bar plot * New wandb.plot.* plots * Reset wandb.plots.* to master branch * Get rid of plots in wrong directory. * Deprecation notices for all wandb.plots.* plots * Clean up old plots and add docstrings * Fix flake * Add basic_plots and tweets standalone tests * Fix more flake * Fix flake again? * Remove deprecation notice, fix Python2 * Run codemode Co-authored-by: Shawn Lewis <shlewis@gmail.com> Co-authored-by: Jeff Raubitschek <jeff@wandb.com>
- Loading branch information
1 parent
9288adf
commit 8d913e5
Showing
24 changed files
with
9,836 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import wandb | ||
import random | ||
import math | ||
|
||
# Standalone smoke test: logs one line, histogram, scatter and bar chart
# (built with wandb.plot.*) to a single run.
wandb.init(entity='wandb', project='new-plots-test-5')

# Noisy sine wave: one [step, height] row per integer step in 0..99.
rows = []
for step in range(100):
    rows.append([step, random.random() + math.sin(step / 10)])
table = wandb.Table(data=rows, columns=["step", "height"])

line_plot = wandb.plot.line(
    table, x='step', y='height', title='what a great line plot')
histogram = wandb.plot.histogram(table, value='height', title='my-histo')
scatter = wandb.plot.scatter(table, x='step', y='height', title='scatter!')

# One random accuracy value per class, in a fixed class order.
class_names = ['car', 'bus', 'road', 'person', 'cyclist', 'tree', 'sky']
bar_rows = [[class_name, random.random()] for class_name in class_names]
bar_table = wandb.Table(data=bar_rows, columns=["class", "acc"])
bar = wandb.plot.bar(bar_table, label='class', value='acc', title='bar')

# Log all four charts in a single call.
charts = {
    'line1': line_plot,
    'histogram1': histogram,
    'scatter1': scatter,
    'bar1': bar,
}
wandb.log(charts)
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import pandas as pd | ||
import numpy as np | ||
import wandb | ||
from sklearn.metrics import confusion_matrix | ||
# Standalone test: trains a Naive Bayes classifier on tweet text and logs
# ROC and precision/recall charts built with wandb.plot.*.
wandb.init(entity='wandb', project="tweets-test-2")

# Get a pandas DataFrame object of all the data in the csv file
# (expects tweets.csv in the current working directory):
df = pd.read_csv('tweets.csv')

# Get pandas Series object of the "tweet text" column:
text = df['tweet_text']

# Get pandas Series object of the "emotion" column:
target = df['is_there_an_emotion_directed_at_a_brand_or_product']

# Remove the rows with no tweet text from both series. The mask is computed
# once so that text and target are filtered identically and stay aligned.
has_text = pd.notnull(text)
target = target[has_text]
text = text[has_text]

# Perform feature extraction:
from sklearn.feature_extraction.text import CountVectorizer
count_vect = CountVectorizer()
count_vect.fit(text)
counts = count_vect.transform(text)

# First 6000 rows are the training split, the remainder is the test split.
counts_train = counts[:6000]
target_train = target[:6000]
counts_test = counts[6000:]
target_test = target[6000:]

# Train with this data with a Naive Bayes classifier:
from sklearn.naive_bayes import MultinomialNB

nb = MultinomialNB()
# Fit on the training split only. Fitting on the full dataset would let the
# model see the test rows, making the ROC / PR curves below meaningless.
nb.fit(counts_train, target_train)

X_test = counts_test
y_test = target_test
y_probas = nb.predict_proba(X_test)
y_pred = nb.predict(X_test)

print("y", y_probas.shape)

# ROC
wandb.log({'roc': wandb.plot.roc_curve(y_test, y_probas, nb.classes_)})

# Precision Recall
wandb.log({'pr': wandb.plot.pr_curve(y_test, y_probas, nb.classes_)})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from wandb.plot.bar import bar | ||
from wandb.plot.histogram import histogram | ||
from wandb.plot.line import line | ||
from wandb.plot.pr_curve import pr_curve | ||
from wandb.plot.roc_curve import roc_curve | ||
from wandb.plot.scatter import scatter | ||
|
||
|
||
# Public API of the wandb.plot package: the chart constructors imported above.
__all__ = ["line", "histogram", "scatter", "bar", "roc_curve", "pr_curve"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import wandb | ||
|
||
|
||
def bar(table, label, value, title=None):
    """
    Build a bar chart from two columns of a table.

    Arguments:
        table (wandb.Table): Table holding the data to plot.
        label (string): Name of the column that supplies each bar's label.
        value (string): Name of the column that supplies each bar's value.
        title (string): Title of the plot.

    Returns:
        The result of wandb.plot_table(); pass it to wandb.log() to log
        the chart.

    Example:
        table = wandb.Table(data=[
            ['car', random.random()],
            ['bus', random.random()],
            ['road', random.random()],
            ['person', random.random()],
        ], columns=["class", "acc"])
        wandb.log({'bar-plot1': wandb.plot.bar(table, "class", "acc")})
    """
    # Column-name mapping and string settings are passed as separate dicts.
    column_fields = {'label': label, 'value': value}
    string_fields = {'title': title}
    return wandb.plot_table('wandb/bar/v0', table, column_fields, string_fields)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import wandb | ||
|
||
|
||
def histogram(table, value, title=None):
    """
    Construct a histogram plot.

    Arguments:
        table (wandb.Table): Table of data.
        value (string): Name of column to use as data for bucketing.
        title (string): Plot title.

    Returns:
        A plot object, to be passed to wandb.log()

    Example:
        data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
        table = wandb.Table(data=data, columns=["step", "height"])
        wandb.log({'histogram-plot1': wandb.plot.histogram(table, "height")})
    """
    # Column-name mapping and string settings are passed as separate dicts.
    return wandb.plot_table(
        'wandb/histogram/v0',
        table,
        {'value': value},
        {'title': title})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import wandb | ||
|
||
|
||
def line(table, x, y, stroke=None, title=None):
    """
    Build a line chart from columns of a table.

    Arguments:
        table (wandb.Table): Table holding the data to plot.
        x (string): Name of the column to use for x-axis values.
        y (string): Name of the column to use for y-axis values.
        stroke (string): Name of the column to map to the line stroke scale.
        title (string): Title of the plot.

    Returns:
        The result of wandb.plot_table(); pass it to wandb.log() to log
        the chart.

    Example:
        data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
        table = wandb.Table(data=data, columns=["step", "height"])
        wandb.log({'line-plot1': wandb.plot.line(table, "step", "height")})
    """
    # Column-name mapping and string settings are passed as separate dicts.
    column_fields = {'x': x, 'y': y, 'stroke': stroke}
    string_fields = {'title': title}
    return wandb.plot_table('wandb/line/v0', table, column_fields, string_fields)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import wandb | ||
from wandb import util | ||
from wandb.plots.utils import test_missing, test_types | ||
|
||
|
||
chart_limit = wandb.Table.MAX_ROWS | ||
|
||
|
||
def pr_curve(y_true=None, y_probas=None, labels=None, classes_to_plot=None):
    """
    Computes the tradeoff between precision and recall for different thresholds.

    A high area under the curve represents both high recall and high precision,
    where high precision relates to a low false positive rate, and high recall
    relates to a low false negative rate. High scores for both show that the
    classifier is returning accurate results (high precision), as well as
    returning a majority of all positive results (high recall).
    PR curve is useful when the classes are very imbalanced.

    Arguments:
        y_true (arr): Test set labels.
        y_probas (arr): Test set predicted probabilities. Column i is used as
            the score for the i-th unique value of y_true.
        labels (list): Named labels for target variable (y). Makes plots easier
            to read by replacing target values with corresponding index.
            For example labels= ['dog', 'cat', 'owl'] all 0s are
            replaced by 'dog', 1s by 'cat'. Only applied when the class
            values are integers.
        classes_to_plot (list): Class values (as they appear in y_true) to
            include in the plot. Defaults to every class found in y_true.

    Returns:
        A plot object, to be passed to wandb.log(). None is returned if the
        test_missing / test_types checks do not pass.

    Example:
        wandb.log({'pr-curve': wandb.plot.pr_curve(y_true, y_probas, labels)})
    """
    np = util.get_module("numpy", required="pr_curve requires the numpy library, install with `pip install numpy`")
    util.get_module("sklearn", required="pr_curve requires the scikit library, install with `pip install scikit-learn`")
    # Import the function from its submodule explicitly (as roc_curve does)
    # rather than relying on `sklearn.metrics` being set as an attribute of
    # the top-level package.
    from sklearn.metrics import precision_recall_curve

    # NOTE(review): roc_curve runs the checks below *before* converting to
    # arrays; the order is kept as-is here to avoid changing which inputs
    # are accepted -- confirm which order is intended.
    y_true = np.array(y_true)
    y_probas = np.array(y_probas)

    if not (test_missing(y_true=y_true, y_probas=y_probas)
            and test_types(y_true=y_true, y_probas=y_probas)):
        return None

    classes = np.unique(y_true)
    if classes_to_plot is None:
        classes_to_plot = classes

    # Each curve is down-sampled to this many evenly spaced points.
    samples = 20

    data = []
    limit_reached = False
    indices_to_plot = np.in1d(classes, classes_to_plot)
    for i, to_plot in enumerate(indices_to_plot):
        if not to_plot:
            continue

        precision, recall, _ = precision_recall_curve(
            y_true, y_probas[:, i], pos_label=classes[i])

        # Resolve the display name once per class. Integer class values are
        # mapped through `labels` when given; anything else (string, float,
        # date) is used as-is.
        class_name = classes[i]
        if labels is not None and isinstance(class_name, (int, np.integer)):
            class_name = labels[class_name]

        for k in range(samples):
            p = precision[int(len(precision) * k / samples)]
            r = recall[int(len(recall) * k / samples)]
            data.append([class_name, round(p, 3), round(r, 3)])
            if len(data) >= chart_limit:
                limit_reached = True
                break

        if limit_reached:
            # Stop across all classes, not just the current one, so that the
            # row limit is respected and the warning is emitted only once.
            wandb.termwarn("wandb uses only the first %d datapoints to create the plots." % wandb.Table.MAX_ROWS)
            break

    table = wandb.Table(
        columns=['class', 'precision', 'recall'],
        data=data)
    return wandb.plot_table(
        'wandb/area-under-curve/v0',
        table,
        {'x': 'recall', 'y': 'precision', 'class': 'class'},
        {'title': 'Precision v. Recall'})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import wandb | ||
from wandb import util | ||
from wandb.plots.utils import test_missing, test_types | ||
chart_limit = wandb.Table.MAX_ROWS | ||
|
||
|
||
def roc_curve(y_true=None, y_probas=None, labels=None, classes_to_plot=None):
    """
    Calculates receiver operating characteristic scores and visualizes them as the
    ROC curve.

    Arguments:
        y_true (arr): Test set labels.
        y_probas (arr): Test set predicted probabilities. Column i is used as
            the score for the i-th unique value of y_true.
        labels (list): Named labels for target variable (y). Makes plots easier
            to read by replacing target values with corresponding index.
            For example labels= ['dog', 'cat', 'owl'] all 0s are
            replaced by 'dog', 1s by 'cat'. Only applied when the class
            values are integers.
        classes_to_plot (list): Class values (as they appear in y_true) to
            include in the plot. Defaults to every class found in y_true.

    Returns:
        A plot object, to be passed to wandb.log(). None is returned if the
        test_missing / test_types checks do not pass.

    Example:
        wandb.log({'roc-curve': wandb.plot.roc_curve(y_true, y_probas, labels)})
    """
    np = util.get_module("numpy", required="roc requires the numpy library, install with `pip install numpy`")
    util.get_module("sklearn", required="roc requires the scikit library, install with `pip install scikit-learn`")
    # Aliased so the import does not shadow this function's own name.
    from sklearn.metrics import roc_curve as sklearn_roc_curve

    if not (test_missing(y_true=y_true, y_probas=y_probas)
            and test_types(y_true=y_true, y_probas=y_probas)):
        return None

    y_true = np.array(y_true)
    y_probas = np.array(y_probas)
    classes = np.unique(y_true)

    if classes_to_plot is None:
        classes_to_plot = classes

    data = []
    limit_reached = False
    indices_to_plot = np.in1d(classes, classes_to_plot)
    for i, to_plot in enumerate(indices_to_plot):
        # Skip classes that were not requested; their curves are never used.
        if not to_plot:
            continue

        fpr, tpr, _ = sklearn_roc_curve(
            y_true, y_probas[:, i], pos_label=classes[i])

        # Resolve the display name once per class. Integer class values are
        # mapped through `labels` when given; anything else is used as-is.
        class_name = classes[i]
        if labels is not None and isinstance(class_name, (int, np.integer)):
            class_name = labels[class_name]

        for false_pos, true_pos in zip(fpr, tpr):
            data.append([class_name, round(false_pos, 3), round(true_pos, 3)])
            if len(data) >= chart_limit:
                limit_reached = True
                break

        if limit_reached:
            # Stop across all classes, not just the current one, so that the
            # row limit is respected and the warning is emitted only once.
            wandb.termwarn("wandb uses only the first %d datapoints to create the plots." % wandb.Table.MAX_ROWS)
            break

    table = wandb.Table(
        columns=['class', 'fpr', 'tpr'],
        data=data)
    return wandb.plot_table(
        'wandb/area-under-curve/v0',
        table,
        {'x': 'fpr', 'y': 'tpr', 'class': 'class'},
        {'title': 'ROC', 'x-axis-title': 'False positive rate', 'y-axis-title': 'True positive rate'})
Oops, something went wrong.