# 2966. Divide Array Into Arrays With Max Difference

# Medium

You are given an integer array nums of size n where n is a multiple of 3 and a positive integer k.

Divide the array nums into n / 3 arrays of size 3 satisfying the following condition:

The difference between any two elements in one array is less than or equal to k.
Return a 2D array containing the arrays. If it is impossible to satisfy the conditions, return an empty array. If there are multiple answers, return any of them.

# Example 1:

```
Input: nums = [1,3,4,8,7,9,3,5,1], k = 2

Output: [[1,1,3],[3,4,5],[7,8,9]]

Explanation:

The difference between any two elements in each array is less than or equal to 2.
```

# Example 2:

```
Input: nums = [2,4,2,2,5,2], k = 2

Output: []

Explanation:

Different ways to divide nums into 2 arrays of size 3 are:

[[2,2,2],[2,4,5]] (and its permutations)
[[2,2,4],[2,2,5]] (and its permutations)
Because there are four 2s, there will be an array with the elements 2 and 5 no matter how we divide it. Since 5 - 2 = 3 > k, the condition is not satisfied, so there is no valid division.
```
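Example 2's claim can be checked mechanically. The brute-force sketch below (the helper `valid_splits` is hypothetical, written only for illustration and limited to six elements) tries all C(5, 2) = 10 ways of splitting the array into two triplets and keeps the splits in which both triplets have a range (max minus min) of at most `k`.

```python
from itertools import combinations

def valid_splits(nums: list[int], k: int) -> list[tuple[list[int], list[int]]]:
    """Return every split of a 6-element array into two triplets
    where both triplets have max - min <= k."""
    assert len(nums) == 6, "this sketch only handles n = 6"
    found = []
    # Fix index 0 in the first triplet so each split is counted exactly once.
    for pair in combinations(range(1, 6), 2):
        first_idx = (0, *pair)
        second_idx = tuple(i for i in range(6) if i not in first_idx)
        first = [nums[i] for i in first_idx]
        second = [nums[i] for i in second_idx]
        if max(first) - min(first) <= k and max(second) - min(second) <= k:
            found.append((first, second))
    return found

print(valid_splits([2, 4, 2, 2, 5, 2], 2))  # [] : no valid split for Example 2
```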

# Example 3:

```
Input: nums = [4,2,9,8,2,12,7,12,10,5,8,5,5,7,9,2,5,11], k = 14

Output: [[2,2,12],[4,8,5],[5,9,7],[7,8,5],[5,9,10],[11,12,2]]

Explanation:

The difference between any two elements in each array is less than or equal to 14.

```
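Since any valid division is accepted, checking an answer's validity is more useful than comparing it against one fixed output. The sketch below (the helper `is_valid_division` is hypothetical) checks the three requirements: every group has three elements, every group's range is at most `k`, and the groups use exactly the values of `nums`. It confirms that both Example 3's listed output and the sorted, consecutive-triplet output of the greedy approach described below are valid.

```python
from collections import Counter

def is_valid_division(nums: list[int], groups: list[list[int]], k: int) -> bool:
    """Check a proposed answer against the problem's conditions."""
    if any(len(g) != 3 for g in groups):
        return False
    # The largest pairwise difference in a group is max - min.
    if any(max(g) - min(g) > k for g in groups):
        return False
    # The groups must use exactly the multiset of values in nums.
    return Counter(x for g in groups for x in g) == Counter(nums)

nums = [4, 2, 9, 8, 2, 12, 7, 12, 10, 5, 8, 5, 5, 7, 9, 2, 5, 11]
listed = [[2, 2, 12], [4, 8, 5], [5, 9, 7], [7, 8, 5], [5, 9, 10], [11, 12, 2]]
greedy = [[2, 2, 2], [4, 5, 5], [5, 5, 7], [7, 8, 8], [9, 9, 10], [11, 12, 12]]
print(is_valid_division(nums, listed, 14))  # True
print(is_valid_division(nums, greedy, 14))  # True
```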

# Constraints:

- n == nums.length
- 1 <= n <= 10^5
- n is a multiple of 3
- 1 <= nums[i] <= 10^5
- 1 <= k <= 10^5


The problem asks us to divide an array `nums` of size `n` (where `n` is a multiple of 3) into `n / 3` arrays of size 3. The crucial condition is that for each of these 3-element subarrays, the difference between _any_ two elements must be less than or equal to `k`. If a valid division is possible, return any valid 2D array; otherwise, return an empty array.

### Key Insight for Optimal Solution

The most important insight to solve this problem efficiently is to **sort the input array `nums` first.**

Why sorting?
If the difference between _any_ two elements in a group of three must be at most `k`, then in particular the difference between the group's smallest and largest elements must be at most `k`; conversely, that one difference bounds every other pairwise difference in the group. After sorting, take three consecutive elements `nums[i]`, `nums[i+1]`, `nums[i+2]`: the maximum difference within that group is `nums[i+2] - nums[i]`. If this maximum difference is $\le k$, then all other differences within the group (`nums[i+1] - nums[i]` and `nums[i+2] - nums[i+1]`) are also $\le k$.

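Grouping consecutive elements of the sorted array is also safe: if this greedy grouping fails, no valid division exists at all. Suppose the window starting at index `3t` fails, i.e. `nums[3t+2] - nums[3t] > k`. In any division, the groups that contain at least one of the `3t + 1` smallest elements (indices `0` to `3t`) hold a multiple of 3 elements, so at least `3t + 3` of them; they cannot all fit within indices `0` to `3t+1`. Some group must therefore contain an element at an index of at most `3t` together with an element at an index of at least `3t+2`, and that group's difference is at least `nums[3t+2] - nums[3t] > k`.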
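The argument can also be spot-checked empirically. This sketch, with hypothetical helpers `greedy` and `brute_force_possible`, compares the sorted-triplet greedy against an exhaustive search over every division on small random inputs. A randomized check is supporting evidence, not a substitute for the argument.

```python
import random
from itertools import combinations

def greedy(nums: list[int], k: int) -> list[list[int]]:
    """Consecutive triplets of the sorted array, or [] if any range exceeds k."""
    s = sorted(nums)
    groups = [s[i:i + 3] for i in range(0, len(s), 3)]
    return groups if all(g[2] - g[0] <= k for g in groups) else []

def brute_force_possible(nums: list[int], k: int) -> bool:
    """Exhaustively try every division of nums into triplets."""
    if not nums:
        return True
    first, rest = nums[0], nums[1:]
    # The first element must share a triplet with some pair from the rest.
    for i, j in combinations(range(len(rest)), 2):
        group = [first, rest[i], rest[j]]
        if max(group) - min(group) <= k:
            remaining = [x for idx, x in enumerate(rest) if idx not in (i, j)]
            if brute_force_possible(remaining, k):
                return True
    return False

random.seed(0)
for _ in range(500):
    n = 3 * random.randint(1, 3)  # n in {3, 6, 9}
    nums = [random.randint(1, 10) for _ in range(n)]
    k = random.randint(1, 5)
    assert (greedy(nums, k) != []) == brute_force_possible(nums, k), (nums, k)
print("greedy agrees with brute force on 500 random cases")
```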

### Algorithm 1: Greedy Approach (Optimal)

This is the most efficient and straightforward way to solve the problem after understanding the key insight.

**Algorithm:**

1. **Sort `nums`:** Sort the input array `nums` in non-decreasing order. This is the cornerstone of this approach.
2. **Initialize Result:** Create an empty list `result` to store the 2D array.
3. **Iterate and Form Groups:** Iterate through the sorted `nums` array with a step of 3 (i.e., `i` takes values `0, 3, 6, ...`).
   a. For each `i`, consider the subarray `[nums[i], nums[i+1], nums[i+2]]`.
   b. Check the condition: Calculate the difference between the largest and smallest element in this group, which is `nums[i+2] - nums[i]`.
   c. If `nums[i+2] - nums[i] > k`: It is impossible to satisfy the condition. The groups that touch the first `i + 1` elements of the sorted array hold at least `i + 3` elements, so any division must place some element at an index of at most `i` in the same group as some element at an index of at least `i+2`, and that group's difference would also exceed `k`. Therefore, return an empty array `[]` immediately.
   d. If `nums[i+2] - nums[i] <= k`: This group satisfies the condition. Add `[nums[i], nums[i+1], nums[i+2]]` to `result`.
4. **Return Result:** If the loop completes without returning an empty array, it means all groups were successfully formed. Return `result`.

**Example Walkthrough:** `nums = [1,3,4,8,7,9,3,5,1]`, `k = 2`

1. **Sort `nums`:** `[1, 1, 3, 3, 4, 5, 7, 8, 9]`
2. `result = []`

3. **Iterate:**

   - **`i = 0`:**

     - Group: `[nums[0], nums[1], nums[2]] = [1, 1, 3]`
     - Difference: `3 - 1 = 2`
     - `2 <= k` (2 <= 2) is True.
     - Add `[1, 1, 3]` to `result`. `result = [[1, 1, 3]]`

   - **`i = 3`:**

     - Group: `[nums[3], nums[4], nums[5]] = [3, 4, 5]`
     - Difference: `5 - 3 = 2`
     - `2 <= k` (2 <= 2) is True.
     - Add `[3, 4, 5]` to `result`. `result = [[1, 1, 3], [3, 4, 5]]`

   - **`i = 6`:**
     - Group: `[nums[6], nums[7], nums[8]] = [7, 8, 9]`
     - Difference: `9 - 7 = 2`
     - `2 <= k` (2 <= 2) is True.
     - Add `[7, 8, 9]` to `result`. `result = [[1, 1, 3], [3, 4, 5], [7, 8, 9]]`

4. **Loop finishes.** Return `[[1,1,3],[3,4,5],[7,8,9]]`.
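The walkthrough can be reproduced with a tracing variant of the greedy loop. `divide_with_trace` is a hypothetical helper for illustration; it sorts a copy with `sorted()` and prints each window's difference before deciding.

```python
def divide_with_trace(nums: list[int], k: int) -> list[list[int]]:
    """Greedy grouping that prints each step, mirroring the walkthrough."""
    s = sorted(nums)  # sorted copy; the caller's list is left unchanged
    print(f"sorted: {s}")
    result: list[list[int]] = []
    for i in range(0, len(s), 3):
        group = s[i:i + 3]
        diff = group[2] - group[0]
        print(f"i = {i}: group {group}, difference {diff}, {diff} <= {k} is {diff <= k}")
        if diff > k:
            return []
        result.append(group)
    return result

print(divide_with_trace([1, 3, 4, 8, 7, 9, 3, 5, 1], 2))
# final line printed: [[1, 1, 3], [3, 4, 5], [7, 8, 9]]
```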

**Time Complexity:**

- Sorting: $O(N \log N)$ where $N$ is the length of `nums`.
- Iteration: $O(N)$ for creating the groups.
- Overall: **$O(N \log N)$** (dominated by sorting).

**Space Complexity:**

- Sorting: $O(\log N)$ to $O(N)$ auxiliary space, depending on the sort implementation (Python's Timsort can need up to about $N/2$ extra slots when merging, so $O(N)$ in the worst case).
- `result` array: $O(N)$ as we store all $N$ elements.
- Overall: **$O(N)$**.

**Code:**

```python
class Solution:
    def divideArray_greedy_sort(self, nums: list[int], k: int) -> list[list[int]]:
        # 1. Sort the input array
        nums.sort()

        n = len(nums)
        result = []

        # 2. Iterate and form groups of 3
        # We iterate with a step of 3
        for i in range(0, n, 3):
            # n is a multiple of 3, so nums[i+2] always exists; without that
            # guarantee an explicit bounds check would be needed here.

            # The current group is [nums[i], nums[i+1], nums[i+2]].
            # Because the array is sorted, its largest difference is nums[i+2] - nums[i].
            if nums[i+2] - nums[i] <= k:
                result.append([nums[i], nums[i+1], nums[i+2]])
            else:
                # If this sorted window's range exceeds k, no division of the
                # array can satisfy the condition, so return an empty list.
                return []

        return result

```
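Note that `nums.sort()` in the code above reorders the caller's list in place. If the input must stay untouched, one option, sketched below under the assumption that an extra $O(N)$ copy is acceptable, is to sort with `sorted()` instead. The function name `divide_array_no_mutation` is hypothetical.

```python
def divide_array_no_mutation(nums: list[int], k: int) -> list[list[int]]:
    """Same greedy grouping, but sorts a copy instead of the caller's list."""
    s = sorted(nums)  # new list; nums itself is not modified
    groups = [s[i:i + 3] for i in range(0, len(s), 3)]
    # Each group is sorted, so g[2] - g[0] is its largest pairwise difference.
    if any(g[2] - g[0] > k for g in groups):
        return []
    return groups

data = [1, 3, 4, 8, 7, 9, 3, 5, 1]
print(divide_array_no_mutation(data, 2))  # [[1, 1, 3], [3, 4, 5], [7, 8, 9]]
print(data)                               # [1, 3, 4, 8, 7, 9, 3, 5, 1] (unchanged)
```

Unlike the loop above, this version builds all groups before checking them, trading the early exit for a more compact expression.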

### Algorithm 2: Counting Sort / Frequency Array (Potentially faster for specific constraints, but more complex)

The greedy grouping itself does not change here; only the sorting step does. If the range of `nums[i]` values (`1 <= nums[i] <= 10^5`) is small relative to `N`, a counting sort can replace the $O(N \log N)$ comparison sort with an $O(N + MaxVal)$ pass. Given `N <= 10^5`, however, a standard comparison sort is perfectly adequate, and this approach adds code for at most a modest gain.

**Algorithm:**

1.  **Frequency Counting:** Create a frequency array (or hash map) to count occurrences of each number in `nums`. Let `max_val` be the maximum value in `nums`; the frequency array has size `max_val + 1`.
2.  **Stream the Sorted Order:** Instead of explicitly building a sorted array, iterate through the frequency array in increasing order of value and emit each value `freq[value]` times. This stream is exactly `nums` in sorted order.
3.  **Form Groups:**
    a. Initialize `result = []` and `current_group = []`.
    b. Iterate from `num = 1` to `max_val`.
    c. For each `num`, while its count `freq[num]` is greater than 0:
       i. Add `num` to `current_group`.
       ii. Decrement `freq[num]`.
       iii. If `len(current_group) == 3`: if `current_group[2] - current_group[0] > k`, return `[]`; otherwise add `current_group` to `result` and reset `current_group = []`.
    d. Partial groups (one or two elements) need no special handling. Because values are emitted in ascending order, the groups are exactly the consecutive triplets of the sorted array, and the check in step iii covers every group, just as in Algorithm 1.

    **Simpler Refinement:** Instead of forming groups during the counting pass, first use the frequency array to reconstruct a sorted `nums` array, then apply Algorithm 1. This costs `O(N + max_val)` for the "sort" and `O(N)` for grouping.
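The single-pass version of the steps above can be sketched as follows. The name `divide_array_counting` is hypothetical, and the sketch assumes non-negative integer values so they can index the frequency array, which the constraints `1 <= nums[i] <= 10^5` guarantee.

```python
def divide_array_counting(nums: list[int], k: int) -> list[list[int]]:
    """Counting-sort variant that forms triplets while scanning the counts.

    Values come out in ascending order, so the stream is exactly the sorted
    array and partial groups need no special handling.
    """
    freq = [0] * (max(nums) + 1)
    for x in nums:
        freq[x] += 1

    result: list[list[int]] = []
    group: list[int] = []
    for value, count in enumerate(freq):
        for _ in range(count):
            group.append(value)
            if len(group) == 3:
                # group is in ascending order, so group[2] - group[0] is its range
                if group[2] - group[0] > k:
                    return []
                result.append(group)
                group = []
    return result

print(divide_array_counting([1, 3, 4, 8, 7, 9, 3, 5, 1], 2))  # [[1, 1, 3], [3, 4, 5], [7, 8, 9]]
print(divide_array_counting([2, 4, 2, 2, 5, 2], 2))           # []
```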

**Time Complexity:**

- Frequency Counting: $O(N + MaxVal)$ where `MaxVal` is the maximum value in `nums`.
- Reconstructing and Grouping: $O(N)$
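- Overall: **$O(N + MaxVal)$**. In theory this beats $O(N \log N)$ when $MaxVal$ is smaller than roughly $N \log N$. With $N, MaxVal \le 10^5$, $N \log_2 N \approx 10^5 \times 17 = 1.7 \times 10^6$, an order of magnitude above $MaxVal$, so the counting approach performs fewer basic operations on paper. In practice, Python's built-in sort runs in optimized C while the counting loops run in the interpreter, so the $O(N \log N)$ sort is often comparable or faster, and it is simpler to implement.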
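To see how this trade-off plays out on a particular machine, a rough timing sketch follows. `via_sort` and `via_counting` are hypothetical names, the input is random, and timings vary by machine and Python version, so no numbers are claimed here.

```python
import random
import timeit

def via_sort(nums: list[int], k: int) -> list[list[int]]:
    """Greedy grouping after the built-in comparison sort."""
    s = sorted(nums)
    groups = [s[i:i + 3] for i in range(0, len(s), 3)]
    return [] if any(g[2] - g[0] > k for g in groups) else groups

def via_counting(nums: list[int], k: int) -> list[list[int]]:
    """Greedy grouping after a counting sort (assumes small non-negative ints)."""
    freq = [0] * (max(nums) + 1)
    for x in nums:
        freq[x] += 1
    s: list[int] = []
    for value, count in enumerate(freq):
        s.extend([value] * count)
    groups = [s[i:i + 3] for i in range(0, len(s), 3)]
    return [] if any(g[2] - g[0] > k for g in groups) else groups

random.seed(1)
data = [random.randint(1, 10**5) for _ in range(99_999)]  # length is a multiple of 3
k = 10**5  # large enough that every group is valid, so both do the full work
assert via_sort(data, k) == via_counting(data, k)
for fn in (via_sort, via_counting):
    seconds = timeit.timeit(lambda: fn(data, k), number=3)
    print(f"{fn.__name__}: {seconds:.3f} s for 3 runs")
```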

**Space Complexity:** $O(MaxVal)$ for the frequency array, plus $O(N)$ for the rebuilt sorted array and the result.

**Code (Using Counting Sort to explicitly build the sorted array, then applying Greedy):**

```python
class Solution:
    def divideArray_counting_sort_then_greedy(self, nums: list[int], k: int) -> list[list[int]]:
        n = len(nums)

        # Size the count array from the data (the constraints guarantee
        # 1 <= nums[i] <= 10^5, so this is at most 10^5 + 1 entries)
        max_val = max(nums)

        # 1. Frequency counting
        counts = [0] * (max_val + 1)
        for num in nums:
            counts[num] += 1

        # 2. Rebuild the sorted array explicitly from the counts
        sorted_nums_from_counts = []
        for value in range(1, max_val + 1):
            sorted_nums_from_counts.extend([value] * counts[value])

        # 3. Apply the same greedy grouping logic as Algorithm 1
        result = []
        for i in range(0, n, 3):
            group = sorted_nums_from_counts[i:i+3]
            if group[2] - group[0] <= k:
                result.append(group)
            else:
                return []  # Impossible to form valid groups

        return result
```

### Algorithm 3: Using a Min-Heap (Less Optimal than Sorting)

This approach has the same asymptotic cost as sorting but more overhead in practice; it demonstrates a different way to think about the problem. It repeatedly extracts the smallest remaining elements, and each extraction costs $O(\log N)$ to restore the heap property.

**Algorithm:**

1. **Build Min-Heap:** Insert all elements from `nums` into a min-heap.
2. **Form Groups:**
   a. Initialize `result = []`.
   b. While the heap is not empty:
      i. Extract the three smallest elements from the heap: `val1 = heapq.heappop(heap)`, `val2 = heapq.heappop(heap)`, `val3 = heapq.heappop(heap)`.
      ii. Check condition: If `val3 - val1 > k`, then it's impossible to form a valid division. Return `[]`.
      iii. Add `[val1, val2, val3]` to `result`.
3. **Return Result:** Return `result`.

**Time Complexity:**

- Building heap: $O(N)$
- Extracting 3 elements $N/3$ times: $N/3 \times 3 \log N = O(N \log N)$
- Overall: **$O(N \log N)$**.

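**Space Complexity:** $O(N)$ overall: `heapq.heapify` rearranges `nums` in place (no separate heap is allocated), and `result` stores all $N$ elements.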

**Code:**

```python
import heapq

class Solution:
    def divideArray_min_heap(self, nums: list[int], k: int) -> list[list[int]]:
        n = len(nums)

        # Build a min-heap from the nums array
        heapq.heapify(nums) # O(N) operation to transform list into heap in-place

        result = []

        # Iterate n/3 times to form groups
        for _ in range(n // 3):
            # Extract the 3 smallest elements from the heap
            # We assume n is a multiple of 3, so there will always be 3 elements.
            # If n was not guaranteed, we would need checks like if len(nums) < 3 etc.
            val1 = heapq.heappop(nums)
            val2 = heapq.heappop(nums)
            val3 = heapq.heappop(nums)

            # Check the condition for the current group
            if val3 - val1 > k:
                return [] # Condition not met, impossible to divide

            result.append([val1, val2, val3])

        return result

```
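`heapq.heapify(nums)` and the `heappop` calls in the code above operate on the caller's list, so that list is rearranged and, when a valid division is found, emptied. A sketch that heapifies a copy instead (hypothetical name `divide_array_heap_copy`) keeps the input intact at the cost of an extra $O(N)$ list.

```python
import heapq

def divide_array_heap_copy(nums: list[int], k: int) -> list[list[int]]:
    """Min-heap variant that works on a copy, leaving the caller's list intact."""
    heap = list(nums)      # O(N) copy; heapify rearranges this copy in place
    heapq.heapify(heap)
    result = []
    while heap:
        # Pop the three smallest remaining values (n is a multiple of 3).
        a = heapq.heappop(heap)
        b = heapq.heappop(heap)
        c = heapq.heappop(heap)
        if c - a > k:
            return []
        result.append([a, b, c])
    return result

data = [2, 4, 2, 2, 5, 2]
print(divide_array_heap_copy(data, 2))  # []
print(data)                             # [2, 4, 2, 2, 5, 2] (unchanged)
```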

### Conclusion

The **Greedy Approach (Sorting First)** is the most idiomatic and efficient solution for this problem. Its $O(N \log N)$ time complexity is dominated by the sort, and it relies on the fact that grouping consecutive elements of the sorted array keeps every group's range (largest minus smallest) as small as possible: if any consecutive triplet fails, no valid division exists.

The Counting Sort approach is an alternative when the range of values is small; it avoids the comparison sort but adds a frequency array and extra bookkeeping for little practical gain at these constraints.

The Min-Heap approach provides the same time complexity as sorting but is often slightly slower in practice due to the overhead of heap operations compared to highly optimized `sort()` implementations.


class Solution:
    def divideArray(self, nums: list[int], k: int) -> list[list[int]]:
        """
        Divides an array into n/3 arrays of size 3 such that the difference
        between any two elements in one array is less than or equal to k.

        Args:
            nums: The input integer array where n is a multiple of 3.
            k: A positive integer representing the maximum allowed difference.

        Returns:
            A 2D array containing the arrays if conditions are met, otherwise an empty array.
        """
        # Algorithm: Greedy approach with sorting.
        # Sorting is the key step. After sorting, the consecutive triplet
        # [nums[i], nums[i+1], nums[i+2]] is the tightest group available for
        # nums[i], and if its range nums[i+2] - nums[i] exceeds k, no valid
        # division of the whole array exists.

        # 1. Sort the input array `nums` in non-decreasing order.
        nums.sort()

        n = len(nums)
        result = []  # Initialize an empty list to store the resulting 2D array

        # 2. Iterate through the sorted array with a step of 3 to form triplets.
        for i in range(0, n, 3):
            # Extract the current triplet candidate.
            # Python's slicing handles cases where i+3 might exceed n gracefully,
            # but since n is a multiple of 3 and we iterate up to n-3,
            # `nums[i:i+3]` will always yield a list of 3 elements.
            group = nums[i:i+3]

            # 3. Check the condition: The difference between the largest and smallest
            # element in the current triplet (group[2] and group[0] after sorting)
            # must be less than or equal to k.
            if group[2] - group[0] > k:
                # 4. If the condition is violated, no division of the array can
                # satisfy the requirements: some group would have to pair an
                # element at index <= i with one at index >= i+2, and that
                # difference already exceeds k. So, return an empty array immediately.
                return []
            else:
                # 5. If the condition is satisfied, add this valid triplet to the result.
                result.append(group)

        # 6. If the loop completes, it means all n/3 triplets were successfully
        # formed and satisfied the condition. Return the accumulated result.
        return result


# --- Edge Cases and Test Cases ---
if __name__ == "__main__":
    sol = Solution()

    # Example 1: Basic case from problem description
    nums1 = [1, 3, 4, 8, 7, 9, 3, 5, 1]
    k1 = 2
    expected1 = [[1, 1, 3], [3, 4, 5], [7, 8, 9]]
    result1 = sol.divideArray(list(nums1), k1)  # pass a copy; divideArray sorts in place
    print(f"Test Case 1: nums={nums1}, k={k1}")
    print(f"Output: {result1}")
    print(f"Expected: {expected1}")
    # Compare triplets position by position (the greedy output lists them in ascending
    # order); sorting each triplet ignores element order within a triplet.
    assert all(sorted(t) == sorted(e) for t, e in zip(result1, expected1)) and len(result1) == len(expected1), "Test Case 1 Failed"
    print("-" * 30)


    # Example 2: Impossible scenario from problem description
    nums2 = [2, 4, 2, 2, 5, 2]
    k2 = 2
    expected2 = []
    result2 = sol.divideArray(list(nums2), k2)  # pass a copy; divideArray sorts in place
    print(f"Test Case 2: nums={nums2}, k={k2}")
    print(f"Output: {result2}")
    print(f"Expected: {expected2}")
    assert result2 == expected2, f"Test Case 2 Failed"
    print("-" * 30)

    # Example 3: Large k, so every consecutive triplet of the sorted array is valid
    nums3 = [4, 2, 9, 8, 2, 12, 7, 12, 10, 5, 8, 5, 5, 7, 9, 2, 5, 11]
    k3 = 14
    # Expected output after sorting: [[2,2,2], [4,5,5], [5,5,7], [7,8,8], [9,9,10], [11,12,12]]
    # Max diff in first triplet [2,2,2] is 0 <= 14.
    # Max diff in last triplet [11,12,12] is 1 <= 14.
    expected3_sorted_nums = sorted(nums3)
    expected3 = []
    for i in range(0, len(expected3_sorted_nums), 3):
        expected3.append(expected3_sorted_nums[i:i+3])

    result3 = sol.divideArray(list(nums3), k3)  # pass a copy; divideArray sorts in place
    print(f"Test Case 3: nums={nums3}, k={k3}")
    print(f"Output: {result3}")
    print(f"Expected: {expected3}")
    assert all(sorted(t) == sorted(e) for t, e in zip(result3, expected3)) and len(result3) == len(expected3), f"Test Case 3 Failed"
    print("-" * 30)

    # Edge Case 1: Minimum n (n=3)
    nums4 = [10, 12, 11]
    k4 = 2
    expected4 = [[10, 11, 12]]
    result4 = sol.divideArray(list(nums4), k4)  # pass a copy; divideArray sorts in place
    print(f"Test Case 4: nums={nums4}, k={k4}")
    print(f"Output: {result4}")
    print(f"Expected: {expected4}")
    assert all(sorted(t) == sorted(e) for t, e in zip(result4, expected4)) and len(result4) == len(expected4), f"Test Case 4 Failed"
    print("-" * 30)

    # Edge Case 2: Minimum n (n=3) - impossible
    nums5 = [1, 5, 10]
    k5 = 3
    expected5 = []
    result5 = sol.divideArray(list(nums5), k5)  # pass a copy; divideArray sorts in place
    print(f"Test Case 5: nums={nums5}, k={k5}")
    print(f"Output: {result5}")
    print(f"Expected: {expected5}")
    assert result5 == expected5, f"Test Case 5 Failed"
    print("-" * 30)

    # Edge Case 3: All elements are the same
    nums6 = [7, 7, 7, 7, 7, 7]
    k6 = 0
    expected6 = [[7, 7, 7], [7, 7, 7]]
    result6 = sol.divideArray(nums6, k6)
    print(f"Test Case 6: nums={nums6}, k={k6}")
    print(f"Output: {result6}")
    print(f"Expected: {expected6}")
    assert all(sorted(t) == sorted(e) for t, e in zip(result6, expected6)) and len(result6) == len(expected6), f"Test Case 6 Failed"
    print("-" * 30)

    # Edge Case 4: Identical elements with k = 0. The constraints guarantee k >= 1,
    # but k = 0 still exercises the boundary: every difference is 0, so 0 <= k holds.
    nums7 = [7, 7, 7, 7, 7, 7]
    k7 = 0
    expected7 = [[7, 7, 7], [7, 7, 7]]
    result7 = sol.divideArray(list(nums7), k7)
    print(f"Test Case 7 (k=0): nums={nums7}, k={k7}")
    print(f"Output: {result7}")
    print(f"Expected: {expected7}")
    assert all(sorted(t) == sorted(e) for t, e in zip(result7, expected7)) and len(result7) == len(expected7), "Test Case 7 Failed"
    print("-" * 30)

    # Edge Case 5: Large array size, still possible
    nums8 = list(range(1, 3 * 10**4 + 1)) # N = 30000
    k8 = 2
    # This should pass. Each group will be [i, i+1, i+2]
    # (i+2) - i = 2 <= k
    result8 = sol.divideArray(nums8, k8)
    expected8 = []
    for i in range(0, len(nums8), 3):
        expected8.append([nums8[i], nums8[i+1], nums8[i+2]])

    print(f"Test Case 8: Large array, n={len(nums8)}, k={k8}")
    # Don't print full arrays for large inputs
    # print(f"Output: {result8[:2]}...{result8[-2:]}")
    # print(f"Expected: {expected8[:2]}...{expected8[-2:]}")
    assert len(result8) == len(expected8) and all(result8[i] == expected8[i] for i in range(len(result8))), f"Test Case 8 Failed"
    print("Test Case 8 (Large array, possible) PASSED")
    print("-" * 30)

    # Edge Case 6: Large array size, impossible
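    nums9 = list(range(1, 3 * 10**4 + 1))
    # Replace the value 4 with 10^5. After sorting, the last triplet becomes
    # [29999, 30000, 100000], whose range far exceeds k = 2, so the answer is [].
    # (A small replacement such as 100 would NOT work: sorting re-packs the values
    # into consecutive triplets that all still have range 2.)
    nums9[3] = 10**5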
    k9 = 2
    result9 = sol.divideArray(nums9, k9)
    print(f"Test Case 9: Large array, n={len(nums9)}, k={k9}, impossible scenario")
    # print(f"Output: {result9}")
    # print(f"Expected: {[]}")
    assert result9 == [], f"Test Case 9 Failed"
    print("Test Case 9 (Large array, impossible) PASSED")
    print("-" * 30)

    print("\nAll implemented test cases passed!")

from typing import List

class Solution:
    def divideArray(self, nums: List[int], k: int) -> List[List[int]]:
        nums.sort()
        res = []
        for i in range(0, len(nums), 3):
            group = nums[i:i+3]
            if len(group) < 3 or group[2] - group[0] > k:
                return []
            res.append(group)
        return res

def test_divideArray():
    sol = Solution()

    # ✅ Basic valid case
    assert sol.divideArray([1,3,4,8,7,9,3,5,1], 2) == [[1,1,3],[3,4,5],[7,8,9]]

    # ❌ Invalid case due to difference > k
    assert sol.divideArray([2,4,2,2,5,2], 2) == []

    # ✅ Large k allows wide grouping
    assert sol.divideArray([4,2,9,8,2,12,7,12,10,5,8,5,5,7,9,2,5,11], 14) == [[2,2,2],[4,5,5],[5,5,7],[7,8,8],[9,9,10],[11,12,12]]

    # ✅ Minimum input size
    assert sol.divideArray([1,1,1], 0) == [[1,1,1]]

    # ❌ Edge case: last group violates condition
    assert sol.divideArray([1,2,3,4,5,10], 2) == []

    # ❌ Edge case: group has fewer than 3 elements
    assert sol.divideArray([1,2], 1) == []

    # ✅ All elements same
    assert sol.divideArray([7,7,7,7,7,7], 0) == [[7,7,7],[7,7,7]]

    print("All test cases passed!")

test_divideArray()

class Solution:
    def divideArray(self, nums: list[int], k: int) -> list[list[int]]:
        # Algorithm: Greedy approach with sorting.

        # 1. Sort the input array `nums`.
        # This is crucial because it ensures that when we pick three consecutive
        # numbers, their difference will be minimized for that starting point.
        # If nums[i+2] - nums[i] <= k, then all other pairwise differences in
        # [nums[i], nums[i+1], nums[i+2]] will also be <= k.
        nums.sort()

        n = len(nums)
        result = []  # Initialize an empty list to store the resulting 2D array

        # 2. Iterate and form triplets.
        # We iterate with a step of 3 because each triplet consumes 3 elements.
        for i in range(0, n, 3):
            # Check for edge case where remaining elements are not enough for a triplet.
            # This should technically not happen given n is a multiple of 3,
            # but good practice for robustness.
            if i + 2 >= n:
                # This means the array length is not a multiple of 3 or an unexpected error.
                # Per constraints, n is always a multiple of 3, so this check is mostly defensive.
                return [] # Should not be reached if constraints are followed.

            # Consider the current triplet candidate: nums[i], nums[i+1], nums[i+2]
            # 3. Check the condition: The difference between the largest and smallest
            # element in the triplet must be less than or equal to k.
            # Since the array is sorted, nums[i+2] is the largest and nums[i] is the smallest.
            if nums[i+2] - nums[i] <= k:
                # 4. If the condition is satisfied, add this triplet to the result.
                result.append([nums[i], nums[i+1], nums[i+2]])
            else:
                # 5. If the condition is violated, no division of the array can
                # satisfy the requirements: some group would have to pair an
                # element at index <= i with one at index >= i+2, and that
                # difference already exceeds k. So, return an empty array immediately.
                return []

        # 6. If the loop completes, it means all triplets were successfully formed
        # satisfying the condition. Return the result.
        return result

import collections
import math

class Solution:
    def minimumScore(self, nums: list[int], edges: list[list[int]]) -> int:
        n = len(nums)
        adj = collections.defaultdict(list)
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)

        # total_xor of the entire tree
        total_xor = 0
        for x in nums:
            total_xor ^= x

        # sub_xors[i] will store the XOR sum of the subtree rooted at i
        # child_nodes[i] will store all nodes in the subtree rooted at i (for checking ancestor/descendant)
        sub_xors = [0] * n
        child_nodes = [set() for _ in range(n)]

        # DFS to calculate subtree XORs and identify child nodes
        def dfs_sub_xor(u, parent):
            sub_xors[u] = nums[u]
            child_nodes[u].add(u) # Add itself to its children set
            for v in adj[u]:
                if v == parent:
                    continue
                dfs_sub_xor(v, u)
                sub_xors[u] ^= sub_xors[v]
                child_nodes[u].update(child_nodes[v]) # Add all children of v to u's children set
            return sub_xors[u]

        dfs_sub_xor(0, -1) # Start DFS from node 0 (arbitrary root)

        min_score = math.inf

        # Iterate over all distinct pairs of edges
        for i in range(n - 1):
            for j in range(i + 1, n - 1):
                edge1_u, edge1_v = edges[i]
                edge2_u, edge2_v = edges[j]

                # Determine which node is the "child" in the edge (for subtree XOR calculation)
                # For edge (u, v), assume v is the child if u is its parent
                # We need to ensure that the sub_xors[cut_node] refers to the XOR sum of the component formed by cutting that edge.
                # If child_nodes[edge1_u] contains edge1_v, then edge1_v is a descendant of edge1_u
                cut1_node = edge1_v if edge1_v in child_nodes[edge1_u] else edge1_u
                cut2_node = edge2_v if edge2_v in child_nodes[edge2_u] else edge2_u

                xor1 = 0
                xor2 = 0
                xor3 = 0

                # Determine the three XOR sums based on the relationship between cut1_node and cut2_node
                if cut2_node in child_nodes[cut1_node]: # cut2_node is a descendant of cut1_node
                    xor1 = sub_xors[cut2_node]
                    xor2 = sub_xors[cut1_node] ^ sub_xors[cut2_node]
                    xor3 = total_xor ^ sub_xors[cut1_node]
                elif cut1_node in child_nodes[cut2_node]: # cut1_node is a descendant of cut2_node
                    xor1 = sub_xors[cut1_node]
                    xor2 = sub_xors[cut2_node] ^ sub_xors[cut1_node]
                    xor3 = total_xor ^ sub_xors[cut2_node]
                else: # Disjoint subtrees
                    xor1 = sub_xors[cut1_node]
                    xor2 = sub_xors[cut2_node]
                    xor3 = total_xor ^ sub_xors[cut1_node] ^ sub_xors[cut2_node]

                current_xor_values = sorted([xor1, xor2, xor3])
                current_score = current_xor_values[2] - current_xor_values[0]
                min_score = min(min_score, current_score)
        
        return min_score

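from typing import List
from collections import defaultdict
from functools import reduce
from math import inf

class Solution: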
    def minimumScore(self, nums: List[int], edges: List[List[int]]) -> int:
        def dfs(i: int, fa: int) -> int:
            res = nums[i]
            for j in g[i]:
                if j != fa:
                    res ^= dfs(j, i)
            return res

        def dfs2(i: int, fa: int) -> int:
            nonlocal s, s1, ans
            res = nums[i]
            for j in g[i]:
                if j != fa:
                    s2 = dfs2(j, i)
                    res ^= s2
                    mx = max(s ^ s1, s2, s1 ^ s2)
                    mn = min(s ^ s1, s2, s1 ^ s2)
                    ans = min(ans, mx - mn)
            return res

        g = defaultdict(list)
        for a, b in edges:
            g[a].append(b)
            g[b].append(a)
        s = reduce(lambda x, y: x ^ y, nums)
        n = len(nums)
        ans = inf
        for i in range(n):
            for j in g[i]:
                s1 = dfs(i, j)
                dfs2(i, j)
        return ans

from typing import List
from collections import defaultdict
from functools import reduce
from math import inf

class Solution:
    def minimumScore(self, nums: List[int], edges: List[List[int]]) -> int:
        def computeXOR(node: int, parent: int) -> int:
            xor_sum = nums[node]
            for neighbor in graph[node]:
                if neighbor != parent:
                    xor_sum ^= computeXOR(neighbor, node)
            return xor_sum

        def evaluatePartition(node: int, parent: int) -> int:
            nonlocal total_xor, subtree_xor, min_score
            xor_sum = nums[node]
            for neighbor in graph[node]:
                if neighbor != parent:
                    child_xor = evaluatePartition(neighbor, node)
                    xor_sum ^= child_xor
                    max_val = max(total_xor ^ subtree_xor, child_xor, subtree_xor ^ child_xor)
                    min_val = min(total_xor ^ subtree_xor, child_xor, subtree_xor ^ child_xor)
                    min_score = min(min_score, max_val - min_val)
            return xor_sum

        graph = defaultdict(list)
        for u, v in edges:
            graph[u].append(v)
            graph[v].append(u)

        total_xor = reduce(lambda x, y: x ^ y, nums)
        num_nodes = len(nums)
        min_score = inf

        for node in range(num_nodes):
            for neighbor in graph[node]:
                subtree_xor = computeXOR(node, neighbor)
                evaluatePartition(node, neighbor)

        return min_score