In [9]:
from typing import List

import numpy as np
import pandas as pd

import pickle


from sklearn.cluster import KMeans, MiniBatchKMeans





In [10]:
from sklearnex import patch_sklearn
patch_sklearn()

# patch_sklearn() swaps the estimators inside the sklearn modules; names that
# were bound by an earlier `from sklearn.cluster import ...` may still point at
# the stock classes. Re-import after patching so the accelerated versions are
# the ones actually used below.
from sklearn.cluster import KMeans, MiniBatchKMeans

Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [11]:
def convert_semanticID_to_text(hie_id: List[int]) -> str:
    """Render one hierarchical semantic id as a space-separated string.

    Example: [2, 5, 9, 5, 52] --> "2 5 9 5 52"

    Arguments:
      :hie_id - List[int]: the hierarchical semantic id of a single item
    return:
      :str_id - str: every component passed through ``str`` and joined by
                     single spaces (an empty id gives an empty string)
    """
    return " ".join(str(component) for component in hie_id)

def create_structured_semantic_id(
        img_emb_path: str = "flickr-30k/figures_repr_mat.npy",
        k: int = 5,
        c: int = 30,
        emb_dim: int = 768,
        cluster_bsz: int = int(1e3),
        data_dir: str = "flickr-30k",
        random_state=None,
):
    """Create structured (hierarchical) semantic ids from data representations.

    The embedding matrix is clustered into ``k`` groups; every group holding
    more than ``c`` items is clustered again, recursively. The id of an item is
    the list of cluster indices along its path, followed by its position
    inside the final leaf group (a group of exactly one item gets no extra
    position component).

    Files written under ``data_dir``:
        * ``new_ids_list.pkl`` - list of semantic ids (lists of ints), one per
          row of the embedding matrix
        * ``old2new_id_mapper_k{k}_c{c}.pkl`` - dict old id -> "a b c" string
        * ``new2old_id_mapper_k{k}_c{c}.pkl`` - the inverse dict
    File read from ``data_dir``:
        * ``old_ids_list.pkl`` - the original ids, in the same order as the
          rows of the embedding matrix

    Args:
        img_emb_path(str): path of the data representation. It should be
            loadable with ``np.load`` as a 2-D array (n_items, emb_dim).
        k(int): number of clusters used at every clustering step (>= 2)
        c(int): maximum number of items in a leaf cluster (>= 1)
        emb_dim(int): dim of the data representation
        cluster_bsz(int): batch size of MiniBatchKMeans; groups with at least
            this many items are clustered with MiniBatchKMeans, smaller ones
            with KMeans
        data_dir(str): directory holding ``old_ids_list.pkl`` and receiving
            the output pickles
        random_state: forwarded to both estimators; pass an int to make the
            clustering reproducible. ``None`` keeps the unseeded behaviour.

    Raises:
        ValueError: on invalid ``k``/``c``, when the loaded matrix does not
            have shape (n, emb_dim), or when the number of old ids differs
            from the number of embeddings.
    """
    if k < 2:
        # With a single cluster a group can never be split, so the recursion
        # below would never make progress.
        raise ValueError(f"k must be at least 2, got {k}")
    if c < 1:
        raise ValueError(f"c must be at least 1, got {c}")

    # load the embedding npy file
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only use this with trusted files. A plain float matrix does not need it.
    with open(img_emb_path, 'rb') as f:
        X = np.load(f, allow_pickle=True)

    if X.ndim != 2 or X.shape[1] != emb_dim:
        raise ValueError(
            f"expected embeddings of shape (n, {emb_dim}), got {X.shape}"
        )

    kmeans = KMeans(
        n_clusters=k,
        max_iter=500,
        n_init=100,
        init='k-means++',
        tol=1e-6,
        verbose=20,
        random_state=random_state,
    )

    mini_kmeans = MiniBatchKMeans(
        n_clusters=k,
        max_iter=300,
        n_init=100,
        init='k-means++',
        batch_size=cluster_bsz,
        reassignment_ratio=0.01,
        max_no_improvement=50,
        tol=1e-7,
        verbose=1,
        random_state=random_state,
    )

    # semantic_id_list[row] is the (growing) semantic id of embedding `row`
    semantic_id_list = []

    def assign_leaf_ids(x_data_pos):
        # Final level: number the members of a leaf group 0..n-1.
        for idx, pos in enumerate(x_data_pos):
            semantic_id_list[int(pos)].append(idx)

    def classify_recursion(x_data_pos):
        n_items = x_data_pos.shape[0]

        # Empty cluster: nothing to label. Single item: its cluster path is
        # already unique, so no leaf index is appended.
        if n_items <= 1:
            return

        # Small enough to be a leaf. Groups with fewer items than clusters
        # cannot be clustered at all (KMeans would raise), so they are
        # treated as leaves as well.
        if n_items <= c or n_items < k:
            assign_leaf_ids(x_data_pos)
            return

        # Gather the rows of this group in one shot; float64 matches the
        # buffer the clustering has always been run on here.
        temp_data = np.asarray(X[x_data_pos], dtype=np.float64)

        estimator = mini_kmeans if n_items >= cluster_bsz else kmeans
        pred = estimator.fit_predict(temp_data)

        # If every item landed in the same cluster (e.g. many identical
        # embeddings) recursing would loop forever on the same group; fall
        # back to enumerating the members.
        if np.unique(pred).size < 2:
            assign_leaf_ids(x_data_pos)
            return

        for i in range(k):
            members = x_data_pos[pred == i]
            for pos in members:
                semantic_id_list[int(pos)].append(i)
            classify_recursion(members)

    print('Start First Clustering')
    pred = kmeans.fit_predict(X)
    print(pred.shape)  # one cluster label in [0, k) per embedding
    print(kmeans.n_iter_)

    # plain Python ints keep the pickled ids independent of numpy scalar types
    for class_ in pred.tolist():
        semantic_id_list.append([class_])

    print('Start Recursively Clustering...')
    for i in range(k):
        print(i, "th cluster")
        classify_recursion(np.flatnonzero(pred == i))
    print('Complete!')

    # This process is very time-consuming, so it will be better to save the
    # intermediate result to disk
    with open(f"{data_dir}/new_ids_list.pkl", "wb") as f:
        pickle.dump(semantic_id_list, f)

    # load old id
    # NOTE(review): pickle.load executes code embedded in the file; only load
    # pickles produced by a trusted pipeline.
    with open(f"{data_dir}/old_ids_list.pkl", "rb") as f:
        old_id_list = pickle.load(f)

    if len(old_id_list) != len(semantic_id_list):
        raise ValueError(
            f"{len(old_id_list)} old ids but {len(semantic_id_list)} embeddings; "
            "old_ids_list.pkl must be aligned with the embedding matrix"
        )

    # generate old2new_id_mapper (the freshly built list is used directly;
    # it is identical to what was just written to new_ids_list.pkl)
    old2new_id_mapper = {
        old: convert_semanticID_to_text(new)
        for old, new in zip(old_id_list, semantic_id_list)
    }
    # generate new2old_id_mapper
    new2old_id_mapper = {new: old for old, new in old2new_id_mapper.items()}

    # save to disk
    with open(f"{data_dir}/old2new_id_mapper_k{k}_c{c}.pkl", "wb") as f:
        pickle.dump(old2new_id_mapper, f)
    with open(f"{data_dir}/new2old_id_mapper_k{k}_c{c}.pkl", "wb") as f:
        pickle.dump(new2old_id_mapper, f)

# Run the whole pipeline with the default arguments (k=5, c=30, 768-d
# embeddings read from flickr-30k/). The clustering is time-consuming and
# prints the estimators' verbose logs.
create_structured_semantic_id()

Start First Clustering
Initialization complete
Iteration 0, inertia 35924.15234375.
Initialization complete
Iteration 1, inertia 35831.92578125.
Initialization complete
Iteration 2, inertia 35831.92578125.
Initialization complete
Iteration 3, inertia 35832.1796875.
Initialization complete
Iteration 4, inertia 35924.15234375.
Initialization complete
Iteration 5, inertia 35924.15234375.
Initialization complete
Iteration 6, inertia 35832.0078125.
Initialization complete
Iteration 7, inertia 35831.609375.
Initialization complete
Iteration 8, inertia 35924.15234375.
Initialization complete
Iteration 9, inertia 35924.15234375.
Initialization complete
Iteration 10, inertia 35922.73046875.
Initialization complete
Iteration 11, inertia 35831.92578125.
Initialization complete
Iteration 12, inertia 35924.15234375.
Initialization complete
Iteration 13, inertia 35924.15234375.
Initialization complete
Iteration 14, inertia 35831.73828125.
Initialization complete
Iteration 15, inertia 35924.15234375.

  "MiniBatchKMeans is known to have a memory leak on "


Inertia for init 2/100: 7071.84224962362
Init 3/100 with method k-means++
Inertia for init 3/100: 6380.491453894814
Init 4/100 with method k-means++
Inertia for init 4/100: 6329.060743705633
Init 5/100 with method k-means++
Inertia for init 5/100: 7120.824777234356
Init 6/100 with method k-means++
Inertia for init 6/100: 6064.05781549347
Init 7/100 with method k-means++
Inertia for init 7/100: 6553.93764421682
Init 8/100 with method k-means++
Inertia for init 8/100: 6422.038381949092
Init 9/100 with method k-means++
Inertia for init 9/100: 6692.573243501959
Init 10/100 with method k-means++
Inertia for init 10/100: 6262.972284837174
Init 11/100 with method k-means++
Inertia for init 11/100: 6186.16516689887
Init 12/100 with method k-means++
Inertia for init 12/100: 6492.353583005461
Init 13/100 with method k-means++
Inertia for init 13/100: 7308.873345553165
Init 14/100 with method k-means++
Inertia for init 14/100: 6059.293465259439
Init 15/100 with method k-means++
Inertia for init 1

  "MiniBatchKMeans is known to have a memory leak on "


Inertia for init 2/100: 5509.474579873572
Init 3/100 with method k-means++
Inertia for init 3/100: 5404.341127007102
Init 4/100 with method k-means++
Inertia for init 4/100: 5776.875050201645
Init 5/100 with method k-means++
Inertia for init 5/100: 5236.653340654402
Init 6/100 with method k-means++
Inertia for init 6/100: 6078.621539094563
Init 7/100 with method k-means++
Inertia for init 7/100: 5914.546455856184
Init 8/100 with method k-means++
Inertia for init 8/100: 5566.334396982798
Init 9/100 with method k-means++
Inertia for init 9/100: 6016.20668142236
Init 10/100 with method k-means++
Inertia for init 10/100: 5773.045061609345
Init 11/100 with method k-means++
Inertia for init 11/100: 5346.637029025864
Init 12/100 with method k-means++
Inertia for init 12/100: 5485.488024316934
Init 13/100 with method k-means++
Inertia for init 13/100: 5569.769134442238
Init 14/100 with method k-means++
Inertia for init 14/100: 5682.961531739159
Init 15/100 with method k-means++
Inertia for ini

  "MiniBatchKMeans is known to have a memory leak on "


Init 1/100 with method k-means++
Inertia for init 1/100: 4939.0350564033115
Init 2/100 with method k-means++
Inertia for init 2/100: 5205.498077201858
Init 3/100 with method k-means++
Inertia for init 3/100: 5275.319319243537
Init 4/100 with method k-means++
Inertia for init 4/100: 4966.318998308703
Init 5/100 with method k-means++
Inertia for init 5/100: 4759.479115793449
Init 6/100 with method k-means++
Inertia for init 6/100: 5081.816630253719
Init 7/100 with method k-means++
Inertia for init 7/100: 4866.286225591295
Init 8/100 with method k-means++
Inertia for init 8/100: 5062.422621791877
Init 9/100 with method k-means++
Inertia for init 9/100: 4821.00536851254
Init 10/100 with method k-means++
Inertia for init 10/100: 5127.833249151706
Init 11/100 with method k-means++
Inertia for init 11/100: 5061.677861754459
Init 12/100 with method k-means++
Inertia for init 12/100: 5484.223158679159
Init 13/100 with method k-means++
Inertia for init 13/100: 5038.008850168665
Init 14/100 with 

  "MiniBatchKMeans is known to have a memory leak on "


Inertia for init 2/100: 1380.0590063478066
Init 3/100 with method k-means++
Inertia for init 3/100: 1227.1774463011354
Init 4/100 with method k-means++
Inertia for init 4/100: 1301.7615322906204
Init 5/100 with method k-means++
Inertia for init 5/100: 1279.6831070297683
Init 6/100 with method k-means++
Inertia for init 6/100: 1224.9716048690727
Init 7/100 with method k-means++
Inertia for init 7/100: 1293.3533208789988
Init 8/100 with method k-means++
Inertia for init 8/100: 1317.2849873045648
Init 9/100 with method k-means++
Inertia for init 9/100: 1239.372967458849
Init 10/100 with method k-means++
Inertia for init 10/100: 1389.7932518985144
Init 11/100 with method k-means++
Inertia for init 11/100: 1312.1682094691055
Init 12/100 with method k-means++
Inertia for init 12/100: 1252.223042976886
Init 13/100 with method k-means++
Inertia for init 13/100: 1259.0309286531485
Init 14/100 with method k-means++
Inertia for init 14/100: 1248.6711448190856
Init 15/100 with method k-means++
Ine

  "MiniBatchKMeans is known to have a memory leak on "


Inertia for init 3/100: 944.2308164624492
Init 4/100 with method k-means++
Inertia for init 4/100: 886.1091886379141
Init 5/100 with method k-means++
Inertia for init 5/100: 1014.0595164570503
Init 6/100 with method k-means++
Inertia for init 6/100: 947.7846814300468
Init 7/100 with method k-means++
Inertia for init 7/100: 972.1094187927463
Init 8/100 with method k-means++
Inertia for init 8/100: 913.5628890030383
Init 9/100 with method k-means++
Inertia for init 9/100: 1044.1205034984282
Init 10/100 with method k-means++
Inertia for init 10/100: 964.5966904304637
Init 11/100 with method k-means++
Inertia for init 11/100: 909.9591155287991
Init 12/100 with method k-means++
Inertia for init 12/100: 962.4008760733535
Init 13/100 with method k-means++
Inertia for init 13/100: 1000.8115289669165
Init 14/100 with method k-means++
Inertia for init 14/100: 938.4048838059322
Init 15/100 with method k-means++
Inertia for init 15/100: 966.8753265256257
Init 16/100 with method k-means++
Inertia f