Skip to content

Commit

Permalink
Merge 23aeda5 into b1330dd
Browse files Browse the repository at this point in the history
  • Loading branch information
timmahrt committed Jan 7, 2023
2 parents b1330dd + 23aeda5 commit 5ffcc6a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 39 deletions.
25 changes: 15 additions & 10 deletions praatio/data_classes/textgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Dict,
)
from typing_extensions import Literal
from collections import OrderedDict


from praatio.utilities.constants import (
Expand Down Expand Up @@ -50,8 +51,7 @@ def __init__(self, minTimestamp: float = None, maxTimestamp: float = None):
maxTimestamp: the maximum allowable timestamp in the textgrid
"""

self.tierNameList: List[str] = [] # Preserves the order of the tiers
self.tierDict: Dict[str, textgrid_tier.TextgridTier] = {}
self.tierDict: OrderedDict[str, textgrid_tier.TextgridTier] = OrderedDict()

# Timestamps are determined by the first tier added
self.minTimestamp: float = minTimestamp # type: ignore[assignment]
Expand All @@ -72,6 +72,10 @@ def __eq__(self, other):

return isEqual

@property
def tierNameList(self):
return tuple(self.tierDict.keys())

def addTier(
self,
tier: textgrid_tier.TextgridTier,
Expand Down Expand Up @@ -103,15 +107,17 @@ def addTier(
)
errorReporter = utils.getErrorReporter(reportingMode)

if tier.name in list(self.tierDict.keys()):
if tier.name in self.tierNameList:
raise errors.TierNameExistsError("Tier name already in tier")

if tierIndex is None:
self.tierNameList.append(tier.name)
else:
self.tierNameList.insert(tierIndex, tier.name)

tmpTierNameList = list(self.tierNameList)
self.tierDict[tier.name] = tier
if tierIndex is not None: # Need to recreate the tierDict with the new order
tmpTierNameList.insert(tierIndex, tier.name)
newTierDict = OrderedDict()
for tmpName in tmpTierNameList:
newTierDict[tmpName] = self.tierDict[tmpName]
self.tierDict = newTierDict

minV = tier.minTimestamp
if self.minTimestamp is not None and minV < self.minTimestamp:
Expand Down Expand Up @@ -148,7 +154,7 @@ def appendTextgrid(self, tg: "Textgrid", onlyMatchingNames: bool) -> "Textgrid":

# Get all tier names. Ordered first by this textgrid and
# then by the other textgrid.
combinedTierNameList = self.tierNameList[:]
combinedTierNameList = list(self.tierNameList)
for tierName in tg.tierNameList:
if tierName not in combinedTierNameList:
combinedTierNameList.append(tierName)
Expand Down Expand Up @@ -493,7 +499,6 @@ def renameTier(self, oldName: str, newName: str) -> None:
self.addTier(oldTier.new(newName, oldTier.entryList), tierIndex)

def removeTier(self, name: str) -> textgrid_tier.TextgridTier:
self.tierNameList.pop(self.tierNameList.index(name))
return self.tierDict.pop(name)

def replaceTier(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_openTextgrid_can_rename_tiers_when_textgrid_has_duplicate_tier_names(se
inputFN, False, duplicateNamesMode=constants.DuplicateNames.RENAME
)

self.assertEqual(["Mary", "Mary_2", "Mary_3"], sut.tierNameList)
self.assertSequenceEqual(["Mary", "Mary_2", "Mary_3"], sut.tierNameList)

def test_tg_io_long_vs_short(self):
"""Tests reading of long vs short textgrids"""
Expand Down
36 changes: 8 additions & 28 deletions tests/test_textgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_add_tier_can_add_a_tier_to_a_tg(self):
sut.addTier(tier2, reportingMode="error")
sut.addTier(tier3, reportingMode="error")

self.assertEqual(["words", "phrases", "max pitch"], sut.tierNameList)
self.assertSequenceEqual(["words", "phrases", "max pitch"], sut.tierNameList)
self.assertEqual(tier1, sut.tierDict["words"])
self.assertEqual(tier2, sut.tierDict["phrases"])
self.assertEqual(tier3, sut.tierDict["max pitch"])
Expand All @@ -99,7 +99,7 @@ def test_add_tier_can_add_a_tier_to_a_tg_at_specific_indices(self):
sut.addTier(tier3, tierIndex=1, reportingMode="error")

# tier3 was inserted last but with index 1, so it will appear second
self.assertEqual(["words", "max pitch", "phrases"], sut.tierNameList)
self.assertSequenceEqual(["words", "max pitch", "phrases"], sut.tierNameList)
self.assertEqual(tier1, sut.tierDict["words"])
self.assertEqual(tier3, sut.tierDict["max pitch"])
self.assertEqual(tier2, sut.tierDict["phrases"])
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_append_textgrid_with_matching_names_only(self):

self.assertEqual(0, sut.minTimestamp)
self.assertEqual(10, sut.maxTimestamp)
self.assertEqual(["words", "max pitch"], sut.tierNameList)
self.assertSequenceEqual(["words", "max pitch"], sut.tierNameList)
self.assertEqual(expectedTier1, sut.tierDict["words"])
self.assertEqual(expectedTier2, sut.tierDict["max pitch"])

Expand Down Expand Up @@ -190,7 +190,7 @@ def test_append_textgrid_without_matching_names_only(self):

self.assertEqual(0, sut.minTimestamp)
self.assertEqual(10, sut.maxTimestamp)
self.assertEqual(
self.assertSequenceEqual(
["words", "max pitch", "phrases", "min pitch", "cats", "dogs"],
sut.tierNameList,
)
Expand Down Expand Up @@ -844,8 +844,7 @@ def test_rename_tier_renames_a_tier(self):
expectedRenamedTier = makeIntervalTier("cats", [[5, 6.7, "hey there"]])

self.assertEqual(expectedRenamedTier, sut.tierDict["cats"])
self.assertEqual(["phrases", "cats", "phones"], sut.tierNameList)
self.assertCountEqual(["phrases", "cats", "phones"], sut.tierDict.keys())
self.assertSequenceEqual(["phrases", "cats", "phones"], sut.tierNameList)

def test_remove_tier_removes_a_tier(self):
sut = textgrid.Textgrid(0, 10)
Expand All @@ -860,8 +859,7 @@ def test_remove_tier_removes_a_tier(self):
removedTier = sut.removeTier("words")

self.assertEqual(removedTier, tier2)
self.assertEqual(["phrases", "phones"], sut.tierNameList)
self.assertCountEqual(["phrases", "phones"], sut.tierDict.keys())
self.assertSequenceEqual(["phrases", "phones"], sut.tierNameList)

def test_replace_tier_replaces_one_tier_with_another(self):
sut = textgrid.Textgrid(0, 10)
Expand All @@ -872,12 +870,11 @@ def test_replace_tier_replaces_one_tier_with_another(self):
sut.addTier(tier1)
sut.addTier(tier2)

self.assertEqual(["words", "phones"], sut.tierNameList)
self.assertSequenceEqual(["words", "phones"], sut.tierNameList)

sut.replaceTier("words", newTier1)

self.assertEqual(["cats", "phones"], sut.tierNameList)
self.assertCountEqual(["cats", "phones"], sut.tierDict.keys())
self.assertSequenceEqual(["cats", "phones"], sut.tierNameList)
self.assertEqual(newTier1, sut.tierDict["cats"])

def test_replace_tier_reports_if_new_tier_is_larger_than_textgrid(self):
Expand Down Expand Up @@ -909,23 +906,6 @@ def test_validate_throws_error_if_reporting_mode_is_invalid(self):
with self.assertRaises(errors.WrongOption) as _:
sut.validate("bird")

def test_validate_throws_error_if_two_tiers_have_the_same_name(self):
# Users shouldn't be manually manipulating the tierNameList
# (currently its the only way to trigger this error)
sut = textgrid.Textgrid()
tier1 = makeIntervalTier(name="phones")
tier2 = makeIntervalTier(name="words")

sut.addTier(tier1)
sut.addTier(tier2)
self.assertTrue(sut.validate())

sut.tierNameList.append("phones")
self.assertFalse(sut.validate(constants.ErrorReportingMode.SILENCE))

with self.assertRaises(errors.TierNameExistsError) as _:
sut.validate(constants.ErrorReportingMode.ERROR)

def test_validate_throws_error_if_tiers_and_textgrid_dont_agree_on_min_timestamp(
self,
):
Expand Down

0 comments on commit 5ffcc6a

Please sign in to comment.