# Partition List

Given the `head` of a linked list and a value `x`, partition it so that all nodes with values less than `x` come before all nodes with values greater than or equal to `x`.

You should preserve the original relative order of the nodes in each of the two partitions.
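In other words, the task asks for a *stable* partition. The minimal sketch below pins down the expected ordering only. It works on plain Python lists rather than linked nodes, and the helper name `stable_partition_values` is hypothetical. Two list comprehensions, each of which keeps the original order, are enough:

```python
from typing import List


def stable_partition_values(values: List[int], x: int) -> List[int]:
    """Reference model (hypothetical helper): values < x first, then values >= x,
    each group kept in its original relative order."""
    smaller = [v for v in values if v < x]
    larger_or_equal = [v for v in values if v >= x]
    return smaller + larger_or_equal


print(stable_partition_values([1, 4, 3, 2, 5, 2], 3))  # [1, 2, 2, 4, 3, 5]
print(stable_partition_values([2, 1], 2))  # [1, 2]
```

Note that `4` still precedes `3` in the output: both are `>= 3`, so their original order is kept even though they are not sorted.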

Example 1:

```
Input:   1 -> 4 -> 3 -> 2 -> 5 -> 2,  x = 3
Output:  1 -> 2 -> 2 -> 4 -> 3 -> 5

head = [1,4,3,2,5,2]
x = 3
expected = [1,2,2,4,3,5]
```
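The example writes lists in Python array notation (`head = [1,4,3,2,5,2]`). A pair of small helpers makes it easy to move between that notation and actual nodes. This is a sketch: `from_values` and `to_values` are hypothetical names, and `ListNode` mirrors the definition used in the solution below.

```python
from typing import List, Optional


class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


def from_values(values: List[int]) -> Optional[ListNode]:
    """Hypothetical helper: build a singly-linked list from a Python list (None for [])."""
    dummy = ListNode()
    tail = dummy
    for v in values:
        tail.next = ListNode(v)
        tail = tail.next
    return dummy.next


def to_values(head: Optional[ListNode]) -> List[int]:
    """Hypothetical helper: walk the list from head and collect node values in order."""
    values = []
    while head:
        values.append(head.val)
        head = head.next
    return values


head = from_values([1, 4, 3, 2, 5, 2])
print(" -> ".join(map(str, to_values(head))))  # 1 -> 4 -> 3 -> 2 -> 5 -> 2
```

With these helpers, a check such as `to_values(result) == [1, 2, 2, 4, 3, 5]` can replace a long chain of `.next.val` comparisons.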

```python
from typing import Optional


# Definition for singly-linked list.
class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


class Solution:
    # Time: O(n), a single pass over the list.
    # Space: O(1) extra; existing nodes are relinked and only two dummy heads are allocated.
    def partition(self, head: Optional[ListNode], x: int) -> Optional[ListNode]:
        # Dummy heads for the "before" (< x) and "after" (>= x) lists;
        # `before` and `after` always point at the current tail of each list.
        before = before_head = ListNode(0)
        after = after_head = ListNode(0)

        while head:
            if head.val < x:
                # Node value is less than x: append the node to the "before" list.
                before.next = head
                before = before.next
            else:
                # Node value is greater than or equal to x: append it to the "after" list.
                after.next = head
                after = after.next

            # Move ahead in the original list.
            head = head.next

        # The tail of the "after" list becomes the tail of the result. Its next pointer
        # still refers to whatever followed it in the original list, so cut it off
        # to avoid leaving a cycle or stray nodes.
        after.next = None
        # Join the two lists: the "before" tail points at the first "after" node.
        before.next = after_head.next

        return before_head.next
```
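The `after.next = None` step is required. The last node appended to the "after" list keeps its original `next` pointer, and that pointer can lead back into the "before" list. The sketch below reruns the same two-list split with a switch to skip that step (the names `partition_sketch`, `has_cycle`, and `build` are hypothetical). It then applies Floyd's tortoise-and-hare cycle check to Example 1. Without the terminator, node `5` still points at the final `2`, which the join step has redirected to `4`, so the result loops forever.

```python
from typing import List, Optional


class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


def partition_sketch(head: Optional[ListNode], x: int, terminate: bool = True) -> Optional[ListNode]:
    """Same split-and-join as Solution.partition; terminate=False skips after.next = None."""
    before = before_head = ListNode(0)
    after = after_head = ListNode(0)
    while head:
        if head.val < x:
            before.next = head
            before = before.next
        else:
            after.next = head
            after = after.next
        head = head.next
    if terminate:
        after.next = None  # the step under discussion
    before.next = after_head.next
    return before_head.next


def has_cycle(head: Optional[ListNode]) -> bool:
    """Floyd's check: a fast pointer meets the slow one only if the list loops."""
    slow = fast = head
    while fast and fast.next:
        slow = slow.next
        fast = fast.next.next
        if slow is fast:
            return True
    return False


def build(values: List[int]) -> Optional[ListNode]:
    """Build a fresh list, since partitioning relinks the nodes in place."""
    head = None
    for v in reversed(values):
        head = ListNode(v, head)
    return head


print(has_cycle(partition_sketch(build([1, 4, 3, 2, 5, 2]), 3, terminate=True)))   # False
print(has_cycle(partition_sketch(build([1, 4, 3, 2, 5, 2]), 3, terminate=False)))  # True
```

A fresh list is built for each call because both variants reuse and relink the input nodes rather than copying them.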

```python
head = ListNode(1)
head.next = ListNode(4)
head.next.next = ListNode(3)
head.next.next.next = ListNode(2)
head.next.next.next.next = ListNode(5)
head.next.next.next.next.next = ListNode(2)
x = 3

output = Solution().partition(head, x)
print(output.val, '->',
      output.next.val, '->',
      output.next.next.val, '->',
      output.next.next.next.val, '->',
      output.next.next.next.next.val, '->',
      output.next.next.next.next.next.val)

assert output.val == 1
assert output.next.val == 2
assert output.next.next.val == 2
assert output.next.next.next.val == 4
assert output.next.next.next.next.val == 3
assert output.next.next.next.next.next.val == 5
assert output.next.next.next.next.next.next is None
```

Output:

```
1 -> 2 -> 2 -> 4 -> 3 -> 5
```

```python
head = ListNode(2)
head.next = ListNode(1)
x = 2

output = Solution().partition(head, x)
print(output.val, '->',
      output.next.val)

assert output.val == 1
assert output.next.val == 2
assert output.next.next is None
```

Output:

```
1 -> 2
```
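Beyond the two hand-written cases, one way to gain more confidence is a randomized comparison against the list-comprehension model of a stable partition. The empty-list edge case (`head = None`) is also worth checking: both dummy lists stay empty, so the result is `None`. This is a sketch under stated assumptions: `check_random`, `from_values`, and `to_values` are hypothetical helpers, and `ListNode` and `Solution` are repeated so the block runs on its own.

```python
import random
from typing import List, Optional


class ListNode:
    def __init__(self, val=0, next=None):
        self.val = val
        self.next = next


class Solution:
    def partition(self, head: Optional[ListNode], x: int) -> Optional[ListNode]:
        before = before_head = ListNode(0)
        after = after_head = ListNode(0)
        while head:
            if head.val < x:
                before.next = head
                before = before.next
            else:
                after.next = head
                after = after.next
            head = head.next
        after.next = None
        before.next = after_head.next
        return before_head.next


def from_values(values: List[int]) -> Optional[ListNode]:
    """Hypothetical helper: Python list -> linked list (None for [])."""
    head = None
    for v in reversed(values):
        head = ListNode(v, head)
    return head


def to_values(head: Optional[ListNode]) -> List[int]:
    """Hypothetical helper: linked list -> Python list of values."""
    values = []
    while head:
        values.append(head.val)
        head = head.next
    return values


def check_random(trials: int = 1000, seed: int = 0) -> None:
    """Compare Solution.partition with a stable list-comprehension model on random inputs."""
    rng = random.Random(seed)
    for _ in range(trials):
        values = [rng.randint(-5, 5) for _ in range(rng.randint(0, 10))]
        x = rng.randint(-6, 6)
        expected = [v for v in values if v < x] + [v for v in values if v >= x]
        actual = to_values(Solution().partition(from_values(values), x))
        assert actual == expected, (values, x, actual, expected)


assert Solution().partition(None, 3) is None  # empty input stays empty
check_random()
print("all random checks passed")
```

The small value range (`-5` to `5`) is a deliberate choice: it produces many duplicates and many values equal to `x`, which is where an ordering or `<` versus `>=` mistake would show up.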