# task.py

# In [13]:
import numpy as np
import pickle
import os
from google.cloud import storage

import tensorflow as tf
from annoy import AnnoyIndex

class MyClassifier:
    """Build an Annoy nearest-neighbour index from TFRecord embeddings in GCS.

    Despite its name, this class does not classify anything. It:

    1. globs ``gs://<bucket_name>/<blob_name>/embeddings/*.tfrecords``;
    2. reads every ``tf.train.Example`` in those files, adding its
       ``embedding`` float list to an Annoy index and recording a mapping from
       the sequential Annoy item id to the example's ``id`` (UTF-8 decoded);
    3. writes the index to ``index_path`` and the pickled mapping to
       ``mapping_path``;
    4. uploads both files to ``gs://<bucket_name>/<blob_name>/index/``.

    Args:
        vector_length: Dimensionality passed to ``AnnoyIndex``; every
            embedding must have this many floats.
        metric: Annoy distance metric (default ``'angular'``).
        bucket_name: GCS bucket that holds the embeddings and receives the
            built index.
        num_trees: ``n_trees`` passed to ``AnnoyIndex.build``.
        blob_name: Object-name prefix ("folder") inside the bucket.
        index_path: Local file the Annoy index is saved to before upload.
        mapping_path: Local file the pickled id mapping is saved to before
            upload.
    """

    def __init__(self,
                 vector_length=512,
                 metric='angular',
                 bucket_name='easy666',
                 num_trees=100,
                 blob_name='Qa',
                 index_path='./index.ann',
                 mapping_path='./mapping.pickle'):

        self.vector_length = vector_length
        self.metric = metric
        self.bucket_name = bucket_name
        self.num_trees = num_trees
        self.blob_name = blob_name
        self.index_path = index_path
        self.mapping_path = mapping_path

        # NOTE: the GCS client is created eagerly, so merely constructing this
        # object talks to google.cloud.storage with default client settings.
        self.storage_client = storage.Client()
        self.bucket = self.storage_client.bucket(bucket_name)

    def process(self):
        """Build the index locally, upload index + mapping, then print 'done!'."""
        self.build_index()
        self.upload_file(self.index_path)
        self.upload_file(self.mapping_path)
        print('done!')

    def build_index(self):
        """Build and save the Annoy index and the annoy-id -> record-id mapping.

        Raises:
            FileNotFoundError: if no embedding files match the GCS pattern.
                Without this check an empty index would be built and
                ``process`` would upload it over any existing index.
        """
        embed_file_list = self.get_embed_file_list()
        if not embed_file_list:
            raise FileNotFoundError(
                f"no embedding files found under "
                f"gs://{self.bucket_name}/{self.blob_name}/embeddings/")

        mapping = {}
        annoy_index = AnnoyIndex(self.vector_length, metric=self.metric)
        annoy_id = 0
        for embed_file in embed_file_list:
            record_iterator = tf.compat.v1.python_io.tf_record_iterator(path=embed_file)
            for string_record in record_iterator:
                example = tf.train.Example()
                example.ParseFromString(string_record)
                features = example.features.feature

                embedding = np.array(features['embedding'].float_list.value)
                annoy_index.add_item(annoy_id, embedding)

                # Record ids are stored as bytes; decode so the mapping holds str.
                mapping[annoy_id] = features['id'].bytes_list.value[0].decode('utf-8')

                annoy_id += 1

        annoy_index.build(n_trees=self.num_trees)
        annoy_index.save(self.index_path)
        annoy_index.unload()

        self.save_mapping(mapping)

    def get_embed_file_list(self):
        """Return the embedding TFRecord paths in GCS, sorted by name.

        Example match: gs://<bucket>/<blob>/embeddings/embed-00000-of-00003.tfrecords

        Sorting matters: Annoy ids are assigned sequentially across files, so
        a stable file order keeps the id -> record mapping reproducible
        between runs (glob result order is not documented as sorted).
        """
        pattern = f"gs://{self.bucket_name}/{self.blob_name}/embeddings/*.tfrecords"
        return sorted(tf.io.gfile.glob(pattern))

    def save_mapping(self, mapping):
        """Pickle ``mapping`` (annoy id -> record id) to ``self.mapping_path``."""
        with open(self.mapping_path, 'wb') as file:
            pickle.dump(mapping, file, protocol=pickle.HIGHEST_PROTOCOL)

    def upload_file(self, file_path):
        """Upload a local file to ``<blob_name>/index/<basename>`` in the bucket."""
        filename = os.path.basename(file_path)
        blob_name = f"{self.blob_name}/index/{filename}"
        blob = self.bucket.blob(blob_name)
        blob.upload_from_filename(file_path, content_type='application/octet-stream')
        
# Instantiate with all defaults; the constructor creates a google.cloud.storage
# client and a handle to the default bucket as a side effect.
myClassifier = MyClassifier()


## test

# In [14]:
# Build the Annoy index and id mapping locally, upload both files to GCS,
# then print 'done!'.
myClassifier.process()

# Output: done!


# In [8]:
# with open(myClassifier.mapping_path, 'rb') as f:
#     mapping = pickle.load(f)
# mapping
