# Solution

> Official python docs: [https://docs.python.org/3/library/itertools.html#itertools.permutations](https://docs.python.org/3/library/itertools.html#itertools.permutations)

## 1. Define the permutation function

In [1]:
def permutations(iterable, r=None):
    """Yield successive r-length permutations of the elements of ``iterable``.

    This is the pure-Python equivalent of ``itertools.permutations`` given in
    the official itertools documentation.

    Args:
        iterable: Source of elements; it is consumed into a tuple up front.
        r: Length of each permutation. Defaults to the number of elements.

    Yields:
        Tuples of length ``r``. Nothing is yielded when ``r`` exceeds the
        number of elements.
    """
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210

    pool = tuple(iterable)  # materialize so elements can be indexed repeatedly
    n = len(pool)
    r = n if r is None else r
    if r > n:
        # Not enough elements to fill a permutation of length r.
        return
    # indices is the current arrangement of positions into pool; its first r
    # entries describe the permutation being emitted.
    indices = list(range(n))
    # cycles[i] counts down the swaps left for position i before its tail is
    # rotated back; it starts at n - i (n, n-1, ..., n-r+1).
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])
    while n:  # with no elements, nothing follows the first yield
        # Advance the rightmost position that still has swaps remaining.
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # Position i is exhausted: move indices[i] to the end (rotate
                # the tail left by one), reset its counter, then carry over
                # to the next position to the left.
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                # Bring a new element into position i and emit the result.
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            # Every position was exhausted without a swap: all done.
            return

## 2. Unit tests

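Let's modify this function to support our desired behaviour: a `mapping` argument holding a list of prefixes, so that only permutations starting with one of those prefixes are yielded.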

But first, to make this easier, let's write some tests to check that our function works correctly.

In [10]:
def test_permutations(permutation_func):
    result = list(permutation_func(
        'abc',
        r=3,
    ))
    expected = [
        ('a', 'b', 'c'),
        ('a', 'c', 'b'),
        ('b', 'a', 'c'),
        ('b', 'c', 'a'),
        ('c', 'a', 'b'),
        ('c', 'b', 'a'),
    ]
    assert expected == result, f'{expected} != {result}'

def test_permutations_mapping(permutation_func):
    result = list(permutation_func(
        'abcd',
        r=4,
        mapping=[('a', 'c')],
    ))
    expected = [
        ('a', 'c', 'b', 'd'),
        ('a', 'c', 'd', 'b'),
    ]
    assert expected == result, f'{expected} != {result}'
    
def test_func(f):
    """Run every permutation check in this notebook against ``f``.

    Each check raises AssertionError on a mismatch; the first failure stops
    the run.
    """
    for check in (test_permutations, test_permutations_mapping):
        check(f)

In [3]:
def permutations(iterable, r=None, mapping=None):
    """Yield successive r-length permutations of ``iterable`` (draft version).

    Same algorithm as the itertools recipe; ``mapping`` is accepted in the
    signature but not applied yet.

    Args:
        iterable: Source of elements; it is consumed into a tuple up front.
        r: Length of each permutation. Defaults to the number of elements.
        mapping: Accepted but intentionally ignored in this draft. Every
            permutation is yielded regardless, which is what makes the
            mapping test fail at this stage. Defaults to ``None`` rather than
            a shared mutable list.

    Yields:
        Tuples of length ``r``. Nothing is yielded when ``r`` exceeds the
        number of elements.
    """
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210

    pool = tuple(iterable)  # materialize so elements can be indexed repeatedly
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    # indices: current arrangement of positions into pool (first r are used).
    indices = list(range(n))
    # cycles[i]: swaps remaining for position i; starts at n - i.
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # Position i exhausted: rotate its tail left and reset.
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            # No position could advance: all permutations have been emitted.
            return

In [4]:
# Expected to fail: this version accepts `mapping` but ignores it.
test_permutations_mapping(permutations)

AssertionError: (('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b')) != [('a', 'b', 'c', 'd'), ('a', 'b', 'd', 'c'), ('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c'), ('a', 'd', 'c', 'b'), ('b', 'a', 'c', 'd'), ('b', 'a', 'd', 'c'), ('b', 'c', 'a', 'd'), ('b', 'c', 'd', 'a'), ('b', 'd', 'a', 'c'), ('b', 'd', 'c', 'a'), ('c', 'a', 'b', 'd'), ('c', 'a', 'd', 'b'), ('c', 'b', 'a', 'd'), ('c', 'b', 'd', 'a'), ('c', 'd', 'a', 'b'), ('c', 'd', 'b', 'a'), ('d', 'a', 'b', 'c'), ('d', 'a', 'c', 'b'), ('d', 'b', 'a', 'c'), ('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b'), ('d', 'c', 'b', 'a')]

### 2.1. Customise the function

Cool! Now that we have a failing test, let's update the code to behave the way we want. As the cell below shows, the current version still ignores `mapping`:

In [5]:
# `mapping` is still ignored here, so all six permutations come back.
list(permutations('abc', r=3, mapping=[('a',)]))

[('a', 'b', 'c'),
 ('a', 'c', 'b'),
 ('b', 'a', 'c'),
 ('b', 'c', 'a'),
 ('c', 'a', 'b'),
 ('c', 'b', 'a')]

In [6]:
def permutations(iterable, r=None, mapping=None):
    """Yield r-length permutations of ``iterable``, optionally filtered by prefix.

    Permutations are generated in the same order as ``itertools.permutations``.

    Args:
        iterable: Source of elements; it is consumed into a tuple up front.
        r: Length of each permutation. Defaults to the number of elements.
        mapping: Optional iterable of prefixes (tuples, lists or any other
            sequences). When it contains at least one prefix, only the
            permutations starting with one of them are yielded, each at most
            once. When ``None`` or empty, every permutation is yielded.

    Yields:
        Tuples of length ``r``. Nothing is yielded when ``r`` exceeds the
        number of elements.
    """
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210

    # Normalise prefixes once, as tuples. Tuples compare equal to tuple slices
    # even when the caller passes lists. Materialising also means a one-shot
    # iterator is not exhausted after the first candidate.
    prefixes = [tuple(m) for m in mapping] if mapping is not None else []

    def _wanted(option):
        # Keep everything when unfiltered. Otherwise keep the option if any
        # prefix matches; `any` guarantees a single yield even when several
        # prefixes match the same permutation.
        return not prefixes or any(option[:len(p)] == p for p in prefixes)

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    # indices: current arrangement of positions into pool (first r are used).
    indices = list(range(n))
    # cycles[i]: swaps remaining for position i; starts at n - i.
    cycles = list(range(n, n - r, -1))

    option = tuple(pool[k] for k in indices[:r])
    if _wanted(option):
        yield option

    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # Position i exhausted: rotate its tail left and reset.
                indices[i:] = indices[i + 1:] + indices[i:i + 1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                option = tuple(pool[k] for k in indices[:r])
                if _wanted(option):
                    yield option
                break
        else:
            # No position could advance: all permutations have been examined.
            return

In [7]:
# Only the permutations starting with 'a' are kept.
list(permutations('abc', r=3, mapping=[('a',)]))

[('a', 'b', 'c'), ('a', 'c', 'b')]

In [13]:
# Only the permutations starting with ('a', 'c') are kept.
list(permutations('abcd', r=4, mapping=[('a', 'c')]))

[('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b')]

In [11]:
# Without `mapping`, every permutation is yielded.
list(permutations('abc'))

[('a', 'b', 'c'),
 ('a', 'c', 'b'),
 ('b', 'a', 'c'),
 ('b', 'c', 'a'),
 ('c', 'a', 'b'),
 ('c', 'b', 'a')]

In [12]:
# Run both checks; each raises AssertionError on a mismatch.
test_func(permutations)