Skip to content

Commit

Permalink
Add wandb.Audio support to Artifacts (#1694)
Browse files Browse the repository at this point in the history
* Add wandb.Audio support to Artifacts

* Lint

* Lint

* Add audio to table

Co-authored-by: David Jackson <davidjackson@Davids-MacBook-Pro.local>
  • Loading branch information
davidwallacejackson and David Jackson committed Jan 14, 2021
1 parent 402c707 commit 7696137
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 7 deletions.
29 changes: 23 additions & 6 deletions standalone_tests/artifact_object_reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

import wandb

columns = ["class_id", "id", "bool", "int", "float", "Image", "Clouds", "HTML", "Video", "Bokeh"]
columns = ["class_id", "id", "bool", "int", "float", "Image", "Clouds", "HTML", "Video", "Bokeh", "Audio"]

def _make_wandb_image(suffix=""):
class_labels = {1: "tree", 2: "car", 3: "road"}
Expand Down Expand Up @@ -167,6 +167,20 @@ def _make_video():
vid3 = _make_video()
vid4 = _make_video()

def _make_wandb_audio(frequency, caption):
SAMPLE_RATE = 44100
DURATION_SECONDS = 1

data = np.sin(
2 * np.pi * np.arange(SAMPLE_RATE * DURATION_SECONDS) * frequency / SAMPLE_RATE
)
return wandb.Audio(data, SAMPLE_RATE, caption)

aud1 = _make_wandb_audio(440, "four forty")
aud2 = _make_wandb_audio(480, "four eighty")
aud3 = _make_wandb_audio(500, "five hundred")
aud4 = _make_wandb_audio(520, "five twenty")

def _make_wandb_table():
classes = wandb.Classes([
{"id": 1, "name": "tree"},
Expand All @@ -176,10 +190,10 @@ def _make_wandb_table():
table = wandb.Table(
columns=columns,
data=[
[1, "string", True, 1, 1.4, _make_wandb_image(), pc1, _make_html(), vid1, b1],
[2, "string", True, 1, 1.4, _make_wandb_image(), pc2, _make_html(), vid2, b2],
[1, "string2", False, -0, -1.4, _make_wandb_image("2"), pc3, _make_html(), vid3, b3],
[3, "string2", False, -0, -1.4, _make_wandb_image("2"), pc4, _make_html(), vid4, b4],
[1, "string", True, 1, 1.4, _make_wandb_image(), pc1, _make_html(), vid1, b1, aud1],
[2, "string", True, 1, 1.4, _make_wandb_image(), pc2, _make_html(), vid2, b2, aud2],
[1, "string2", False, -0, -1.4, _make_wandb_image("2"), pc3, _make_html(), vid3, b3, aud3],
[3, "string2", False, -0, -1.4, _make_wandb_image("2"), pc4, _make_html(), vid4, b4, aud4],
],
)
table.cast("class_id", classes.get_type())
Expand All @@ -188,10 +202,10 @@ def _make_wandb_table():
def _make_wandb_joinedtable():
return wandb.JoinedTable(_make_wandb_table(), _make_wandb_table(), "id")


def _b64_to_hex_id(id_string):
return binascii.hexlify(base64.standard_b64decode(str(id_string))).decode("utf-8")


# Artifact1.add_reference(artifact_URL) => recursive reference
def test_artifact_add_reference_via_url():
""" This test creates three artifacts. The middle artifact references the first artifact's file,
Expand Down Expand Up @@ -580,6 +594,8 @@ def test_video_refs():
def test_joined_table_refs():
assert_media_obj_referential_equality(_make_wandb_joinedtable())

def test_audio_refs():
assert_media_obj_referential_equality(_make_wandb_audio(440, "four forty"))

def test_joined_table_referential():
src_image_1 = _make_wandb_image()
Expand Down Expand Up @@ -709,6 +725,7 @@ def test_image_reference_with_preferred_path():
test_video_refs,
test_table_refs,
test_joined_table_refs,
test_audio_refs,
test_joined_table_referential,
test_joined_table_add_by_path,
test_image_reference_with_preferred_path,
Expand Down
22 changes: 21 additions & 1 deletion wandb/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,8 @@ class Audio(BatchableMedia):
caption (string): Caption to display with audio.
"""

artifact_type = "audio-file"

def __init__(self, data_or_path, sample_rate=None, caption=None):
"""Accepts a path to an audio file or a numpy array of audio data."""
super(Audio, self).__init__()
Expand Down Expand Up @@ -793,11 +795,19 @@ def __init__(self, data_or_path, sample_rate=None, caption=None):
def get_media_subdir(cls):
return os.path.join("media", "audio")

@classmethod
def from_json(cls, json_obj, source_artifact):
return cls(
source_artifact.get_path(json_obj["path"]).download(),
json_obj["sample_rate"],
json_obj["caption"],
)

def to_json(self, run):
json_dict = super(Audio, self).to_json(run)
json_dict.update(
{
"_type": "audio-file",
"_type": self.artifact_type,
"sample_rate": self._sample_rate,
"caption": self._caption,
}
Expand Down Expand Up @@ -847,6 +857,16 @@ def captions(cls, audio_list):
else:
return ["" if c is None else c for c in captions]

def __eq__(self, other):
return (
super(Audio, self).__eq__(other)
and self._sample_rate == other._sample_rate
and self._caption == other._caption
)

def __ne__(self, other):
return not self.__eq__(other)


def is_numpy_array(data):
np = util.get_module(
Expand Down

0 comments on commit 7696137

Please sign in to comment.