In [1]:
import os
import json

import numpy as np
import pandas as pd
from tqdm import tqdm
import datetime as dt
import matplotlib.pyplot as plt
from collections import Counter, defaultdict

In [2]:
def f2cat(filename: str) -> str:
    """Return the category name of a data file: the text before its first '.'.

    A name that contains no '.' is returned unchanged.
    """
    category, _, _ = filename.partition('.')
    return category

class Simplified():
    """Reader for the per-category CSV files under ``<input_path>/train_simplified``."""

    def __init__(self, input_path='./input'):
        # Root data directory; the readers below look inside its
        # ``train_simplified`` sub-folder.
        self.input_path = input_path

    def list_all_categories(self):
        """Return the category names found in ``train_simplified``.

        Every directory entry except ``desktop.ini`` is turned into a category
        name with ``f2cat``; the result is sorted case-insensitively.
        """
        files = os.listdir(os.path.join(self.input_path, 'train_simplified'))
        return sorted((f2cat(f) for f in files if f != 'desktop.ini'), key=str.lower)

    def read_training_csv(self, category, nrows=None, usecols=None):
        """Load one category's CSV and post-process its columns.

        Args:
            category: Category name; ``<category>.csv`` is read.
            nrows: Optional cap on the number of rows read.
            usecols: Optional subset of columns to read. Columns that are left
                out are simply skipped by the post-processing below.

        Returns:
            A DataFrame without the ``timestamp`` and ``word`` columns, with
            ``drawing`` decoded from JSON and ``countrycode`` replaced by its
            ``countrymap`` value, for whichever of those columns were loaded.
        """
        path = os.path.join(self.input_path, 'train_simplified', category + '.csv')
        # `timestamp` is discarded right away, so it is not parsed as a date;
        # asking for parse_dates would also fail when `usecols` leaves it out.
        df = pd.read_csv(path, nrows=nrows, usecols=usecols)
        # errors='ignore' keeps this working when `usecols` already excluded them.
        df = df.drop(columns=['timestamp', 'word'], errors='ignore')
        if 'drawing' in df.columns:
            df['drawing'] = df['drawing'].apply(json.loads)
        if 'countrycode' in df.columns:
            # `countrymap` is a module-level lookup defined in a later cell.
            df['countrycode'] = df['countrycode'].apply(lambda x: countrymap[x])
        return df

In [3]:
# Country codes in a fixed order; a code's position in this list is its integer id.
countrycodes = [
    'US', 'GB', 'CA', 'DE', 'AU', 'RU', 'BR', 'SE', 'FI', 'CZ',
    'IT', 'PL', 'FR', 'KR', 'TH', 'PH', 'SA', 'HU', 'NL', 'ID',
    'RO', 'IN', 'SK', 'VN', 'JP', 'AE', 'TW', 'UA', 'MY', 'NO',
    'NZ', 'IE', 'HR', 'TR', 'RS', 'BG', 'HK', 'AT', 'DK', 'MX',
    'ES', 'PT', 'CH', 'SG', 'BE', 'IL', 'AR', 'EE', 'ZA', 'BA',
    'LT', 'IQ', 'GR', 'EG', 'KZ', 'LV', 'KW', 'DZ', 'CL', 'BY',
    'CO', 'QA', 'KH', 'SI', 'JO', 'PK', 'ZZ', 'PE', 'BH', 'MA',
    'MK', 'GE', 'MD', 'IS', 'OM', 'BD', 'PR', 'PS', 'VE', 'LB',
    'EC', 'TN', 'NP', 'MT', 'CY', 'BN', 'ME', 'CR', 'TT', 'LU',
    'AM', 'AL', 'UY', 'GU', 'DO', 'MO', 'AZ', 'RE', 'GT', 'HN',
    'MV', 'PA', 'KE', 'KG', 'MN', 'SV', 'BO', 'MU', 'JM', 'CN',
    'JE',
]

# Code -> id lookup. A code that is not in the list gets the extra id
# len(countrycodes); the defaultdict stores that entry on first lookup.
countrymap = defaultdict(
    lambda: len(countrycodes),
    zip(countrycodes, range(len(countrycodes))),
)

In [4]:
# Number of shard files the training rows are distributed over further down.
NCSVS = 200

# Wall-clock start of the run; the last cell reports the elapsed time against it.
start = dt.datetime.now()

# Dataset reader rooted at the shared input folder, and the category list it finds.
s = Simplified(input_path='../input')
categories = s.list_all_categories()
print("Total categories : ", len(categories))

340


In [None]:
# Scan every category file once, tallying country codes and `recognized`
# flags and recording how many rows each file has.
ls = []                  # row count per category, in `categories` order
recognized = Counter()   # recognized flag -> number of rows
countries = Counter()    # country code -> number of rows
for y, cat in enumerate(tqdm(categories)):
    # Only the two tallied columns are loaded; the other columns are not
    # needed for counting, and skipping them keeps the scan cheap.
    path = os.path.join(s.input_path, 'train_simplified', cat + '.csv')
    df = pd.read_csv(path, usecols=['countrycode', 'recognized'])
    # value_counts() leaves out missing values by default, so `countries`
    # can sum to less than sum(ls) if some rows have no country code.
    countries.update(dict(df.countrycode.value_counts()))
    recognized.update(dict(df.recognized.value_counts()))
    ls.append(len(df))

In [None]:
# Visualize country data
# Each entry is (country code, row count, percentage of all tallied rows).
# The denominator is computed from `countries` itself, so the percentages stay
# correct when the dataset or the category list changes.
# NOTE(review): the previous version divided by the literal 497057.090,
# presumably total/100 from one particular run — confirm.
total_drawings = sum(countries.values())
breakdown = [
    (country, count, 100.0 * count / total_drawings)
    for country, count in countries.most_common()  # descending by count
]
for country, count, pct in breakdown[:10]:
    print(country, count, pct)

fig, axs = plt.subplots()
axs.plot([pct for _, _, pct in breakdown], marker='.')
axs.set_title('Share of datapoints per country')
axs.set_ylabel('% of datapoints')
axs.set_xlabel('n-th country rank');

In [None]:
# Distribute the rows of every category over NCSVS shard files, so that each
# shard ends up holding rows from all categories.
# The output folder is created up front: to_csv does not create missing
# directories, so a fresh checkout would otherwise fail on the first write.
os.makedirs('shuffle_csvs', exist_ok=True)
for y, cat in enumerate(tqdm(categories)):
    df = s.read_training_csv(cat, nrows=100000)
    df['y'] = y  # integer label = position of the category in `categories`
    # Reuse the key_id column as the shard index in [0, NCSVS).
    df['key_id'] = (df.key_id // 10 ** 7) % NCSVS
    for k in range(NCSVS):
        filename = 'shuffle_csvs/train_k{}.csv'.format(k)
        chunk = df[df.key_id == k]
        chunk = chunk.drop(['key_id'], axis=1)
        if y == 0:
            # The first category (re)creates each shard file and writes the header.
            chunk.to_csv(filename, index=False)
        else:
            # Later categories append their rows without repeating the header.
            chunk.to_csv(filename, mode='a', header=False, index=False)

 21%|████████████████▉                                                                | 71/340 [06:28<21:03,  4.70s/it]

In [None]:
# Shuffle the rows of every shard file in place. The shards were filled one
# category after another, so without this step their rows are grouped by class.
SHUFFLE_SEED = 0  # fixed seed: the same input shards always give the same order
rng = np.random.RandomState(SHUFFLE_SEED)
last_shape = None  # shape of the most recently shuffled shard, if any
for k in tqdm(range(NCSVS)):
    filename = 'shuffle_csvs/train_k{}.csv'.format(k)
    if not os.path.exists(filename):
        continue
    df = pd.read_csv(filename)
    # Sort by a temporary random key, then discard the key again.
    df['rnd'] = rng.rand(len(df))
    df = df.sort_values(by='rnd').drop('rnd', axis=1)
    df.to_csv(filename, index=False)
    last_shape = df.shape
# Report only what this cell produced; `df` itself could be a stale frame from
# an earlier cell when no shard file exists.
print(last_shape)

In [None]:
end = dt.datetime.now()
# total_seconds() covers the whole elapsed span; timedelta.seconds is only the
# sub-day component and would wrap around for a run longer than 24 hours.
print('Latest run {}.\nTotal time {}s'.format(end, int((end - start).total_seconds())))