In [1]:
# Widen the notebook container and output area to 98% of the browser width.
# IPython.display is the public import location; importing `display` from
# IPython.core.display is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:98% !important; }</style>"))
display(HTML("<style>.output_result { max-width:98% !important; }</style>"))


## Version info:
- all applications (granted and not granted)
- IPFs only
- Counting by: inventors

### Imports

In [2]:
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import plotly.express as px
import plotly.colors as colors
from datetime import datetime
import re
import os
import csv
from sklearn.preprocessing import MinMaxScaler
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram
from sklearn.base import clone
import plotly.io as pio
import math
import itertools as it
from collections import Counter
from random import choice
from nltk.util import ngrams, everygrams
import copy

# Set default color palette
colors_plotly_default = colors.qualitative.Plotly

# NOTE(review): machine-specific absolute paths; adjust these when running on another machine
main_path_mac = '/Users/philippmetzger/Documents/GitHub/battery_patents/'
main_path_ssd = '/Volumes/Samsung Portable SSD T3 Media/'

import sys
# Make helpers.py (in '02 Code') importable. os.path.join avoids the '//' that plain
# concatenation produces (main_path_mac already ends in '/'). The membership check keeps
# re-runs of this cell from appending the same entry again.
packages_path = os.path.join(main_path_mac, '02 Code')
if packages_path not in sys.path:
    sys.path.append(packages_path)

from helpers import (current_time_string,
                              image_saver,
                              country_labels_dict,
                              ctry_code_name_dict,
                              message,
                              numbers_dict)


In [3]:
# The UN populations dataset spells some country names differently from ctry_code_name_dict,
# so work on a copy with those entries overridden.
ctry_code_name_dict_UN = ctry_code_name_dict.copy()

def replace_dict_value(dict_, key_, new):
    """Overwrite dict_[key_] with `new` in place (returns None)."""
    dict_[key_] = new

# (country code, name as spelled in the UN dataset)
key_new = [
    ('TW', 'China, Taiwan Province of China'),
    ('HK', 'China, Hong Kong SAR'),
    ('SH', 'Saint Helena'),
    ('KP', "Dem. People's Republic of Korea"),
    ('MO', 'China, Macao SAR')
]

for ctry_code, un_name in key_new:
    replace_dict_value(ctry_code_name_dict_UN, ctry_code, un_name)
    

In [4]:
# The World Bank work force dataset spells some country names differently from
# ctry_code_name_dict, so work on a copy with those entries overridden.
ctry_code_name_dict_world_bank = ctry_code_name_dict.copy()

# (country code, name as spelled in the World Bank dataset); each code listed once
key_new = [
    ('KR', 'Korea, Rep.'),
    ('US', 'United States'),
    ('HK', 'Hong Kong SAR, China'),
    ('BS', 'Bahamas, The'),
    ('CZ', 'Czech Republic'),
    ('IR', 'Iran, Islamic Rep.'),
    ('SK', 'Slovak Republic'),
    ('VE', 'Venezuela, RB'),
    ('EG', 'Egypt, Arab Rep.'),
    ('KP', "Korea, Dem. People's Rep."),
    ('KG', 'Kyrgyz Republic'),
    ('LA', 'Lao PDR'),
    ('MO', 'Macao SAR, China'),
    ('LC', 'St. Lucia'),
    ('TZ', 'Tanzania'),
    ('VN', 'Vietnam')
]

# dict.update accepts (key, value) pairs directly, so no helper function is needed here
ctry_code_name_dict_world_bank.update(key_new)
    

In [5]:
# Smoke test: helpers.message() prints a line confirming that helpers.py was imported
message()


executing a function from helpers.py


In [6]:
# Timestamp string from helpers.py (output looks like 'YYYY-MM-DD_HHMM')
current_time_string()


'2022-01-27_1449'

## Read the whole dataset and reduce it to what we are interested in

In [7]:
# Read the whole dataset
dataset_name = 'data_batteries_2022-01-26_1852'

path = main_path_ssd+'Dataset saves/04 From 15 Nov 2021 (release of 2021 Autumn edition)/01 Preprocessed/03 final - technologies tagged/'+dataset_name+'.csv'

print('Loading data from:')
print(path)

# Only '', ' ' and '  ' are parsed as missing. keep_default_na=False stops pandas from also
# treating strings like 'NA' as NaN (presumably to keep codes such as 'NA' = Namibia).
data = pd.read_csv(path, delimiter = ";", low_memory = False, na_values=['', ' ', '  '], keep_default_na = False)

print('Number of rows:', len(data))

print('Distinct values in column "granted":', pd.unique(data['granted']))

# Reduce it to non active parts, electrodes, secondary cells, charging, redox flow, and Nickel-Hydrogen
technology_mask = (
    (data['non_active_parts_electrodes_secondary_cells'] == 1)
    | (data['charging'] == 1)
    | (data['is_Redox flow'] == 1)
    | (data['is_Nickel–hydrogen'] == 1)
)
# Rebinding `data` lets the full frame be garbage-collected
data = data[technology_mask].copy()

# Further reduce it to IPFs only
data_ipf = data[data['tag'] == 'IPF'].copy()
ipf_percentage = (len(set(data_ipf['docdb_family_id'])) / len(set(data['docdb_family_id']))) * 100
print('Percentage of IPFs in relation to all battery patent families: '+str(round(ipf_percentage, 2))+'%')
data = data_ipf


Loading data from:
/Volumes/Samsung Portable SSD T3 Media/Dataset saves/04 From 15 Nov 2021 (release of 2021 Autumn edition)/01 Preprocessed/03 final - technologies tagged/data_batteries_2022-01-26_1852.csv
Number of rows: 4086532
Distinct values in column "granted": ['N' 'Y']
Percentage of IPFs in relation to all battery patent families:19.41%


## Reduce to years we are interested in

In [8]:
# Earliest publication years present before restricting to 2000 onwards (next cell)
print(set(data['earliest_publn_year_this_family_id']))


{1999, 2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019}


In [9]:
# Keep only families whose earliest publication year is 2000 or later
data = data.loc[data['earliest_publn_year_this_family_id'] >= 2000].copy()


In [10]:
# Check: only years from 2000 onwards remain
print(set(data['earliest_publn_year_this_family_id']))


{2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019}


## Define the variable only_granted (used only to build the names of saved plots and files)

In [11]:
# False: the data contains all applications, granted or not. Only used to build output filenames.
only_granted = False


## Make a few checks: Total number of IPFs, number of IPFs related to different categories, distinct values in column 'tag'

In [12]:
# Number of distinct patent families (IPFs) left after the filters above
len(set(data['docdb_family_id']))


92700

In [13]:
# Distinct families flagged non_active_parts_electrodes_secondary_cells
len(set(data.loc[data['non_active_parts_electrodes_secondary_cells'] == 1, 'docdb_family_id']))


63282

In [14]:
# Distinct families flagged charging
len(set(data.loc[data['charging'] == 1, 'docdb_family_id']))


44039

In [15]:
# Distinct families flagged redox flow
len(set(data.loc[data['is_Redox flow'] == 1, 'docdb_family_id']))


843

In [16]:
# Check: distinct values of 'tag' (only 'IPF' is expected after the IPF filter)
set(data['tag'])


{'IPF'}

## In person_ctry_code: Replace NaNs with '  '

In [17]:
# Encode missing person_ctry_code as '  ' (two spaces), the marker get_counts treats as unknown.
# Assign the result back rather than calling fillna(inplace=True) on data['person_ctry_code']:
# that chained in-place call works on an intermediate Series and, under pandas copy-on-write,
# does not update `data`.
data['person_ctry_code'] = data['person_ctry_code'].fillna('  ')
#data['person_ctry_code_imputed'] = data['person_ctry_code_imputed'].fillna('  ')


## Dataset integrity check

In [18]:
def check_if_docdb_family_size_is_equal_to_number_of_applications(data_to_check, use_tqdm=True):

    """Count patent families whose docdb_family_size differs from their number of applications.

    Version 2 - 3. Jan 2022 (altered version from function in Create_db_4).
    Does not print each mismatch; instead increments a counter that it returns.

    Parameters
    ----------
    data_to_check : pandas.DataFrame
        Must contain 'docdb_family_id', 'appln_id' and 'docdb_family_size'.
        Rows with a missing docdb_family_id are ignored.
    use_tqdm : bool, default True
        Show a tqdm progress bar over the families (same switch as get_counts).

    Returns
    -------
    int
        Number of families whose count of distinct (appln_id, docdb_family_size) rows differs
        from docdb_family_size. If a family has more than one distinct docdb_family_size,
        a message is printed and the check stops, returning the count accumulated so far.
    """

    reduced = data_to_check[['docdb_family_id', 'appln_id', 'docdb_family_size']].drop_duplicates()

    # One groupby pass instead of filtering the whole frame once per family.
    # sort=False keeps families in order of first appearance (same order as pd.unique).
    grouped = reduced.groupby('docdb_family_id', sort=False)
    families = tqdm(grouped, total=grouped.ngroups) if use_tqdm else grouped

    counter = 0

    for family_id, family_rows in families:

        docdb_family_size = list(set(family_rows['docdb_family_size']))

        if len(docdb_family_size) > 1:
            print(str(family_id)+': There is more than one docdb_family_size.')
            break

        if len(family_rows) != docdb_family_size[0]:
            counter += 1

    return counter

In [19]:
# Returns 0 if all is as it should be
# Uncomment this to check

#check_if_docdb_family_size_is_equal_to_number_of_applications(data)


## Infer our time frame from data

In [20]:
# Derive the analysed time frame from the data (first and last earliest-publication year)
publication_years = data['earliest_publn_year_this_family_id']
year_begin = min(publication_years)
year_end = max(publication_years)

years = [*range(year_begin, year_end + 1)]
print(years)


[2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]


## Assess missing values situation
### Check the percentage of missing values in person_ctry_code

In [21]:
# Share of rows whose person_ctry_code is missing (encoded as '  ')
missing_ctry_rows = data[data['person_ctry_code'] == '  ']
len(missing_ctry_rows) / len(data)


0.3070879822843029

### Get the application authority distribution for rows with missing person_ctry_code

In [22]:
# Rows with missing person_ctry_code, counted per application authority (descending)
nan_data_appln_auth_counts = (
    data[data['person_ctry_code'] == '  ']
    .groupby(by='appln_auth')['docdb_family_id']
    .count()
    .sort_values(ascending=False)
    .rename('count')
)
nan_data_appln_auth_counts.head(10)


appln_auth
CN    223386
JP    144619
KR     31137
AU     14767
WO     12088
ES      6481
HK      4170
MX      2521
US      2249
BR      1471
Name: count, dtype: int64

### Share of applications with missing person_ctry_code that are filed in China

In [23]:
# Fraction of missing-country rows whose application authority is 'CN'
nan_share_china = nan_data_appln_auth_counts.loc['CN'] / nan_data_appln_auth_counts.sum()
nan_share_china


0.4956533315582773

### Share of applications with missing person_ctry_code that are filed in Japan

In [24]:
# Fraction of missing-country rows whose application authority is 'JP'
nan_share_japan = nan_data_appln_auth_counts.loc['JP'] / nan_data_appln_auth_counts.sum()
nan_share_japan


0.3208835341365462

### Share of applications with missing person_ctry_code that are filed in the rest of the world

In [25]:
# Remaining fraction: application authorities other than CN and JP
nan_share_rest_of_world = 1 - nan_share_china - nan_share_japan
nan_share_rest_of_world


0.18346313430517652

## Define the core function of this notebook

In [26]:
# Version 13. Jan. 2022: Enhanced fractional counting (for technologies)

def get_counts(data_to_function, use_tqdm, nat_intl_separation, count_inventors, technologies):
    """Fractionally count patent families by the country codes of their inventors or applicants.

    Each family carries weight 1 (divided by its number of technologies when `technologies`
    is True). That weight is split equally over the family's distinct
    (psn_name, person_ctry_code) pairs with a known country code. Families without any known
    code are counted under the unknown code '  '.

    Parameters
    ----------
    data_to_function : pandas.DataFrame
        Needs 'docdb_family_id', 'person_ctry_code', 'psn_name', 'invt_seq_nr', 'applt_seq_nr'
        and, if `technologies` is True, 'technologies_one_hot_sum'. Unknown country codes
        must be encoded as '  ' (two spaces).
    use_tqdm : bool
        Show a tqdm progress bar over the families.
    nat_intl_separation : bool
        If True, a family whose pairs all share one country adds its full weight to that code.
        A family spanning several countries splits its weight over its pairs, booking each
        share under '<code>_intl'.
    count_inventors : bool
        Use inventor rows (invt_seq_nr > 0) if True, applicant rows (applt_seq_nr > 0) otherwise.
    technologies : bool
        Divide each family's weight by its 'technologies_one_hot_sum'.

    Returns
    -------
    tuple
        (family_id_ctry_codes, ctry_codes_counts_sorted, known_percentage). When
        `nat_intl_separation` is True, (ctry_codes_this, nat_int_counter) are appended.
        ctry_codes_counts_sorted is ordered by descending count. known_percentage is the share
        (in %) of the total weight not assigned to '  ' (100 if there is none or nothing was counted).

    Raises
    ------
    ValueError
        If `technologies` is True and a family has more than one distinct
        'technologies_one_hot_sum'. Previously this printed a message and returned None,
        which broke the caller's tuple unpacking.
    """

    seq_nr_column = 'invt_seq_nr' if count_inventors else 'applt_seq_nr'
    data_to_function = data_to_function[data_to_function[seq_nr_column] > 0]

    ctry_codes_this = sorted(set(data_to_function['person_ctry_code']))

    family_id_ctry_codes = {}
    # Number of technologies each patent family is assigned to (filled only if `technologies`)
    family_id_num_technologies = {}

    # A single groupby pass instead of filtering the whole frame once per family.
    # sort=False keeps families in order of first appearance (same order as pd.unique), so the
    # floating-point accumulation below runs in the same order as before.
    grouped = data_to_function.groupby('docdb_family_id', sort=False, dropna=False)
    families = tqdm(grouped, total=grouped.ngroups) if use_tqdm else grouped

    for family_id, data_this_family_id in families:

        known_rows = data_this_family_id[data_this_family_id['person_ctry_code'] != '  ']
        # One entry per distinct (psn_name, person_ctry_code) pair, ordered by the pair.
        # NOTE(review): groupby drops rows with a missing psn_name (pandas default dropna=True),
        # so such persons are not counted. Confirm this is intended.
        family_id_ctry_codes[family_id] = list(
            known_rows[['psn_name', 'person_ctry_code']]
            .groupby(by=['psn_name', 'person_ctry_code'])
            .size()
            .reset_index(name='Freq')['person_ctry_code']
        )

        if technologies:
            num_technologies_this_family_id = list(set(data_this_family_id['technologies_one_hot_sum']))
            if len(num_technologies_this_family_id) > 1:
                raise ValueError('More than one one-hot-sum for family '+str(family_id)+': '
                                 +str(num_technologies_this_family_id))
            family_id_num_technologies[family_id] = num_technologies_this_family_id[0]

    # Start every code (and, in separation mode, its '_intl' twin) at zero
    ctry_codes_counts = {}
    for ctry_code in ctry_codes_this:
        ctry_codes_counts[ctry_code] = 0
        if nat_intl_separation:
            ctry_codes_counts[ctry_code+'_intl'] = 0

    nat_int_counter = {'national': 0,
                       'international': 0,
                       'unknown': 0}

    for family_id, family_ctry_codes in family_id_ctry_codes.items():

        size_ = len(family_ctry_codes)
        # This family's weight divisor: its number of technologies, or 1
        l = family_id_num_technologies[family_id] if technologies else 1

        if size_ == 0:
            # No known country code. get() avoids a KeyError when '  ' never occurred in
            # person_ctry_code (possible when all known-country rows lack a psn_name).
            ctry_codes_counts['  '] = ctry_codes_counts.get('  ', 0) + 1/l
            nat_int_counter['unknown'] += 1/l

        elif nat_intl_separation:
            if len(set(family_ctry_codes)) == 1:
                ctry_codes_counts[family_ctry_codes[0]] += 1/l
                nat_int_counter['national'] += 1/l
            else:
                for ctry_code in family_ctry_codes:
                    ctry_codes_counts[ctry_code+'_intl'] += (1/size_)/l
                nat_int_counter['international'] += 1/l

        else:
            for ctry_code in family_ctry_codes:
                ctry_codes_counts[ctry_code] += (1/size_)/l

    ctry_codes_counts_sorted = dict(sorted(ctry_codes_counts.items(), key=lambda x: x[1], reverse=True))

    # Explicit checks replace the former bare `except` (which covered a missing '  ' key
    # and an all-zero total)
    total_weight = sum(ctry_codes_counts.values())
    if '  ' in ctry_codes_counts and total_weight:
        known_percentage = (1 - (ctry_codes_counts['  '] / total_weight)) * 100
    else:
        known_percentage = 100

    if nat_intl_separation:
        return family_id_ctry_codes, ctry_codes_counts_sorted, known_percentage, ctry_codes_this, nat_int_counter

    return family_id_ctry_codes, ctry_codes_counts_sorted, known_percentage
    

## Separation of patent families with inventors from only one country vs. international co-inventions

In [27]:
def nat_int_prepare_df(nat_intl_dict, ctry_codes_list, num_countries_to_plot=6):
    """Build the national-vs-international table for the top countries.

    Parameters
    ----------
    nat_intl_dict : dict
        Counts from get_counts(..., nat_intl_separation=True). Plain country-code keys hold
        national-only families; '<code>_intl' keys hold the fractional international counts.
    ctry_codes_list : list of str
        Country codes to tabulate (the 4th return value of get_counts in that mode).
        The unknown code '  ' is skipped, whether or not it is present.
    num_countries_to_plot : int, default 6
        Number of rows kept after sorting ascending by national, then international count.

    Returns
    -------
    pandas.DataFrame
        Columns 'Country', the two count columns, 'intl/(nat+intl)' (international share)
        and 'nat+intl' (total), sorted ascending by the total.
    """
    col_1 = 'Country'
    col_2 = 'Patent families with national inventors only'
    col_3 = 'Patent families with inventors from at least one other country'
    col_4 = 'intl/(nat+intl)'
    col_5 = 'nat+intl'

    rows = []
    for ctry_code in ctry_codes_list:
        if ctry_code == '  ':
            # Unknown country: not ranked
            continue
        # Use country_labels_dict where it has the code, else the full name mapping
        if ctry_code in country_labels_dict:
            country_name = country_labels_dict[ctry_code]
        else:
            country_name = ctry_code_name_dict[ctry_code]
        rows.append([country_name, nat_intl_dict[ctry_code], nat_intl_dict[ctry_code+'_intl']])

    # Keep the countries with the largest counts. tail() also handles num_countries_to_plot=0,
    # where the former [-n:] slice returned every row.
    nat_intl_df = (
        pd.DataFrame(rows, columns=[col_1, col_2, col_3])
        .sort_values([col_2, col_3], ascending=True)
        .tail(num_countries_to_plot)
    )

    nat_intl_ratio_df = nat_intl_df.copy()
    total = nat_intl_ratio_df[col_2] + nat_intl_ratio_df[col_3]
    nat_intl_ratio_df[col_4] = nat_intl_ratio_df[col_3] / total
    nat_intl_ratio_df[col_5] = total

    return nat_intl_ratio_df.sort_values([col_5], ascending=True)


In [28]:
def nat_int_plot(df, years, save=False):
    """Show a horizontal stacked bar chart of national vs. international family counts.

    Parameters
    ----------
    df : pandas.DataFrame
        Output of nat_int_prepare_df ('Country', the two count columns, 'intl/(nat+intl)',
        'nat+intl').
    years : sequence of int
        Years covered; only min and max are used, for the x-axis title.
    save : bool, default False
        If True, also export the figure with helpers.image_saver. The filename depends on
        the global `only_granted`. The default only shows the plot, as before.
    """
    col_2 = 'Patent families with national inventors only'
    col_3 = 'Patent families with inventors from at least one other country'
    col_4 = 'intl/(nat+intl)'
    col_5 = 'nat+intl'

    x_label_string = 'Total number of battery patent families in '+str(min(years))+'-'+str(max(years))

    nat_intl_counts_plot = px.bar(df,
                                  y="Country",
                                  x=[col_2, col_3])
    nat_intl_counts_plot.update_xaxes(title=x_label_string)

    nat_intl_counts_plot.update_layout(
        legend=dict(yanchor="bottom",
                    y=0.02,
                    xanchor="right",
                    x=0.99,
                    title=''))

    # Print the international share (%) just right of each bar's end. y=i is presumably the
    # i-th category on the y axis, i.e. the i-th row of df. Confirm if the row order changes.
    for i, (total, intl_share) in enumerate(zip(df[col_5], df[col_4])):
        nat_intl_counts_plot.add_annotation(x=total,
                                            y=i,
                                            text=str(round(intl_share*100, 2))+'%',
                                            showarrow=False,
                                            xshift=25)

    nat_intl_counts_plot.show()

    if save:
        filename = 'nat_intl_counts_only_granted' if only_granted else 'nat_intl_counts_all_appln'
        image_saver(nat_intl_counts_plot, filename, True)
    

### Whole time period

In [29]:
# Whole period, counted by inventors, with the national/international split
(family_id_ctry_codes_whole_time_nat_intl,
 ctry_codes_counts_whole_time_nat_intl,
 known_percentage_whole_time_nat_intl,
 ctry_codes_whole_time,
 nat_int_counter_whole_time) = get_counts(
    data,
    use_tqdm=True,
    nat_intl_separation=True,
    count_inventors=True,
    technologies=False,
)


  0%|          | 0/92667 [00:00<?, ?it/s]

### Split in two time periods

In [None]:
# NOTE(review): the second period ends in 2018 although `years` runs to 2019, and the yearly
# totals further down are split at cut = 2014 instead of 2013/2014. Confirm the intended periods.

def _rows_in_year_range(frame, first_year, last_year):
    """Rows of `frame` whose earliest publication year lies in [first_year, last_year]."""
    publn_year = frame['earliest_publn_year_this_family_id']
    return frame[(publn_year >= first_year) & (publn_year <= last_year)]

years_nat_int_first_part = [2000, 2013]
data_nat_int_first_part = _rows_in_year_range(data, *years_nat_int_first_part)

(family_id_ctry_codes_first_part_nat_intl,
 ctry_codes_counts_first_part_nat_intl,
 known_percentage_first_part_nat_intl,
 ctry_codes_first_part,
 nat_int_counter_first_part) = get_counts(
    data_nat_int_first_part,
    use_tqdm=True,
    nat_intl_separation=True,
    count_inventors=True,
    technologies=False,
)

years_nat_int_second_part = [2014, 2018]
data_nat_int_second_part = _rows_in_year_range(data, *years_nat_int_second_part)

(family_id_ctry_codes_second_part_nat_intl,
 ctry_codes_counts_second_part_nat_intl,
 known_percentage_second_part_nat_intl,
 ctry_codes_second_part,
 nat_int_counter_second_part) = get_counts(
    data_nat_int_second_part,
    use_tqdm=True,
    nat_intl_separation=True,
    count_inventors=True,
    technologies=False,
)


  0%|          | 0/39580 [00:00<?, ?it/s]

  0%|          | 0/41949 [00:00<?, ?it/s]

In [None]:
# Check: years contained in the first period
print(set(data_nat_int_first_part['earliest_publn_year_this_family_id']))


In [None]:
# Check: years contained in the second period
print(set(data_nat_int_second_part['earliest_publn_year_this_family_id']))


### Get the dataframes

In [None]:
# National vs. international table of the top countries, whole period
nat_intl_df_whole_whole_time = nat_int_prepare_df(ctry_codes_counts_whole_time_nat_intl,
                                                  ctry_codes_whole_time)
nat_intl_df_whole_whole_time


In [None]:
# National vs. international table of the top countries, first period
nat_intl_df_whole_first_part = nat_int_prepare_df(ctry_codes_counts_first_part_nat_intl,
                                                  ctry_codes_first_part)
nat_intl_df_whole_first_part


In [None]:
# National vs. international table of the top countries, second period
nat_intl_df_whole_second_part = nat_int_prepare_df(ctry_codes_counts_second_part_nat_intl,
                                                   ctry_codes_second_part)
nat_intl_df_whole_second_part


### Plot national/international counts

In [None]:
# Whole period
nat_int_plot(nat_intl_df_whole_whole_time, years)


In [None]:
# First period
nat_int_plot(nat_intl_df_whole_first_part, years_nat_int_first_part)


In [None]:
# Second period
nat_int_plot(nat_intl_df_whole_second_part, years_nat_int_second_part)


### Grouped bar plot comparing percentages from first and second time periods

In [None]:
colname_first_part = f'Share of co-inventions in {min(years_nat_int_first_part)}-{max(years_nat_int_first_part)}'
colname_second_part = f'Share of co-inventions in {min(years_nat_int_second_part)}-{max(years_nat_int_second_part)}'

# One {country: international share} record per period
share_records = [
    dict(nat_intl_df_whole_first_part[['Country', 'intl/(nat+intl)']].values),
    dict(nat_intl_df_whole_second_part[['Country', 'intl/(nat+intl)']].values),
]
comparison_df = pd.DataFrame.from_records(share_records, index=[colname_first_part, colname_second_part])

# Countries as rows, periods as columns; reverse the row order
comparison_df = comparison_df.T.iloc[::-1]
comparison_df


In [None]:
countries = list(comparison_df.index)
cols = list(comparison_df)

# One horizontal bar trace per period
co_invention_ratio_data = [
    go.Bar(name=period_col, y=countries, x=list(comparison_df[period_col]), orientation='h')
    for period_col in cols
]

co_invention_ratio_layout = go.Layout(
    legend=dict(orientation="h", yanchor="top", y=-0.07, xanchor="left", x=0),
    plot_bgcolor="white",
    yaxis=dict(color='black', showgrid=False, dtick=0.5),
    xaxis=dict(color='black', showgrid=True, gridwidth=1, gridcolor='black', tickformat='%'),
)

co_invention_ratio_plot = go.Figure(data=co_invention_ratio_data, layout=co_invention_ratio_layout)
co_invention_ratio_plot.update_yaxes(autorange="reversed")

# Percentage labels; the vertical offset alternates per period (+10 px, -10 px, ...)
for col_idx, period_col in enumerate(comparison_df.columns):
    y_shift = 10 if col_idx % 2 == 0 else -10
    for row_idx, share in enumerate(comparison_df[period_col]):
        co_invention_ratio_plot.add_annotation(x=share,
                                               y=row_idx,
                                               text=str(round(share*100, 1))+'%',
                                               showarrow=False,
                                               xshift=25,
                                               yshift=y_shift,
                                               bgcolor='white')

# Place the period bars side by side within each country
co_invention_ratio_plot.update_layout(barmode='group')
co_invention_ratio_plot.show()

#image_saver(co_invention_ratio_plot, 'co_invention_ratio_plot', True)


## Count patents for the whole time period

In [None]:
# Whole period, counted by inventors, without the national/international split
family_id_ctry_codes_whole_time, ctry_codes_counts_whole_time, known_percentage_whole_time = get_counts(
    data,
    use_tqdm=True,
    nat_intl_separation=False,
    count_inventors=True,
    technologies=False,
)


In [None]:
# Share (%) of the counted family weight that has a known inventor country
known_percentage_whole_time


In [None]:
#ctry_codes_counts_whole_time


### Prepare plot of each country's total over the whole time span

In [None]:
# PATSTAT country reference table; only the code -> continent columns are kept
path = main_path_mac + '03 Extra data/PATSTAT reference tables/TLS801_COUNTRY.csv'
continents_patstat = (
    pd.read_csv(path, delimiter=';')
    .dropna()
    .reset_index(drop=True)
    [['ctry_code', 'continent']]
)
continents_patstat


In [None]:
# Map each PATSTAT country code to its continent (a later duplicate code overwrites an earlier one)
ctry_code_continent_dict = dict(zip(continents_patstat['ctry_code'], continents_patstat['continent']))
    

In [None]:
# Delete the counter for missing ctry_code values and create a sorted the dictionary with the other counters
data_whole_timespan_plot = ctry_codes_counts_whole_time.copy()
data_whole_timespan_plot.pop('  ')
data_whole_timespan_plot = dict(sorted(data_whole_timespan_plot.items(), key=lambda x:x[1], reverse=False))


In [None]:
# Save this data in order to use it in Co_occurences for size of country dots in network visualisation.
# Set to True to (re)write the file.
save_country_sizes = False

if save_country_sizes:

    country_sizes_filename = 'country_sizes_only_granted.csv' if only_granted else 'country_sizes_all_appln.csv'

    # newline='' is what the csv module expects (avoids blank lines on Windows);
    # `with` closes the file even if writing fails
    with open(country_sizes_filename, 'w', newline='') as file_:
        # One row per country: code, count
        csv.writer(file_).writerows(data_whole_timespan_plot.items())


In [None]:
number_countries = 8

# data_whole_timespan_plot is sorted ascending, so its last entries are the top countries
top_codes = list(data_whole_timespan_plot.keys())[-number_countries:]
top_counts = list(data_whole_timespan_plot.values())[-number_countries:]

countries_written = [country_labels_dict[code] for code in top_codes]
continents = [ctry_code_continent_dict[code] for code in top_codes]

df_countries = (
    pd.DataFrame(data=[top_counts, continents], columns=countries_written, index=['Count', 'Continent'])
    .transpose()
    .reset_index(drop=False)
    .rename(columns={'index': 'Country'})
    .sort_values(by='Count', ascending=False)
)

df_countries


In [None]:
top_countries_title = f'Top {number_countries} countries of origin of battery patents, {min(years)}-{max(years)}'

x_label_string = f'Total number of battery patent families in {year_begin}-{year_end}'
y_label_string = 'Country'

top_countries_plot = px.bar(
    df_countries,
    x='Count',
    y='Country',
    color='Continent',
    orientation='h',
    labels={'Count': x_label_string, 'Country': y_label_string},
)

# Countries ordered by ascending total; the title is kept in top_countries_title, not on the figure
top_countries_plot.update_layout(yaxis={'categoryorder': 'total ascending'})
top_countries_plot.show()

# Filename for an optional export (the image_saver call is commented out)
filename = 'totals_by_country_only_granted' if only_granted else 'totals_by_country_all_appln'
#image_saver(top_countries_plot, filename, True)

top_countries_title


## Count patents for each year

In [None]:
# Run get_counts separately for every year (counted by inventors)
family_id_ctry_codes_list = []
ctry_codes_counts_list = []
known_percentage_list = []

for year in years:

    print('Year', str(year))

    data_to_function = data[data['earliest_publn_year_this_family_id'] == year]

    family_id_ctry_codes, ctry_codes_counts, known_percentage = get_counts(
        data_to_function,
        use_tqdm=True,
        nat_intl_separation=False,
        count_inventors=True,
        technologies=False,
    )

    family_id_ctry_codes_list.append(family_id_ctry_codes)
    ctry_codes_counts_list.append(ctry_codes_counts)
    known_percentage_list.append(known_percentage)
        

In [None]:
# Per year: share (%) of the counted family weight with a known inventor country
print(known_percentage_list)


In [None]:
# One row per year, one column per country code; the unknown code '  ' becomes 'unknown'
df = pd.DataFrame.from_records(ctry_codes_counts_list)
df.insert(loc=0, column='year', value=years)
df.rename(columns={'  ':'unknown'}, inplace=True)
# A country missing from a year's dict had no families that year -> 0
df.fillna(0, inplace=True)
# An independent copy: a plain alias (`= df`) would also be changed by any later in-place
# modification of df (such as dropping the 'unknown' column), defeating the purpose of this frame.
df_with_unknowns = df.copy()


In [None]:
# Save this result to csv
#years_string = str(year_begin)+'-'+str(year_end)
#filename = 'country_counts_yearly_'+years_string+'.csv'
#df.to_csv(path_or_buf=filename, sep=';', index=False)


### Yearly totals plot

In [None]:
# Per-year total over every column except 'year' (includes the 'unknown' column, if present)
totals_series = df.drop(columns=['year']).sum(axis='columns')
totals_series


In [None]:
# Yearly totals as a table; also written to 'total_yearly_counts' (';'-separated, no extension)
totals_df = pd.DataFrame({'year': years, 'count': totals_series})

totals_df.to_csv(path_or_buf='total_yearly_counts', sep=';', index=False)

totals_df


In [None]:
# Split the yearly totals at `cut` (the cut year belongs to the first part).
# NOTE(review): the national/international comparison above splits at 2013/2014 instead.
cut = 2014

totals_df_first_part = totals_df.loc[totals_df['year'] <= cut]
totals_df_second_part = totals_df.loc[totals_df['year'] > cut]

years_first_part = totals_df_first_part['year'].tolist()
sum_first_part = totals_df_first_part['count'].sum()

years_second_part = totals_df_second_part['year'].tolist()
sum_second_part = totals_df_second_part['count'].sum()

print(years_first_part)
print(sum_first_part)
print()
print(years_second_part)
print(sum_second_part)


In [None]:
# Single bar trace: total number of families per year
totals_data = [{'type': 'bar', 'x': years, 'y': totals_series}]


In [None]:
# Mean year-over-year relative increase of the yearly totals (each year's increase is printed)
totals_values = totals_series.to_numpy()

increase = []
for previous_total, current_total in zip(totals_values[:-1], totals_values[1:]):
    increase_this_year = (current_total - previous_total) / previous_total
    print(increase_this_year)
    increase.append(increase_this_year)

sum(increase) / len(increase)


In [None]:
totals_title = f'Global development of the number of battery patent families, {min(years)}-{max(years)}'

# The title is not set on the figure (totals_title is displayed as cell output instead);
# legend centred below the plot
totals_layout = {
    'yaxis': {'title': 'Number of battery patent families'},
    'xaxis': {'title': 'Year'},
    'legend': {'xanchor': 'center', 'yanchor': 'top', 'y': -0.18, 'x': 0.5},
}


In [None]:
totals_plot = go.Figure(data=totals_data, layout=totals_layout)
totals_plot.update_xaxes(dtick=1)  # one tick per year
totals_plot.show()

# Filename for an optional export (the image_saver call is commented out)
filename = 'development_global_only_granted' if only_granted else 'development_global_all_appln'
#image_saver(totals_plot, filename, True)

totals_title


### Yearly, by country plot

In [None]:
# Drop the 'unknown' column for the per-country analysis. Reassign instead of dropping in
# place: df_with_unknowns may refer to the same object as df, and an in-place drop would
# remove the column from it as well.
df = df.drop(columns='unknown')
df


In [None]:
# Get the labels of the countries with the highest totals
number_of_countries = 8
highest_total_labels = list(df.drop(columns='year').sum().sort_values(ascending=False).index[0:number_of_countries])
highest_total_labels


In [None]:
line_width = 3

# One line trace per top country, labelled with its display name
person_ctry_code_data = [
    dict(type='scatter',
         x=df['year'],
         y=df[country],
         name=country_labels_dict[country],
         line_width=line_width)
    for country in highest_total_labels
]
                

In [None]:
person_ctry_code_title = ("Development of the absolute number of battery IPFs:<br>Counted by inventors' countries of origin, "
                          f"{min(years)}-{max(years)}")

# Centred black title, black gridlines on a white background, one x tick per year,
# legend to the right of the plot
person_ctry_code_layout = {
    'title': {
        'text': person_ctry_code_title,
        'y': 0.9,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': {'color': 'black'},
    },
    'yaxis': {
        'color': 'black',
        'title': 'Number of IPFs',
        'showgrid': True,
        'gridwidth': 1,
        'gridcolor': 'black',
        'zerolinecolor': 'black',
        'zerolinewidth': 1,
    },
    'xaxis': {
        'color': 'black',
        'title': 'Year',
        'dtick': 1,
    },
    'legend': {
        'xanchor': "left",
        'yanchor': "middle",
        'y': 0.5,
        'x': 1,
        'orientation': "v",
    },
    'plot_bgcolor': 'white',
}


In [None]:
person_ctry_code_plot = go.Figure(data=person_ctry_code_data, layout=person_ctry_code_layout)
person_ctry_code_plot.update_xaxes(dtick=1)
person_ctry_code_plot.show()

# Export the figure; the filename reflects only_granted
filename = 'development_by_country_only_granted' if only_granted else 'development_by_country_all_appln'
image_saver(person_ctry_code_plot, filename, True)

person_ctry_code_title


### China's share plot

In [None]:
# China's share (in %) of each year's total IPF count. The total spans every column except
# 'year' (presumably including unknown origins, given the frame's name; confirm).
all_ = df_with_unknowns.drop(columns=['year'])
china = df_with_unknowns['CN']
chinas_share = china.div(all_.sum(axis=1)).mul(100)


In [None]:
# Single trace: China's yearly share (in %) of all battery patent families.
china_share_data = [dict(type='scatter',
                        x=df_with_unknowns['year'],
                        y=chinas_share,
                        # The ", " separator was missing, producing e.g. "...families2000-2019".
                        name="China's share of the number of battery patent families, "+str(min(years))+'-'+str(max(years))
                       )]


In [None]:
chinas_share_title = ("Development of China's share of the number of battery patent families, "
                      + str(min(years)) + '-' + str(max(years)))

# The title is not put into the figure; it is only echoed as the plotting cell's output.
china_share_layout = {
    'yaxis': {'title': "China's share [%]"},
    'xaxis': {'title': 'Year'},
    'legend': {'xanchor': "center", 'yanchor': "top", 'y': -0.18, 'x': 0.5},
}


In [None]:
# Render China's share chart; export is currently disabled.
china_share_plot = go.Figure(data=china_share_data, layout=china_share_layout)
china_share_plot.show()

filename = 'chinas_share_only_granted' if only_granted else 'chinas_share_all_appln'

#image_saver(china_share_plot, filename, True)

chinas_share_title


### Herfindahl–Hirschman index plot

https://en.wikipedia.org/wiki/Herfindahl%E2%80%93Hirschman_Index

In [None]:
# Percentage share of every column of `all_` (every column of df_with_unknowns except
# 'year') in that year's total.
# The per-year total is computed once and broadcast row-wise, instead of being recomputed
# for every country inside a loop.
countries_shares_df = all_.div(all_.sum(axis=1), axis=0) * 100

#countries_shares_df


In [None]:
# Herfindahl–Hirschman Index per year: the sum of squared percentage shares
# (0 to 10,000, since shares are in %).
# NOTE(review): shares come from df_with_unknowns, so any unknown-origin column is treated
# like a country here. Confirm that is intended.
hh_indices_list = [
    sum(share ** 2 for share in year_shares)
    for _, year_shares in countries_shares_df.iterrows()
]

#hh_indices_list


In [None]:
# One trace: the yearly HHI values against the year column of the shares table.
herfindahl_data = [
    {'type': 'scatter', 'x': df_with_unknowns['year'], 'y': hh_indices_list}
]


In [None]:
herfindahl_title = ("Development of the Herfindahl–Hirschman Index (by countries), "
                    + str(min(years)) + '-' + str(max(years)))

# Axis labels and a legend below the plot; the title is only echoed as cell output.
herfindahl_layout = {
    'yaxis': {'title': "Herfindahl–Hirschman Index"},
    'xaxis': {'title': 'Year'},
    'legend': {'xanchor': "center", 'yanchor': "top", 'y': -0.18, 'x': 0.5},
}


In [None]:
# Render the HHI chart; export is currently disabled.
herfindahl_plot = go.Figure(data=herfindahl_data, layout=herfindahl_layout)
herfindahl_plot.show()

filename = ('Herfindahl_Hirschman_index_only_granted' if only_granted
            else 'Herfindahl_Hirschman_index_all_appln')

#image_saver(herfindahl_plot, filename, True)

herfindahl_title


### Per labor force plot

In [None]:
# Disabled (`if False`): earlier approach that scaled counts by UN total population
# (WPP2019). pop_dict is now built from the World Bank labor-force data further below.
if False:
    
    path = main_path_mac + '03 Analysis/01 Country counts/WPP2019_TotalPopulationBySex.csv'
    populations = pd.read_csv(path, delimiter = ',')

    # Note: Values are in one thousand inhabitants!


    # UN country name -> population for each year in min(years)..max(years).
    pop_dict = {}
    for country in highest_total_labels:
        country_name = ctry_code_name_dict_UN[country]
        population = (populations[(populations['Location'] == country_name) & (populations['Time'] >= min(years)) & (populations['Time'] <= max(years))]['PopTotal']).values
        pop_dict[country_name] = population

    #pop_dict
    # Note: Values are in one thousand inhabitants!
    

#### Create a labor force dict

In [None]:
# World Bank "Labor force, total" table: one row per country ('Country Name'), one column per year.
# NOTE(review): raw World Bank API CSVs normally begin with 4 metadata lines (skiprows=4);
# presumably this copy was cleaned. Confirm.
path = os.path.join(
    main_path_mac,
    '03 Extra data',
    'Worldbank - Labor force, total',
    'API_SL.TLF.TOTL.IN_DS2_en_csv_v2_3471351',
    'API_SL.TLF.TOTL.IN_DS2_en_csv_v2_3471351.csv',
)
labor_force = pd.read_csv(path, sep=',')

#labor_force

# Raw values are absolute head counts (not thousands).


In [None]:
# The year columns of the World Bank table are strings, so build matching labels.
years_strings = [str(year) for year in range(min(years), max(years) + 1)]

#years_strings


In [None]:
# World Bank country name -> labor force per year (years_strings order), divided by 1000,
# i.e. in thousands of workers. A repeated name keeps its last row.
labor_force_dict = {
    row['Country Name']: row[years_strings].values / 1000
    for _, row in labor_force.iterrows()
}

#labor_force_dict


In [None]:
# Taiwan has no World Bank labor-force data, so its series is entered by hand from
# https://eng.stat.gov.tw/ct.asp?xItem=42761&ctNode=1609&mp=5
# NOTE(review): assumes these 20 values are thousands of persons (matching the World Bank
# values divided by 1000) and cover min(years)..max(years) in order. Confirm.
labor_force_taiwan = np.array([
    9784, 9832, 9969, 10076, 10240,
    10371, 10522, 10713, 10853, 10917,
    11070, 11200, 11341, 11445, 11535,
    11638, 11727, 11795, 11874, 11946,
])

print(len(labor_force_taiwan))

labor_force_dict['Chinese Taipei'] = labor_force_taiwan


In [None]:
# World Bank name -> labor force (thousands) for each top country. The name `pop_dict` is
# historical: the values are labor force, not population.
pop_dict = {
    wb_name: labor_force_dict[wb_name]
    for wb_name in (ctry_code_name_dict_world_bank[code] for code in highest_total_labels)
}

#pop_dict


In [None]:
# One trace per top country: battery IPFs per 1M workers.
# pop_dict holds labor force / 1000, so count / (labor force / 1000) * 10**3 is the count
# per 1M workers.
# count_per_1M_CN stays module-level; the growth-rate side calculation below reads it.
person_ctry_code_pop_data = []
for country in highest_total_labels:
    country_name = country_labels_dict[country]
    y = (df[country] / pop_dict[ctry_code_name_dict_world_bank[country]]) * (10**3)

    if country == 'CN':
        count_per_1M_CN = y

    person_ctry_code_pop_data.append(dict(type='scatter',
                                          x=df['year'],
                                          y=y,
                                          name=country_name))


#### Side research: Compute the mean year-over-year increase for China

In [None]:

# Arithmetic mean of China's year-over-year relative change in IPFs per 1M workers.
# Uses positional (.iloc) access, so the result does not depend on df's index labels.
# NOTE(review): a zero value in a previous year yields inf/nan here.
increase_CN = []

for i in range(1, len(count_per_1M_CN)):

    previous = count_per_1M_CN.iloc[i-1]
    increase_this_year = (count_per_1M_CN.iloc[i] - previous) / previous
    print(increase_this_year)
    increase_CN.append(increase_this_year)

sum(increase_CN) / len(increase_CN)


In [None]:
person_ctry_code_pop_title = (
    "Development of the number of battery patent families per 1M workers by countries of origin of the inventors, "
    + str(min(years)) + '-' + str(max(years))
)

# Axis labels and a vertical legend on the right; the title is only echoed as cell output.
person_ctry_code_pop_layout = {
    'yaxis': {'title': 'Number of battery patent families per 1M workers'},
    'xaxis': {'title': 'Year'},
    'legend': {'xanchor': "left", 'yanchor': "middle", 'y': 0.5, 'x': 1, 'orientation': "v"},
}


In [None]:
# Render the per-1M-workers chart (linear y-axis); export is currently disabled.
person_ctry_pop_plot = go.Figure(data=person_ctry_code_pop_data, layout=person_ctry_code_pop_layout)
person_ctry_pop_plot.show()

filename = 'per_pop_size_all_only_granted' if only_granted else 'per_pop_size_all_all_appln'

#image_saver(person_ctry_pop_plot, filename, True)

person_ctry_code_pop_title


### Per labor force plot with logarithmic scale

In [None]:
# Same per-1M-workers traces and layout as above, with a log-scaled y-axis
# (dtick=1 on a log axis: one tick per power of ten).
# The title previously said "per population size", but the values are scaled by labor force.
person_ctry_code_pop_title_for_log = "Development of the number of battery patent families per 1M workers by countries of origin of the inventors, "+str(min(years))+'-'+str(max(years))+'; the y-axis is log-scaled.'

person_ctry_pop_plot_log = go.Figure(data = person_ctry_code_pop_data, layout=person_ctry_code_pop_layout)

person_ctry_pop_plot_log.update_yaxes(type="log", dtick=1)

person_ctry_pop_plot_log.show()

# Save this plot as eps
if only_granted:
    filename = 'per_pop_size_all_log_only_granted'
else:
    filename = 'per_pop_size_all_log_all_appln'

#image_saver(person_ctry_pop_plot_log, filename, True)

person_ctry_code_pop_title_for_log


### Per labor force plot - a close-up that excludes South Korea and Japan

In [None]:
# Country codes kept in the close-up. Note: the filtering cell below excludes traces by
# display name and does not read this list.
highest_total_labels_reduced = 'CN US DE TW'.split()


In [None]:
# Drop the South Korea and Japan traces so the remaining countries are easier to compare.
person_ctry_code_pop_data_reduced = [
    trace for trace in person_ctry_code_pop_data
    if trace['name'] not in ['South Korea', 'Japan']
]

#person_ctry_code_pop_data_reduced


In [None]:
# Render the close-up without South Korea and Japan; export is currently disabled.
person_ctry_pop_reduced_plot = go.Figure(data=person_ctry_code_pop_data_reduced,
                                         layout=person_ctry_code_pop_layout)
person_ctry_pop_reduced_plot.show()

filename = ('per_pop_size_without_jp_sk_only_granted' if only_granted
            else 'per_pop_size_without_jp_sk_all_appln')

#image_saver(person_ctry_pop_reduced_plot, filename, True)

person_ctry_code_pop_title


### Continents plot

In [None]:
# Country-by-year table for the continent aggregation, without the 'NC' and 'MH' columns.
df_continents = df.drop(columns=['NC', 'MH'])


In [None]:
# Fill in continents for country codes missing from the lookup.
# NOTE(review): 'RE' is mapped to '  ', the same code used for unknown origin (see the
# '  ' -> 'unknown' rename in the technology section). Confirm this grouping is intended.
ctry_code_continent_dict.update({
    'PR': 'North America',
    'UK': 'Europe',
    'RE': '  ',
    'GF': 'South America',
    'NA': 'Africa',
})


In [None]:
# One row per country code (index named 'country'), one column per year.
df_continents = df_continents.set_index('year').transpose().rename_axis('country')


In [None]:
# Continent of each country code, in df_continents' row order.
continents_list = [ctry_code_continent_dict[code] for code in df_continents.index]
    

In [None]:
# Attach each country's continent as a new column.
df_continents = df_continents.assign(Continent=continents_list)


In [None]:
# Show the countries classified as "Europe/Asia".
df_continents.loc[df_continents['Continent'].eq('Europe/Asia')]


In [None]:
# Sum country counts per continent, then reshape to one row per year with a 'year' column.
# The '*' appended to "Europe/Asia" refers to the footnote added on the continent plot.
df_continents_grouped = (
    df_continents
    .groupby(by='Continent')
    .sum()
    .transpose()
    .reset_index(drop=False)
    .rename(columns={'Europe/Asia': 'Europe/Asia *'})
)
#df_continents_grouped


In [None]:
# Continent column names, i.e. every column except 'year'.
continents = df_continents_grouped.columns.tolist()
continents.remove('year')
continents


In [None]:
#df_continents_grouped


In [None]:
# Add 1 to every continent value so years with zero IPFs can be drawn on a log-scaled axis.
# NOTE: not idempotent; re-running this cell adds another 1. Later averages and growth
# rates in this section are computed on these incremented values.
for continent_name in continents:
    df_continents_grouped[continent_name] = df_continents_grouped[continent_name].add(1)
    

In [None]:
# Inspect the per-year continent totals after the +1 increment.
df_continents_grouped


In [None]:
# Per-column mean over the years (column sum / number of rows); 'year' is averaged too.
# NOTE(review): computed after the +1 increment, so each continent value is (mean count + 1).
continents_averages = df_continents_grouped.sum(axis=0).div(len(df_continents_grouped))
continents_averages


In [None]:
# How many times Asia's yearly average exceeds Europe's, then North America's (2 decimals).
for reference_continent in ['Europe', 'North America']:
    print(round((continents_averages['Asia'] / continents_averages[reference_continent]), 2))


In [None]:
# Arithmetic mean of the year-over-year relative change for one continent.
# Uses positional (.iloc) access, so the result does not depend on index labels.
# NOTE(review): df_continents_grouped was incremented by 1 earlier in this section, so these
# growth rates are computed on (count + 1), not on raw counts.
to_compute_for = 'Europe'

col_ = df_continents_grouped[to_compute_for]

increase = []
for i in range(1, len(col_)):

    previous = col_.iloc[i-1]
    increase_this_year = (col_.iloc[i] - previous) / previous
    print(increase_this_year)
    increase.append(increase_this_year)

round((sum(increase) / len(increase)), 4)


In [None]:
#df_continents_grouped.drop('  ', axis = 1, inplace = True)


In [None]:
#df_continents_grouped

In [None]:
# One trace per continent column (everything except 'year'); values carry the +1 offset.
# line_width is set here rather than relying on a value left over from an earlier cell.
line_width = 3

continent_data = [
    dict(type='scatter',
         x=df_continents_grouped['year'],
         y=df_continents_grouped[continent],
         name=continent,
         line_width=line_width)
    for continent in df_continents_grouped.columns
    if continent != 'year'
]
     

In [None]:
continent_title = (
    "Development of the number of battery IPFs:<br>Counted by inventors' continents of origin, "
    + str(min(years)) + '-' + str(max(years))
    + '<br>The y-axis is log-scaled and all values are incremented by 1'
)

# Same visual style as the country chart; the title sits slightly higher (y=0.95) to fit
# its third line.
continent_layout = {
    'title': {
        'text': continent_title,
        'y': 0.95,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': {'color': 'black'},
    },
    'yaxis': {
        'color': 'black',
        'title': '1 + Number of IPFs',
        'showgrid': True,
        'gridwidth': 1,
        'gridcolor': 'black',
        'zerolinecolor': 'black',
        'zerolinewidth': 1,
    },
    'xaxis': {
        'color': 'black',
        'title': 'Year',
        'dtick': 1,
    },
    'legend': {
        'xanchor': "left",
        'yanchor': "middle",
        'y': 0.5,
        'x': 1,
        'orientation': "v",
    },
    'plot_bgcolor': 'white',
}


In [None]:
# Continent chart: log y-axis (one tick per power of ten), one tick per year, plus a
# footnote explaining the '*' on the "Europe/Asia *" legend entry.
continent_plot = (
    go.Figure(data=continent_data, layout=continent_layout)
    .update_yaxes(type="log", dtick=1)
    .update_xaxes(dtick=1)
    .add_annotation(
        x=2015.3,
        y=0.15,
        text='* In PATSTAT, the Russian Federation and<br>Turkey are classified as "Europe/Asia".',
        font={'color': 'black'},
        showarrow=False,
        yshift=10,
    )
)

continent_plot.show()

filename = ('development_by_continent_only_granted' if only_granted
            else 'development_by_continent_all_appln')

image_saver(continent_plot, filename, True)

continent_title


### Per labor force plot, sorted by values after scaling

In [None]:
# Independent copy of the per-country yearly counts; it is scaled in place below.
all_scaled_by_population = all_.copy(deep=True)
#all_scaled_by_population


In [None]:
# Remove country codes that have no entry in the labor-force data.
all_scaled_by_population = all_scaled_by_population.drop(
    columns=['SH', 'GF', 'JE', 'AN', 'KN', 'CS']
)


In [None]:
# Country code -> yearly labor force in thousands of workers (World Bank values / 1000;
# Chinese Taipei entered by hand). Keyed by code, unlike pop_dict, which is keyed by name.
pop_dict_all = {
    code: labor_force_dict[ctry_code_name_dict_world_bank[code]]
    for code in all_scaled_by_population.columns
}

#pop_dict_all


In [None]:
# Convert each country's yearly IPF counts to IPFs per 1M workers.
# pop_dict_all holds labor force / 1000, hence the extra factor 10**3.
for country in list(all_scaled_by_population):

    try:
        # Whole-column assignment. `.at` is a scalar accessor; with a `[:, column]` slice it
        # can raise (depending on the pandas/Python version). The except below would only
        # print that error and silently leave the column unscaled.
        all_scaled_by_population[country] = all_scaled_by_population[country] / pop_dict_all[country] * (10**3)

    except Exception as e:
        # Report, then carry on with the remaining countries.
        print(country)
        print(ctry_code_name_dict.get(country, country))
        print(type(e))
        print(e)
        print()

#all_scaled_by_population


# Notes:
# Countries that no longer exist:
# - Yugoslavia
# - Netherlands Antilles
# - Czechoslovakia
# Not found in set(populations['Location']):
# - Jersey

In [None]:
# Exclude countries with irregular behaviour

#all_scaled_by_population.drop(columns=['VG', 'KY', 'LI','LU','BM','MC', 'BB'], inplace=True)
#all_scaled_by_population.drop(columns=['VG', 'KY', 'LI','LU', 'BB'], inplace=True)
#all_scaled_by_population.drop(columns=['VG', 'KY', 'LI','LU', 'BB', 'BS', 'MT'], inplace=True)



In [None]:
# Countries with the largest sum over all years of IPFs per 1M workers.
num_countries_to_plot_scaled = 8

highest_total_labels_scaled = (
    all_scaled_by_population.sum(axis=0)
    .sort_values(ascending=False)
    .head(num_countries_to_plot_scaled)
    .index
    .tolist()
)
highest_total_labels_scaled


In [None]:
# Top 20 countries by summed IPFs per 1M workers, for inspection.
all_scaled_by_population.sum(axis=0).sort_values(ascending=False).head(20)

In [None]:
# One trace per top country by IPFs per 1M workers.
# Display name: country_labels_dict entry if present, otherwise ctry_code_name_dict.
# line_width stays module-level, as in the other trace-building cells.
line_width = 3

person_ctry_code_pop_data_2 = []
for country in highest_total_labels_scaled:

    if country in country_labels_dict:
        country_name = country_labels_dict[country]
    else:
        country_name = ctry_code_name_dict[country]

    person_ctry_code_pop_data_2.append(
        dict(
            type = 'scatter',
            x = df['year'],
            y = all_scaled_by_population.loc[:, country],
            name = country_name,
            line_width = line_width
        )
    )


In [None]:
person_ctry_code_pop_title_2 = (
    "Development of the number of battery IPFs per 1M workers:<br>Counted by inventors' countries of origin, "
    + str(min(years)) + '-' + str(max(years))
)

# Same visual style as the absolute-count chart, with a per-1M-workers y-axis label.
person_ctry_code_pop_layout_2 = {
    'title': {
        'text': person_ctry_code_pop_title_2,
        'y': 0.9,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': {'color': 'black'},
    },
    'yaxis': {
        'color': 'black',
        'title': 'Number of IPFs per 1M workers',
        'showgrid': True,
        'gridwidth': 1,
        'gridcolor': 'black',
        'zerolinecolor': 'black',
        'zerolinewidth': 1,
    },
    'xaxis': {
        'color': 'black',
        'title': 'Year',
        'dtick': 1,
    },
    'legend': {
        'xanchor': "left",
        'yanchor': "middle",
        'y': 0.5,
        'x': 1,
        'orientation': "v",
    },
    'plot_bgcolor': 'white',
}


In [None]:
# Render the top-scaled-countries chart (linear y-axis) and export it.
person_ctry_pop_plot_2 = go.Figure(data=person_ctry_code_pop_data_2,
                                   layout=person_ctry_code_pop_layout_2)
person_ctry_pop_plot_2.show()

filename = ('per_pop_size_countries_with_highest_scaled_total_only_granted' if only_granted
            else 'per_pop_size_countries_with_highest_scaled_total_all_appln')

image_saver(person_ctry_pop_plot_2, filename, True)

person_ctry_code_pop_title_2


#### For comparison: The plot from above that simply shows the top 8 countries in terms of total patent family output

In [None]:
# Re-display the absolute-count chart (person_ctry_code_plot) built earlier, for comparison.
person_ctry_code_plot.show()


### Plot that shows the top 6 countries by values after scaling, plus the US and China for comparison

In [None]:
# The top 6 countries by IPFs per 1M workers, plus the US and China for reference.
# dict.fromkeys removes duplicates (keeping first-seen order) in case US or CN is already in
# the top 6; otherwise the same country would be plotted twice.
per_labor_force_countries_to_plot = list(dict.fromkeys(highest_total_labels_scaled[:6] + ['US', 'CN']))

person_ctry_code_pop_data_3 = []
for country in per_labor_force_countries_to_plot:

    # Display name: country_labels_dict entry if present, otherwise ctry_code_name_dict.
    if country in country_labels_dict:
        country_name = country_labels_dict[country]
    else:
        country_name = ctry_code_name_dict[country]

    person_ctry_code_pop_data_3.append(dict(type='scatter',
                                            x=df['year'],
                                            y=all_scaled_by_population.loc[:, country],
                                            name=country_name))


In [None]:
person_ctry_code_pop_title_3 = (
    "Development of the number of IPFs per 1M workers for selected countries:<br>Counted by inventors' countries of origin, "
    + str(min(years)) + '-' + str(max(years))
)

# Black title/axes, black horizontal grid, one tick per year, white background,
# vertical legend to the right.
person_ctry_code_pop_layout_3 = {
    'title': {
        'text': person_ctry_code_pop_title_3,
        'y': 0.9,
        'x': 0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': {'color': 'black'},
    },
    'yaxis': {
        'title': 'Number of IPFs per 1M workers',
        'color': 'black',
        'showgrid': True,
        'gridwidth': 1,
        'gridcolor': 'black',
        'zerolinecolor': 'black',
        'zerolinewidth': 1,
    },
    'xaxis': {
        'title': 'Year',
        'color': 'black',
        'dtick': 1,
    },
    'legend': {
        'xanchor': "left",
        'yanchor': "middle",
        'y': 0.5,
        'x': 1,
        'orientation': "v",
    },
    'plot_bgcolor': "white",
}


In [None]:
# Render the selected-countries chart; export is currently disabled.
person_ctry_pop_plot_3 = go.Figure(data=person_ctry_code_pop_data_3,
                                   layout=person_ctry_code_pop_layout_3)
person_ctry_pop_plot_3.show()

filename = ('per_pop_size_selected_countries_only_granted' if only_granted
            else 'per_pop_size_selected_countries_all_appln')

#image_saver(person_ctry_pop_plot_3, filename, True)

person_ctry_code_pop_title_3


## Plots that separate different battery technologies

### Create technologies dataframes

In [None]:
# Battery technologies analysed separately. Each name is looked up as the one-hot column
# 'is_<name>' in `data` by the counting cell below.
# NOTE(review): 'Nickel–hydrogen' uses an en dash (–) while the other names use a hyphen (-).
# Confirm the matching column in `data` is spelled the same way.
technologies_list = ['Lead-acid',
                     'Lithium-air',
                     'Lithium-ion',
                     'Lithium-sulfur',
                     'Other lithium',
                     'Magnesium-ion',
                     'Nickel-cadmium',
                     'Nickel-iron',
                     'Nickel-zinc',
                     'Nickel-metal hydride',
                     'Rechargeable alkaline',
                     'Sodium-sulfur',
                     'Sodium-ion',
                     'Solid-state',
                     'Aluminium-ion',
                     'Calcium(-ion)',
                     'Organic radical',
                     'Redox flow',
                     'Nickel–hydrogen']


In [None]:
# Number of technologies in technologies_list.
len(technologies_list)


In [None]:
# For each technology: keep the rows whose one-hot column 'is_<technology>' equals 1, then
# collect get_counts' three outputs year by year.
# Results are lists (one entry per technology) of lists (one entry per year).
family_id_ctry_codes_lists_list = []
ctry_codes_counts_lists_list = []
known_percentage_lists_list = []

for technology in technologies_list:

    print(technology)

    one_hot_col_name_this_technology = 'is_' + technology
    print(one_hot_col_name_this_technology)
    data_this_technology = data[data[one_hot_col_name_this_technology] == 1]
    print(len(data_this_technology))

    family_id_ctry_codes_list = []
    ctry_codes_counts_list = []
    known_percentage_list = []

    for year in tqdm(years):

        data_to_function = data_this_technology[data_this_technology['earliest_publn_year_this_family_id'] == year]

        # The positional flags are passed unchanged; their meaning is defined by get_counts.
        family_id_ctry_codes, ctry_codes_counts, known_percentage = get_counts(
            data_to_function, False, False, True, True)

        family_id_ctry_codes_list.append(family_id_ctry_codes)
        ctry_codes_counts_list.append(ctry_codes_counts)
        known_percentage_list.append(known_percentage)

    family_id_ctry_codes_lists_list.append(family_id_ctry_codes_list)
    ctry_codes_counts_lists_list.append(ctry_codes_counts_list)
    known_percentage_lists_list.append(known_percentage_list)

#new = family_id_ctry_codes_lists_list.copy()


In [None]:
# One year-by-country frame per technology, built from the per-year count records, with a
# leading 'year' column and the '  ' (unknown) code renamed to 'unknown'.
dfs_technologies_list = []

for _technology, counts_per_year in zip(technologies_list, ctry_codes_counts_lists_list):

    df_this_technology = pd.DataFrame.from_records(counts_per_year)
    df_this_technology.insert(loc=0, column='year', value=years)
    df_this_technology.rename(columns={'  ': 'unknown'}, inplace=True)

    dfs_technologies_list.append(df_this_technology)
    

In [None]:
#print(technologies_list[i])

#print(dfs_technologies_list[i].shape)

#dfs_technologies_list[i]


### Global plot

In [None]:
# Total IPF count per year for each technology, summed over all country columns
# (including 'unknown').
# The loop variable is deliberately not named `df`: that name holds the main per-country
# table used by earlier cells, and the previous version overwrote it with the last
# technology's frame.
technology_totals = pd.DataFrame()

# Build df for whole timespan
for technology, technology_df in zip(technologies_list, dfs_technologies_list):

    technology_totals[technology] = technology_df.drop(columns='year').sum(axis=1)

technology_totals.insert(0, 'year', years)


In [None]:
# Display the per-year totals for every technology.
technology_totals


In [None]:
# Worldwide stacked bar chart: one bar segment per technology, one bar per year.
data_plot = [
    go.Bar(name=column, x=years, y=technology_totals[column])
    for column in technology_totals.drop(columns='year').columns
]

technologies_countries_all_plot = go.Figure(data_plot)

technologies_countries_all_title = ("Development of the world's battery technology distribution, "
                                    + str(year_begin) + '-' + str(year_end))

technologies_countries_all_plot.update_layout(
    barmode='stack',
    title={'text': technologies_countries_all_title},
    yaxis={'title': 'Number of battery patent families'},
    xaxis={'title': 'Year'},
    legend={'xanchor': "left", 'yanchor': "middle", 'y': 0.5, 'x': 1, 'orientation': "v"},
)
technologies_countries_all_plot.update_xaxes(dtick=1)

technologies_countries_all_plot.show()

filename = ('technologies_countries_all' if only_granted
            else 'technologies_countries_all_all_appln')

#image_saver(technologies_countries_all_plot, filename, True)

technologies_countries_all_title


In [None]:
# Battery types in the bubble chart, reversed. Presumably this is so that, with plotly
# stacking categories bottom-up in order of first appearance, 'Lead-acid' ends up on top.
bubbles_plot_categories = ['Lead-acid',
                           'Lithium-ion',
                           'Lithium-sulfur',
                           'Solid-state',
                           'Other lithium',
                           'Sodium-ion',
                           'Redox flow'][::-1]

bubbles_plot_data = technology_totals[['year'] + bubbles_plot_categories].set_index('year')

rows = list(bubbles_plot_data.index)
cols = list(bubbles_plot_data)

# Long format: one [year, category, count] record per cell, year-major order.
array = [[row, col, bubbles_plot_data.loc[row, col]] for row in rows for col in cols]

bubbles_plot_data_transformed = pd.DataFrame(array, columns=['year', 'category', 'number'])

#bubbles_plot_data_transformed


In [None]:
# Wrap the long y-axis labels: break after every hyphen, then at every space.
bubbles_plot_data_transformed['category'] = [
    label.replace('-', '-<br>').replace(' ', '<br>')
    for label in bubbles_plot_data_transformed['category']
]

# Disabled: would relabel Lithium-ion as "(except solid-state)"; judged more confusing
# than helpful.
if False:
    
    category_column = list(bubbles_plot_data_transformed['category'])
    category_column_transformed = []
    for item in category_column:

        if item == 'Lithium-<br>ion':
            item = 'Lithium-ion<br>(except solid-state)'
        category_column_transformed.append(item)

    bubbles_plot_data_transformed['category'] = category_column_transformed

#bubbles_plot_data_transformed


In [None]:
# Bubble chart of yearly IPF counts for the selected battery types.
scaler = 3

bubbles_plot = go.Figure(data=[go.Scatter(
    x = bubbles_plot_data_transformed['year'],
    y = bubbles_plot_data_transformed['category'],
    mode='markers+text',
    # Marker size grows with the square root of the (scaled) count.
    marker_size = (bubbles_plot_data_transformed['number'] * scaler) ** (1/2),
    marker_color = colors_plotly_default[0],

    # Labels show the fractional counts rounded to the nearest integer (ties to even), as
    # the title states. The previous round(x.astype(int)) truncated first, so 12.9 showed as 12.
    text = bubbles_plot_data_transformed['number'].round().astype(int),

    textposition='bottom center',
    textfont = dict(color = 'black')

)])

bubbles_plot.update_yaxes(showgrid=True, gridwidth=1, gridcolor='black')

bubbles_plot.update_xaxes(dtick=1)

# Make background white
bubbles_plot.update_layout(plot_bgcolor = "white")

# Set title
title = "Development of the world's battery patenting activity for selected battery types:<br>The depicted battery IPF fractional counts are rounded to the closest integer, "+str(year_begin)+'-'+str(year_end)
bubbles_plot.update_layout(
    title=dict(
        text = title,
        y = 0.9,
        x = 0.5,
        xanchor = 'center',
        yanchor = 'top',
        font = dict(color = 'black')
    ),
    yaxis=dict(title='Battery type',
               color = 'black'),
    xaxis=dict(title='Year',
               color = 'black',
               tickmode = 'array',
               tickvals = list(range(year_begin, year_end + 1))
    )
)

bubbles_plot.show()

image_saver(bubbles_plot, 'bubbles', True)


### Countries' plots

In [None]:
# For each selected country, build a (year x technology) frame of IPF counts.
# A technology in which the country never appears gets an all-NaN column. As before, that
# column is also added to the shared frame in dfs_technologies_list (presumably so every
# technology frame has the same country columns later on).
# The missing column is detected explicitly instead of via a blanket `except Exception`,
# which also hid unrelated errors.
countries_to_plot = ['CN', 'JP', 'KR', 'US', 'DE']

countries_dfs_to_plot = []

for country in countries_to_plot:

    this_country_df = pd.DataFrame()

    for technology, technology_df in zip(technologies_list, dfs_technologies_list):

        if country not in technology_df.columns:
            # Same report as the old KeyError path: the quoted code, then the technology.
            print(repr(country))
            print(technology)

            technology_df[country] = np.nan

        this_country_df[technology] = technology_df[country]

    countries_dfs_to_plot.append(this_country_df)


In [None]:
def plot_technology_development(entities_dfs_to_plot, entity_index, entity_codes=None, save=False):
    """Show a stacked bar chart of one entity's yearly patent counts per battery technology.

    Parameters
    ----------
    entities_dfs_to_plot : list of pandas.DataFrame
        One frame per entity; every column is a technology and is drawn as one bar
        series against the global ``years`` (assumed to have one entry per row).
    entity_index : int
        Position of the entity in ``entities_dfs_to_plot`` and ``entity_codes``.
    entity_codes : list of str, optional
        Country codes parallel to ``entities_dfs_to_plot``, used for the title and axis
        label. Defaults to the global ``countries_to_plot`` (the previous behaviour).
    save : bool, default False
        If True, save the figure through ``image_saver``. Previously the file name was
        built on every call but the save call was commented out.
    """
    if entity_codes is None:
        entity_codes = countries_to_plot

    entity_df = entities_dfs_to_plot[entity_index]

    # One bar trace per technology column; barmode='stack' piles them up per year
    data_plot = [
        go.Bar(name=column, x=years, y=entity_df[column])
        for column in entity_df.columns
    ]

    fig = go.Figure(data_plot)

    country_written = country_labels_dict[entity_codes[entity_index]]

    title = "Development of "+country_written+"'s battery technology distribution, "+str(year_begin)+"-"+str(year_end)

    fig.update_layout(barmode='stack',
                      title=dict(text=title),
                      yaxis=dict(title='Number of battery patent families from '+country_written),
                      xaxis=dict(title='Year'),
                      # Vertical legend just right of the plotting area
                      legend=dict(xanchor="left",
                                  yanchor="middle",
                                  y=0.5,
                                  x=1,
                                  orientation="v")
                      )

    fig.update_xaxes(dtick=1)

    fig.show()

    if save:
        filename = 'technologies_countries_'+country_written
        if not only_granted:
            filename = filename+'_all_appln'
        image_saver(fig, filename, True)

    print(title)

In [None]:
# One stacked-bar chart per country, in the order of countries_to_plot
for entity_index in range(len(countries_dfs_to_plot)):
    plot_technology_development(countries_dfs_to_plot, entity_index)


## Clustering countries using their technology distributions

In [None]:
# Make copies before transforming
dfs_technologies_list_save = copy.deepcopy(dfs_technologies_list)
technologies_list_save = technologies_list.copy()


In [None]:
# This cell is for loading the copies
dfs_technologies_list = copy.deepcopy(dfs_technologies_list_save)
technologies_list = technologies_list_save.copy()


In [None]:
technologies_to_pool = [
    'Lithium-ion',
    'Other lithium'
]


In [None]:
dfs_to_pool = []
for technology_to_pool in technologies_to_pool:
    
    technology_index = technologies_list.index(technology_to_pool)
    print(technology_index)
    
    df_this_technology = dfs_technologies_list[technology_index]
    
    dfs_to_pool.append(df_this_technology)
    

In [None]:
sum_df = dfs_to_pool[0]

for df in dfs_to_pool[1:]:

    sum_df = sum_df.add(df, fill_value=0)

sum_df['year'] = (sum_df['year'] / 2).astype(int)
#sum_df


In [None]:
technologies_clustering = [
    'Lead-acid',
    'Lithium-sulfur',
    #'Other lithium and lithium-ion',
    #'Lithium-ion',
    'Sodium-ion',
    'Solid-state',
    'Redox flow'
]


In [None]:
dfs_technologies_list_clustering = []
dfs_technologies_list_clustering.append(dfs_technologies_list[technologies_list.index('Lead-acid')])
dfs_technologies_list_clustering.append(dfs_technologies_list[technologies_list.index('Lithium-sulfur')])
#dfs_technologies_list_clustering.append(sum_df)
#dfs_technologies_list_clustering.append(dfs_technologies_list[technologies_list.index('Lithium-ion')])
dfs_technologies_list_clustering.append(dfs_technologies_list[technologies_list.index('Sodium-ion')])
dfs_technologies_list_clustering.append(dfs_technologies_list[technologies_list.index('Solid-state')])
dfs_technologies_list_clustering.append(dfs_technologies_list[technologies_list.index('Redox flow')])


In [None]:
dfs_technologies_list = dfs_technologies_list_clustering


In [None]:
technologies_list = technologies_clustering


### Define functions for assessment, clustering, and plotting

In [None]:
figsizes = 15, 8  # default (width, height) in inches for the matplotlib figures below

sns.set()  # seaborn's default theme for all matplotlib output

def get_tech_dist_dfs_absolute_list(time_periods_list, dfs_technologies_list, technologies=None):
    """Sum each country's counts per technology over each time period.

    Parameters
    ----------
    time_periods_list : list of [start, end]
        Inclusive year ranges.
    dfs_technologies_list : list of pandas.DataFrame
        One frame per technology, with a 'year' column, one column per country and
        optionally an 'unknown' column (excluded from the result).
    technologies : list of str, optional
        Names paired by position with ``dfs_technologies_list``, used as output
        column names. Defaults to the global ``technologies_list`` (previous behaviour).

    Returns
    -------
    list of pandas.DataFrame
        One frame per time period: rows = countries, columns = technologies, with
        missing combinations filled with 0.

    Notes
    -----
    Countries are collected from ALL technology frames (outer join). Before, every
    column was aligned to the first frame's countries, so a country missing from the
    first technology frame was silently dropped.
    """
    if technologies is None:
        technologies = technologies_list

    if len(technologies) < len(dfs_technologies_list):
        raise ValueError('Need one technology name per technology frame')

    tech_dist_dfs_absolute_list = []

    for time_period in time_periods_list:

        period_start, period_end = time_period[0], time_period[1]

        # Per-technology totals over the period: one Series indexed by country code
        period_sums = []
        for df_ in dfs_technologies_list:
            to_sum = df_[(df_['year'] >= period_start) & (df_['year'] <= period_end)].fillna(0)
            period_sums.append(to_sum.drop(columns=['year', 'unknown'], errors='ignore').sum(axis=0))

        if period_sums:
            # Outer join on the country index; a technology without a column for a
            # country contributes 0 for it
            distributions_df_this_time_period = pd.concat(
                period_sums, axis=1, keys=list(technologies[:len(period_sums)])
            ).fillna(0)
        else:
            distributions_df_this_time_period = pd.DataFrame()

        tech_dist_dfs_absolute_list.append(distributions_df_this_time_period)

    return tech_dist_dfs_absolute_list


def get_tech_dist_dfs_normalised_list(tech_dist_dfs_absolute_list):
    """Turn each country x technology table into row-wise shares.

    Every row is divided by its own total, so it sums to 1. A row whose total is zero
    in ANY of the tables is then removed from ALL tables, so the returned frames keep
    the same set of countries.
    """
    totals_per_table = [table.sum(axis=1) for table in tech_dist_dfs_absolute_list]

    shares_per_table = [
        table.divide(totals, axis=0)
        for table, totals in zip(tech_dist_dfs_absolute_list, totals_per_table)
    ]

    keep = totals_per_table[0] != 0
    for totals in totals_per_table:
        keep = keep & (totals != 0)

    return [shares.loc[keep, :] for shares in shares_per_table]


def get_tech_dist_dfs_normalised_scaled_list(tech_dist_dfs_normalised_list):
    """Min-max scale every column of each table to [0, 1], keeping row and column labels.

    A separate MinMaxScaler is fitted for each table.
    """
    def _scale(table):
        scaled_values = MinMaxScaler().fit_transform(table)
        return pd.DataFrame(data=scaled_values, index=list(table.index), columns=list(table))

    return [_scale(table) for table in tech_dist_dfs_normalised_list]


# The following contains adapted code from Data Mining labs 8 and 12 at NOVA IMS, 2020.

def get_ss(df):
    """Total sum of squared deviations from the column means, added up over all columns."""
    per_column = df.var() * (df.count() - 1)  # sample variance * (n - 1) = SS per column
    return np.sum(per_column)


def r2(df, labels):
    """R² of a partition: 1 - (within-cluster SS / total SS)."""
    total_ss = get_ss(df)
    within_ss = np.sum(df.groupby(labels).apply(get_ss))
    return 1 - within_ss / total_ss
    
    
def get_r2_scores(df, clusterer, min_k=1, max_k=10):
    """
    Fit a fresh clone of `clusterer` for each k in [min_k, max_k) and return {k: R²}.

    Works with any sklearn clusterer that has an `n_clusters` parameter.
    """
    scores = {}
    for k in range(min_k, max_k):
        fitted_labels = clone(clusterer).set_params(n_clusters=k).fit_predict(df)
        scores[k] = r2(df, fitted_labels)
    return scores


def get_dataframes(i,
                   tech_dist_dfs_absolute_list,
                   tech_dist_dfs_normalised_list,
                   tech_dist_dfs_normalised_scaled_list):
    """Return the absolute, normalised and scaled tables of time period ``i``.

    The absolute table is a copy restricted to (and ordered like) the rows that
    survived normalisation; the other two are returned as stored.
    """
    data_normalised = tech_dist_dfs_normalised_list[i]
    kept_rows = data_normalised.index
    data_absolute = tech_dist_dfs_absolute_list[i].copy().loc[kept_rows]
    data_to_cluster = tech_dist_dfs_normalised_scaled_list[i]
    return data_absolute, data_normalised, data_to_cluster


def check_clustering_methods(data_to_cluster):
    """Plot R² against the number of clusters for k-means and four hierarchical linkages.

    Parameters
    ----------
    data_to_cluster : array-like or pandas.DataFrame
        Observations (rows) to cluster.

    Returns
    -------
    str
        The figure title (not drawn on the plot).
    """
    # Set up the clusterers
    kmeans = KMeans(
        init='k-means++',
        n_init=20,
        random_state=42
    )

    # Euclidean distance is AgglomerativeClustering's default. It is no longer passed
    # as `affinity=`: that keyword was deprecated in scikit-learn 1.2 and removed in 1.4.
    hierarchical = AgglomerativeClustering()

    # Obtaining the R² scores for each cluster solution
    r2_scores = {'kmeans': get_r2_scores(data_to_cluster, kmeans)}

    for linkage in ['complete', 'average', 'single', 'ward']:
        r2_scores[linkage] = get_r2_scores(
            data_to_cluster, hierarchical.set_params(linkage=linkage)
        )

    # Visualizing the R² scores for each cluster solution
    pd.DataFrame(r2_scores).plot.line(figsize=figsizes)

    title = "R² plot for various clustering methods"

    plt.legend(title="Cluster methods", title_fontsize=14)
    plt.xlabel("Number of clusters", fontsize=14)
    plt.ylabel("R² metric", fontsize=14)
    plt.show()

    return title

    
def agglom_clustering_full(data_to_cluster, y_threshold_1 = 3.25, y_threshold_2 = 2.25, *args):
    """Plot the full Ward-linkage dendrogram with dashed horizontal cut lines.

    Parameters
    ----------
    data_to_cluster : array-like or pandas.DataFrame
        Observations (rows) to cluster.
    y_threshold_1, y_threshold_2 : float
        Heights of the dashed lines labelled '2 clusters' and '3 clusters'.
    *args
        Optional ``args[0]``: height of an extra line labelled '5 clusters'. When given,
        it is also the dendrogram's colour threshold; otherwise ``y_threshold_2`` is.

    Returns
    -------
    str
        The figure title (not drawn on the plot).
    """
    # setting distance_threshold=0 and n_clusters=None ensures we compute the full tree.
    # Euclidean distance is AgglomerativeClustering's default (and the only metric Ward
    # accepts); it is not passed as `affinity=`, which was removed in scikit-learn 1.4.
    linkage = 'ward'
    distance = 'euclidean'  # only used for the axis label
    hclust = AgglomerativeClustering(linkage=linkage, distance_threshold=0, n_clusters=None)
    hclust.fit_predict(data_to_cluster)

    # Adapted from:
    # https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html#sphx-glr-auto-examples-cluster-plot-agglomerative-dendrogram-py

    # create the counts of samples under each node (number of points being merged)
    counts = np.zeros(hclust.children_.shape[0])
    n_samples = len(hclust.labels_)

    # hclust.children_ contains the observation ids that are being merged together
    # At the i-th iteration, children[i][0] and children[i][1] are merged to form node n_samples + i
    for i, merge in enumerate(hclust.children_):
        # track the number of observations in the current cluster being formed
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                # If this is True, then we are merging an observation
                current_count += 1  # leaf node
            else:
                # Otherwise, we are merging a previously formed cluster
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    # the hclust.children_ is used to indicate the two points/clusters being merged (dendrogram's u-joins)
    # the hclust.distances_ indicates the distance between the two points/clusters (height of the u-joins)
    # the counts indicate the number of points being merged (dendrogram's x-axis)
    linkage_matrix = np.column_stack(
        [hclust.children_, hclust.distances_, counts]
    ).astype(float)

    # Plot the corresponding dendrogram
    sns.set()
    plt.figure(figsize=figsizes)

    plus_ = 0.03          # vertical offset of the line labels above their lines
    pos_x = 140           # x position (dendrogram units) of the line labels
    text_fontsize = 15

    plt.hlines(y_threshold_1, 0, 1000, colors="r", linestyles="dashed")
    plt.hlines(y_threshold_2, 0, 1000, colors="r", linestyles="dashed")

    plt.text(pos_x, y_threshold_1 + plus_, '2 clusters', fontsize=text_fontsize)
    plt.text(pos_x, y_threshold_2 + plus_, '3 clusters', fontsize=text_fontsize)

    if len(args) != 0:
        y_threshold_5 = args[0]
        plt.hlines(y_threshold_5, 0, 1000, colors="r", linestyles="dashed")
        plt.text(pos_x, y_threshold_5 + plus_, '5 clusters', fontsize=text_fontsize)
        dendrogram(linkage_matrix, truncate_mode='level', p=5, color_threshold=y_threshold_5, above_threshold_color='k')

    else:
        dendrogram(linkage_matrix, truncate_mode='level', p=5, color_threshold=y_threshold_2, above_threshold_color='k')

    title = f'Dendrogram - Hierarchical Clustering using {linkage.title()}\'s linkage'

    plt.xlabel('Index of point or (number of points in node)', fontsize=20)
    plt.ylabel(f'{distance.title()} Distance', fontsize=20)

    plt.tick_params(axis = 'both', labelsize = 15)

    plt.show()

    return title
    
    
def run_k_means(data_to_cluster, data_normalised, data_absolute, random, n_clusters=2):
    """Cluster ``data_to_cluster`` with k-means and attach the labels to copies of two tables.

    Parameters
    ----------
    data_to_cluster : array-like or pandas.DataFrame
        Features used for clustering (rows = countries).
    data_normalised, data_absolute : pandas.DataFrame
        Tables that receive a 'label' column; assumed to be in the same row order as
        ``data_to_cluster`` (as returned together by ``get_dataframes``).
    random : bool
        If True, no ``random_state`` is set, so each call can give a different
        clustering (scikit-learn then draws from NumPy's global random state).
        If False, ``random_state=10`` makes the result reproducible.
    n_clusters : int, default 2
        Number of clusters.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        Copies of ``data_normalised`` and ``data_absolute`` with a 'label' column.
    """
    # n_init is pinned to 10, scikit-learn's default before version 1.4. From 1.4 the
    # default is 'auto' (one run with k-means++ init). That would silently change both
    # the deterministic clusters and the frequencies collected over repeated runs.
    kmeans_params = {'n_clusters': n_clusters, 'n_init': 10}
    if not random:
        kmeans_params['random_state'] = 10
    kmeans = KMeans(**kmeans_params)

    labels = kmeans.fit_predict(data_to_cluster)

    result_normalised = data_normalised.copy()
    result_normalised['label'] = labels

    result_absolute = data_absolute.copy()
    result_absolute['label'] = labels

    return result_normalised, result_absolute


def run_agglomerative(data_to_cluster, data_normalised, data_absolute, n_clusters=2):
    """Cluster hierarchically (sklearn defaults) and attach the labels to copies of two tables.

    Returns copies of ``data_normalised`` and ``data_absolute``, each with a 'label' column.
    """
    labels = AgglomerativeClustering(n_clusters=n_clusters).fit_predict(data_to_cluster)

    labelled = []
    for table in (data_normalised, data_absolute):
        table_with_labels = table.copy()
        table_with_labels['label'] = labels
        labelled.append(table_with_labels)

    result_normalised, result_absolute = labelled
    return result_normalised, result_absolute
    

def show_cluster_profiles(data_to_plot_clustering):
    """Parallel-coordinates plot of each cluster's mean feature values (grouped by 'label')."""
    cluster_means = data_to_plot_clustering.groupby(by='label', as_index=False).mean()

    sns.set()
    plt.figure(figsize=figsizes)
    pd.plotting.parallel_coordinates(cluster_means, 'label', color=sns.color_palette())

    # Layout (the title is printed below the figure rather than drawn on it)
    plt.xlabel('Technology', fontsize=20)
    plt.ylabel('Distribution value', fontsize=20)
    plt.tick_params(axis='both', labelsize=15)

    plt.show()

    print('Cluster profiles')
    
def plot_cluster_counts(data_to_plot_clustering):
    """Bar chart of how many rows (countries) fall into each cluster label."""
    first_colname = list(data_to_plot_clustering)[0]

    # Non-null count of the first feature per label, as a frame with columns label, count
    counts = (
        data_to_plot_clustering
        .groupby(by='label')[first_colname]
        .count()
        .rename('count')
        .reset_index()
    )

    sns.set()
    plt.figure(figsize=figsizes)
    sns.barplot(x='label', y="count", data=counts)

    # Layout (the title is printed below the figure rather than drawn on it)
    plt.xlabel('Cluster', fontsize=20)
    plt.ylabel('Count', fontsize=20)
    plt.tick_params(axis='both', labelsize=15)

    plt.show()

    print('Counts per cluster')
    
    
def show_result(label, result=None):
    """Return the rows of cluster ``label``, largest total count first, and print their names.

    Parameters
    ----------
    label : int
        Cluster label to select.
    result : pandas.DataFrame, optional
        Absolute counts with a 'label' column (e.g. from ``run_k_means``). Defaults to
        the global ``result_absolute``, i.e. whichever clustering cell ran last (the
        previous behaviour). Pass it explicitly to avoid depending on that hidden state.

    Returns
    -------
    pandas.DataFrame
        The selected rows (country codes as index), sorted by descending total count.
    """
    if result is None:
        result = result_absolute

    result_show = result[result['label'] == label].copy()

    # Total over the technology columns only; the 'label' column is excluded
    result_show['sum'] = result_show.drop(columns=['label']).sum(axis=1)
    result_show = result_show.sort_values(by='sum', axis=0, ascending=False).drop(columns=['sum'])

    # Readable country names: preferred labels first, otherwise the generic code -> name map
    countries_list = [
        country_labels_dict[code] if code in country_labels_dict else ctry_code_name_dict[code]
        for code in result_show.index
    ]

    print(len(countries_list))
    print(countries_list)

    return result_show


def cluster_profiles_radar(plot_data, title, legend_pos, title_pad, yticks, save_fig, *args):
    """Radar (polar) chart with one closed line per cluster showing its mean profile.

    Parameters
    ----------
    plot_data : pandas.DataFrame
        Feature columns followed by a final cluster-label column (the last column is
        always treated as the label; it must be named 'label' and hold integer ids).
    title : str
        Figure title.
    legend_pos : tuple
        ``bbox_to_anchor`` for the legend, anchored at its lower-left corner.
    title_pad : float
        Padding between the title and the axes.
    yticks : list of float
        Radial tick positions.
    save_fig : bool
        If True, also write the figure to 'radar.eps' in the working directory
        (overwriting any existing file).
    *args
        Optional grid style, in this order: visible, which, axis, color, linestyle,
        linewidth (forwarded to ``plt.grid``). If omitted the grid call is skipped;
        previously omitting them raised IndexError.

    Returns
    -------
    dict
        ``{label: [mean of each feature..., mean of the first feature again]}``; the
        first value is repeated at the end to close the polygon.
    """
    features = list(plot_data)[:-1]
    num_features = len(features)

    # Cluster centroid = per-feature mean of the cluster's rows
    centroids = {}
    for label in list(set(plot_data['label'])):

        data_this_cluster = plot_data[plot_data['label'] == label].iloc[:, :num_features]

        centroids[label] = list(data_this_cluster.sum(axis=0) / len(data_this_cluster))

    plt.style.use('ggplot')

    # Pad some tick labels with spaces so they do not overlap the plot. The patterns were
    # tuned by hand for 4 and 5 features (previously picked with if True / if False
    # toggles); any count other than 4 uses the 5-feature pattern, as before.
    space = '          '
    if num_features == 4:
        pad_before, pad_after = {0}, {2}
    else:
        pad_before, pad_after = {2, 3}, {0}
    features = [
        (space if i in pad_before else '') + item + (space if i in pad_after else '')
        for i, item in enumerate(features)
    ]

    # Spoke angles (radians), starting at pi (180 degrees)
    angles = np.linspace(0, 2*np.pi, num_features, endpoint=False) + np.pi
    angles = np.concatenate((angles, [angles[0]]))

    # Close the polygons: repeat the first label and the first value at the end
    features.append(features[0])

    for item in centroids:
        centroids[item].append(centroids[item][0])

    # Define figure
    fig = plt.figure(figsize=(6, 7))
    ax = fig.add_subplot(111, polar=True)

    # Default plotly colours, to match the plotly figures (named so that the
    # module-level `colors` import is not shadowed)
    line_colors = colors_plotly_default

    # One line per cluster; labels are shown 1-based
    for i, item in enumerate(centroids):
        ax.plot(
            angles,
            centroids[item],
            '-',
            color=line_colors[i],
            linewidth=3,
            label='Cluster ' + str(item + 1)
        )

    # Angular tick positions in degrees, matching `angles`. These were previously
    # computed for a hard-coded 5 features, which mislabelled the spokes otherwise.
    start = 180
    positions = []
    for i in range(num_features + 1):
        angle = start + i * (360 / num_features)
        positions.append(angle if angle < 360 else angle - 360)

    ax.set_thetagrids(positions, features)

    if args:
        plt.grid(
            visible=args[0],
            which=args[1],
            axis=args[2],
            color=args[3],
            linestyle=args[4],
            linewidth=args[5]
        )

    # Define Layout
    ax.set_facecolor("white")
    plt.tight_layout()

    # Define legend
    plt.legend(facecolor="white",
               frameon=False,
               loc='lower left',
               bbox_to_anchor=legend_pos
               )

    # Radial tick labels drawn along the 0-degree direction
    ax.set_rlabel_position(0)

    # Title
    plt.title(title,
              fontdict=None,
              loc=None,
              pad=title_pad,
              y=None)

    ax.set_yticks(yticks)

    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')

    # Save plot as .eps and display it
    if save_fig:
        plt.savefig('radar.eps')
    plt.show()

    return centroids
    

### 2010-2019

In [None]:
time_periods_list = [[2010,2019]]

time_periods_list


In [None]:
tech_dist_dfs_absolute_list_2010_2019 = get_tech_dist_dfs_absolute_list(time_periods_list, dfs_technologies_list)
#tech_dist_dfs_absolute_list_2010_2019[0]


In [None]:
tech_dist_dfs_absolute_list_2010_2019[0] = tech_dist_dfs_absolute_list_2010_2019[0]
print(len(tech_dist_dfs_absolute_list_2010_2019[0]))

bool_ = tech_dist_dfs_absolute_list_2010_2019[0].sum(axis=1)>= 0

tech_dist_dfs_absolute_list_2010_2019[0] = tech_dist_dfs_absolute_list_2010_2019[0][bool_]

print(len(tech_dist_dfs_absolute_list_2010_2019[0]))


In [None]:
tech_dist_dfs_normalised_list_2010_2019 = get_tech_dist_dfs_normalised_list(tech_dist_dfs_absolute_list_2010_2019)
tech_dist_dfs_normalised_list_2010_2019[0]

print(len(tech_dist_dfs_normalised_list_2010_2019[0]))


In [None]:
tech_dist_dfs_normalised_scaled_list_2010_2019 = get_tech_dist_dfs_normalised_scaled_list(tech_dist_dfs_normalised_list_2010_2019)
tech_dist_dfs_normalised_scaled_list_2010_2019[0]

print(len(tech_dist_dfs_normalised_scaled_list_2010_2019[0]))


In [None]:
data_absolute, data_normalised, data_to_cluster = get_dataframes(0,
                                                                tech_dist_dfs_absolute_list_2010_2019,
                                                                tech_dist_dfs_normalised_list_2010_2019,
                                                                tech_dist_dfs_normalised_scaled_list_2010_2019)


In [None]:
print(len(data_absolute))

print(len(data_normalised))

print(len(data_to_cluster))

print(list(data_absolute.index) == list(data_normalised.index))

print(list(data_absolute.index) == list(data_to_cluster.index))


In [None]:
check_clustering_methods(data_to_cluster)


In [None]:
agglom_clustering_full(data_to_cluster, y_threshold_1 = 2.25, y_threshold_2 = 1.75)


#### Find 2 clusters

In [None]:
# k = 2: deterministic k-means (random=False -> random_state=10 in run_k_means). The results
# are also bound to the generic names result_normalised / result_absolute, which the
# plotting cells and show_result() read.
result_normalised_2010_2019_k_2, result_absolute_2010_2019_k_2 = run_k_means(data_to_cluster, data_normalised, data_absolute, False, n_clusters=2)

result_normalised = result_normalised_2010_2019_k_2
result_absolute = result_absolute_2010_2019_k_2

show_cluster_profiles(result_normalised)

plot_cluster_counts(result_normalised)


In [None]:
# Disabled alternative: hierarchical clustering with k = 2. Enabling it rebinds
# result_normalised / result_absolute, so later cells would show these results instead.
if False:
    
    result_normalised_2010_2019_h_2, result_absolute_2010_2019_h_2 = run_agglomerative(data_to_cluster, data_normalised, data_absolute, n_clusters=2)

    result_normalised = result_normalised_2010_2019_h_2
    result_absolute = result_absolute_2010_2019_h_2

    show_cluster_profiles(result_normalised)

    plot_cluster_counts(result_normalised)


In [None]:
# Radar plot of the k = 2 k-means solution. The k-means result is passed explicitly, so
# the plot does not depend on which cell last rebound result_normalised. The title
# previously said "three clusters".
centroids_2 = cluster_profiles_radar(
    result_normalised_2010_2019_k_2,
    "Clustering inventors' countries of origin by their\nbattery type distribution using recent ten years' data:\nProfiles of two clusters computed by k-means algorithm",
    (0.7, 0.87), # legend_pos
    40, # title_pad
    [0.2, 0.4, 0.6, 0.8], # y_ticks
    False, # save_fig
    True, # grid: visible
    'major', # grid: which
    'both', # grid: axis
    'black', # grid: color
    '-', # grid: linestyle
    1 # grid: linewidth
)


In [None]:
# Per-cluster mean profiles returned by the radar cell (first value repeated at the end)
centroids_2


In [None]:
# show_result reads the global result_absolute, i.e. the most recently run clustering cell
show_result(0)


In [None]:
show_result(1)


#### Find 3 clusters

In [None]:
# k = 3: deterministic k-means (random=False -> random_state=10 in run_k_means); also
# rebinds result_normalised / result_absolute for the cells below.
result_normalised_2010_2019_k_3, result_absolute_2010_2019_k_3 = run_k_means(data_to_cluster, data_normalised, data_absolute, False, n_clusters=3)

result_normalised = result_normalised_2010_2019_k_3
result_absolute = result_absolute_2010_2019_k_3

show_cluster_profiles(result_normalised)

plot_cluster_counts(result_normalised)


In [None]:
# Radar plot of the current result_normalised (the k = 3 k-means solution if cells ran in order)
centroids_3 = cluster_profiles_radar(
    result_normalised,
    "Clustering inventors' countries of origin by their\nbattery type distribution using recent ten years' data:\nProfiles of three clusters computed by k-means algorithm",
    (0.77, 0.87), # legend_pos
    40, # title_pad
    [0.2, 0.4, 0.6, 0.8], # y_ticks
    True, # save_fig: also writes radar.eps
    True, # grid: visible
    'major', # grid: which
    'both', # grid: axis
    'black', # grid: color
    '-', # grid: linestyle
    1 # grid: linewidth
)

##### Run k-means n times to compute each country's cluster-affiliation distribution

In [None]:
check = data_absolute.copy()
check['sum'] = check.sum(axis = 1)
check.sort_values('sum', ascending = False, inplace = True)
countries_descending = list(check.index)
sums = list(check['sum'])
#countries_descending


In [None]:
sum_dict = {}

for i in range(len(check)):
    
    sum_dict[countries_descending[i]] = sums[i]
    
#sum_dict


In [None]:
# Number of randomly initialised k-means runs
n = 10000

# run_k_means(..., random=True) passes no random_state, so scikit-learn draws from NumPy's
# global random state. Seeding it makes these runs, and everything derived from them
# (handles, neighbourhoods, affiliation distributions), reproducible.
np.random.seed(0)

clusterings = []

for i in tqdm(range(n)):

    run_multiple_k_means_normalised, _ = run_k_means(data_to_cluster, data_normalised, data_absolute, True, n_clusters=3)

    # One clustering = list of sets of country codes, one set per cluster label
    clustering = [
        set(run_multiple_k_means_normalised.index[run_multiple_k_means_normalised['label'] == cluster])
        for cluster in set(run_multiple_k_means_normalised['label'])
    ]

    clusterings.append(clustering)

#clusterings

In [None]:
# Get all co-occurrences
co_occurrences = set()

for clustering in clusterings:
    
    for cluster in clustering:
        
        for pair in it.product(cluster, cluster, repeat = 1):
            
            if pair[0] != pair[1]:
                
                co_occurrences.add(pair)
    
len(co_occurrences)


In [None]:
# Get all theoretically possible triples

all_countries = set(run_multiple_k_means_normalised.index)

possible_triples = set()

for triple in it.product(all_countries, all_countries, all_countries, repeat = 1):
            
    if (triple[0] != triple[1]) & (triple[0] != triple[2]) & (triple[1] != triple[2]):

        possible_triples.add(triple)
        
possible_triples = list(possible_triples)

len(possible_triples)


In [None]:
# Get triples of countries that are always in their own separate cluster - these are triples that can be used as handles

never_occur_together = set()

for triple in possible_triples:
        
    pairs = []
    for pair in it.product(triple, triple, repeat = 1):
    
        if (pair[0] != pair[1]):
            pairs.append(pair)
        
    co_occurred = False
    for pair in pairs:
        
        if pair in co_occurrences:
            co_occurred = True
    
    if not co_occurred:
        never_occur_together.add(triple)
        
len(never_occur_together)
#never_occur_together


In [None]:
# Print one candidate handle triple (arbitrary element of the set) with readable names
example_triple = list(never_occur_together)[0]

for country in example_triple:
    readable_name = country_labels_dict[country] if country in country_labels_dict else ctry_code_name_dict[country]
    print(country, readable_name)

print(example_triple)


In [None]:
# Define handles
handles = ['TH', 'KP', 'CA']

n_before = len(run_multiple_k_means_normalised)
print(n_before)

neighbors = []

for i, clustering in enumerate(clusterings):
    
    for j, handle in enumerate(handles):
        
        for cluster in clustering:
                        
            if handle in cluster:
                
                if i == 0:

                    neighbors.append(cluster)

                else:
                    
                    neighbors[j] = neighbors[j].intersection(cluster)

n_after = 0

for neighborhood in neighbors:
    
    n_after += len(neighborhood)

print(n_after)

In [None]:
# Sort neighbors in descending according to their total count in these categories

neighbors_descending = [[], [], []]

for i, neighborhood in enumerate(neighbors):
    
    for country in countries_descending:
        
        if country in neighborhood:
            
            if country in country_labels_dict:
                neighbors_descending[i].append(country_labels_dict[country])
            else:
                neighbors_descending[i].append(ctry_code_name_dict[country])

neighbors_descending


In [None]:
# One line per neighbourhood: "A, B, C (3 countries)." followed by an empty line
for neighborhood in neighbors_descending:
    print(f"{', '.join(neighborhood)} ({len(neighborhood)} countries).")
    print()
        

#### Compute each country's cluster affiliation distribution and build most probable clusters

In [None]:
handles

In [None]:
counters = {}

print(handles)

for i, country in enumerate(all_countries):
    
    counter = {}
    for handle in handles:
        counter[handle] = 0
    
    for clustering in clusterings:
        
        for cluster in clustering:
            
            if country in cluster:
                
                for handle in handles:
                    
                    if handle in cluster:
                        
                        counter[handle] += 1
                        
    counters[country] = counter

distributions = {}
for country in list(counters):
    
    distributions[country] = []
    
for country in list(counters):
    
    for count in list(counters[country]):
                
        distributions[country].append(counters[country][count] / n)
    
handles_distributions = np.eye(3).tolist()

for i, handle in enumerate(handles):
    distributions[handle] = handles_distributions[i]
    
distributions


In [None]:
# Check if there are any leveled distributions (two values the same)

for country in list(distributions):
    
    distribution = distributions[country]
    
    for i,v1 in enumerate(distribution):
        for j,v2 in enumerate(distribution):
            if (i!=j) & (v1==v2) & ((v1!=0) & (v2!=0)):
                print(v1, v2)
                

In [None]:
# Assign each country to the handle's cluster it joined most often (ties -> first handle).
# Entries are (readable name, affiliation probability, total count).
most_prob_clusters = [[] for _ in handles]

for country, distribution in distributions.items():
    max_ = max(distribution)
    country_written = country_labels_dict[country] if country in country_labels_dict else ctry_code_name_dict[country]
    most_prob_clusters[distribution.index(max_)].append((country_written, max_, sum_dict[country]))

# Within each cluster: highest probability first, ties broken by larger total count
most_prob_clusters_sorted = [
    sorted(cluster_members, key=lambda tup: (tup[1], tup[2]), reverse=True)
    for cluster_members in most_prob_clusters
]

most_prob_clusters_sorted


#### Create LaTeX code

In [None]:
def _latex_country_entry(name, prob):
    """LaTeX snippet for one country, with a background colour chosen by affiliation probability."""
    if prob == 1:
        colour = 'rb100'
    elif 0.95 <= prob < 1:
        colour = 'rb70'
    elif 0.9 <= prob < 0.95:
        colour = 'rb30'
    elif 0.85 <= prob < 0.9:
        colour = 'rb0'
    else:
        # Previously such a country was silently left out, leaving an empty ", ," in the
        # list. Keep it (uncoloured) and warn.
        print('######################')
        print('outside defined ranges: '+name+' ('+str(prob)+')')
        print('######################')
        return r'{\hz '+name+'}'
    return r'\colorbox{'+colour+r'}{\hz '+name+'}'


# Raw strings keep the LaTeX backslashes literal. The old '\item', '\colorbox', '\hz' and
# '\end' literals relied on invalid escape sequences (DeprecationWarning; SyntaxWarning
# since Python 3.12).
latex_code = r'\begin{itemize}'

for i, cluster in enumerate(most_prob_clusters_sorted):

    string_ = '\n'+r'\item Cluster '+str(i + 1)+' ('+str(len(cluster))+' countries):'+'\n\n'

    entries = [_latex_country_entry(tuple_[0], tuple_[1]) for tuple_ in cluster]
    if entries:
        string_ += ', '.join(entries)+'.'

    latex_code += '\n'+string_

latex_code += '\n\n'+r'\end{itemize}'
print(latex_code)
        

In [None]:
# Disabled alternative: hierarchical clustering with k = 3. Enabling it rebinds
# result_normalised / result_absolute, so the show_result cells below would use it.
if False:
    
    result_normalised_2010_2019_h_3, result_absolute_2010_2019_h_3 = run_agglomerative(data_to_cluster, data_normalised, data_absolute, n_clusters=3)

    result_normalised = result_normalised_2010_2019_h_3
    result_absolute = result_absolute_2010_2019_h_3

    show_cluster_profiles(result_normalised)

    plot_cluster_counts(result_normalised)


In [None]:
# show_result reads the global result_absolute (the k = 3 k-means result if cells ran in order)
show_result(0)


In [None]:
# NOTE(review): hand-copied from the show_result output above; it will not update if the
# clustering changes -- verify against the printed list before reusing.
list_ = ['Japan', 'USA', 'South Korea', 'Germany', 'Italy', 'Taiwan', 'Belgium', 'Austria', 'Netherlands', 'Australia', 'Thailand', 'Switzerland']
print(', '.join(list_)+' ('+str(len(list_))+' countries).')


In [None]:
show_result(1)


In [None]:
# NOTE(review): hand-copied from the show_result output above -- verify before reusing.
list_ = ['India', 'Russia', 'Turkey', 'Bulgaria', 'New Zealand', 'Luxembourg', 'Poland', 'Sweden', 'Malta', 'Mexico', "Democratic People's Republic of Korea", 'Kazakhstan', 'Hungary', 'Serbia', 'Greece']
print(', '.join(list_)+' ('+str(len(list_))+' countries).')


In [None]:
show_result(2)


In [None]:
# NOTE(review): hand-copied from the show_result output above -- verify before reusing.
list_ = ['China', 'UK', 'France', 'Canada', 'Spain', 'Israel', 'Norway', 'Hong Kong SAR (China)', 'Ukraine']
print(', '.join(list_)+' ('+str(len(list_))+' countries).')


#### Find 4 clusters

In [None]:
run_k_means


In [None]:
result_normalised_2010_2019_k_4, result_absolute_2010_2019_k_4 = run_k_means(
    data_to_cluster, data_normalised, data_absolute, False, n_clusters=4)

result_normalised = result_normalised_2010_2019_k_4
result_absolute = result_absolute_2010_2019_k_4

show_cluster_profiles(result_normalised)

plot_cluster_counts(result_normalised)


In [None]:
if False:
    
    result_normalised_2010_2019_h_4, result_absolute_2010_2019_h_4 = run_agglomerative(
        data_to_cluster, data_normalised, data_absolute, n_clusters=4)

    result_normalised = result_normalised_2010_2019_h_4
    result_absolute = result_absolute_2010_2019_h_4

    show_cluster_profiles(result_normalised)

    plot_cluster_counts(result_normalised)


In [None]:
# Radar plot of the k = 4 k-means solution, passed explicitly rather than via the global
# result_normalised. The title previously said "three clusters".
cluster_profiles_radar(
    result_normalised_2010_2019_k_4,
    "Clustering inventors' countries of origin by their\nbattery type distribution using recent ten years' data:\nProfiles of four clusters computed by k-means algorithm",
    (0.7, 0.87), # legend_pos
    40, # title_pad
    [0.2, 0.4, 0.6, 0.8], # y_ticks
    False, # save_fig
    True, # grid: visible
    'major', # grid: which
    'both', # grid: axis
    'black', # grid: color
    '-', # grid: linestyle
    1 # grid: linewidth
)


In [None]:
# show_result reads the global result_absolute (the k = 4 result if cells ran in order)
show_result(0)


In [None]:
show_result(1)


In [None]:
show_result(2)


In [None]:
show_result(3)


#### Find 5 clusters

In [None]:
# k = 5: deterministic k-means (random=False -> random_state=10 in run_k_means); also
# rebinds result_normalised / result_absolute for the cells below.
result_normalised_2010_2019_k_5, result_absolute_2010_2019_k_5 = run_k_means(
    data_to_cluster, data_normalised, data_absolute, False, n_clusters=5)

result_normalised = result_normalised_2010_2019_k_5
result_absolute = result_absolute_2010_2019_k_5

show_cluster_profiles(result_normalised)

plot_cluster_counts(result_normalised)


In [None]:
# Disabled alternative: hierarchical clustering with k = 5 (rebinds result_* if enabled)
if False:
    
    result_normalised_2010_2019_h_5, result_absolute_2010_2019_h_5 = run_agglomerative(
        data_to_cluster, data_normalised, data_absolute, n_clusters=5)

    result_normalised = result_normalised_2010_2019_h_5
    result_absolute = result_absolute_2010_2019_h_5

    show_cluster_profiles(result_normalised)

    plot_cluster_counts(result_normalised)


In [None]:
# Radar plot of the k = 5 k-means solution, passed explicitly rather than via the global
# result_normalised. The title previously said "three clusters".
cluster_profiles_radar(
    result_normalised_2010_2019_k_5,
    "Clustering inventors' countries of origin by their\nbattery type distribution using recent ten years' data:\nProfiles of five clusters computed by k-means algorithm",
    (0.7, 0.87), # legend_pos
    40, # title_pad
    [0.2, 0.4, 0.6, 0.8], # y_ticks
    False, # save_fig
    True, # grid: visible
    'major', # grid: which
    'both', # grid: axis
    'black', # grid: color
    '-', # grid: linestyle
    1 # grid: linewidth
)


In [None]:
# show_result reads the global result_absolute (the k = 5 result if cells ran in order)
show_result(0)


In [None]:
show_result(1)


In [None]:
show_result(2)


In [None]:
show_result(3)


In [None]:
show_result(4)


#### Find 6 clusters

In [None]:
# k = 6: deterministic k-means (random=False -> random_state=10 in run_k_means); also
# rebinds result_normalised / result_absolute for the cells below.
result_normalised_2010_2019_k_6, result_absolute_2010_2019_k_6 = run_k_means(
    data_to_cluster, data_normalised, data_absolute, False, n_clusters=6)

result_normalised = result_normalised_2010_2019_k_6
result_absolute = result_absolute_2010_2019_k_6

show_cluster_profiles(result_normalised)

plot_cluster_counts(result_normalised)


In [None]:
# Disabled alternative: hierarchical clustering with k = 6 (rebinds result_* if enabled)
if False:
    
    result_normalised_2010_2019_h_6, result_absolute_2010_2019_h_6 = run_agglomerative(
        data_to_cluster, data_normalised, data_absolute, n_clusters=6)

    result_normalised = result_normalised_2010_2019_h_6
    result_absolute = result_absolute_2010_2019_h_6

    show_cluster_profiles(result_normalised)

    plot_cluster_counts(result_normalised)


In [None]:
# Radar plot of the k = 6 k-means solution, passed explicitly rather than via the global
# result_normalised. The title previously said "three clusters".
cluster_profiles_radar(
    result_normalised_2010_2019_k_6,
    "Clustering inventors' countries of origin by their\nbattery type distribution using recent ten years' data:\nProfiles of six clusters computed by k-means algorithm",
    (0.7, 0.87), # legend_pos
    40, # title_pad
    [0.2, 0.4, 0.6, 0.8], # y_ticks
    False, # save_fig
    True, # grid: visible
    'major', # grid: which
    'both', # grid: axis
    'black', # grid: color
    '-', # grid: linestyle
    1 # grid: linewidth
)


In [None]:
# show_result reads the global result_absolute (the k = 6 result if cells ran in order)
show_result(0)


In [None]:
show_result(1)


In [None]:
show_result(2)


In [None]:
show_result(3)


In [None]:
show_result(4)


In [None]:
show_result(5)


### 2000-2009 and 2010-2019

In [None]:
time_periods_list = [[2000,2009], [2010,2019]]

time_periods_list


In [None]:
# Build one absolute technology-distribution frame per period in time_periods_list.
# NOTE(review): presumably the list order follows time_periods_list ([0] = 2000-2009,
# [1] = 2010-2019) — confirm in get_tech_dist_dfs_absolute_list.
tech_dist_dfs_absolute_list_2000_2010_2019 = get_tech_dist_dfs_absolute_list(time_periods_list, dfs_technologies_list)
#tech_dist_dfs_absolute_list_2000_2010_2019[0]


In [None]:
# Number of rows in the first period's absolute frame.
#len(tech_dist_dfs_absolute_list_2000_2010_2019[1])
len(tech_dist_dfs_absolute_list_2000_2010_2019[0])


In [None]:
# Derive normalised frames from the absolute ones (one per period).
tech_dist_dfs_normalised_list_2000_2010_2019 = get_tech_dist_dfs_normalised_list(
    tech_dist_dfs_absolute_list_2000_2010_2019
)
#tech_dist_dfs_normalised_list_2000_2010_2019[0]


In [None]:
# Number of rows in the first period's normalised frame.
#len(tech_dist_dfs_normalised_list_2000_2010_2019[1])
len(tech_dist_dfs_normalised_list_2000_2010_2019[0])


In [None]:
# Scale the normalised frames; these scaled frames are what get_dataframes returns as
# data_to_cluster further below (presumably — confirm in get_dataframes).
tech_dist_dfs_normalised_scaled_list_2000_2010_2019 = get_tech_dist_dfs_normalised_scaled_list(
    tech_dist_dfs_normalised_list_2000_2010_2019
)
#tech_dist_dfs_normalised_scaled_list_2000_2010_2019[0]


In [None]:
if False:
    # Disabled sanity check: prints whether the row index of the first-period frame equals
    # that of the second-period frame, for the absolute, normalised and scaled lists.
    print(list(tech_dist_dfs_absolute_list_2000_2010_2019[0].index) == list(tech_dist_dfs_absolute_list_2000_2010_2019[1].index))

    print(list(tech_dist_dfs_normalised_list_2000_2010_2019[0].index) == list(tech_dist_dfs_normalised_list_2000_2010_2019[1].index))

    print(list(tech_dist_dfs_normalised_scaled_list_2000_2010_2019[0].index) == list(tech_dist_dfs_normalised_scaled_list_2000_2010_2019[1].index))


#### 2000-2009

In [None]:
# Select the frames for the first period (entry 0 of each list; presumably 2000-2009,
# following the order of time_periods_list).
data_absolute, data_normalised, data_to_cluster = get_dataframes(
    0,
    tech_dist_dfs_absolute_list_2000_2010_2019,
    tech_dist_dfs_normalised_list_2000_2010_2019,
    tech_dist_dfs_normalised_scaled_list_2000_2010_2019,
)


In [None]:
# Compare clustering approaches on the 2000-2009 data to be clustered.
check_clustering_methods(data_to_cluster)


In [None]:
# Agglomerative clustering overview for 2000-2009.
# NOTE(review): 2.75, 2.2 and 1.9 are unexplained magic numbers — presumably distance
# thresholds at which to cut the hierarchy; confirm against agglom_clustering_full and
# consider naming them.
agglom_clustering_full(data_to_cluster, 2.75, 2.2, 1.9)


##### Find 2 clusters

In [None]:
# k-means with two clusters on the 2000-2009 frames. The generic result_* names are
# rebound so that following cells operate on this result.
(result_normalised_2000_2009_k_2,
 result_absolute_2000_2009_k_2) = run_k_means(
    data_to_cluster,
    data_normalised,
    data_absolute,
    False,
    n_clusters=2,
)
result_normalised, result_absolute = (result_normalised_2000_2009_k_2,
                                      result_absolute_2000_2009_k_2)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)


In [None]:
# Agglomerative clustering with two clusters on the 2000-2009 frames, for comparison
# with the k-means run above. Rebinding result_* means later cells see this result.
(result_normalised_2000_2009_h_2,
 result_absolute_2000_2009_h_2) = run_agglomerative(
    data_to_cluster, data_normalised, data_absolute,
    n_clusters=2,
)
result_normalised, result_absolute = (result_normalised_2000_2009_h_2,
                                      result_absolute_2000_2009_h_2)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)


In [None]:
# Inspect the two clusters. Note: the previous cell rebound result_normalised /
# result_absolute to the agglomerative (h_2) result, so these presumably show the
# agglomerative clusters rather than the k-means ones — confirm this is intended.
show_result(0)


In [None]:
show_result(1)


##### Find 3 clusters

In [None]:
# k-means with three clusters on the 2000-2009 frames; the generic result_* names
# are rebound to this result for the following cells.
(result_normalised_2000_2009_k_3,
 result_absolute_2000_2009_k_3) = run_k_means(
    data_to_cluster,
    data_normalised,
    data_absolute,
    False,
    n_clusters=3,
)
result_normalised, result_absolute = (result_normalised_2000_2009_k_3,
                                      result_absolute_2000_2009_k_3)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)

In [None]:
show_result(1)  # only cluster 1 of the three-cluster k-means result is inspected here

In [None]:
# Agglomerative clustering with three clusters on the 2000-2009 frames; rebinds the
# generic result_* names to this result.
(result_normalised_2000_2009_h_3,
 result_absolute_2000_2009_h_3) = run_agglomerative(
    data_to_cluster, data_normalised, data_absolute,
    n_clusters=3,
)
result_normalised, result_absolute = (result_normalised_2000_2009_h_3,
                                      result_absolute_2000_2009_h_3)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)

##### Find 5 clusters

In [None]:
# k-means with five clusters on the 2000-2009 frames; rebinds the generic result_*
# names to this result.
(result_normalised_2000_2009_k_5,
 result_absolute_2000_2009_k_5) = run_k_means(
    data_to_cluster,
    data_normalised,
    data_absolute,
    False,
    n_clusters=5,
)
result_normalised, result_absolute = (result_normalised_2000_2009_k_5,
                                      result_absolute_2000_2009_k_5)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)


In [None]:
# Agglomerative clustering with five clusters on the 2000-2009 frames; rebinds the
# generic result_* names to this result.
(result_normalised_2000_2009_h_5,
 result_absolute_2000_2009_h_5) = run_agglomerative(
    data_to_cluster, data_normalised, data_absolute,
    n_clusters=5,
)
result_normalised, result_absolute = (result_normalised_2000_2009_h_5,
                                      result_absolute_2000_2009_h_5)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)


#### 2010-2019

In [None]:
# Select the frames for the second period (entry 1 of each list; presumably 2010-2019,
# following the order of time_periods_list). This replaces the 2000-2009 frames.
data_absolute, data_normalised, data_to_cluster = get_dataframes(
    1,
    tech_dist_dfs_absolute_list_2000_2010_2019,
    tech_dist_dfs_normalised_list_2000_2010_2019,
    tech_dist_dfs_normalised_scaled_list_2000_2010_2019,
)


In [None]:
# Compare clustering approaches on the 2010-2019 data to be clustered.
check_clustering_methods(data_to_cluster)


In [None]:
# Agglomerative clustering overview for 2010-2019.
# NOTE(review): called with two numeric arguments here but three in the 2000-2009
# section; presumably they are optional cut thresholds — confirm against
# agglom_clustering_full's signature.
agglom_clustering_full(data_to_cluster, 2.5, 1.85)


##### Find 2 clusters

In [None]:
# k-means with two clusters on the 2010-2019 frames; rebinds the generic result_*
# names to this result.
(result_normalised_2010_2019_k_2,
 result_absolute_2010_2019_k_2) = run_k_means(
    data_to_cluster,
    data_normalised,
    data_absolute,
    False,
    n_clusters=2,
)
result_normalised, result_absolute = (result_normalised_2010_2019_k_2,
                                      result_absolute_2010_2019_k_2)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)


In [None]:
# Agglomerative clustering with two clusters on the 2010-2019 frames.
# Stored under the *_h_2 names (as in the 2000-2009 section) so that it no longer
# overwrites the k-means result held in result_*_2010_2019_k_2.
result_normalised_2010_2019_h_2, result_absolute_2010_2019_h_2 = run_agglomerative(data_to_cluster, data_normalised, data_absolute, n_clusters=2)

result_normalised = result_normalised_2010_2019_h_2
result_absolute = result_absolute_2010_2019_h_2

show_cluster_profiles(result_normalised)

plot_cluster_counts(result_normalised)


In [None]:
# Inspect the two clusters. result_normalised / result_absolute currently hold the
# two-cluster agglomerative result from the previous cell (presumably what show_result
# reads — TODO confirm).
show_result(0)


In [None]:
show_result(1)


##### Find 3 clusters

In [None]:
# k-means with three clusters on the 2010-2019 frames; this is the result visualised
# in the radar charts. Rebinds the generic result_* names to it.
(result_normalised_2010_2019_k_3,
 result_absolute_2010_2019_k_3) = run_k_means(
    data_to_cluster,
    data_normalised,
    data_absolute,
    False,
    n_clusters=3,
)
result_normalised, result_absolute = (result_normalised_2010_2019_k_3,
                                      result_absolute_2010_2019_k_3)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)


In [None]:
# Radar chart of the three 2010-2019 k-means cluster profiles.
# The k=3 result is passed explicitly, so the chart no longer depends on which result
# the frequently reassigned global result_normalised happens to hold when this cell runs.
cluster_profiles_radar(
    result_normalised_2010_2019_k_3,
    "Clustering inventors' countries of origin by their\nbattery type distribution using recent ten years' data:\nProfiles of three clusters computed by k-means algorithm",
    (0.7, 0.87), # legend_pos
    40, # title_pad
    [0.2, 0.4, 0.6, 0.8], # y_ticks
    False,
    True,
    'major',
    'both',
    'black',
    '-',
    1
)


In [None]:
len(result_normalised)  # row count of the current (2010-2019, k=3 k-means) normalised result


In [None]:
if False:
    # Disabled: three-cluster agglomerative clustering for 2010-2019, kept for reference.
    # If enabled, it would rebind result_normalised / result_absolute used below.
    result_normalised_2010_2019_h_3, result_absolute_2010_2019_h_3 = run_agglomerative(data_to_cluster, data_normalised, data_absolute, n_clusters=3)

    result_normalised = result_normalised_2010_2019_h_3
    result_absolute = result_absolute_2010_2019_h_3

    show_cluster_profiles(result_normalised)

    plot_cluster_counts(result_normalised)


In [None]:
# Inspect each of the three 2010-2019 clusters. The agglomerative alternative above is
# disabled, so result_normalised / result_absolute still hold the k=3 k-means result
# (presumably what show_result reads — TODO confirm).
show_result(0)


In [None]:
show_result(1)


In [None]:
show_result(2)


##### Find 4 clusters

In [None]:
# k-means with four clusters on the 2010-2019 frames; rebinds the generic result_*
# names to this result (so they no longer refer to the k=3 result).
(result_normalised_2010_2019_k_4,
 result_absolute_2010_2019_k_4) = run_k_means(
    data_to_cluster,
    data_normalised,
    data_absolute,
    False,
    n_clusters=4,
)
result_normalised, result_absolute = (result_normalised_2010_2019_k_4,
                                      result_absolute_2010_2019_k_4)
show_cluster_profiles(result_normalised)
plot_cluster_counts(result_normalised)

In [None]:
if False:
    # Disabled: four-cluster agglomerative clustering for 2010-2019, kept for reference.
    # If enabled, it would rebind result_normalised / result_absolute used below.
    result_normalised_2010_2019_h_4, result_absolute_2010_2019_h_4 = run_agglomerative(
        data_to_cluster, data_normalised, data_absolute, n_clusters=4)

    result_normalised = result_normalised_2010_2019_h_4
    result_absolute = result_absolute_2010_2019_h_4

    show_cluster_profiles(result_normalised)

    plot_cluster_counts(result_normalised)

In [None]:
# Inspect each of the four 2010-2019 clusters. The agglomerative alternative above is
# disabled, so the current result is the k=4 k-means result (presumably what
# show_result reads — TODO confirm).
show_result(0)


In [None]:
show_result(1)


In [None]:
show_result(2)


In [None]:
show_result(3)


## Collect all final plots in one place

In [None]:
# Re-display figures built in earlier sections of the notebook (not defined in this part;
# presumably plotly figures, since they expose .show()).
continent_plot.show()


In [None]:
person_ctry_code_plot.show()


In [None]:
person_ctry_pop_plot_2.show()


In [None]:
bubbles_plot.show()


In [None]:
# Final version of the 2010-2019 three-cluster k-means radar chart.
# Two fixes:
#  * Pass result_normalised_2010_2019_k_3 explicitly. In a top-to-bottom run, the
#    global result_normalised holds the four-cluster k-means result at this point,
#    which contradicted the "three clusters" title.
#  * Every other cluster_profiles_radar call in this notebook passes two flags
#    (False, True) before the grid-style arguments. This call passed only one, which
#    shifted 'major', 'both', 'black', '-', 1 one position to the left.
#    NOTE(review): restored to match the other calls. If this final cell was meant to
#    e.g. save the figure, set the first flag accordingly — confirm its meaning in
#    cluster_profiles_radar.
cluster_profiles_radar(
    result_normalised_2010_2019_k_3,
    "Clustering inventors' countries of origin by their\nbattery type distribution using recent ten years' data:\nProfiles of three clusters computed by k-means algorithm",
    (0.7, 0.87), # legend_pos
    40, # title_pad
    [0.2, 0.4, 0.6, 0.8], # y_ticks
    False,
    True,
    'major',
    'both',
    'black',
    '-',
    1
)


In [None]:
technologies_countries_all_plot.show()  # figure built in an earlier section of the notebook
