In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# import pandas as pd
import os
from sklearn.neighbors import NearestNeighbors
import joblib
import pandas as pd
from google.cloud import storage
from vincentvanbot.params import BUCKET_NAME, BUCKET_INITIAL_DATASET_FOLDER
from vincentvanbot.preprocessing.utils import get_jpg_link

Workflow:
- get an example input file (in bytes)
- preprocess it --> returns a resized and flattened df
- KNN predict --> returns the indices of the closest images
- indices to URLs --> returns a list of URLs of the closest images

In [3]:
# train
def train_model(df_transformed):
    """Fit a nearest-neighbours model on the preprocessed training data.

    The fitted model and the index of ``df_transformed`` are also saved to
    the current working directory as ``model.joblib`` and
    ``train_indexes.joblib``. The index is what allows a neighbour position
    to be referred back to the initial database later on.

    Parameters
    ----------
    df_transformed : pandas.DataFrame
        Preprocessed training data.

    Returns
    -------
    tuple
        ``(knn_model, train_indexes)``: the fitted ``NearestNeighbors``
        instance and ``df_transformed.index``.
    """
    knn_model = NearestNeighbors().fit(df_transformed)
    train_indexes = df_transformed.index

    joblib.dump(knn_model, 'model.joblib')
    joblib.dump(train_indexes, 'train_indexes.joblib')

    # The docstring always promised these values; actually hand them back.
    return knn_model, train_indexes

In [4]:
def save_model_to_cloud(rm=False):
    """Upload the locally saved model and index files to the GCloud bucket.

    Both ``model.joblib`` and ``train_indexes.joblib`` are read from the
    current working directory and stored under ``predict/`` in
    ``BUCKET_NAME``. When ``rm`` is true, the local copies are deleted once
    both uploads have finished.
    """
    artifacts = ('model.joblib', 'train_indexes.joblib')
    bucket = storage.Client().bucket(BUCKET_NAME)

    for artifact in artifacts:
        destination = f"predict/{artifact}"
        bucket.blob(destination).upload_from_filename(artifact)
        print(f"=> {artifact} uploaded to bucket {BUCKET_NAME} inside {destination}")

    if not rm:
        return

    # Local copies are only removed after every upload has gone through.
    for artifact in artifacts:
        os.remove(artifact)

In [5]:
# predict
def _download_and_load(bucket, local_name):
    """Download ``predict/<local_name>`` from ``bucket`` into the working directory and joblib-load it."""
    storage_location = f"predict/{local_name}"
    bucket.blob(storage_location).download_to_filename(local_name)
    print(f"=> {local_name} downloaded from storage")
    # NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    # code. Only point this at a bucket whose contents are trusted.
    return joblib.load(local_name)


def get_closest_images_indexes(user_input_transformed, nsimilar=3, rm=True):
    """Return the train-data indexes of the images closest to the user input.

    Downloads the fitted KNN model (``model.joblib``) and the matching train
    indexes (``train_indexes.joblib``) from ``predict/`` in ``BUCKET_NAME``
    into the current working directory, queries the model and translates the
    neighbour positions into the saved index values.

    Parameters
    ----------
    user_input_transformed : np.array
        Preprocessed user image. Only the neighbours of the first query row
        are returned. NOTE(review): presumably shaped (1, n_features) -
        confirm against the preprocessing step.
    nsimilar : int, default 3
        Number of neighbours to look up.
    rm : bool, default True
        Delete the downloaded files afterwards. The cleanup also runs when a
        download or load fails, so no stray files are left behind.

    Returns
    -------
    list of int
        Index values of the ``nsimilar`` closest training images.
    """
    local_names = ('model.joblib', 'train_indexes.joblib')
    bucket = storage.Client().bucket(BUCKET_NAME)

    try:
        model, indexes = [_download_and_load(bucket, name) for name in local_names]
    finally:
        if rm:
            for name in local_names:
                # A failed download may never have created the file.
                if os.path.exists(name):
                    os.remove(name)

    # kneighbors returns (distances, positions); keep the positions of the first query row.
    neighbor_positions = model.kneighbors(user_input_transformed, n_neighbors=nsimilar)[1][0]

    return [int(indexes[position]) for position in neighbor_positions]

In [6]:
def get_info_from_index(indexes, all_info=False):
    """Look up catalog information for the given image indexes.

    Reads ``catalog.csv`` from
    ``gs://BUCKET_NAME/BUCKET_INITIAL_DATASET_FOLDER/`` and picks the rows at
    the given positions (``iloc``), in the order supplied.

    Parameters
    ----------
    indexes : iterable of int
        Row positions in the catalog. NOTE(review): this assumes the saved
        train indexes coincide with the catalog's row positions - confirm.
    all_info : bool, default False
        When true, titles and authors are returned alongside the URLs.

    Returns
    -------
    list or tuple of lists
        ``urls`` when ``all_info`` is false, otherwise
        ``(urls, titles, authors)``. URLs are passed through
        ``get_jpg_link``.
    """
    dataset_filename = 'catalog.csv'
    path = f"gs://{BUCKET_NAME}/{BUCKET_INITIAL_DATASET_FOLDER}/{dataset_filename}"

    # The gs:// path goes straight to pandas, so no storage client is needed here.
    df = pd.read_csv(path, encoding='unicode_escape')

    # Select the requested rows once and rewrite only their URLs,
    # rather than mapping the whole catalog and indexing row by row.
    selected = df.iloc[list(indexes)]
    urls = selected['URL'].map(get_jpg_link).tolist()

    if not all_info:
        return urls

    # get additional info
    titles = selected['TITLE'].tolist()
    authors = selected['AUTHOR'].tolist()

    return urls, titles, authors

In [7]:
# prepare inputs for test
from vincentvanbot.preprocessing.utils import preprocess_image
from vincentvanbot.data import get_pickle

EXAMPLE_INPUT_PATH = 'example-input.jpg'
# NOTE(review): presumably has to match the size used for the training data - confirm.
INPUT_DIM = (36, 42)

# Preprocess one example picture to act as the user's query image.
user_img = preprocess_image(EXAMPLE_INPUT_PATH, dim=INPUT_DIM)
# Training data for the model; presumably already preprocessed - confirm against get_pickle.
train_df = get_pickle()

In [8]:
# workflow
# 1) fit the KNN model on the training data; the model and the train indexes
#    are written to local joblib files
train_model(train_df)
# 2) upload both joblib files to the bucket, then delete the local copies (rm=True)
save_model_to_cloud(rm=True)
# 3) download model and indexes again and look up the closest training images
#    (nsimilar defaults to 3)
indexes = get_closest_images_indexes(user_img)
# 4) turn the indexes into image URLs; as the last expression of the cell,
#    the resulting list is displayed as the cell output
get_info_from_index(indexes)

=> model.joblib uploaded to bucket vincent-van-bot-bucket inside predict/model.joblib
=> train_indexes.joblib uploaded to bucket vincent-van-bot-bucket inside predict/train_indexes.joblib
=> model.joblib downloaded from storage
=> train_indexes.joblib downloaded from storage


['https://www.wga.hu/art/a/abbate/torfani2.jpg',
 'https://www.wga.hu/art/a/abbati/abbati1.jpg',
 'https://www.wga.hu/art/a/ademollo/ark1.jpg']