In [28]:
class Node:
    """A list node holding ``item`` plus links to its neighbours.

    ``next`` references the following node and ``prev`` the preceding one;
    both default to ``None``. The singly linked helpers shown here only
    follow ``next``.
    """

    def __init__(self, item, next=None, prev=None):
        self.item, self.next, self.prev = item, next, prev

class LinkedList:
    """Singly linked list identified by its head node.

    Nodes are expected to expose ``item`` and ``next`` attributes (as ``Node``
    does), with ``next`` set to ``None`` on the last node. Walking a cyclic
    list (``length`` or ``str``) never terminates.
    """

    def __init__(self, head=None):
        self.head = head

    def _nodes(self):
        """Yield each node from head to tail."""
        curr = self.head
        while curr is not None:
            yield curr
            curr = curr.next

    def length(self):
        """Return the number of nodes, counted by walking the list."""
        return sum(1 for _ in self._nodes())

    def __str__(self):
        """Render as ``"a -> b -> ... -> None"`` (just ``"None"`` when empty)."""
        # Collect the pieces and join once instead of growing a string with
        # ``+=`` in the loop, which is quadratic in general.
        parts = [str(node.item) for node in self._nodes()]
        parts.append("None")
        return " -> ".join(parts)
    
def make_ll(l):
    """Build a ``LinkedList`` whose nodes hold the elements of ``l`` in order.

    ``l`` may be any finite iterable. An empty one yields an empty list
    (``head`` is ``None``). Only ``next`` links are set; ``prev`` stays ``None``.
    """
    head = None
    # Prepend from the back so each node is linked as soon as it is created;
    # this also covers the empty case without a special branch.
    for item in reversed(list(l)):
        head = Node(item, head)
    return LinkedList(head)

**2.1 Remove Dups:** Write code to remove duplicates from an unsorted linked list.

FOLLOW UP

How would you solve this problem if a temporary buffer is not allowed?

In [9]:
def remove_dups(ll):
    """Remove repeated items from ``ll`` in place, keeping first occurrences.

    Tracks already-seen items in a set, so it needs one pass and O(n) extra
    space, and items must be hashable. An empty list is left unchanged.
    """
    prev = ll.head
    if prev is None:
        return
    seen = {prev.item}
    curr = prev.next
    while curr is not None:
        if curr.item in seen:
            # Unlink curr; prev stays put so its new successor is examined next.
            prev.next = curr.next
        else:
            seen.add(curr.item)
            prev = curr
        curr = curr.next

In [10]:
def remove_dups(ll):
    """Remove repeated items from ``ll`` in place without a temporary buffer.

    For each node, a runner scans the rest of the list and unlinks every later
    node whose item equals it. O(n^2) time and O(1) extra space; items only
    need to support ``==``.
    """
    outer = ll.head
    while outer is not None:
        runner = outer
        while runner.next is not None:
            if runner.next.item == outer.item:
                # Unlink the duplicate but do not advance: the new
                # runner.next may itself be a duplicate (e.g. 1 -> 1 -> 1).
                runner.next = runner.next.next
            else:
                runner = runner.next
        outer = outer.next

In [11]:
# Input 3 -> 3 -> 1 -> 2 -> 2 -> 1; only the first occurrence of each value
# should survive, in its original position.
node1, node2, node3, node4, node5, node6 = (Node(v) for v in (3, 3, 1, 2, 2, 1))
node1.next, node2.next, node3.next, node4.next, node5.next = (
    node2, node3, node4, node5, node6)
ll = LinkedList(node1)

remove_dups(ll)

assert ll.length() == 3
# Check the contents too: a count alone would accept keeping the wrong
# copies or reordering them.
assert str(ll) == "3 -> 1 -> 2 -> None"

**2.2 Return Kth to Last:** Implement an algorithm to find the kth to last element of a singly linked list.

In [12]:
def return_last_kth(ll, k):
    """Return the item of the k-th to last node of ``ll`` (``k=1`` is the tail).

    Two pointers are kept ``k`` nodes apart; when the leading one runs off the
    end, the trailing one is on the answer. O(n) time, O(1) extra space.

    Raises:
        IndexError: if ``k < 1`` or the list has fewer than ``k`` nodes.
    """
    if k < 1:
        raise IndexError(
            "[return_kth] k must be at least 1, got {k}".format(k=k))
    slow = fast = ll.head
    for _ in range(k):
        # Fail loudly instead of returning the head's item for a too-large k.
        if fast is None:
            raise IndexError(
                "[return_kth] linked list is shorter than {k}".format(k=k))
        fast = fast.next

    while fast is not None:
        slow = slow.next
        fast = fast.next

    return slow.item

In [13]:
# Six-node list 3 -> 3 -> 1 -> 2 -> 2 -> 1 shared by the k-th-to-last checks.
node1, node2, node3, node4, node5, node6 = (Node(v) for v in (3, 3, 1, 2, 2, 1))
node1.next, node2.next, node3.next, node4.next, node5.next = (
    node2, node3, node4, node5, node6)
ll = LinkedList(node1)

assert return_last_kth(ll, 3) == 2  # third from the end: 3 3 1 [2] 2 1
assert return_last_kth(ll, 1) == 1  # the tail itself

**2.3 Delete Middle Node:** Implement an algorithm to delete a node in the middle (i.e. any node but the first and last node, not necessarily the exact middle) of a singly linked list, given only access to that node.

In [14]:
def delete_middle(ll):
    """Unlink the middle node of ``ll`` in place.

    The node at 0-based index ``n // 2`` (``n`` = number of nodes) is removed.
    For ``n >= 3`` it is never the first or last node. Lists with fewer than
    three nodes have no such node and are left unchanged.
    """
    head = ll.head
    if head is None or head.next is None or head.next.next is None:
        return
    prev, slow, fast = None, head, head
    # fast moves two nodes per step of slow. Once fast can no longer take a
    # full double step, slow sits at index n // 2 and prev just before it.
    while fast is not None and fast.next is not None:
        prev, slow, fast = slow, slow.next, fast.next.next
    prev.next = slow.next

In [15]:
# Delete the middle of 3 -> 3 -> 1 -> 2 -> 2 -> 1 (one of the adjacent 2s).
node1, node2, node3, node4, node5, node6 = (Node(v) for v in (3, 3, 1, 2, 2, 1))
node1.next, node2.next, node3.next, node4.next, node5.next = (
    node2, node3, node4, node5, node6)
ll = LinkedList(node1)

delete_middle(ll)

assert return_last_kth(ll, 3) == 1
# Pin the whole result rather than a single position.
assert str(ll) == "3 -> 3 -> 1 -> 2 -> 1 -> None"

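**2.4 Partition:** Write code to partition a linked list around a value x, such that all nodes less than x come before all nodes greater than or equal to x.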

In [16]:
def partition(ll, x):
    """Stably partition ``ll`` around ``x`` and return the reordered list.

    Nodes with ``item < x`` come first, then nodes with ``item >= x``. Relative
    order within each group is preserved. The existing nodes are relinked
    rather than copied, so ``ll.head`` is also updated to the new first node.
    Otherwise ``ll`` would keep pointing at its old head, which may now sit in
    the middle and hide earlier nodes. The returned ``LinkedList`` shares
    these nodes.
    """
    # Dummy heads avoid special-casing the first node of each group.
    less_dummy, rest_dummy = Node(None), Node(None)
    less_tail, rest_tail = less_dummy, rest_dummy
    curr = ll.head
    while curr is not None:
        if curr.item < x:
            less_tail.next = curr
            less_tail = curr
        else:
            rest_tail.next = curr
            rest_tail = curr
        curr = curr.next
    # The last ">= x" node may still point into the "< x" chain. Terminate
    # it, then splice the two groups together.
    rest_tail.next = None
    less_tail.next = rest_dummy.next
    ll.head = less_dummy.next
    return LinkedList(ll.head)

In [17]:
# 3 -> 5 -> 8 -> 5 -> 10 -> 2 -> 1 partitioned around 5: values below 5 move
# to the front, and each group keeps its original relative order.
node1, node2, node3, node4, node5, node6, node7 = (
    Node(v) for v in (3, 5, 8, 5, 10, 2, 1))
node1.next, node2.next, node3.next, node4.next, node5.next, node6.next = (
    node2, node3, node4, node5, node6, node7)
ll = LinkedList(node1)

assert str(partition(ll, 5)) == "3 -> 2 -> 1 -> 5 -> 8 -> 5 -> 10 -> None"

**2.5 Sum Lists:** You have two numbers represented by a linked list, where each node contains a single digit. The digits are stored in reverse order, such that the 1's digit is at the head of the list. Write a function that adds the two numbers and returns the sum as a linked list.

EXAMPLE

Input: (7 -> 1 -> 6) + (5 -> 9 -> 2). That is, 617 + 295.

Output: 2 -> 1 -> 9. That is, 912.

In [38]:
def sum_lists(la, lb):
    """Add two numbers stored as reverse-order digit lists.

    Each node holds one decimal digit, with the ones digit at the head, e.g.
    ``7 -> 1 -> 6`` is 617. Missing digits of the shorter list count as 0.
    Returns a new ``LinkedList`` in the same format; the inputs are not
    modified.
    """
    curra, currb = la.head, lb.head
    carry = 0
    pre_head = Node(None)  # dummy so the first digit needs no special case
    tail = pre_head
    while curra is not None or currb is not None or carry:
        total = carry
        if curra is not None:
            total += curra.item
            curra = curra.next
        if currb is not None:
            total += currb.item
            currb = currb.next
        # divmod keeps integer arithmetic on both Python 2 and 3. Plain ``/``
        # gives a float carry on Python 3.
        carry, digit = divmod(total, 10)
        tail.next = Node(digit)
        tail = tail.next
    return LinkedList(pre_head.next)

In [39]:
# 617 + 395 == 1012, digits stored least-significant first.
la, lb = make_ll([7, 1, 6]), make_ll([5, 9, 3])
assert str(sum_lists(la, lb)) == "2 -> 1 -> 0 -> 1 -> None"

FOLLOW UP

Suppose the digits are stored in forward order. Repeat the above problem.

EXAMPLE

Input: (6 -> 1 -> 7) + (2 -> 9 -> 5). That is, 617 + 295.

Output: 9 -> 1 -> 2. That is, 912.

In [61]:
def sum_lists(la, lb):
    """Add two numbers stored as forward-order digit lists.

    The most significant digit is at the head, e.g. ``6 -> 1 -> 7`` is 617.
    The shorter input is first left-padded with zero nodes. These are new
    nodes placed in front of its existing head, so the caller's lists are not
    modified. Digits are then added recursively from the tail, so the
    recursion depth equals the longer list's length. Returns a new
    ``LinkedList``.
    """
    def pad(ll, count):
        # Return a head with ``count`` zero nodes prepended to ll's nodes.
        head = ll.head
        for _ in range(count):
            head = Node(0, head)
        return head

    def recur(na, nb):
        # Inputs are aligned by pad(), so na and nb are both None or both nodes.
        # Returns (carry out of this position, sum list from this position on).
        if na is None and nb is None:
            return 0, None
        carry, rest = recur(na.next, nb.next)
        # divmod keeps integer arithmetic; ``/`` gives a float on Python 3.
        carry, digit = divmod(na.item + nb.item + carry, 10)
        return carry, Node(digit, rest)

    lena, lenb = la.length(), lb.length()
    longest = max(lena, lenb)
    carry, head = recur(pad(la, longest - lena), pad(lb, longest - lenb))

    if carry:
        head = Node(carry, head)

    return LinkedList(head)

In [62]:
# 617 + 395 == 1012, digits stored most-significant first.
la, lb = make_ll([6, 1, 7]), make_ll([3, 9, 5])
assert str(sum_lists(la, lb)) == "1 -> 0 -> 1 -> 2 -> None"

**2.6 Palindrome:** Implement a function to check if a linked list is a palindrome.

In [None]:
def is_palindrome(ll, length=None):
    """Return True if the first ``length`` items of ``ll`` read the same reversed.

    ``length`` defaults to ``ll.length()``, i.e. the whole list. The check
    recurses inward from both ends, so recursion depth is about
    ``length / 2``.
    """
    if length is None:
        # Computed per call. A default of ``ll.length()`` in the signature
        # would be evaluated once, against whatever global ``ll`` existed
        # when the function was defined.
        length = ll.length()

    def recur(head, length):
        # Check the ``length`` nodes starting at ``head``. Returns
        # (matches, node just past that span), so each caller can compare its
        # own head with that mirror node on the way back up.
        if length == 0:
            return True, head
        if length == 1:
            return True, head.next
        if length == 2:
            return head.item == head.next.item, head.next.next
        ok, mirror = recur(head.next, length - 2)
        return ok and head.item == mirror.item, mirror.next

    return recur(ll.head, length)[0]

In [None]:
# ll1: 3 -> 3 -> 1 -> 2 -> 2 -> 1 (not a palindrome)
node1, node2, node3, node4, node5, node6 = (Node(v) for v in (3, 3, 1, 2, 2, 1))
node1.next, node2.next, node3.next, node4.next, node5.next = (
    node2, node3, node4, node5, node6)
ll1 = LinkedList(node1)

# ll2: 3 -> 2 -> 1 -> 2 -> 3 (odd-length palindrome)
node1, node2, node3, node4, node5 = (Node(v) for v in (3, 2, 1, 2, 3))
node1.next, node2.next, node3.next, node4.next = node2, node3, node4, node5
ll2 = LinkedList(node1)

# ll3: 2 -> 2 (even-length palindrome)
node1, node2 = Node(2), Node(2)
node1.next = node2
ll3 = LinkedList(node1)

assert not is_palindrome(ll1, 6)
assert is_palindrome(ll2, 5)
assert is_palindrome(ll3, 2)
assert is_palindrome(LinkedList(), 0)

**2.7 Intersection:** Given two (singly) linked lists, determine if the two lists intersect. Return the intersecting node. Note that the intersection is defined based on reference, not value.

In [None]:
def intersection(ll1, ll2):
    """Return the first node shared (by identity) by ``ll1`` and ``ll2``.

    Returns ``None`` if the lists do not intersect. The longer list's pointer
    is first advanced by the length difference, so both pointers are the same
    distance from the end. Both then advance in lock-step until they reference
    the same node (or both reach ``None``). O(n + m) time, O(1) extra space.
    """
    len1, len2 = ll1.length(), ll2.length()
    if len1 >= len2:
        longer, shorter = ll1.head, ll2.head
    else:
        longer, shorter = ll2.head, ll1.head

    # abs(): the skip must cover the surplus whichever argument is longer.
    for _ in range(abs(len1 - len2)):
        longer = longer.next

    # Intersection is by reference, so compare identity rather than equality.
    while shorter is not longer:
        shorter = shorter.next
        longer = longer.next

    return shorter

In [None]:
# ll1:      1 -> 2 ->
#                    \
# ll2: 3 -> 4 -> 5 -> 6 -> 7 -> 8 -> 9   (both share node6 onward)
node1, node2, node3, node4, node5, node6, node7, node8, node9 = (
    Node(v) for v in range(1, 10))
# ll1's private prefix 1 -> 2 joins the shared tail at node6.
node1.next, node2.next = node2, node6
# ll2's private prefix 3 -> 4 -> 5 joins at the same node.
node3.next, node4.next, node5.next = node4, node5, node6
# The shared tail 6 -> 7 -> 8 -> 9.
node6.next, node7.next, node8.next = node7, node8, node9
ll1, ll2 = LinkedList(node1), LinkedList(node3)

assert intersection(ll1, ll2) == node6

**2.8 Loop Detection:** Given a circular linked list, implement an algorithm that returns the node at the beginning of the loop.

DEFINITION

Circular linked list: A (corrupt) linked list in which a node's next pointer points to an earlier node, so as to make a loop in the linked list.

EXAMPLE

Input: A -> B -> C -> D -> E -> C

Output: C

In [None]:
def detect_loop(ll):
    """Return the node where the cycle in ``ll`` begins, or ``None`` if acyclic.

    Floyd's algorithm: a slow pointer (1 step) and a fast pointer (2 steps)
    meet inside the loop if there is one. A second pointer then starts from
    the head, and it and the slow pointer advance one step at a time. They
    first coincide at the loop's starting node. O(n) time, O(1) extra space.
    """
    slow = fast = ll.head
    while fast is not None and fast.next is not None:
        slow = slow.next
        fast = fast.next.next
        if slow is fast:
            break
    else:
        # fast ran off the end (or the list is empty): no cycle.
        return None

    # Compare before stepping: if the loop starts at the head, the pointers
    # already coincide and the head is the answer.
    finder = ll.head
    while finder is not slow:
        finder = finder.next
        slow = slow.next
    return slow

In [None]:
# A -> B -> C -> D -> E, with E linking back to C: the loop begins at C.
nodeA, nodeB, nodeC, nodeD, nodeE = (Node(c) for c in "ABCDE")
nodeA.next, nodeB.next, nodeC.next, nodeD.next = nodeB, nodeC, nodeD, nodeE
nodeE.next = nodeC  # the corrupting back-edge that closes the loop
ll = LinkedList(nodeA)

assert detect_loop(ll) == nodeC