# Groups Exercise Companion Notebook

**6.7970/8.750 Symmetry and its Application to Machine Learning**

This notebook follows the Groups exercise section by section. Use it to **prototype your code** and **test your implementations** against the course library before submitting on the website.

Each section includes small tests you can use to check your work.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/atomicarchitects/symm4ml-colabs/blob/main/groups_companion.ipynb)

## Setup

In [2]:
%%capture
!pip install https://symm4ml.mit.edu/_static/symm4ml_s26/symm4ml/symm4ml_latest.zip

In [3]:
import itertools
import numpy as np
from IPython.display import HTML

from symm4ml import groups, groups_fast, plot, vis

### Reference data

These tables and matrices are used throughout the exercise for testing.

In [4]:
# The 2x2 rotation/reflection representation of P(3) from Dresselhaus
E = np.eye(2)
A = np.array([[-1., 0.], [0., 1.]])
B = np.array([[1./2., -np.sqrt(3.)/2.], [-np.sqrt(3.)/2., -1/2.]])
C = np.array([[1./2., np.sqrt(3.)/2.], [np.sqrt(3.)/2., -1/2.]])
D = np.array([[-1./2., np.sqrt(3.)/2.], [-np.sqrt(3.)/2., -1/2.]])
F = np.array([[-1./2., -np.sqrt(3.)/2.], [np.sqrt(3.)/2., -1/2.]])
p3_dresselhaus = np.stack([E, A, B, C, D, F], axis=0)

# Reference multiplication tables used in tests
ans_table1 = np.array([[0,1,2,3,4,5],[1,0,3,2,5,4],[2,5,0,4,3,1],[3,4,1,5,2,0],[4,3,5,1,0,2],[5,2,4,0,1,3]])
ans_table2 = np.array([[0,1,2,3],[1,0,3,2],[2,3,0,1],[3,2,1,0]])
ans_table3 = np.array([[0,1,2,3],[1,0,3,2],[2,3,1,0],[3,2,0,1]])
ans_table4 = np.array([[0,1,2,3],[1,2,3,0],[2,3,0,1],[3,0,1,2]])

---
## Section 1: From Matrices to Groups

### Context: Two representations of $P(3)$

$P(3)$ ($\cong D_3$) is the smallest nonabelian group, with 6 elements. We can represent it either as $3\times 3$ permutation matrices or as $2\times 2$ rotation/reflection matrices.
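To make "nonabelian" concrete, here is a minimal sketch, independent of the course library, that builds a 120° rotation and a mirror as $2\times 2$ matrices, finds their orders, and checks that they do not commute. The helper `order` is a hypothetical name used only for illustration.

```python
import numpy as np

theta = 2 * np.pi / 3
# Counterclockwise rotation by 120 degrees
r = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
# Mirror across the y-axis (flips the sign of x)
s = np.array([[-1.0, 0.0],
              [0.0, 1.0]])

def order(m, max_power=12):
    """Hypothetical helper: smallest k >= 1 with m^k == identity (up to float tolerance)."""
    identity = np.eye(m.shape[0])
    power = identity
    for k in range(1, max_power + 1):
        power = power @ m
        if np.allclose(power, identity):
            return k
    raise ValueError("order exceeds max_power")

rotation_order = order(r)             # 3
mirror_order = order(s)               # 2
commutes = np.allclose(r @ s, s @ r)  # False, so the group is nonabelian
print(rotation_order, mirror_order, commutes)
```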

In [5]:
p3_perm = groups.permutation_matrices(3)
print(f"P(3) permutation matrices: {p3_perm.shape}")
HTML(plot.matrix_grid(p3_perm, labels=["E","A","B","C","D","F"], cell_size=20))

P(3) permutation matrices: (6, 3, 3)


[Rendered HTML output: grid of the six 3×3 permutation matrices, labeled E, A, B, C, D, F]


In [6]:
print(f"D3 rotation/reflection matrices: {p3_dresselhaus.shape}")
HTML(plot.matrix_grid(p3_dresselhaus, labels=["E","A","B","C","D","F"], cell_size=30))

D3 rotation/reflection matrices: (6, 2, 2)


[Rendered HTML output: grid of the six 2×2 rotation/reflection matrices, labeled E, A, B, C, D, F]


### 1.1 `permutation_matrices(n)`

Write your implementation here, then test against the course version.

In [7]:
def permutation_matrices(n):
    """Generates all permutation matrices of n elements
    Input:
        n: int
    Output:
        matrices: np.array of shape [n!, n, n]
    """
    # Rows of the identity: row k has a 1 in column k and 0s everywhere else
    rows = np.eye(n)
    # Reordering the rows by each permutation gives one permutation matrix
    matrices = [rows[list(perm)] for perm in itertools.permutations(range(n))]
    return np.array(matrices)

permutation_matrices(2)


array([[[1., 0.],
        [0., 1.]],

       [[0., 1.],
        [1., 0.]]])

In [8]:
# Small tests from the course library
# Matrices can be returned in any order, so we sort before comparing
result_2 = permutation_matrices(2)
assert result_2.shape == (2, 2, 2), f"Expected shape (2, 2, 2), got {result_2.shape}"
np.testing.assert_allclose(
    np.unique(result_2, axis=0),
    np.unique(np.array([[[1, 0], [0, 1]], [[0, 1], [1, 0]]]), axis=0),
)

result_3 = permutation_matrices(3)
assert result_3.shape == (6, 3, 3), f"Expected shape (6, 3, 3), got {result_3.shape}"
np.testing.assert_allclose(
    np.unique(result_3, axis=0),
    np.unique(groups.permutation_matrices(3), axis=0),
)
print("permutation_matrices tests passed!")

permutation_matrices tests passed!


### 1.2 `generate_group(matrices)`

Use closure under multiplication to generate a full group from a subset of elements.

In [14]:
def generate_group(matrices, decimals=4):
    """Generate new group elements from matrices (group representations)
    Input:
        matrices: np.array of shape [n, d, d] of known elements
        decimals: int number of decimals to round to when comparing matrices
    Output:
        group: np.array of shape [m, d, d], where m is the size of the resultant group
    """
    def contains(elements, m):
        # Exact comparison of rounded entries (np.array_equal treats -0.0 and 0.0 as equal)
        return any(np.array_equal(m, g) for g in elements)

    # Start from the known elements, rounded and with duplicates removed
    group = []
    for m in matrices:
        m = np.round(m, decimals)
        if not contains(group, m):
            group.append(m)

    # Multiply every pair of known elements; repeat until no new product appears (closure)
    changed = True
    while changed:
        changed = False
        for a in list(group):
            for b in list(group):
                new = np.round(a @ b, decimals)
                if not contains(group, new):
                    group.append(new)
                    changed = True
    return np.array(group)

# testing
p2 = permutation_matrices(2)
generate_group(p2[1:])


array([[[0., 1.],
        [1., 0.]],

       [[1., 0.],
        [0., 1.]]])

In [10]:
# Small tests: generating P(3) from subsets
p3 = groups.permutation_matrices(3)
np.testing.assert_allclose(
    np.unique(generate_group(p3[:-2]), axis=0),
    np.unique(p3, axis=0),
)
print("generate_group tests passed!")

generate_group tests passed!

In [None]:
# Visualize: generating D3 from a mirror and rotation
generators = p3_dresselhaus[[1, 4]]  # A (mirror) and D (rotation)
p3_generated = groups.generate_group(generators)
print(f"Generated {len(p3_generated)} elements from 2 generators")
HTML(plot.matrix_grid(p3_generated, cell_size=30))

### 1.3 `cyclic_matrices(n)`

Generate cyclic group matrices using a single generator and `generate_group`.
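A natural single generator for $C_n$ is the $n\times n$ cyclic shift matrix $g$, which satisfies $g^n = E$. The sketch below builds $g$ with `np.roll` and stacks its powers explicitly instead of calling `generate_group`, so it runs on its own. `cyclic_shift` and `cyclic_powers` are hypothetical names, and this is one possible approach rather than the course's.

```python
import numpy as np

def cyclic_shift(n):
    """Hypothetical helper: n x n permutation matrix sending e_i to e_{(i+1) mod n}."""
    return np.roll(np.eye(n), 1, axis=0)

def cyclic_powers(n):
    """Hypothetical helper: stack g^0, g^1, ..., g^(n-1) for the cyclic shift g."""
    g = cyclic_shift(n)
    powers = [np.eye(n)]
    for _ in range(n - 1):
        powers.append(g @ powers[-1])
    return np.array(powers)

c3 = cyclic_powers(3)
print(c3.shape)  # (3, 3, 3)
# g^n returns to the identity, so the n powers close into a group of order n
print(np.allclose(np.linalg.matrix_power(cyclic_shift(3), 3), np.eye(3)))  # True
```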

In [None]:
def cyclic_matrices(n):
    """Generates all cyclic matrices of n elements
    Input:
        n: int
    Output:
        matrices: np.array of shape [n, n, n]
    """
    # YOUR CODE HERE
    pass

In [None]:
# Quick check: C_1, C_2, C_3 should have 1, 2, 3 elements
for n in [1, 2, 3]:
    result = cyclic_matrices(n)
    assert result.shape == (n, n, n), f"cyclic_matrices({n}) shape should be ({n},{n},{n}), got {result.shape}"
print("cyclic_matrices tests passed!")

### 1.4 `make_multiplication_table(matrices)`

Build the Cayley table: entry at row $g$, column $h$ gives the index of $g \circ h$.
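One way to fill the table, assuming the input matrices already form a group: for each product $g h$, look for the index of the single matrix it matches within tolerance. This is an $O(n^3)$ sketch with a hypothetical name, `cayley_table`, and is not necessarily how the course version is written.

```python
import numpy as np

def cayley_table(matrices, tol=1e-8):
    """Hypothetical helper: table[i, j] = k such that matrices[i] @ matrices[j] ~ matrices[k]."""
    n = len(matrices)
    table = np.zeros((n, n), dtype=int)
    for i in range(n):
        for j in range(n):
            product = matrices[i] @ matrices[j]
            # Largest absolute entrywise difference between the product and each candidate
            diffs = np.abs(matrices - product).reshape(n, -1).max(axis=1)
            matches = np.flatnonzero(diffs < tol)
            if len(matches) != 1:
                raise ValueError("Product not found exactly once; not a group?")
            table[i, j] = matches[0]
    return table

# Z_2 represented by {+1, -1} as 1x1 matrices
z2 = np.array([[[1.0]], [[-1.0]]])
print(cayley_table(z2))
# [[0 1]
#  [1 0]]
```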

In [None]:
def make_multiplication_table(matrices, *, tol=1e-8):
    """Makes multiplication table for group.
    Input:
        matrices: np.array of shape [n, d, d], n matrices of dimension d that form a group under matrix multiplication.
        tol: float numerical tolerance
    Output:
        Group multiplication table.
        np.array of shape [n, n] where entries correspond to indices of first dim of matrices.
    """
    # YOUR CODE HERE
    pass

In [None]:
# Compare your table with the course version
table_yours = make_multiplication_table(groups.permutation_matrices(3))
table_course = groups.make_multiplication_table(groups.permutation_matrices(3))
np.testing.assert_array_equal(table_yours, table_course)
print("make_multiplication_table tests passed!")

In [None]:
# Visualize: the two P(3) tables look different because elements are ordered differently
table_perm = groups.make_multiplication_table(p3_perm)
table_2d = groups.make_multiplication_table(p3_dresselhaus)

HTML(plot.compare_tables(
    table_2d, table_perm,
    labels1=["E","A","B","C","D","F"],
    labels2=[str(i) for i in range(6)],
))

---
## Section 2: Group Definition

A group requires: closure, a unique identity, inverses for all elements, and associativity.
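Each axiom can be read directly off the table. For associativity, NumPy fancy indexing gives `table[table]` and `table[:, table]`, two $[n, n, n]$ arrays holding $(ab)c$ and $a(bc)$ for every triple, so one comparison covers all $n^3$ cases. A minimal sketch assuming the table is already closed (every entry is a valid index); `is_associative_vectorized` is a hypothetical name, not a course function.

```python
import numpy as np

def is_associative_vectorized(table):
    """Hypothetical helper: check (a*b)*c == a*(b*c) for all a, b, c at once."""
    table = np.asarray(table)
    left = table[table]       # left[a, b, c]  = table[table[a, b], c]
    right = table[:, table]   # right[a, b, c] = table[a, table[b, c]]
    return bool(np.array_equal(left, right))

klein = np.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 3, 0, 1], [3, 2, 1, 0]])
print(is_associative_vectorized(klein))                              # True
print(is_associative_vectorized([[0, 1, 0], [0, 0, 0], [0, 0, 0]]))  # False
```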

### 2.1 `identity(table)`

Find the unique identity element, or raise `ValueError("No or multiple identities")`.
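On the table, $e$ is an identity exactly when row $e$ and column $e$ both read $0, 1, \dots, n-1$. A minimal sketch of that test; the hypothetical helper `find_identities` returns every index that passes instead of raising, so a valid group table yields exactly one.

```python
import numpy as np

def find_identities(table):
    """Hypothetical helper: indices e whose row and column both equal [0, 1, ..., n-1]."""
    table = np.asarray(table)
    idx = np.arange(len(table))
    return [e for e in range(len(table))
            if np.array_equal(table[e, :], idx) and np.array_equal(table[:, e], idx)]

print(find_identities([[1, 2, 0], [2, 0, 1], [0, 1, 2]]))  # [2]
print(find_identities([[0, 1, 2], [0, 1, 2], [2, 0, 1]]))  # []
```

An empty list, or one with more than one entry, is the case where `identity` should raise `ValueError("No or multiple identities")`.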

In [None]:
def identity(table):
    """Returns the index of the identity element.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
    Output:
        Index of identity element.
    Raises:
        ValueError("No or multiple identities") if there is no or multiple identities.
    """
    # YOUR CODE HERE
    pass

In [None]:
# Tests from the course library
assert identity(ans_table1) == 0
assert identity(ans_table2) == 0
assert identity(np.array([[1, 2, 0], [2, 0, 1], [0, 1, 2]])) == 2

# Should raise ValueError for table with no identity
try:
    identity(np.array([[0, 1, 2], [0, 1, 2], [2, 0, 1]]))
    assert False, "Should have raised ValueError"
except ValueError as e:
    assert "No or multiple identities" in str(e)

print("identity tests passed!")

### 2.2 `inverses(table)`

Return an array whose entry $i$ is the index of the inverse of element $i$. Raise `ValueError("Every element does not have one inverse")` if some element does not have exactly one inverse.
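Once the identity index $e$ is known, $j$ is an inverse of $i$ when `table[i, j] == e` and `table[j, i] == e`. The sketch below, with the hypothetical name `inverse_indices` and $e$ passed in explicitly, requires exactly one such $j$ on each side. That also rejects the invalid table used in the tests for this function, where the identity (index 2) appears twice in row 1 and twice in column 0.

```python
import numpy as np

def inverse_indices(table, e):
    """Hypothetical helper: inv with table[i, inv[i]] == e == table[inv[i], i]."""
    table = np.asarray(table)
    inv = []
    for i in range(len(table)):
        right = np.flatnonzero(table[i, :] == e)  # all j with i*j == e
        left = np.flatnonzero(table[:, i] == e)   # all j with j*i == e
        if len(right) != 1 or len(left) != 1 or right[0] != left[0]:
            raise ValueError("Every element does not have one inverse")
        inv.append(int(right[0]))
    return np.array(inv)

z4 = np.array([[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2]])
print(inverse_indices(z4, e=0))  # [0 3 2 1]
```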

In [None]:
def inverses(table):
    """Returns the indices of the inverses of each element.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
    Output:
        np.array of shape [n] where the ith entry is the index of the inverse of the ith element.
    Raises:
        ValueError("Every element does not have one inverse") if there is no or multiple inverses.
    """
    # YOUR CODE HERE
    pass

In [None]:
# Tests
t = np.array([[0,1,2,3],[1,0,3,2],[2,3,0,1],[3,2,1,0]])
inv = inverses(t)
assert np.all(t[inv, np.arange(4)] == identity(t))

# Should raise ValueError
try:
    inverses(np.array([[2, 0, 0], [2, 2, 1], [0, 1, 2]]))
    assert False, "Should have raised ValueError"
except ValueError as e:
    assert "inverse" in str(e).lower()

print("inverses tests passed!")

### 2.3 `is_closed(table)`

In [None]:
def is_closed(table):
    """Tests whether the multiplication table is closed.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
    Output:
        True if the table represents a closed binary operation, False otherwise.
    """
    # YOUR CODE HERE
    pass

In [None]:
assert is_closed(np.array([[0, 0], [0, 0]])) == True
assert is_closed(np.array([[1, 2], [3, 4]])) == False
print("is_closed tests passed!")

### 2.4 `is_associative(table)`

In [None]:
def is_associative(table):
    """Tests whether the multiplication table is associative.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
    Output:
        True if the table represents an associative binary operation, False otherwise.
    """
    # YOUR CODE HERE
    pass

In [None]:
assert is_associative(np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])) == True
assert is_associative(np.array([[0, 1, 0], [0, 0, 0], [0, 0, 0]])) == False
print("is_associative tests passed!")

### 2.5 `test_group(table)`

Combine all four checks. For each kind of failure, raise the specific `ValueError` message listed in the docstring.

In [None]:
def test_group(table):
    """Tests whether the multiplication table is valid.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
    Raises:
        ValueError("Invalid indices") if the table contains invalid indices (is not closed).
        ValueError("No or multiple identities") if the table does not contain exactly one identity.
        ValueError("Every element does not have one inverse") if not every element has an inverse.
        ValueError("Not associative") if the table is not associative.
    """
    # YOUR CODE HERE
    pass

In [None]:
# Helper to capture ValueError messages
def value_error(fn, *args):
    try:
        fn(*args)
    except ValueError as e:
        return " ".join(map(str, e.args))
    return None

# Valid group tables should pass
t = np.array([[0,1,2,3],[1,0,3,2],[2,3,0,1],[3,2,1,0]])
test_group(t)  # should not raise
test_group(ans_table1)  # should not raise

# Invalid tables should raise specific errors
assert value_error(test_group, np.array([[0, 1, 3], [1, 2, 0], [2, 0, 1]])) == "Invalid indices"
assert value_error(test_group, np.array([[0, 1, 2], [1, 2, 0], [2, 0, 2]])) == "Not associative"

print("test_group tests passed!")

In [None]:
# Explore group properties interactively
# Use the tabs: Elements, Rearrangement, Inverses, Subgroups, Conjugacy
HTML(plot.multiplication_table(table_2d, labels=["E","A","B","C","D","F"]))

---
## Section 3: Subgroups

By Lagrange's theorem, the order of a subgroup divides the order of the group. Use `itertools.combinations` to search over subsets of the right sizes.

### 3.1 `subgroups(table)`

The course provides `groups.factors(n)`. Use it to find candidate subgroup sizes.
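In a finite group, any nonempty subset closed under the operation is already a subgroup, because the powers of each element eventually reach the identity and the element's inverse. A brute-force sketch built on that fact: try only subset sizes that divide $n$, and only subsets containing the identity (assumed here to be index 0). `divisors` stands in for `groups.factors` so the block runs alone; both function names are hypothetical.

```python
import itertools
import numpy as np

def divisors(n):
    """Hypothetical stand-in for groups.factors: all positive divisors of n."""
    return {d for d in range(1, n + 1) if n % d == 0}

def brute_force_subgroups(table, e=0):
    """Hypothetical helper: set of frozensets of indices closed under the operation."""
    table = np.asarray(table)
    n = len(table)
    others = [g for g in range(n) if g != e]
    found = set()
    for size in sorted(divisors(n)):
        # Every subgroup contains the identity e, so choose the other size-1 elements
        for rest in itertools.combinations(others, size - 1):
            subset = (e,) + rest
            members = set(subset)
            if all(int(table[a, b]) in members for a in subset for b in subset):
                found.add(frozenset(subset))
    return found

klein = np.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 3, 0, 1], [3, 2, 1, 0]])
print(sorted(map(sorted, brute_force_subgroups(klein))))
```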

In [None]:
# factors is provided for you
assert groups.factors(12) == {1, 2, 3, 4, 6, 12}
assert groups.factors(6) == {1, 2, 3, 6}
print("Factors of 6:", groups.factors(6))

In [None]:
def subgroups(table):
    """Find all subgroups of group.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
    Output:
        Yields tuples of elements that form subgroup.
    """
    # YOUR CODE HERE
    pass

In [None]:
t = np.array([[0,1,2,3],[1,0,3,2],[2,3,0,1],[3,2,1,0]])
assert subgroups(t) == {
    frozenset({0}),
    frozenset({0, 1}),
    frozenset({0, 2}),
    frozenset({0, 3}),
    frozenset({0, 1, 2, 3}),
}
print("subgroups tests passed!")

### 3.2 Questions: $C_3$ and $\mathbb{Z}_2$ in $P(3)$

Use the course implementations to find which indices of `permutation_matrices(3)` form $C_3$ and $\mathbb{Z}_2$.

Utility functions `groups.remap_to_minimal` and `groups.subgroup_table_from_group_table` may be helpful.

In [None]:
p3_subgroups = groups.subgroups(table_perm)
print("Subgroups of P(3):")
for sg in sorted(p3_subgroups, key=lambda s: (len(s), min(s))):
    print(f"  {sorted(sg)}  (order {len(sg)})")

# Visualize subgroup structure
HTML(plot.structure_explorer(table_perm, labels=[str(i) for i in range(6)]))

In [None]:
# YOUR ANSWERS:
# c3_in_p3 = {(?, ?, ?)}  # sorted tuple of indices forming C_3
# z2_in_p3 = {(?, ?), (?, ?), (?, ?)}  # sorted tuples of indices forming Z_2 copies

---
## Section 4: Cosets

Left cosets: $gH = \{gh : h \in H\}$. Right cosets: $Hg = \{hg : h \in H\}$.
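Read off the table, $gH$ is the set of entries `table[g, h]` and $Hg$ is the set of entries `table[h, g]` for $h \in H$. A minimal sketch that collects the distinct cosets as frozensets; `cosets` is a hypothetical name. It uses a copy of `ans_table1`, a nonabelian group of order 6, where left and right cosets of $\{0, 1\}$ differ.

```python
import numpy as np

def cosets(table, subgroup, side="left"):
    """Hypothetical helper: distinct cosets gH (side='left') or Hg (side='right')."""
    table = np.asarray(table)
    result = set()
    for g in range(len(table)):
        if side == "left":
            coset = frozenset(int(table[g, h]) for h in subgroup)
        else:
            coset = frozenset(int(table[h, g]) for h in subgroup)
        result.add(coset)
    return result

# Copy of ans_table1 from the reference data
s3 = np.array([[0, 1, 2, 3, 4, 5], [1, 0, 3, 2, 5, 4], [2, 5, 0, 4, 3, 1],
               [3, 4, 1, 5, 2, 0], [4, 3, 5, 1, 0, 2], [5, 2, 4, 0, 1, 3]])
print(sorted(map(sorted, cosets(s3, {0, 1}, "left"))))   # [[0, 1], [2, 5], [3, 4]]
print(sorted(map(sorted, cosets(s3, {0, 1}, "right"))))  # [[0, 1], [2, 3], [4, 5]]
```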

### 4.1 `right_coset(table, subgroup_indices)`

In [None]:
def right_coset(table, subgroup_indices):
    """Returns the right coset of the ith element.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
        subgroup_indices: Indices of elements in the subgroup.
    Output:
        Set of right cosets for each element in the group. Each coset is represented as a frozenset of indices.
    Example:
        right_coset(np.array([[0, 1], [1, 0]]), {0}) == {frozenset({1}), frozenset({0})}
    """
    # YOUR CODE HERE
    pass

In [None]:
t = np.array([[0,1,2,3],[1,0,3,2],[2,3,0,1],[3,2,1,0]])
assert right_coset(t, {0, 1}) == {frozenset({0, 1}), frozenset({2, 3})}
assert right_coset(t, {0, 2}) == {frozenset({0, 2}), frozenset({1, 3})}
print("right_coset tests passed!")

### 4.2 `left_coset(table, subgroup_indices)`

In [None]:
def left_coset(table, subgroup_indices):
    """Returns the left coset of the ith element.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
        subgroup_indices: Indices of elements in the subgroup.

    Output:
        Set of left cosets for each element in the group. Each coset is represented as a set of indices.
    """
    # YOUR CODE HERE
    pass

In [None]:
t = np.array([[0,1,2,3],[1,0,3,2],[2,3,0,1],[3,2,1,0]])
assert left_coset(t, {0, 1}) == {frozenset({0, 1}), frozenset({2, 3})}
print("left_coset tests passed!")

In [None]:
# Compare left and right cosets of P(3)
# For a non-normal subgroup {E, A}, left != right cosets
print("Left cosets of {E, A}:", groups.left_coset(table_2d, {0, 1}))
print("Right cosets of {E, A}:", groups.right_coset(table_2d, {0, 1}))
print()
# For the normal subgroup C_3 = {E, D, F}, they match
print("Left cosets of {E, D, F}:", groups.left_coset(table_2d, {0, 4, 5}))
print("Right cosets of {E, D, F}:", groups.right_coset(table_2d, {0, 4, 5}))

---
## Section 5: Conjugacy, Classes, and Factor Groups

### 5.1 `conjugacy_classes(table)`

$b$ is conjugate to $a$ if $\exists x \in G$ such that $b = xax^{-1}$.
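In table indices, with `inv[x]` the inverse of $x$, the class of $a$ is $\{$`table[table[x, a], inv[x]]` $: x \in G\}$. The sketch below assumes index 0 is the identity, derives `inv` from it, and groups all elements into classes; `classes` is a hypothetical name. The example is a copy of `ans_table1`.

```python
import numpy as np

def classes(table, e=0):
    """Hypothetical helper: conjugacy classes as frozensets, assuming index e is the identity."""
    table = np.asarray(table)
    n = len(table)
    # inv[x] is the first j with x*j == e (unique in a valid group)
    inv = [int(np.flatnonzero(table[x] == e)[0]) for x in range(n)]
    # The class of a collects x*a*x^{-1} over every x in the group
    return {frozenset(int(table[table[x, a], inv[x]]) for x in range(n))
            for a in range(n)}

s3 = np.array([[0, 1, 2, 3, 4, 5], [1, 0, 3, 2, 5, 4], [2, 5, 0, 4, 3, 1],
               [3, 4, 1, 5, 2, 0], [4, 3, 5, 1, 0, 2], [5, 2, 4, 0, 1, 3]])
print(sorted(map(sorted, classes(s3))))  # [[0], [1, 2, 4], [3, 5]]
```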

In [None]:
def conjugacy_classes(table):
    """Returns the conjugacy classes of the group.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
    Output:
        Set of conjugacy classes. Each conjugacy class is a frozenset of integers.
    """
    # YOUR CODE HERE
    pass

In [None]:
# D2 is abelian, so every element forms its own conjugacy class
t = np.array([[0,1,2,3],[1,0,3,2],[2,3,0,1],[3,2,1,0]])
assert conjugacy_classes(t) == {
    frozenset({0}), frozenset({1}), frozenset({2}), frozenset({3}),
}
print("conjugacy_classes tests passed!")

In [None]:
# P(3) conjugacy classes: {E}, {D, F} (rotations), {A, B, C} (mirrors)
conj = groups.conjugacy_classes(table_2d)
labels = ["E","A","B","C","D","F"]
for c in sorted(conj, key=lambda s: (len(s), min(s))):
    print("{"+", ".join(labels[i] for i in sorted(c))+"}")

### 5.2 `selfconjugate_subgroups(table)`

A subgroup $H$ is self-conjugate (normal) if $gHg^{-1} = H$ for all $g \in G$.
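Since conjugation by $g$ is a bijection, $gHg^{-1} = H$ is equivalent to $gHg^{-1} \subseteq H$, so it suffices to check that every conjugate of every $h \in H$ lands back in $H$. A sketch assuming index 0 is the identity; `is_normal` is a hypothetical name, and the table is a copy of `ans_table1`.

```python
import numpy as np

def is_normal(table, subgroup, e=0):
    """Hypothetical helper: True if g*h*g^{-1} is in the subgroup for all g and h."""
    table = np.asarray(table)
    n = len(table)
    inv = [int(np.flatnonzero(table[x] == e)[0]) for x in range(n)]
    H = set(subgroup)
    return all(int(table[table[g, h], inv[g]]) in H for g in range(n) for h in H)

s3 = np.array([[0, 1, 2, 3, 4, 5], [1, 0, 3, 2, 5, 4], [2, 5, 0, 4, 3, 1],
               [3, 4, 1, 5, 2, 0], [4, 3, 5, 1, 0, 2], [5, 2, 4, 0, 1, 3]])
print(is_normal(s3, {0, 3, 5}))  # True: index-2 subgroup
print(is_normal(s3, {0, 1}))     # False
```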

In [None]:
def selfconjugate_subgroups(table):
    """Returns the self-conjugate (normal) subgroups of the group.
    Input:
        table: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the group.
    Output:
        Set of self-conjugate subgroups. Each subgroup is a frozenset of element indices.
    """
    # YOUR CODE HERE
    pass

In [None]:
# D2 is abelian so all subgroups are normal
t = np.array([[0,1,2,3],[1,0,3,2],[2,3,0,1],[3,2,1,0]])
assert selfconjugate_subgroups(t) == {
    frozenset({0}),
    frozenset({0, 1}),
    frozenset({0, 2}),
    frozenset({0, 3}),
    frozenset({0, 1, 2, 3}),
}
print("selfconjugate_subgroups tests passed!")

### 5.3 `factor_group(table, selfconj_sub)`

The factor group $G/H$ treats each coset of $H$ as a single element.
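When $H$ is normal, $(aH)(bH) = (ab)H$ does not depend on which representatives $a, b$ are picked, so one representative per coset is enough to fill the table. The sketch below (hypothetical name `factor_table`) uses left cosets, which coincide with right cosets for a normal subgroup, and numbers cosets by their smallest element. That labeling may differ from the course's.

```python
import numpy as np

def factor_table(table, normal_subgroup):
    """Hypothetical helper: cosets of a normal H and the Cayley table of G/H."""
    table = np.asarray(table)
    H = sorted(normal_subgroup)
    # Distinct cosets gH, ordered by their smallest element for reproducibility
    coset_sets = sorted({frozenset(int(table[g, h]) for h in H) for g in range(len(table))},
                        key=min)
    coset_of = {g: k for k, c in enumerate(coset_sets) for g in c}
    reps = [min(c) for c in coset_sets]  # any representative works when H is normal
    m = len(coset_sets)
    ft = np.array([[coset_of[int(table[reps[i], reps[j]])] for j in range(m)]
                   for i in range(m)])
    return coset_sets, ft

# Copy of ans_table1; {0, 3, 5} is its normal subgroup of order 3
s3 = np.array([[0, 1, 2, 3, 4, 5], [1, 0, 3, 2, 5, 4], [2, 5, 0, 4, 3, 1],
               [3, 4, 1, 5, 2, 0], [4, 3, 5, 1, 0, 2], [5, 2, 4, 0, 1, 3]])
coset_list, ft = factor_table(s3, {0, 3, 5})
print([sorted(c) for c in coset_list])  # [[0, 3, 5], [1, 2, 4]]
print(ft)
# [[0 1]
#  [1 0]]
```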

In [None]:
def factor_group(table, selfconj_sub):
    """Returns the factor group of the group.
    Input:
        table: np.array of shape [n, n] where entries correspond to indices of group elements.
        selfconj_sub: set of indices for self-conjugate subgroup.
    Output:
        A tuple of two multiplication tables for the factor group of order n2:
        - np.array of shape [n2, n2] whose entries are sets of ints (each coset given as a set of group elements)
        - np.array of shape [n2, n2] whose entries are indices of the right cosets
    """
    # YOUR CODE HERE
    pass

In [None]:
# Compare your factor group with the course version
_, ft_yours = factor_group(ans_table1, frozenset({0, 3, 5}))
_, ft_course = groups.factor_group(ans_table1, frozenset({0, 3, 5}))

# The tables should be isomorphic (possibly different labeling)
print("Your factor group table:")
print(ft_yours)
print("Course factor group table:")
print(ft_course)

In [None]:
# P(3) / C_3 ≅ Z_2
coset_labels, factor_table = groups.factor_group(table_2d, frozenset({0, 4, 5}))
print("Factor group P(3)/C_3:")
print(factor_table)
print("This is Z_2!")

In [None]:
# Visualize cosets and factor groups
HTML(plot.structure_explorer(table_2d, labels=["E","A","B","C","D","F"]))

---
## Section 6: Comparing Tables

### 6.1 `isomorphisms(table_src, table_dst)`

Find all relabelings $h$ such that $h(g_1 \cdot g_2) = h(g_1) \cdot h(g_2)$.
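The most direct, factorial-time approach is to try every bijection $h$ as a permutation tuple and keep those with `table_dst[h[a], h[b]] == h[table_src[a, b]]` for all $a, b$. This sketch has a hypothetical name and is practical only for small groups; Section 8 switches to `groups_fast` for $P(4)$. The example compares copies of `ans_table2` ($D_2$) and `ans_table4` ($C_4$).

```python
import itertools
import numpy as np

def brute_force_isomorphisms(table_src, table_dst):
    """Hypothetical helper: every bijection h with h(a*b) == h(a)*h(b)."""
    src, dst = np.asarray(table_src), np.asarray(table_dst)
    n = len(src)
    if len(dst) != n:
        return set()
    found = set()
    for h in itertools.permutations(range(n)):
        h_arr = np.array(h)
        # Left side: h(a*b) for all (a, b); right side: h(a)*h(b) for all (a, b)
        if np.array_equal(h_arr[src], dst[h_arr[:, None], h_arr[None, :]]):
            found.add(h)
    return found

klein = np.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 3, 0, 1], [3, 2, 1, 0]])  # ans_table2
z4 = np.array([[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2]])     # ans_table4
print(len(brute_force_isomorphisms(klein, z4)))     # 0
print(len(brute_force_isomorphisms(klein, klein)))  # 6
```

The same check finds two isomorphisms between `ans_table3` and `ans_table4`, since both tables describe $C_4$ (in `ans_table3`, element 2 has order 4).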

You may find `groups.permute_mul_table` helpful for testing.

In [None]:
def isomorphisms(table_src, table_dst):
    """Finds all isomorphisms between two multiplication tables of same order.
    Input:
        table_src: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the source group.
        table_dst: np.array of shape [n, n] where the entry at [i, j] is the index of the product of the ith and jth elements in the destination group.
    Output:
        A set of isomorphisms encoded as tuples ``h`` of length ``n``.
        Each element ``h[i]`` is the index of the image of the ith element in the source group.
    """
    # YOUR CODE HERE
    pass

In [None]:
# Z_2 has only the identity isomorphism to itself
assert isomorphisms(np.array([[0, 1], [1, 0]]), np.array([[0, 1], [1, 0]])) == {(0, 1)}

# C_4 (ans_table4) and D_2 (ans_table2) are NOT isomorphic (same order, different structure)
assert isomorphisms(ans_table2, ans_table4) == set()

# D_2 to itself has multiple isomorphisms (its automorphism group is isomorphic to S_3)
assert len(isomorphisms(ans_table2, ans_table2)) == 6

print("isomorphisms tests passed!")

In [None]:
# The two P(3) representations are isomorphic
isos = groups.isomorphisms(table_2d, table_perm)
print(f"Found {len(isos)} isomorphisms from D3 to P(3) perm")

# Use one isomorphism to reorder and compare
reorder = list(list(isos)[0])
table_perm_reordered = groups.make_multiplication_table(p3_perm[reorder])
HTML(plot.compare_tables(
    table_2d, table_perm_reordered,
    labels1=["E","A","B","C","D","F"],
    labels2=["E","A","B","C","D","F"],
))

### 6.2 `surjective_homomorphisms(table_src, table_dst)`

Like isomorphisms but the map need not be injective — only surjective.
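Without injectivity, $h$ ranges over all $n_{dst}^{n_{src}}$ maps rather than $n_{src}!$ permutations, so brute force grows quickly. A sketch (hypothetical name) that enumerates every map with `itertools.product`, skips non-surjective ones, and keeps those that preserve the operation:

```python
import itertools
import numpy as np

def brute_force_surjections(table_src, table_dst):
    """Hypothetical helper: every surjective map h with h(a*b) == h(a)*h(b)."""
    src, dst = np.asarray(table_src), np.asarray(table_dst)
    n_src, n_dst = len(src), len(dst)
    found = set()
    for h in itertools.product(range(n_dst), repeat=n_src):
        if len(set(h)) != n_dst:
            continue  # some element of the destination is never hit
        h_arr = np.array(h)
        if np.array_equal(h_arr[src], dst[h_arr[:, None], h_arr[None, :]]):
            found.add(h)
    return found

z2 = np.array([[0, 1], [1, 0]])
klein = np.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 3, 0, 1], [3, 2, 1, 0]])
print(brute_force_surjections(z2, z2))             # {(0, 1)}
print(sorted(brute_force_surjections(klein, z2)))  # [(0, 0, 1, 1), (0, 1, 0, 1), (0, 1, 1, 0)]
```

Each of the three maps from $D_2$ onto $\mathbb{Z}_2$ has a different order-2 subgroup as its kernel.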

In [None]:
def surjective_homomorphisms(table_src, table_dst):
    """Finds all surjective homomorphisms from one group to another.
    Input:
        table_src: np.array of shape [n_src, n_src] where the entry at [i, j] is the index of the product of the ith and jth elements in the source group.
        table_dst: np.array of shape [n_dst, n_dst] where the entry at [i, j] is the index of the product of the ith and jth elements in the destination group.
    Output:
        A set of surjective homomorphisms encoded as tuples ``h`` of length ``n_src``.
        Each element ``h[i]`` is the index of the image of the ith element in the source group.
    """
    # YOUR CODE HERE
    pass

In [None]:
assert surjective_homomorphisms(
    np.array([[0, 1], [1, 0]]), np.array([[0, 1], [1, 0]])
) == {(0, 1)}
print("surjective_homomorphisms tests passed!")

### $C_4$ vs $D_2$: same order, not isomorphic

$C_4$ has elements of order 4, $D_2$ does not.
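Element orders are an isomorphism invariant that can be computed from the table alone: the order of $g$ is the smallest $k$ with $g^k = e$. The sketch assumes index 0 is the identity and that the table is a valid group (otherwise the loop might never reach $e$). `element_orders` is a hypothetical name, and the tables are copies of `ans_table2` and `ans_table4`.

```python
import numpy as np

def element_orders(table, e=0):
    """Hypothetical helper: order of each element, assuming index e is the identity."""
    table = np.asarray(table)
    orders = []
    for g in range(len(table)):
        power, k = g, 1
        while power != e:
            power = int(table[power, g])  # g^(k+1) = g^k * g
            k += 1
        orders.append(k)
    return orders

d2 = np.array([[0, 1, 2, 3], [1, 0, 3, 2], [2, 3, 0, 1], [3, 2, 1, 0]])  # ans_table2
c4 = np.array([[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2]])  # ans_table4
print(element_orders(d2))  # [1, 2, 2, 2]
print(element_orders(c4))  # [1, 4, 2, 4]
```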

In [None]:
table_c4 = groups.make_multiplication_table(groups.cyclic_matrices(4))

isos_c4_d2 = groups.isomorphisms(table_c4, groups.D2_table)
print(f"C4 ≅ D2? {'Yes' if isos_c4_d2 else 'No — NOT isomorphic'}")

HTML(plot.compare_tables(
    table_c4, groups.D2_table,
    labels1=["e", "r", "r²", "r³"],
    labels2=["e", "a", "b", "c"],
))

---
## Section 7: Symmetries of Molecule $AB_4$

The $AB_4$ molecule has a central atom $A$ with four $B$ atoms at the corners of a square (not coplanar with $A$). Its symmetry group has 8 elements.

### 7.1 `AB4_group()`

Return $3 \times 3$ rotation and reflection matrices that leave the molecule invariant.

Coordinates: A = (0, 0, 1), B₁ = (1, 1, 0), B₂ = (−1, 1, 0), B₃ = (−1, −1, 0), B₄ = (1, −1, 0)
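One way to test a candidate matrix $R$ is to apply it to every atom: $A$ must map to $A$, and each $B$ atom must land on some $B$ position. The sketch below only checks candidates; it is not a list of the 8 operations. `leaves_invariant` is a hypothetical helper, and the three test matrices are illustrative choices.

```python
import numpy as np

A_pos = np.array([0.0, 0.0, 1.0])
B_pos = np.array([[1, 1, 0], [-1, 1, 0], [-1, -1, 0], [1, -1, 0]], dtype=float)

def leaves_invariant(R, tol=1e-8):
    """Hypothetical helper: True if R fixes atom A and maps the B atoms onto B positions."""
    if not np.allclose(R @ A_pos, A_pos, atol=tol):
        return False
    moved = B_pos @ R.T  # row i is R applied to B_i
    # Each moved B atom must coincide with some original B atom
    return all(np.any(np.all(np.abs(B_pos - b) < tol, axis=1)) for b in moved)

rot90_z = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
mirror_x = np.diag([-1.0, 1.0, 1.0])  # reflection through the yz-plane
rot90_x = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
print(leaves_invariant(rot90_z), leaves_invariant(mirror_x), leaves_invariant(rot90_x))
# True True False
```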

In [None]:
def AB4_group():
    """Return 3D rotation and reflection matrices
    that represent the symmetry operations for the
    molecule AB_4, where the B atoms lie at the
    corners of a square and the A atom is at the
    center and is not coplanar with the B atoms
    Output:
        np.array of shape [N, 3, 3] that represent
        symmetry operations of the AB_4 molecule
    """
    # YOUR CODE HERE
    pass

In [None]:
ab4 = AB4_group()
assert ab4.shape[0] == 8, f"AB4 group should have 8 elements, got {ab4.shape[0]}"
assert ab4.shape[1:] == (3, 3), "Matrices should be 3x3"

# Check all matrices are orthogonal
for m in ab4:
    np.testing.assert_allclose(m @ m.T, np.eye(3), atol=1e-8)

# Compare with course implementation
ab4_course = groups.AB4_group()
table_ab4 = groups.make_multiplication_table(ab4)
table_ab4_course = groups.make_multiplication_table(ab4_course)
assert len(groups.isomorphisms(table_ab4, table_ab4_course)) > 0, "Not isomorphic to course solution!"
print("AB4_group tests passed!")

In [None]:
ab4_course = groups.AB4_group()
table_ab4 = groups.make_multiplication_table(ab4_course)
print(f"AB₄ symmetry group has {len(ab4_course)} elements")
HTML(plot.matrix_grid(ab4_course, cell_size=24))

In [None]:
HTML(plot.multiplication_table(table_ab4))

### 7.2 `AB4_sc_subs_iso_C4_vs_D2(AB4_matrices)`

Classify order-4 self-conjugate subgroups as $C_4$ or $D_2$.

You can access `groups.D2_table` and `groups.C4_table` directly. You may find `groups.remap_to_minimal` and `groups.subgroup_table_from_group_table` helpful.

In [None]:
def AB4_sc_subs_iso_C4_vs_D2(AB4_matrices):
    """Identify which self-conjugate subgroups of
    the symmetry group of AB4 are isomorphic to C_4
    vs. D_4.
    Input:
        AB4_matrices: np.array of shape [|G|, 3, 3] of all 3D rotations and reflections that leave the molecule AB4 invariant
    Output:
        A tuple of two sets of frozensets (so that it's hashable). The first set is of self-conjugate subgroups that are isomorphic to C_4 given as frozensets of of element indices. The second set is of self-conjugate subgroups that are isomorphic to D_2 given as frozensets of element indices.
        Example for random indices just to show format...
            return ({frozenset(0, 1, 5, 7), frozenset(0, 4, 6, 9)}, {frozenset(0, 7, 8, 9)})
    """
    # YOUR CODE HERE
    pass

In [None]:
C4_sets, D2_sets = AB4_sc_subs_iso_C4_vs_D2(groups.AB4_group())
print(f"Self-conjugate subgroups isomorphic to C4: {C4_sets}")
print(f"Self-conjugate subgroups isomorphic to D2: {D2_sets}")

# Check against course
C4_course, D2_course = groups.AB4_sc_subs_iso_C4_vs_D2(groups.AB4_group())
assert C4_sets == C4_course and D2_sets == D2_course, "Does not match course solution!"
print("AB4_sc_subs_iso_C4_vs_D2 tests passed!")

In [None]:
HTML(plot.structure_explorer(table_ab4))

---
## Section 8: Playing with $P(4)$

$P(4)$ has $4! = 24$ elements. The naive `groups.subgroups` and `groups.isomorphisms` functions are too slow at this scale — use `groups_fast` instead.

**Important:** Use `groups_fast.generate_subgroups_dynamic_programming` and `groups_fast.isomorphisms_generator_backtracking`.

In [None]:
p4 = groups.permutation_matrices(4)
table_p4 = groups.make_multiplication_table(p4)
print(f"P(4) has {len(p4)} elements")

### 8.1 Order of $P(4)$

In [None]:
p4_order = len(p4)
print(f"|P(4)| = {p4_order}")

### 8.2 Conjugacy classes of $P(4)$

Match each conjugacy class to its geometric interpretation: $E$, $C_2$, $C_3$, $\sigma_d$, $S_4$.
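In $P(n)$, two permutations are conjugate exactly when they have the same cycle type, so cycle types give a cross-check on the classes printed in the cells that follow. For a permutation matrix, the trace counts fixed points and the determinant is the sign of the permutation, which is why those two quantities help. A sketch that reads the cycle type off a $0/1$ permutation matrix (`cycle_type` is a hypothetical helper). It treats column $j$'s 1 as marking the image of $j$. The other convention gives the inverse permutation, which has the same cycle type.

```python
import numpy as np

def cycle_type(P):
    """Hypothetical helper: cycle lengths (descending) of the permutation in matrix P."""
    perm = np.argmax(P, axis=0)  # perm[j] = row holding the 1 in column j
    seen, lengths = set(), []
    for start in range(len(perm)):
        if start in seen:
            continue
        length, j = 0, start
        while j not in seen:  # follow the cycle through start
            seen.add(j)
            j = int(perm[j])
            length += 1
        lengths.append(length)
    return tuple(sorted(lengths, reverse=True))

four_cycle = np.roll(np.eye(4), 1, axis=0)
double_swap = np.eye(4)[[1, 0, 3, 2]]
print(cycle_type(four_cycle), cycle_type(double_swap), cycle_type(np.eye(4)))
# (4,) (2, 2) (1, 1, 1, 1)
```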

In [None]:
conj_p4 = groups.conjugacy_classes(table_p4)
print(f"{len(conj_p4)} conjugacy classes:")
for c in sorted(conj_p4, key=lambda s: (len(s), min(s))):
    print(f"  {sorted(c)}  (size {len(c)})")

In [None]:
# Look at the actual matrices to identify each class geometrically
# Hint: check determinants (rotation vs improper) and traces
for c in sorted(conj_p4, key=lambda s: (len(s), min(s))):
    rep = sorted(c)[0]
    det = np.linalg.det(p4[rep])
    tr = np.trace(p4[rep])
    print(f"  Class {sorted(c)}: det={det:+.0f}, trace={tr:+.0f}")
    print(f"    Representative matrix:\n{p4[rep]}\n")

In [None]:
HTML(plot.multiplication_table(table_p4, cell_size=22))

### 8.3 Factor groups of $P(4)$: finding $P(4)/H \cong P(3)$

**Strategy:**
1. Find subgroups with `groups_fast.generate_subgroups_dynamic_programming`
2. Check normality by comparing left and right cosets
3. Compute factor groups for non-trivial normal subgroups
4. Test isomorphism with $P(3)$ using `groups_fast.isomorphisms_generator_backtracking`

In [None]:
# Step 1: Find subgroups efficiently
p4_subgroups = groups_fast.generate_subgroups_dynamic_programming(
    np.array(table_p4, dtype=np.int32)
)
print(f"P(4) has {len(p4_subgroups)} subgroups")

# Step 2: Find normal subgroups (left cosets == right cosets)
sc_subgroups = []
for s in p4_subgroups:
    if groups.right_coset(table_p4, s) == groups.left_coset(table_p4, s):
        sc_subgroups.append(s)

print(f"{len(sc_subgroups)} are normal:")
for sg in sorted(sc_subgroups, key=lambda s: (len(s), min(s))):
    print(f"  {sorted(sg)}  (order {len(sg)})")

In [None]:
# Steps 3 & 4: Find which normal subgroup gives factor group ≅ P(3)
table_p3 = groups.make_multiplication_table(groups.permutation_matrices(3))

for sg in sorted(sc_subgroups, key=lambda s: (len(s), min(s))):
    if len(sg) in (1, len(table_p4)):  # skip trivial
        continue
    _, ft = groups.factor_group(table_p4, sg)
    if len(ft) == len(table_p3):
        # Use fast isomorphism check
        found = False
        for iso in groups_fast.isomorphisms_generator_backtracking(
            np.array(ft, dtype=np.int32), np.array(table_p3, dtype=np.int32)
        ):
            found = True
            break
        print(f"H = {sorted(sg)}: P(4)/H ≅ P(3)? {found}")

In [None]:
HTML(plot.structure_explorer(table_p4, cell_size=22))

In [None]:
# Try a different group!


In [None]:
# Explore subgroups and factor groups


In [None]:
# Compare two groups for isomorphism
