Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Commit

Permalink
a faster iterator for set partitions with given block sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
mantepse committed Jul 17, 2018
1 parent 7524eb5 commit 84ba2c8
Showing 1 changed file with 48 additions and 38 deletions.
86 changes: 48 additions & 38 deletions src/sage/combinat/set_partition.py
Expand Up @@ -56,6 +56,7 @@
from sage.functions.other import factorial
from sage.misc.prandom import random, randint
from sage.probability.probability_distribution import GeneralDiscreteDistribution
from sage.graphs.linearextensions import LinearExtensions

@add_metaclass(InheritComparisonClasscallMetaclass)
class AbstractSetPartition(ClonableArray):
Expand Down Expand Up @@ -1651,41 +1652,6 @@ def _element_constructor_(self, s, check=True):

Element = SetPartition

def _iterator_part(self, part):
"""
Return an iterator for the set partitions with block sizes
corresponding to the partition ``part``.
INPUT:
- ``part`` -- a :class:`Partition` object
EXAMPLES::
sage: S = SetPartitions(3)
sage: it = S._iterator_part(Partition([1,1,1]))
sage: sorted(map(list, next(it)))
[[1], [2], [3]]
sage: S21 = SetPartitions(3,Partition([2,1]))
sage: len(list(S._iterator_part(Partition([2,1])))) == S21.cardinality()
True
"""
nonzero = []
expo = [0] + part.to_exp()

for i in range(len(expo)):
if expo[i] != 0:
nonzero.append([i, expo[i]])

taillesblocs = [(x[0])*(x[1]) for x in nonzero]

blocs = OrderedSetPartitions(self._set, taillesblocs)

for b in blocs:
lb = [IterableFunctionCall(_listbloc, nonzero[i][0], nonzero[i][1], b[i]) for i in range(len(nonzero))]
for x in itertools.product(*lb):
yield _union(x)

def is_less_than(self, s, t):
r"""
Check if `s < t` in the refinement ordering on set partitions.
Expand Down Expand Up @@ -2079,7 +2045,7 @@ def __init__(self, s, parts):
sage: TestSuite(S).run()
"""
SetPartitions_set.__init__(self, s)
self.parts = parts
self.parts = Partition(parts)

def _repr_(self):
"""
Expand Down Expand Up @@ -2136,6 +2102,40 @@ def cardinality(self):
cardinal /= prod(repetitions)
return Integer(cardinal)

def _set_partition_poset(self):
"""
Return a poset whose linear extensions correspond to the set
partitions with specified block sizes
TESTS::
sage: n = 12
sage: all(SetPartitions(n, mu).cardinality() == _set_partition_poset(mu).linear_extensions().cardinality() for mu in Partitions(n))
"""
c = self.parts.to_exp_dict()
covers = dict()
i = 1
for s in sorted(c):
# s is the block size
# each block is one tree in the poset
for m in range(c[s]):
# m is the multiplicity of blocks with size s
#
# the first element in each non-final block has an
# additional cover
first = i
if s == 1:
covers[i] = []
else:
for j in range(s-1):
covers[i] = [i+1]
i += 1
i += 1
if m < c[s]-1:
covers[first].append(i)
return DiGraph(covers)

def __iter__(self):
"""
An iterator for all the set partitions of the given set with
Expand All @@ -2145,9 +2145,19 @@ def __iter__(self):
sage: SetPartitions(3, [2,1]).list()
[{{1}, {2, 3}}, {{1, 3}, {2}}, {{1, 2}, {3}}]
"""
for sp in self._iterator_part(self.parts):
yield self.element_class(self, sp)
# Ruskey, Combinatorial Generation, sec. 5.10.1 and Knuth TAOCP 4A 7.2.1.5, Exercise 6
k = len(self.parts)
P = self._set_partition_poset2()

sums = [0]
for b in sorted(self.parts):
sums.append(sums[-1] + b)

for ext in LinearExtensions(P):
pi = Permutation(ext, check_input = False).inverse()
yield SetPartition([pi[sums[i]:sums[i+1]] for i in range(k)])

def __contains__(self, x):
"""
Expand Down

0 comments on commit 84ba2c8

Please sign in to comment.