In [19]:
from bs4 import BeautifulSoup
import pandas as pd
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from skimage import measure
import pickle
import imagehash
from string import ascii_lowercase

In [20]:
color_dict = {}

In [21]:
flag_df = pd.read_csv("flag_df.csv", index_col = "country")

In [12]:
flag_df.head()

Unnamed: 0_level_0,flag
country,Unnamed: 1_level_1
Afghanistan,flags/Afghanistan.pkl
Albania,flags/Albania.pkl
Algeria,flags/Algeria.pkl
Andorra,flags/Andorra.pkl
Angola,flags/Angola.pkl


In [34]:
def makeArray(npy_file):
    with open(npy_file,'rb') as f:
        arr = pickle.load(f)
        return arr

In [35]:
flag_df['flag'] = flag_df['flag'].apply(makeArray)

In [36]:
flag_df.head()

Unnamed: 0_level_0,flag
country,Unnamed: 1_level_1
Afghanistan,"[[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], ..."
Albania,"[[[255, 0, 0], [255, 0, 0], [255, 0, 0], [255,..."
Algeria,"[[[0, 102, 51], [0, 102, 51], [0, 102, 51], [0..."
Andorra,"[[[0, 0, 153], [0, 0, 153], [0, 0, 153], [0, 0..."
Angola,"[[[204, 0, 51], [204, 0, 51], [204, 0, 51], [2..."


In [16]:
len(flag_df)

195

In [22]:
def get_web_safe_colors():
    lst = []
    for i in range(0, 256, 51):
        for j in range(0, 256, 51):
            for k in range(0, 256, 51):
                tup = (i, j, k)
                lst.append(tup)
                
    return lst

In [23]:
def find_nearest_color(color, color_dict):
    if color in color_dict:
        return color_dict[color]
    
    web_safe = get_web_safe_colors()
    
    min_dist = 1000
    nearest_color = (0, 0, 0)
    for cur_col in web_safe:
        dist = sum(tuple(map(lambda i, j: abs(i - j), color, cur_col)))
        if dist < min_dist:
            min_dist = dist
            nearest_color = cur_col
            
    color_dict[color] = nearest_color
    return nearest_color

In [24]:
def fix_image_color(img):
    # Get the size of the image
    width, height = img.size

    # Process every pixel
    for x in range(width):
        for y in range(height):
            current_color = img.getpixel( (x,y) )
            new_color = find_nearest_color(current_color, color_dict)
            img.putpixel( (x,y), new_color)

In [25]:
def process_img(url):
    # getting the image from the url
    response = requests.get(url)
    image_bytes = BytesIO(response.content)
    img = Image.open(image_bytes)
    
    # converting it to RGB
    img = img.convert("RGB")
    
    # resizing to 180 x 90
    img = img.resize((180, 90))
    
    # standardizing the colors to 'web safe' colors
    fix_image_color(img)
    
    # converting an array of pixels
    img = np.array(img)
    return img

In [26]:
def mse(imageA, imageB):
    err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
    err /= float(imageA.shape[0] * imageA.shape[1])

    return err

In [27]:
def closest_flag(country):
    flag = flag_df["flag"].loc[country]
    max_ssim = -1
    closest_country = 0
    for c in flag_df.index:
        cur_flag = flag_df["flag"].loc[c]
        dist = measure.compare_ssim(flag, cur_flag, multichannel = True)
        if dist > max_ssim and dist != 1:
            max_ssim = dist
            closest_country = c
            
    print(max_ssim)
    print(closest_country)

In [28]:
def visualize_flag(flag_pixels):
    Image.fromarray(flag_pixels, "RGB").show()

In [29]:
def identify_flag_ssim(url):
    flag = process_img(url)
    
    max_ssim = -1
    max_index = 0
    for c in flag_df.index:
        cur_flag = flag_df["flag"].loc[c]
        dist = measure.compare_ssim(flag, cur_flag, multichannel = True)
        if dist > max_ssim:
            max_ssim = dist
            max_index = c
            
#     print(max_ssim)
    return max_index

In [30]:
def identify_flag_mse(url):
    flag = process_img(url)
    
    min_mse = -1
    max_index = 0
    for c in flag_df.index:
        cur_flag = flag_df["flag"].loc[c]
        error = mse(flag, cur_flag)
        if error < min_mse or min_mse == -1:
            min_mse = error
            max_index = c
            
#     print(min_mse)
    return max_index

In [31]:
def identify_flag_hash(url):
    flag = process_img(url)
    
    min_hash = -1
    max_index = 0
    for c in flag_df.index:
        cur_flag = flag_df["flag"].loc[c]
        firsthash = imagehash.average_hash(Image.fromarray(flag))
        otherhash = imagehash.average_hash(Image.fromarray(cur_flag))
        hash_dist = firsthash - otherhash
        
        if hash_dist < min_hash or min_hash == -1:
            min_hash = hash_dist
            max_index = c
            
#     print(min_hash)
    return max_index

In [37]:
arr1 = flag_df["flag"].loc["Georgia"]
arr2 = flag_df["flag"].loc["Denmark"]
sum(sum(arr1 * arr2))**2

array([68, 64, 64], dtype=uint8)

In [38]:
url = "https://www.cia.gov/library/publications/the-world-factbook/attachments/flags/AL-flag.gif"
identify_flag_ssim(url)
identify_flag_mse(url)
identify_flag_hash(url)

'Albania'

In [86]:
flag_df.head()

Unnamed: 0_level_0,flag
country,Unnamed: 1_level_1
Afghanistan,"[[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], ..."
Albania,"[[[255, 0, 0], [255, 0, 0], [255, 0, 0], [255,..."
Algeria,"[[[0, 102, 51], [0, 102, 51], [0, 102, 51], [0..."
Andorra,"[[[0, 0, 153], [0, 0, 153], [0, 0, 153], [0, 0..."
Angola,"[[[204, 0, 51], [204, 0, 51], [204, 0, 51], [2..."


In [64]:
visualize_flag(flag_df["flag"].loc["Georgia"])

In [57]:
for index in flag_df.index:
    print(index)

Afghanistan
Albania
Algeria
Andorra
Angola
Antigua And Barbuda
Argentina
Armenia
Australia
Austria
Azerbaijan
The Bahamas
Bahrain
Bangladesh
Barbados
Belarus
Belgium
Belize
Benin
Bhutan
Bolivia
Bosnia And Herzegovina
Botswana
Brazil
Brunei
Bulgaria
Burkina Faso
Burundi
Cambodia
Cameroon
Canada
Cape Verde
The Central African Republic
Chad
Chile
China
Colombia
The Comoros
The Democratic Republic Of The Congo
The Republic Of The Congo
Costa Rica
Croatia
Cuba
Cyprus
The Czech Republic
Denmark
Djibouti
Dominica
The Dominican Republic
East Timor
Ecuador
Egypt
El Salvador
Equatorial Guinea
Eritrea
Estonia
Eswatini
Ethiopia
Fiji
Finland
France
Gabon
The Gambia
Georgia (Country)
Germany
Ghana
Greece
Grenada
Guatemala
Guinea
Guinea-Bissau
Guyana
Haiti
Honduras
Hungary
Iceland
India
Indonesia
Iran
Iraq
Ireland
Israel
Italy
Ivory Coast
Jamaica
Japan
Jordan
Kazakhstan
Kenya
Kiribati
North Korea
South Korea
Kuwait
Kyrgyzstan
Laos
Latvia
Lebanon
Lesotho
Liberia
Libya
Liechtenstein
Lithuania
Luxembour

In [17]:
firsthash = imagehash.average_hash(Image.fromarray(process_img(url)))
otherhash = imagehash.average_hash(Image.fromarray(flag_df["flag"].loc["Paraguay"]))

print(firsthash - otherhash)

33


In [20]:
visualize_flag(process_img(url))

In [49]:
def test_with_cia(identify_flag_func):
    html = requests.get("https://www.cia.gov/library/publications/the-world-factbook/docs/flagsoftheworld.html").text
    soup = BeautifulSoup(html, "lxml")
    
    num_correct = 0
    num_total = 0
    
    for letter in ascii_lowercase:
        test_flags = soup.find_all("li", class_ = "flag appendix-entry ln-" + letter)
        
        for flag in test_flags:
            
            flag_country = flag.find("span").get_text().title()
            if "," in flag_country:
                split = flag_country.split(", ")
                flag_country = split[1] + " " + split[0]
                
#             print(flag_country, flag_country in flag_df.index)
            flag_img = flag.find("img").get("src").replace("..", "https://www.cia.gov/library/publications/the-world-factbook")
            if flag_country in flag_df.index or "The " + flag_country in flag_df.index:
                num_total += 1
                identified_flag = identify_flag_func(flag_img)
                if identified_flag == flag_country or identified_flag == "The " + flag_country:
                    num_correct+=1
            else:
                print(flag_country, "is not in the dataframe")
    print(num_correct / num_total)
    print(num_total)

In [68]:
def test_two(identify_flag_func):
    html = requests.get("https://flagpedia.net/index").text
    soup = BeautifulSoup(html, "lxml")
    
    num_correct = 0
    num_total = 0

    test_flags = soup.find("ul", class_ = "flag-grid").find_all("li")

    for flag in test_flags:

        flag_country = flag.find("span").get_text().title()

        if "," in flag_country:
            split = flag_country.split(", ")
            flag_country = split[1] + " " + split[0]

        flag_img = "https://flagpedia.net/" + flag.find("img").get("src")
        if flag_country in flag_df.index or "The " + flag_country in flag_df.index:
            num_total += 1
            identified_flag = identify_flag_func(flag_img)
            if identified_flag == flag_country or identified_flag == "The " + flag_country:
                num_correct+=1
        else:
            print(flag_country, "is not in the dataframe")
    print(num_correct / num_total)
    print(num_total)

In [64]:
test_one(identify_flag_ssim)

Akrotiri is not in the dataframe
American Samoa is not in the dataframe
Anguilla is not in the dataframe
Aruba is not in the dataframe
Ashmore And Cartier Islands is not in the dataframe
Bermuda is not in the dataframe
Bouvet Island is not in the dataframe
British Indian Ocean Territory is not in the dataframe
British Virgin Islands is not in the dataframe
Burma is not in the dataframe
Cabo Verde is not in the dataframe
Cayman Islands is not in the dataframe
Christmas Island is not in the dataframe
Clipperton Island is not in the dataframe
Cocos (Keeling) Islands is not in the dataframe
Cook Islands is not in the dataframe
Coral Sea Islands is not in the dataframe
Cote D'Ivoire is not in the dataframe
Curacao is not in the dataframe
Czechia is not in the dataframe
Dhekelia is not in the dataframe
European Union is not in the dataframe
Falkland Islands (Islas Malvinas) is not in the dataframe
Faroe Islands is not in the dataframe
French Polynesia is not in the dataframe
French Southern 

In [69]:
test_two(identify_flag_mse)

Åland Islands is not in the dataframe
American Samoa is not in the dataframe
Anguilla is not in the dataframe
Antarctica is not in the dataframe
Aruba is not in the dataframe
Bermuda is not in the dataframe
Bouvet Island is not in the dataframe
British Indian Ocean Territory is not in the dataframe
Caribbean Netherlands is not in the dataframe
Cayman Islands is not in the dataframe
Christmas Island is not in the dataframe
Cocos Islands is not in the dataframe
Dr Congo is not in the dataframe
Cook Islands is not in the dataframe
Côte D'Ivoire is not in the dataframe
Curaçao is not in the dataframe
Czechia is not in the dataframe
England is not in the dataframe
Falkland Islands is not in the dataframe
Faroe Islands is not in the dataframe
French Guiana is not in the dataframe
French Polynesia is not in the dataframe
French Southern And Antarctic Lands is not in the dataframe
Gibraltar is not in the dataframe
Greenland is not in the dataframe
Guadeloupe is not in the dataframe
Guam is not

In [59]:
mse_test1 = 0.8716577540106952 * 187
mse_test2 = 0.9894736842105263 * 190
print(mse_test1, mse_test2)

163.0 188.0


In [65]:
ssim_test1 = 0.5401069518716578 * 187
ssim_test2 = 0.9842105263157894 * 190
print(ssim_test1, ssim_test2)

101.00000000000001 187.0


In [67]:
hash_test1 = 0.46524064171123 * 187
hash_test2 = 0.6578947368421053 * 190
print(hash_test1, hash_test2)

87.0 125.00000000000001
