## XOR TRIPLETS:
Problem Description

Given an array of integers A of size N.

A triplet (i, j, k) with i < j <= k is called a power triplet if A[i] ^ A[i+1] ^ ... ^ A[j-1] = A[j] ^ A[j+1] ^ ... ^ A[k]. (The left part must be non-empty, so j > i.)

Here, ^ denotes bitwise XOR.

Return the count of all possible power triplets. Since the answer can be large, return it modulo 10^9 + 7.
### Problem Constraints
1 <= N <= 100000

1 <= A[i] <= 100000
### Input Format
The first argument given is the integer array A.
### Output Format
Return the count of all possible power triplets modulo 10^9 + 7.
## Example Input
### Input 1:
 A = [5, 2, 7]
### Input 2:
 A = [1, 2, 3]
## Example Output
### Output 1:
 2
### Output 2:
 2
## Example Explanation
### Explanation 1:
 All possible power triplets are:
    
    1. (1, 2, 3) ->  A[1] = A[2] ^ A[3]
    
    2. (1, 3, 3) ->  A[1] ^ A[2] = A[3]
### Explanation 2:
 All possible power triplets are:
    
    1. (1, 2, 3) ->  A[1] = A[2] ^ A[3]
    
    2. (1, 3, 3) ->  A[1] ^ A[2] = A[3]

Approach:

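Let P[k] = A[0] ^ A[1] ^ ... ^ A[k] be the prefix XOR, with P[-1] = 0 for the empty prefix.
For any split point j, the two halves are equal iff A[i] ^ ... ^ A[k] = 0, i.e. iff P[i-1] = P[k]; the condition does not depend on j.
So once we know that the XOR of A[i..k] is ZERO, every j with i < j <= k works, which gives k - i triplets.
We therefore push each prefix XOR value into a trie (or a dictionary), together with how many times it has occurred and the sum of the matching start indices i. When P[k] appears again, we add k * count - sum in O(1).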

In [44]:
from collections import defaultdict
class TrieNode():
    """One node of the bit trie used by ``Trie``.

    Only the node reached after the last (32nd) bit carries data: the value
    stored there, how many times it was inserted, and the sum of
    ``index + 1`` over those insertions.
    """

    def __init__(self):
        # bit (0 or 1) -> child node. A plain dict is enough; a
        # ``defaultdict()`` with no default_factory behaves like one anyway.
        self.children = {}
        self.value = -1  # value stored at this leaf; -1 until something is inserted
        self.cnt = 0  # number of insertions of ``value``
        self.sum_of_indices = 0  # sum of (index + 1) over those insertions


class Trie():
    """Binary trie keyed on the 32-bit binary form of an int.

    For every inserted value it keeps an occurrence count and the running sum
    of ``index + 1``. A caller can then evaluate ``sum(k - (index + 1))`` over
    all earlier occurrences with a single lookup.
    """

    def __init__(self):
        self.root = self.get_node()

    def get_node(self):
        """Return a new, empty node."""
        return TrieNode()

    def insert(self, n, index):
        """Record one occurrence of ``n`` observed at position ``index``.

        Repeated inserts of the same ``n`` accumulate: ``cnt`` grows by 1 and
        ``sum_of_indices`` by ``index + 1``. ``n`` is assumed non-negative and
        below 2**32 (a negative number would not format as 32 binary digits).
        """
        node = self.root
        for bit in '{:032b}'.format(n):
            key = int(bit)
            if key not in node.children:
                node.children[key] = self.get_node()
            node = node.children[key]
        node.value = n
        node.cnt += 1
        # Accumulate (not assign) so every earlier occurrence keeps contributing.
        node.sum_of_indices += index + 1

    def search(self, n):
        """Return ``[cnt, sum_of_indices]`` for ``n`` if it was inserted, else False."""
        node = self.root
        for bit in '{:032b}'.format(n):
            node = node.children.get(int(bit))
            if node is None:
                return False
        return [node.cnt, node.sum_of_indices]


class Solution:
    MOD = 10**9 + 7

    # @param A : list of integers
    # @return an integer
    def solve(self, a):
        """Count power triplets in ``a`` (modulo 10**9 + 7) using a prefix-xor trie.

        A segment a[s..k] whose xor is 0 yields k - s triplets (one per split
        point j with s < j <= k). Such a segment exists exactly when the prefix
        xor ending at s - 1 equals the prefix xor ending at k.
        """
        t = Trie()
        # The empty prefix (ending at index -1) has xor 0; it is stored as -1 + 1 = 0.
        t.insert(0, -1)
        xor = 0
        ans = 0
        for i, x in enumerate(a):
            xor ^= x
            found = t.search(xor)
            if found:
                cnt, index_sum = found
                # Sum over matching earlier prefixes ending at p of i - (p + 1).
                ans += i * cnt - index_sum
            # Always record this prefix so later equal prefixes also count it.
            t.insert(xor, i)
        return ans % self.MOD

    def solve1(self, a):
        """Same count as ``solve`` but with a dict instead of a trie.

        Maps each prefix-xor value to ``[occurrences, sum of (end_index + 1)]``.
        """
        seen = {0: [1, 0]}  # the empty prefix: xor 0, stored index -1 + 1 = 0
        xor = 0
        ans = 0
        for i, x in enumerate(a):
            xor ^= x
            entry = seen.get(xor)
            if entry is None:
                seen[xor] = [1, i + 1]
            else:
                ans += i * entry[0] - entry[1]
                # Update the entry so this prefix is counted by later matches.
                entry[0] += 1
                entry[1] += i + 1
        return ans % self.MOD
        
# Sample input from the notebook, wrapped ten values per line.
A = [
    804, 621, 170, 320, 234, 81, 57, 175, 513, 189,
    163, 610, 656, 52, 957, 632, 33, 920, 280, 317,
    931, 848, 630, 511, 251, 754, 899, 648, 284, 598,
    818, 428, 18, 996, 629, 203, 449, 925, 25, 961,
    451, 80, 625, 284, 945, 190, 650, 501, 265, 56,
    919, 803, 762, 514, 973, 564, 356, 775, 538, 550,
    755, 903, 106, 365, 230, 174, 882, 918, 290, 775,
    169, 251, 477, 49, 107, 967, 368, 432, 272, 5,
    556, 223, 460, 812, 848, 853, 513, 470, 833, 966,
    786, 641, 916, 892, 448, 973, 488, 669, 819, 687,
]
o = Solution()

# Print the dict-based count for the sample.
print(o.solve1(A))


171


In [None]:
# Highest bit position examined by insert()/query() (bits lg..0), so values
# below 2**21 are distinguished. Prefix xors of inputs <= 100000 (< 2**17) fit.
lg = 20
  
# Structure of a Trie Node 
class TrieNode:
    """Binary-trie node: a slot per bit value plus per-leaf aggregates."""

    def __init__(self):
        # children[b] is the subtree for next bit b; None until it is created.
        self.children = [None] * 2
        # Aggregates for the value ending here; only the final node of a path is updated.
        self.number_of_indexes = 0
        self.sum_of_indexes = 0
  
# Function to insert curr_xor into the trie 
def insert(node, num, index):
    """Insert ``num`` into the binary trie rooted at ``node``.

    Walks bits ``lg`` down to 0 of ``num``, creating missing children, then
    records one more occurrence at the final node: ``number_of_indexes`` grows
    by 1 and ``sum_of_indexes`` by ``index``.

    Args:
        node: root ``TrieNode`` of the trie.
        num: value to insert. Only bits lg..0 are examined, so higher bits are
            silently ignored; callers must keep values below 2**(lg + 1).
        index: integer associated with this occurrence (the caller passes the
            start position of a candidate subarray).
    """
    for bit in range(lg, -1, -1):
        curr_bit = (num >> bit) & 1
        # Create the child for this bit if the path does not exist yet.
        if node.children[curr_bit] is None:
            node.children[curr_bit] = TrieNode()
        node = node.children[curr_bit]

    node.sum_of_indexes += index
    node.number_of_indexes += 1

  
# Function to check if curr_xor is present in trie or not 
def query(node, num, index):
    """Look up ``num`` in the trie rooted at ``node`` and score it against ``index``.

    Returns 0 if ``num`` was never inserted. Otherwise, with ``count``
    recorded occurrences whose indexes sum to ``total``, returns
    ``count * index - total``, i.e. the sum of ``index - s`` over every
    inserted index ``s``.

    Args:
        node: root ``TrieNode`` of the trie.
        num: value to look up (bits lg..0 only).
        index: current position the stored indexes are measured against.
    """
    for bit in range(lg, -1, -1):
        curr_bit = (num >> bit) & 1
        child = node.children[curr_bit]
        if child is None:
            # Path missing: no earlier prefix has this xor, so nothing to add.
            return 0
        node = child

    count = node.number_of_indexes
    total = node.sum_of_indexes
    return count * index - total

class Solution:
    # @param A : list of integers
    # @return an integer
    def solve(self, A):
        """Count power triplets in ``A`` modulo 10**9 + 7 using the bit trie.

        Before consuming A[i], the xor of A[0..i-1] is inserted with index i.
        After consuming A[i], querying the xor of A[0..i] finds every start s
        with A[s..i] xoring to 0 and adds i - s for each.
        """
        mod = 10**9 + 7  # integer modulus; a float here would turn the count into a float
        number_of_triplets = 0
        curr_xor = 0

        # The root of the trie
        root = TrieNode()

        for i, x in enumerate(A):
            # Record the prefix xor that ends just before position i.
            insert(root, curr_xor, i)

            # Extend the prefix to include A[i].
            curr_xor ^= x

            # Add (count * i - sum of starts) for earlier equal prefixes.
            number_of_triplets = (number_of_triplets + query(root, curr_xor, i)) % mod

        return number_of_triplets

In [None]:
class Solution:
    # @param A : list of integers
    # @return an integer
    def solve(self, arr):
        """Count power triplets in ``arr`` modulo 10**9 + 7 in O(n).

        If the prefix xor ending at p equals the prefix xor ending at i
        (p < i), then arr[p+1..i] xors to 0 and contributes i - p - 1 choices
        of j. Summing that over all earlier p with the same prefix xor is
        ``occurrences * (i - 1) - sum(p)``, so each prefix value only needs a
        count and an index sum rather than the full list of indexes.
        """
        m = 10**9 + 7
        # prefix xor -> [number of prefix ends p with that xor, sum of those p].
        # p = -1 represents the empty prefix, whose xor is 0.
        seen = {0: [1, -1]}
        res = 0  # running prefix xor
        count = 0
        for i, value in enumerate(arr):
            res ^= value
            entry = seen.get(res)
            if entry is None:
                seen[res] = [1, i]
            else:
                occurrences, index_sum = entry
                count += occurrences * (i - 1) - index_sum
                entry[0] += 1
                entry[1] += i
        return count % m
