## This notebook scrapes images from Google Image search and then filters out irrelevant ones with CLIP. This can be useful if you need to expand your dataset.

## This is my first post for the Kaggle community — I hope you enjoy it!

# Install

In [None]:
# install google chrome
!wget https://dl.google.com/linux/linux_signing_key.pub
!sudo apt-key add linux_signing_key.pub
!echo 'deb [arch=amd64] http://dl.google.com/linux/chrome/deb/ stable main' >> /etc/apt/sources.list.d/google-chrome.list
!sudo apt-get -y update
!sudo apt-get install -y google-chrome-stable

In [None]:
# install chromedriver
# !apt-get install -y qq unzip
!wget -O /tmp/chromedriver.zip http://chromedriver.storage.googleapis.com/`curl -sS chromedriver.storage.googleapis.com/LATEST_RELEASE`/chromedriver_linux64.zip
!unzip /tmp/chromedriver.zip chromedriver -d /usr/local/bin/

In [None]:
# install selenium
!sudo apt install -y python3-selenium
!pip install selenium==3.141.0 > /dev/null

In [None]:
# To check Google Chrome's version.
# Its major version should match the ChromeDriver version printed by the next cell;
# ChromeDriver is built for a specific Chrome major version.
!google-chrome --version

In [None]:
# To check Chrome Driver's version.
# Compare its major version with the Chrome version printed above.
!chromedriver -v

# Start

In [None]:
# import libraries
import errno
import hashlib
import io
import os
import shutil
import signal
import time
from multiprocessing import Pool
from urllib.parse import quote_plus

import requests
from PIL import Image, ImageDraw
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from tqdm import tqdm

In [None]:
# Root folder for the downloaded images (one sub-folder per query).
# exist_ok=True keeps the cell safe to re-run; `!mkdir data` errors once the folder exists.
os.makedirs('data', exist_ok=True)

In [None]:
# Options shared by every browser started in this notebook (parse_query reuses them).
chrome_options = webdriver.ChromeOptions()
chrome_options.add_argument('--no-sandbox')
chrome_options.add_argument('--headless')
chrome_options.add_argument('--disable-gpu')
chrome_options.add_argument('--disable-dev-shm-usage')
chrome_options.add_argument("--window-size=1920,1080")

# Smoke test: start one browser with these options so a broken Chrome/chromedriver setup
# fails here, then quit it. Nothing else uses this session, and without quit() it would
# keep a headless Chrome process running for the rest of the notebook.
driver = webdriver.Chrome(options=chrome_options)
driver.quit()

In [None]:
def fetch_image_urls(query:str, max_links_to_fetch:int, wd:webdriver, sleep_between_interactions:int=1, max_idle_rounds:int=3):
    """
    Scroll through the Google Images results for ``query`` and collect full-size image URLs.

    Each thumbnail is clicked so the page shows the real image behind it, whose ``src``
    is recorded. When every thumbnail seen so far has been processed, the page is
    scrolled again and the "show more results" button (``.mye4qd``) is clicked if present.

    Parameters
    ----------
    query : str
        Search phrase. It is URL-encoded, so characters such as ``&``, ``#``, ``+`` or
        non-ASCII letters cannot corrupt the search URL.
    max_links_to_fetch : int
        Stop as soon as at least this many distinct URLs have been collected.
    wd : selenium WebDriver
        Browser session used to load and click the result page.
    sleep_between_interactions : int
        Seconds to wait after each scroll and after each thumbnail click.
    max_idle_rounds : int
        Give up after this many consecutive rounds in which scrolling produced no new
        thumbnails (Google ran out of results, or the page layout changed).

    Returns
    -------
    set of str
        Distinct image URLs whose ``src`` contains "http". It may hold fewer than
        ``max_links_to_fetch`` entries when the results run out.
    """
    def scroll_to_end(wd):
        wd.execute_script("window.scrollTo(0, document.body.scrollHeight);")
        time.sleep(sleep_between_interactions)

    # build the google query
    search_url = "https://www.google.com/search?safe=off&site=&tbm=isch&source=hp&q={q}&oq={q}&gs_l=img"

    # load the page
    wd.get(search_url.format(q=quote_plus(query)))

    image_urls = set()
    results_start = 0
    idle_rounds = 0
    while len(image_urls) < max_links_to_fetch:
        scroll_to_end(wd)

        # get all image thumbnail results
        # NOTE(review): "Q4LuWd", "n3VNCb" and "mye4qd" look like Google's generated class names and may change
        thumbnail_results = wd.find_elements(By.CSS_SELECTOR, "img.Q4LuWd")
        number_results = len(thumbnail_results)

        # Without this guard the loop would spin forever once Google has no more results.
        if number_results > results_start:
            idle_rounds = 0
        else:
            idle_rounds += 1
            if idle_rounds >= max_idle_rounds:
                print(f"No new search results after {idle_rounds} attempts, stopping with {len(image_urls)} image links")
                break

        print(f"Found: {number_results} search results. Extracting links from {results_start}:{number_results}")

        for img in thumbnail_results[results_start:number_results]:
            # try to click every thumbnail such that we can get the real image behind it
            try:
                img.click()
                time.sleep(sleep_between_interactions)
            except Exception:
                continue

            # extract image urls; read 'src' once per element (each get_attribute is a browser round-trip)
            for actual_image in wd.find_elements(By.CSS_SELECTOR, 'img.n3VNCb'):
                src = actual_image.get_attribute('src')
                if src and 'http' in src:
                    image_urls.add(src)
                    print(src)

            if len(image_urls) >= max_links_to_fetch:
                print(f"Found: {len(image_urls)} image links, done!")
                break
        else:
            print("Found:", len(image_urls), "image links, looking for more ...")
            time.sleep(10)
            # find_elements returns [] instead of raising when the button is absent
            if wd.find_elements(By.CSS_SELECTOR, ".mye4qd"):
                wd.execute_script("document.querySelector('.mye4qd').click();")

        # move the result startpoint further down
        results_start = number_results

    return image_urls

In [None]:
def persist_image(folder_path:str,url:str, request_timeout:float=10):
    """
    Download the image at ``url`` and save it as a JPEG inside ``folder_path``.

    The file name is the first 10 hex characters of the SHA-1 of the downloaded bytes,
    so downloading identical bytes twice writes to the same file instead of creating
    a duplicate. Errors are printed, never raised.

    Parameters
    ----------
    folder_path : str
        Existing directory to write the image into.
    url : str
        Address of the image.
    request_timeout : float
        Seconds ``requests`` waits on the server before giving up (default 10).
    """
    try:
        response = requests.get(url, timeout=request_timeout)
        # an HTML error page (404, 403, ...) is reported as a download error, not a decode error
        response.raise_for_status()
        image_content = response.content
    except Exception as e:
        print(f"ERROR - Could not download {url} - {e}")
        return  # nothing was downloaded, so there is nothing to save

    try:
        # convert('RGB') because JPEG cannot store alpha / palette images
        image = Image.open(io.BytesIO(image_content)).convert('RGB')
        file_path = os.path.join(folder_path, hashlib.sha1(image_content).hexdigest()[:10] + '.jpg')
        image.save(file_path, "JPEG", quality=85)
        print(f"SUCCESS - saved {url} - as {file_path}")
    except Exception as e:
        print(f"ERROR - Could not save {url} - {e}")

In [None]:
class TimeoutError(Exception):
    """Raised by ``timeout`` when the guarded block runs longer than allowed.

    Note: this shadows the built-in ``TimeoutError`` for the rest of the notebook.
    """


class timeout:
    """
    Context manager that interrupts the code inside its ``with`` block after
    ``seconds`` (fractions allowed) by raising ``TimeoutError(error_message)``.

    It is based on SIGALRM, so it only works on Unix and only in a process's main
    thread (a restriction of Python's ``signal`` module). The previous SIGALRM
    handler is restored on exit, and a TimeoutError is not suppressed.
    """
    def __init__(self, seconds=1, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # remember who owned SIGALRM before us so __exit__ can hand it back
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        # setitimer (unlike alarm) accepts fractional seconds; 0 disables the timer
        signal.setitimer(signal.ITIMER_REAL, self.seconds)
        return self

    def __exit__(self, type, value, traceback):
        signal.setitimer(signal.ITIMER_REAL, 0)
        # getsignal/signal return None for handlers not installed from Python; those can't be re-installed
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
        # returning None lets any exception (including TimeoutError) propagate

In [None]:
# the following queries will be inserted into google search
# examples of famous places are taken from this site
# https://www.architecturaldigest.com/story/most-iconic-buildings-around-the-world

ALL_QUERIES = ['statue of liberty',
               'Hagia Sophia',
               'Dancing House',
               'The Pyramids of Giza',
               'Acropolis of Athens',
               'Gateway Arch',
               'Le Centre Pompidou',
               'Musée d’Orsay',
               'Dresden Frauenkirche',
               'Château Frontenac',
               'St. Basil’s Cathedral',
               'Casa Milà',
               'White House',
               'Lincoln Center',
               'Angkor Wat',
               'Sultan Ahmed Mosque',
               'Musée du Louvre',
               'Sydney Opera House',
               'Guggenheim Museum',
               'Burj Khalifa',
               'Leaning Tower of Pisa',
               'Flatiron Building',
               'Eiffel Tower',
               'The Colosseum',
               'Stonehenge'
              ]

# Only this two-place subset is scraped and filtered below;
# set `queries = ALL_QUERIES` to build the full dataset.
queries = ['The Colosseum', 'Stonehenge']

In [None]:
N = 400                 # how many image links to collect per query
NUMBER_OF_BROUSERS = 2  # the number of browsers (worker processes) that search queries in parallel;
                        # the best value mostly depends on the speed of your internet connection

def parse_query(query):
    """
    Collect up to ``N`` image URLs for ``query`` and download them into ``data/<query>``.

    Called through ``Pool.map`` below, so each call runs in a worker process and
    starts its own Chrome session (configured by ``chrome_options``). The session
    is always shut down. Downloads that take longer than 5 seconds are skipped.
    """
    driver = webdriver.Chrome(options=chrome_options)
    try:
        links = fetch_image_urls(query, N, driver)
    finally:
        # quit even if scraping fails; otherwise every call leaves a Chrome process behind
        driver.quit()

    directory = 'data/{}'.format(query.replace(" ", "_"))
    os.makedirs(directory, exist_ok=True)  # create directory for this query if it doesn't exist

    for link in links:
        try:
            with timeout(seconds=5):  # if it takes more than 5 sec to save an image, we will skip it
                persist_image(directory, link)
        except TimeoutError:
            # skip only this image instead of aborting the remaining downloads for the query
            print(f"TIMEOUT - skipped {link}")


if __name__ == '__main__':
    with Pool(NUMBER_OF_BROUSERS) as p:
        p.map(parse_query, queries)

In [None]:
# Number of downloaded photos per query. Only regular files are counted (sub-folders are not
# photos), and a query whose folder was never created is reported as 0 instead of crashing.
for query in queries:
    directory = 'data/{}'.format(query.replace(" ", "_"))
    if os.path.isdir(directory):
        with os.scandir(directory) as entries:
            n_photos = sum(1 for entry in entries if entry.is_file())
    else:
        n_photos = 0

    print("There are {} photos of {}".format(n_photos, query))

In [None]:
# How many objects (search queries / places) we have
len(queries)

# Filter images with CLIP

In [None]:
# Installation
# !conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
# !pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import torch
import clip # how to install clip: https://github.com/openai/CLIP#usage
from PIL import Image

# Run CLIP on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# `model` produces the image/text embeddings; `preprocess` turns a PIL image into a model input tensor.
model, preprocess = clip.load("ViT-B/32", device=device)

In [None]:
from scipy import spatial

def find_suspecious_images(query, THREASHOLD = 0.30):
    """
    Move images of ``query`` that CLIP does not match to the text "photo of <query>"
    into a separate folder, so they can be reviewed before being deleted.

    Each image is scored by the cosine similarity between its CLIP image embedding
    and the CLIP text embedding of ``'photo of ' + query``.

    Parameters
    ----------
    query : str
        Search phrase the images were downloaded for. Images are read from
        ``data/<query with spaces replaced by '_'>``.
    THREASHOLD : float
        Images whose similarity is below this value are treated as suspicious.

    Side effects
    ------------
    Creates ``<directory>/bad <query>`` if needed, and moves into it every suspicious
    image and every image file that cannot be opened. Uses the notebook-level
    ``model``, ``preprocess`` and ``device`` created when CLIP was loaded.
    """
    # ideally we expect all pictures to look like a query, but in practice there is a lot of garbage:
    # Google images are full of pictures, schemes and other garbage.
    # If an image doesn't look like a photo of the query it is moved out (and deleted later).
    LIST_TO_COMPARE_WITH = ['photo of ' + query] # 'photo of Golden Gate Bridge' for example

    images_format = ('.jpg', '.jpeg', '.png')  # compared case-insensitively
    directory = 'data/{}'.format(query.replace(" ", "_"))
    files = [f for f in os.listdir(directory) if os.path.splitext(f)[1].lower() in images_format]

    bad_pictures_dir = '{}/bad {}'.format(directory, query)
    os.makedirs(bad_pictures_dir, exist_ok=True)

    # The text prompt is the same for every image, so tokenize and encode it only once.
    text = clip.tokenize(LIST_TO_COMPARE_WITH).to(device)
    with torch.no_grad():
        text_vector = model.encode_text(text)[0].tolist()

    for file in tqdm(files):
        path = "{}/{}".format(directory, file)
        try:
            # 'with' closes the underlying file handle once the tensor is built
            with Image.open(path) as picture:
                pixels = preprocess(picture)
        except (OSError, Image.DecompressionBombError) as e:
            # an unreadable / truncated file is certainly not a usable photo
            print(path, 'cannot be opened ({})'.format(e))
            similarity = None
        else:
            with torch.no_grad():
                image_vector = model.encode_image(pixels.unsqueeze(0).to(device))[0].tolist()
            # cosine similarity between the picture vector and the text vector
            similarity = 1 - spatial.distance.cosine(image_vector, text_vector)

        if similarity is None or similarity < THREASHOLD:
            print(path, 'is bad')
            # we put the suspicious file into bad_pictures_dir
            os.replace(path, "{}/{}".format(bad_pictures_dir, file))
                

In [None]:
# Run the CLIP filter on every scraped query; suspicious images end up in each query's "bad ..." folder.
for query in queries:
    print(f"{query} ...")
    find_suspecious_images(query)

print('done')

## We can now check which images the algorithm considers suspicious before deleting them. For each query, these images have been moved into a sub-folder named `bad <query>` inside that query's folder.

## The cell below will delete all suspicious images — don't execute it if you want to check them first.

In [None]:
# Permanently remove each query's "bad <query>" folder (the images CLIP flagged as suspicious).
for query in queries:
    bad_pictures_dir = 'data/{}/bad {}'.format(query.replace(" ", "_"), query)

    if not os.path.isdir(bad_pictures_dir):
        continue  # nothing was flagged (or the folder is already gone)

    shutil.rmtree(bad_pictures_dir)
    print(f"{query} - suspecious images removed")

In [None]:
# Number of remaining photos per query. Only regular files are counted, so a "bad <query>"
# sub-folder that was kept for review is not counted as a photo; a missing folder counts as 0.
for query in queries:
    directory = 'data/{}'.format(query.replace(" ", "_"))
    if os.path.isdir(directory):
        with os.scandir(directory) as entries:
            n_photos = sum(1 for entry in entries if entry.is_file())
    else:
        n_photos = 0

    print("There are {} photos of {}".format(n_photos, query))

In [None]:
# List the per-query folders created under data/
!ls data

In [None]:
# List the contents of the attached Kaggle input dataset "pictures-of-famous-places"
# (this relative path only exists when that dataset is attached to the notebook)
!ls ../input/pictures-of-famous-places/