In [None]:
from sklearn.mixture import GaussianMixture
import lmdb
import numpy as np
import lmdb
from tqdm import tqdm

# LMDB environments holding the clean set: feature vectors and the id -> file name index.
DB_clean_features = lmdb.open('./clean/features.lmdb/', map_size=1_200_000_000)  # 1200 MB
DB_clean_id_to_filename = lmdb.open('./clean/id_to_filename.lmdb/', map_size=50_000_000)  # 50 MB

def int_to_bytes(x: int) -> bytes:
    """Encode ``x`` as a 4-byte big-endian byte string (the LMDB key format used here)."""
    return x.to_bytes(length=4, byteorder='big')

def int_from_bytes(xbytes: bytes) -> int:
    """Decode a big-endian, unsigned byte string back into an integer."""
    return int.from_bytes(xbytes, byteorder='big')

def get_all_data(db, size=20000):
    """Read up to ``size`` (id, feature vector) records from an LMDB environment.

    Keys are decoded with ``int_from_bytes``; values are interpreted as flat
    ``float32`` vectors.

    Args:
        db: Open LMDB environment (any object exposing ``begin(buffers=True)``).
        size: Maximum number of records to read, in cursor order.

    Returns:
        ``(ids, features)``: a list of ints and a list of 1-D ``float32``
        arrays, aligned by index.
    """
    ids = []
    features = []
    with db.begin(buffers=True) as txn:
        with txn.cursor() as curs:
            records = tqdm(curs.iternext(keys=True, values=True))
            for i, (key, value) in enumerate(records):
                if i >= size:
                    break
                ids.append(int_from_bytes(key))
                # With buffers=True the value is a view into LMDB's memory map
                # and is only valid while the transaction is open. Copy it so
                # the returned arrays do not alias memory that is released
                # once the `with` block exits.
                features.append(np.frombuffer(value, dtype=np.float32).copy())
    return ids, features


def get_file_name(image_id,file_name_db):
    """Look up the file name stored for ``image_id``.

    Args:
        image_id: Integer id; encoded with ``int_to_bytes`` to form the key.
        file_name_db: Open LMDB environment mapping encoded ids to UTF-8
            encoded file names.

    Returns:
        The file name as ``str``.

    Raises:
        KeyError: If the database holds no entry for ``image_id``.
    """
    with file_name_db.begin(buffers=False) as txn:
        file_name = txn.get(int_to_bytes(image_id))
    if file_name is None:
        # Fail with the offending id instead of an AttributeError on a
        # sentinel value that has no .decode().
        raise KeyError(f"no file name stored for image id {image_id}")
    return file_name.decode("utf-8")

In [None]:
# Load up to 300000 clean records and convert both the id list and the list of
# feature vectors to NumPy arrays in one step.
clean_ids, clean_features = map(np.array, get_all_data(DB_clean_features, 300000))

In [None]:
# Fit a 16-component, full-covariance Gaussian mixture to the clean features.
# random_state is fixed so that re-running the notebook reproduces the same
# mixture; the hard-coded score ranges used further down depend on this fit.
gmm = GaussianMixture(n_components=16, covariance_type='full', random_state=0)
gmm.fit(clean_features)

# Optional: persist the fitted model.
# import pickle 
# with open("./gmm.model","wb") as file:
#     pickle.dump(gmm,file)

In [None]:
import seaborn as sns

# Score every clean sample under the fitted mixture.
clean_scores = gmm.score_samples(clean_features)
# Histogram of the strictly positive scores only.
sns.histplot(clean_scores[clean_scores > 0])

In [None]:
# LMDB environments holding the test set: feature vectors and the id -> file name index.
DB_test_features = lmdb.open('./test/features.lmdb/', map_size=1_200_000_000)  # 1200 MB
DB_test_id_to_filename = lmdb.open('./test/id_to_filename.lmdb/', map_size=50_000_000)  # 50 MB

# Load up to 300000 test records and convert ids and feature vectors to NumPy arrays.
test_ids, test_features = map(np.array, get_all_data(DB_test_features, 300000))

In [None]:
# Score every test sample under the mixture fitted on the clean set.
test_scores = gmm.score_samples(test_features)
# Histogram of the strictly positive scores only.
sns.histplot(test_scores[test_scores > 0])

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import math

def read_img_file(f):
    """Open the image at ``f`` and return it in RGB mode.

    An image that is already RGB is returned as opened; any other mode is
    converted with ``convert('RGB')``.
    """
    picture = Image.open(f)
    return picture if picture.mode == 'RGB' else picture.convert('RGB')

def plot_imgs(file_names,IMG_PATH):
    """Show the given images on a square grid, each resized to 256x256.

    Args:
        file_names: Sequence of image file names.
        IMG_PATH: Prefix concatenated as-is in front of each file name, so it
            must already end with a path separator.

    Returns:
        None. Does nothing when ``file_names`` is empty.
    """
    if len(file_names) == 0:
        # Nothing to draw, and an empty (0x0) grid cannot be created.
        return
    s = math.ceil(math.sqrt(len(file_names)))
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid; by default a
    # single subplot comes back as a bare Axes, which has no .flatten().
    _, axs = plt.subplots(s, s, figsize=(12, 12), squeeze=False)
    axs = axs.flatten()
    imgs = [np.array(read_img_file(IMG_PATH+el).resize((256,256))) for el in file_names]
    for img, ax in zip(imgs, axs):
        ax.imshow(img)
    # The grid is s*s cells; hide the ones that received no image.
    for ax in axs[len(imgs):]:
        ax.axis('off')
    plt.show()

In [None]:
# File names of the clean images whose score lies strictly between 1000 and 2000.
clean_filenames_range = [
    get_file_name(int(image_id), DB_clean_id_to_filename)
    for image_id, score in zip(clean_ids, clean_scores)
    if 1000 < score < 2000
]

In [None]:
# Show at most the first 64 clean images from the selected score range.
plot_imgs(file_names=clean_filenames_range[:64], IMG_PATH="./clean/images/")

In [None]:
# File names of the test images whose score lies strictly between 1000 and 2000.
test_filenames_range = [
    get_file_name(int(image_id), DB_test_id_to_filename)
    for image_id, score in zip(test_ids, test_scores)
    if 1000 < score < 2000
]

In [None]:
# Show at most the first 64 test images from the selected score range.
plot_imgs(file_names=test_filenames_range[:64], IMG_PATH="./test/images/")