In [None]:
import bisect
import random
from timeit import default_timer as timer

In [None]:
class TPA:
  """Auditor-side bookkeeping used to challenge a server about stored file blocks.

  NOTE(review): "TPA" presumably stands for "third-party auditor" -- confirm.

  The object records one root hash per file version, the block hashes of the
  initial file, and a log of later block updates. `challenge_server` asks the
  server collaborator for the sibling hashes of a random (version, block)
  pair, recomputes the root hash through the security collaborator, and
  checks it against the recorded root hash for that version.
  """

  def __init__(self, server_segment_tree_control, security_implementation):
    """Store the two collaborators and start with empty bookkeeping.

    Args:
      server_segment_tree_control: object exposing
        `challenge_server(file_version_no=..., block_index=..., hashes=...)`.
      security_implementation: object exposing
        `calculate_root_from_sib_path(input_hash=..., sibling_nodes=..., node_based=...)`.
    """
    self.server_segment_tree_control = server_segment_tree_control
    self.security_implementation = security_implementation
    # Root hash of every file version, indexed by version number.
    self.version_root_hash = []
    # Block hashes of the initial file (version 0).
    self.base_encrypted_blocks_hash = []
    # Update log: entry k-1 is [bitmask of block indexes updated so far,
    # block hash written by update number k].
    self.updates = []
    # Per block: ascending list of the update numbers that touched it.
    self.block_update_history = []

  def add_new_version_root_hash(self, new_version_root_hash):
    """Record the root hash of the next file version."""
    self.version_root_hash.append(new_version_root_hash)

  def initial_file_block_save(self, encrypted_blocks_hash):
    """Remember the initial block hashes and add one empty history per block.

    The given list is stored by reference, not copied. One empty history
    list per block is appended to `block_update_history`; existing entries
    are left in place.
    """
    self.base_encrypted_blocks_hash = encrypted_blocks_hash
    self.block_update_history.extend([] for _ in encrypted_blocks_hash)

  def store_updated_blocks(self, updated_encryption_block_hash, block_id):
    """Log an update of block `block_id` with its new hash.

    Appends `[mask, hash]` to `updates`, where `mask` is the previous mask
    with bit `block_id` set, and appends the 1-based update number to the
    history of that block.
    """
    # The very first update must also be stored as a bit (1 << block_id),
    # not as the raw index, otherwise every later mask inherits wrong bits.
    previous_mask = self.updates[-1][0] if self.updates else 0
    self.updates.append(
        [previous_mask | (1 << block_id), updated_encryption_block_hash])
    # Update numbers are 1-based: number k lives at self.updates[k - 1].
    self.block_update_history[block_id].append(len(self.updates))

  def challenge_server(self, number_of_queries: int = 1, file_version_no: int = 0) -> float:
    """Run random audits against the server and return the mean time of one.

    Each query picks a random block index and a random file version,
    requests the sibling hashes from the server, recomputes the root hash
    and compares it with the recorded root hash of that version.

    Args:
      number_of_queries: how many random challenges to run.
      file_version_no: kept for interface compatibility only; the value is
        ignored because every query draws its own random version.

    Returns:
      Average seconds per query spent on the server request, the root
      recomputation and the comparison (console output is not timed).
      Returns 0.0 when `number_of_queries` is not positive.

    Raises:
      AssertionError: if a recomputed root hash differs from the recorded one.
    """
    if number_of_queries <= 0:
      # Nothing to measure; also avoids dividing by zero below.
      return 0.0

    last_block_index = len(self.base_encrypted_blocks_hash) - 1
    last_version_no = len(self.version_root_hash) - 1
    total_seconds = 0.0
    for _ in range(number_of_queries):
      block_index = random.randint(0, last_block_index)
      file_version_no = random.randint(0, last_version_no)
      input_hash = self.get_hash_for_particular_block_for_different_updates(
          file_version_no=file_version_no, block_index=block_index)
      expected_root_hash = self.version_root_hash[file_version_no]

      start = timer()
      sibling_nodes = self.server_segment_tree_control.challenge_server(
          file_version_no=file_version_no, block_index=block_index, hashes=True)
      calculated_root_hash = self.security_implementation.calculate_root_from_sib_path(
          input_hash=input_hash, sibling_nodes=sibling_nodes, node_based=False)
      matched = calculated_root_hash == expected_root_hash
      # Stop the clock before any printing so console I/O does not skew it.
      total_seconds += timer() - start

      print(calculated_root_hash, expected_root_hash, file_version_no)
      if not matched:
        print("block index = {} file_version_no = {}".format(block_index, file_version_no))
        print("Not matched")
        # Raised explicitly (not via `assert`) so the check survives `python -O`.
        raise AssertionError(
            "root hash mismatch for block index {} at file version {}".format(
                block_index, file_version_no))
    return total_seconds / number_of_queries

  def get_hash_for_particular_block_for_different_updates(self, file_version_no, block_index):
    """Return the hash block `block_index` had at version `file_version_no`.

    Update numbers in the block history are compared directly against the
    version number: the hash written by the latest update whose number is
    <= `file_version_no` is returned. If there is no such update (or the
    version is 0) the hash from the initial file is returned.
    """
    if file_version_no == 0:
      return self.base_encrypted_blocks_hash[block_index]

    history = self.block_update_history[block_index]
    # The history is ascending because store_updated_blocks only ever appends
    # a growing update counter, so bisect finds the last entry <= version.
    position = bisect.bisect_right(history, file_version_no)
    if position == 0:
      # The block was not updated up to this version.
      return self.base_encrypted_blocks_hash[block_index]
    update_number = history[position - 1]
    return self.updates[update_number - 1][1]  # update numbers are 1-based

  def generate_file_blocks_from_tpa(self, version_no):
    """Placeholder: not implemented yet, returns None."""
    pass

  def generate_file_blocks_from_seg_tree(self, version_no):
    """Placeholder: not implemented yet, returns None."""
    pass