In [1]:
from setup_general import *
from setup_embedding import *

# Setup


In [2]:
type_indicators = {}
with open('data/type_indicators/type_ind_cut.txt', 'r') as f:
    for line in f:
        a = line.split('\'')
        type = a[1]
        indicators = a[2].split()
        type_indicators[type] = indicators
save_indicators = {}
with open('data/type_indicators/save_indicator.txt', 'r') as f:
    for line in f:
        a = line.split('\'')
        type = a[1]
        indicators = a[2].split()
        save_indicators[type] = indicators

In [3]:
# naive functions for type from text keywords

def filtering(text):
    pred = []
    for type in types:            
        if type in text:
            pred.append(type)
    if ('drawing' in text) or ('sketch' in text) or ('design' in text):
        pred.append('design/drawing/sketch')
    if len(pred) > 0:
        return pred[-1]
    else:
        return -1
    
def indicating(text):
    pred = []
    for type in types:
        for indicator in type_indicators[type]:
            if indicator in text:
                pred.append(type)
    if len(pred) > 0:
        return pred[-1]
    else:
        return -1

def save_indicating(text):
    pred = []
    for type in types:
        if type in save_indicators.keys():
            for indicator in save_indicators[type]:
                if indicator in text:
                    pred.append(type)
    if len(pred) > 0:
        return pred[-1]
    else:
        return -1


In [4]:
# convert to numeric
def replace_value(value: str):
    if pd.isnull(value):
        return value
    return np.float64(value.replace(',', '.'))


# convert to numeric and only keep year part
def replace_start_end(value: str):
    if pd.isnull(value):
        return value
    if re.match('^d?ddd$', value):
        return int(value)
    if re.match('dddd$', value):
        return int(value[-4:])
    elif not value[0].isdigit():
        return int(f'19{value[-2:]}')
    else:
        return nan


def extract_year_from_name(row):
    name = row['name']
    start = row['start']
    if pd.isnull(start) and not pd.isnull(name):
        match = re.search('\d\d\d\d', name)
        if match:
            start = match.group()
    return start


def preprocess_dataframe(df, submission=False):
    categorical_cols = ['material', 'location', 'before_Christ', 'country_and_unit', 'technique', 'parameter',
                        'museum_abbr', 'damages', 'state', 'color', 'event_type', 'collection_mark']
    categorical_cols += ['unit', 'participants_role', 'participant', 'musealia_mark']

    # just keeping track what values are used
    numeric_cols = ['start', 'end', 'value', 'collection_queue_nr', 'is_original', 'ks', 'element_count',
                    'musealia_seria_nr', 'musealia_queue_nr']

    dropped_cols = ['id', 'parish']  # can't use
    dropped_cols += ['full_nr', 'class', 'collection_additional_nr', 'additional_text', 'text', 'initial_info',
                     'musealia_additional_nr']  # 'commentary','name', 'legend'

    if not submission: dropped_cols.append('type')

    df['start'] = df['start'].apply(replace_start_end)
    df['end'] = df['end'].apply(replace_start_end)
    df['value'] = df['value'].apply(replace_value)
    df['start'] = df[['name', 'start']].apply(extract_year_from_name, axis=1)

    df = df.drop(columns=dropped_cols)
    df = pd.get_dummies(df, columns=categorical_cols)
    df = df.fillna(0)
    return df

In [5]:
def extract_label_from_comment(row):
    # comment #################################################
    comment = row['commentary']

    if not pd.isnull(comment):
        comment = str(comment).lower()

        comment_dict = {
            'lakk': 'pitser/templijäljend',
            'must-valge negatiiv': 'fotonegatiiv',
            'pitserilakk': 'pitser/templijäljend',
            'käepide': 'pitsat',
            'перф': 'fotonegatiiv',
            'fotoemulsioon': 'fotomaterjal',
            'plakat':'plakat'
        }
        for key, val in comment_dict.items():
            if comment.startswith(key):
                return val

        if re.match('^\d,\d\d\sg$', comment):
            return 'münt'

        if 'diapositiiv' in comment:
            return 'diapositiiv'

    # name #################################################
    name = row['name']

    if not pd.isnull(name):
        name = str(name).lower()
        if name == ['denaar', 'killing', 'penn', 'schilling', '1/2 örtug', 'dirhem', 'fyrk']:
            return 'münt'

        for val in ['medal', 'plakat', 'märkmed', 'maal', 'kiri', 'kleit', 'kava', 'joonistus', 'graafika', 'dokument',
                    'ajakiri', 'telegramm', 'skulptuur', 'raamat', 'postkaart', 'nukk', 'skulptuur', 'käsikiri']:
            if name.startswith(val):
                return val

        name_dict = {
            'kaustik': 'kaustik/vihik',
            'vihik': 'kaustik/vihik',
            'reprofoto': 'diapositiiv',
        }
        for key, val in name_dict.items():
            if name.startswith(key):
                return val
    return nan

def replace_predictions(labels, pred):
    result = np.array(pred, copy=True)
    for i, label in enumerate(labels):
        if not pd.isnull(label) and label != 0:
            result[i] = label
    return result

# combine models via class-probability combination (soft-voting)


In [6]:
# is the full ds used for submission?
full = False
# submit to 
sub_name = 'kaspar_type_checks_xgrfnn_no_emb.csv'

In [7]:
#define models to be used for testing use 03 for submission use full
import pickle
xgb = XGBClassifier()
xgb.load_model('models/xg/xg_est_data_smote100_03.json')

rf = pickle.load(open('./models/rf/train_prep_full_best' , 'rb'))

boost_emb = XGBClassifier()
boost_emb.load_model('models/nlp/xgboost_03.json')

nn = TabNetClassifier()
nn.load_model('models/nn/tabnet_full.zip')



In [8]:
data = test_est_prepared.copy() if full else val_est_prepared.copy()

features = data.drop('type', axis=1)
labels = data.type

if not full:
    # at least xgboost cannot deal with string labels
    label_encoder = LabelEncoder()
    label_encoder = label_encoder.fit(labels)
    labels = label_encoder.transform(labels)
    y_test = labels

X_test = features

In [9]:
results = pd.DataFrame()
results['id'] = X_test.index
results.set_index('id', inplace=True)
if not full: results['type'] = y_test

#results['rf'] = rf.predict(X_test)
#results['xg'] = xgb.predict_proba(X_test)
#results['nn'] = nn.predict(X_test.values)

results['filter'] = [-1] * len(results)
results['indi'] = [-1] * len(results)
results['save'] = [-1] * len(results)
results['emb'] = [[-1]] * len(results)


results['xg'] = [[-1]] * len(results)

In [10]:
for i,item in enumerate(xgb.predict_proba(X_test)):
    results['xg'].iloc[i] = np.array(item)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  results['xg'].iloc[i] = np.array(item)


In [11]:
text = test_curie.copy() if full else val_curie.copy()

features = text.drop('type', axis=1)
labels = text.type

#text['pred'] = boost_emb.predict_proba(features)
text['pred'] = [[-1]] * len(features)

In [12]:
for i,item in enumerate(boost_emb.predict_proba(features)):
    text['pred'].iloc[i] = np.array(item)

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  text['pred'].iloc[i] = np.array(item)


In [13]:
for index, item in text.iterrows():
        if index in results.index:
            results.at[index, 'emb'] = item['pred']

# evalaluate

In [14]:
from operator import add
def vote(preds):
    if preds[-1][0] == -1:
        preds = preds[:-1]
    res = np.sum(preds, axis=0)
    return np.argmax(res)

In [23]:
train, val = train_test_split(train_prep, test_size=0.3, random_state=0)

In [25]:
train.to_csv('data/prepared_ready/train_prepared.csv')
val.to_csv('data/prepared_ready/val_prepared.csv')

In [15]:
results['prediction'] = results.apply(lambda row: vote([row.xg, row.emb]), axis=1)
if not full:
    results.prediction = results.prediction.replace(type_lookup.id.to_list(), type_lookup.estonian.to_list())

    a = results.prediction.copy().tolist()

    df_submission = pd.read_csv("data/general/val.csv")
    x2 = preprocess_dataframe(df_submission, submission=True)
    # reorder columns + add missing columns + remove extra columns
    x2_labels = x2.apply(extract_label_from_comment, axis=1)    

    results.prediction = replace_predictions(x2_labels, results.prediction)
    submission = pd.DataFrame({'id': results.index ,'type': results.prediction})

    b = results.prediction.copy().tolist()

    count = 0
    for i in range(len(a)):
        if a[i] != b[i]:
            count += 1
            print(a[i], b[i])

    print(count)

    results.type = results.type.replace(type_lookup.id.to_list(), type_lookup.estonian.to_list())

    print(accuracy_score(results.type, results.prediction))
    print(classification_report(results.type, results.prediction))

kaustik/vihik märkmed
kiri, postkaart kiri
kiri postkaart
fotonegatiiv kiri
dokument kava
foto diapositiiv
foto skulptuur
väiketrükis postkaart
kavand/joonis/eskiis kava
graafika joonistus
10
0.960952380952381
                               precision    recall  f1-score   support

                      ajakiri       1.00      0.88      0.94        33
                      ajaleht       0.83      0.74      0.78        34
                        album       1.00      0.85      0.92        13
          arheoloogiline leid       1.00      1.00      1.00       240
             aukiri/auaadress       1.00      0.70      0.82        10
                  diapositiiv       0.95      1.00      0.97        19
           digitaalne kujutis       1.00      1.00      1.00        57
                     dokument       0.83      0.89      0.86       125
                          ehe       1.00      1.00      1.00         6
                         foto       0.97      0.99      0.98      1067
        

In [21]:
results[results.type == 'aukiri/auaadress']

Unnamed: 0_level_0,type,filter,indi,save,emb,xg,prediction
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
3981017,aukiri/auaadress,-1,-1,-1,"[0.05482011, 0.0065371706, 0.004327592, 0.0033103472, 0.0057972725, 0.0033087928, 0.005312763, 0.0041445396, 0.0039782496, 0.004369275, 0.003493294, 0.0028150713, 0.0029248407, 0.010927359, 0.11970181, 0.0044690203, 0.009431428, 0.003429786, 0.011684738, 0.0065783095, 0.0075291474, 0.006122921, 0.04034316, 0.09955946, 0.009296848, 0.016754968, 0.016625019, 0.0038950744, 0.007263255, 0.0056075403, 0.0027574417, 0.010583477, 0.06880006, 0.0048672473, 0.0031847265, 0.007827149, 0.25767124, 0.035119977, 0.0044708448, 0.0035428074, 0.004447958, 0.001981915, 0.005400954, 0.015822336, 0.0066100056, 0.0028548504, 0.0038920257, 0.006920653, 0.03261955, 0.003431533, 0.002740225, 0.011506938, 0.003415175, 0.009222703, 0.0059490027]","[3.7493588e-05, 3.3211967e-05, 6.718247e-05, 2.148774e-05, 0.00027233292, 7.523039e-05, 2.1203827e-05, 3.3023665e-05, 5.0521943e-05, 1.9591118e-05, 2.0848433e-05, 8.974229e-05, 2.6517198e-05, 3.5062236e-05, 0.001228506, 6.0772723e-05, 1.6292895e-05, 1.6710923e-05, 3.0420912e-05, 0.58532476, 1.6482089e-05, 1.3070626e-05, 0.01896723, 0.07556466, 2.5793084e-05, 0.0006818629, 0.00018835215, 5.852108e-05, 5.6444725e-05, 3.2559776e-05, 8.168336e-06, 4.733427e-05, 0.00022884808, 9.5231546e-05, 3.645736e-05, 0.00017868419, 0.02611204, 4.6379937e-05, 8.92714e-06, 0.00026367634, 3.4266945e-05, 0.00363886, 0.027102204, 7.411641e-05, 3.1205407e-05, 4.566685e-05, 2.515986e-05, 3.3118726e-05, 3.0267278e-05, 2.6893458e-05, 4.527091e-05, 0.00028714054, 1.8289616e-05, 1.8518875e-05, 0.25847742]",kutse
3919333,aukiri/auaadress,-1,-1,-1,"[5.455613e-05, 2.9018593e-05, 5.1901276e-05, 4.953952e-05, 3.3625485e-05, 8.44691e-05, 6.3002226e-05, 6.5141576e-05, 4.5150573e-05, 4.6654408e-05, 4.7051093e-05, 2.1001213e-05, 4.05426e-05, 5.8504793e-05, 0.002722113, 6.391697e-05, 0.000100076686, 8.698379e-05, 4.4668537e-05, 0.0001707827, 4.2809635e-05, 0.0006295877, 0.006583444, 0.9856528, 7.549475e-05, 3.7460493e-05, 0.00021332045, 6.805574e-05, 6.4047854e-05, 0.00024764065, 4.93309e-05, 3.8167185e-05, 6.676827e-05, 9.604127e-05, 4.7659592e-05, 5.1244144e-05, 0.0002871228, 0.0002792484, 8.567215e-05, 5.616535e-05, 5.9146805e-05, 0.00025728907, 0.00018804094, 3.814415e-05, 8.854093e-05, 5.6457222e-05, 7.186659e-05, 0.00013256414, 0.00012873393, 5.1353043e-05, 4.3712716e-05, 9.2473834e-05, 5.110827e-05, 7.1364666e-05, 0.00011843657]","[3.3368382e-05, 3.24704e-05, 6.5682405e-05, 2.1007958e-05, 8.6437e-05, 6.602499e-05, 1.530377e-05, 7.919516e-05, 4.9393875e-05, 2.3682496e-05, 2.0382924e-05, 2.282999e-05, 2.7821097e-05, 3.4848315e-05, 0.0077552563, 5.941578e-05, 0.00022608397, 8.21487e-06, 4.588319e-05, 0.008883088, 1.1483538e-05, 1.8231589e-05, 0.68556213, 0.25389504, 2.2570512e-05, 5.6596775e-05, 0.0009525746, 5.479174e-05, 5.5184366e-05, 3.1832773e-05, 8.21545e-06, 6.024293e-05, 0.00013472026, 0.0021103963, 3.564333e-05, 3.6529153e-05, 0.00043715938, 4.5344354e-05, 8.006517e-06, 4.609116e-05, 3.3501794e-05, 0.003468853, 0.03243137, 5.3845597e-05, 3.076855e-05, 1.869431e-05, 3.5977537e-05, 3.237921e-05, 3.206993e-05, 2.629295e-05, 4.4260087e-05, 0.0021687571, 1.7881224e-05, 1.9134859e-05, 0.00044701443]",aukiri/auaadress
3357383,aukiri/auaadress,-1,-1,-1,"[0.0002968234, 0.000104709405, 0.00022225745, 0.0001787915, 0.0003193982, 0.00039727162, 0.00020423828, 0.00014923282, 0.00034095583, 0.00024111076, 0.00033520145, 9.626588e-05, 0.00021201523, 0.0002519607, 0.00061706104, 0.00024390333, 0.00013444416, 0.00021485344, 6.000971e-05, 0.00024743256, 0.00018960515, 0.0005009036, 0.00049251295, 0.97482705, 0.00023598812, 0.0004114674, 0.0003035261, 0.00023660716, 0.0030597365, 0.00027730115, 0.00031913817, 0.0006262896, 0.0009564463, 0.00019306986, 0.00025390126, 0.0003521886, 0.004736345, 0.00057706016, 0.00013655965, 0.00024737947, 0.0002440867, 0.0006974296, 0.00013084496, 0.00045648063, 0.00022526628, 0.00036406322, 0.00014053729, 0.00043958865, 0.00021325586, 0.00045487835, 0.0005099274, 0.0015162056, 0.00015957661, 0.00043438442, 0.0002124816]","[1.5097351e-05, 3.2253586e-06, 6.524382e-06, 2.6023479e-06, 3.3862016e-05, 6.558419e-06, 2.6670884e-06, 1.1799526e-05, 9.395962e-06, 2.9891032e-06, 2.1752016e-06, 4.1855956e-06, 1.9816896e-05, 2.669814e-06, 8.567478e-05, 8.0229465e-06, 2.951193e-06, 5.4326006e-06, 4.020479e-06, 4.727457e-06, 9.165928e-06, 1.8913373e-06, 0.0058237324, 0.978323, 3.461823e-06, 7.72924e-06, 0.00056662475, 3.9907177e-06, 0.014485962, 5.2748055e-06, 2.5154382e-06, 5.345153e-06, 7.519662e-05, 2.195036e-06, 3.921488e-06, 9.526245e-05, 7.6683034e-05, 6.0951943e-06, 3.21439e-06, 3.7562295e-06, 3.3278125e-06, 0.0001458808, 3.563179e-05, 4.494456e-06, 3.4280763e-06, 1.3838052e-05, 6.81651e-06, 3.4739937e-06, 4.102788e-06, 2.6117384e-06, 1.6604316e-05, 1.3369646e-05, 2.9256473e-06, 1.8456367e-06, 2.2972429e-06]",aukiri/auaadress
2601136,aukiri/auaadress,-1,-1,-1,"[0.0005852842, 0.0009492809, 0.00036559353, 0.00097428216, 0.46167693, 0.00080585136, 0.00065019855, 0.0006340824, 0.0010119078, 0.0012452521, 0.0006727174, 0.0036678754, 0.0006777405, 0.001480181, 0.003291682, 0.00046406346, 0.0054377518, 0.00096086925, 0.003959462, 0.0013884109, 0.0012855448, 0.00054625206, 0.0035639834, 0.0004908742, 0.00049170444, 0.00052383915, 0.124670826, 0.0005896263, 0.00048776498, 0.0007200852, 0.0010929053, 0.0004669731, 0.000572041, 0.00094819383, 0.00041630247, 0.0015343047, 0.3365208, 0.010891257, 0.0016683788, 0.00042058955, 0.0008056704, 0.012163395, 0.0003246152, 0.00016737556, 0.0008329433, 0.00028820732, 0.001811489, 0.0007899551, 0.00065112434, 0.0007079606, 0.00044761662, 0.00065022206, 0.00042781056, 0.0005857817, 0.0005441489]","[7.732092e-05, 4.626696e-05, 0.00016555267, 3.9214654e-05, 0.0020468282, 1.3943638e-05, 0.00013565987, 0.00056930014, 6.3054074e-05, 4.7666435e-05, 5.866741e-05, 4.090746e-05, 3.6940604e-05, 4.4594923e-05, 0.024948029, 8.4661326e-05, 3.1546602e-05, 0.0011393966, 4.099158e-05, 0.003306738, 1.8898767e-05, 5.9162176e-05, 0.57534003, 0.0004147213, 6.4363776e-05, 4.8755042e-05, 0.19774765, 1.588629e-05, 7.8632e-05, 0.0001165811, 0.00013281996, 0.00011618308, 0.0015679303, 3.9075032e-05, 5.0788058e-05, 6.0883503e-05, 0.045088038, 4.6593785e-05, 5.9093833e-05, 2.645331e-05, 0.00023395292, 0.12406142, 0.021060297, 0.00012116343, 5.2237134e-05, 4.697351e-05, 2.8985141e-05, 4.983357e-05, 4.3032083e-05, 3.6708192e-05, 6.3920226e-05, 8.4060106e-05, 2.919053e-05, 2.1541326e-05, 3.6885547e-05]",kiri
2301893,aukiri/auaadress,-1,-1,-1,"[7.53118e-05, 4.4530174e-05, 7.713388e-05, 3.0156038e-05, 0.00022688818, 3.4856534e-05, 6.401084e-05, 3.7575348e-05, 6.8613386e-05, 3.5629848e-05, 4.1394935e-05, 1.7935232e-05, 9.195305e-05, 3.3685566e-05, 0.0003050259, 3.4963738e-05, 4.677202e-05, 2.7204711e-05, 2.9911032e-05, 4.6149937e-05, 2.9072275e-05, 7.5750475e-05, 0.0009373545, 0.9952799, 7.3355004e-05, 7.6101875e-05, 0.00016085956, 2.160633e-05, 3.6127072e-05, 5.334247e-05, 3.402659e-05, 3.0552037e-05, 2.7143784e-05, 3.202165e-05, 6.0697374e-05, 2.6012644e-05, 0.00043620417, 5.7703157e-05, 3.4516906e-05, 3.448706e-05, 4.2525855e-05, 9.640117e-05, 4.3461772e-05, 0.00037859078, 4.6329675e-05, 1.43095e-05, 2.7855564e-05, 5.3533022e-05, 3.402792e-05, 3.832102e-05, 5.478162e-05, 0.00015678808, 4.053173e-05, 4.749507e-05, 3.8579576e-05]","[0.00052734214, 0.00022317216, 0.00045144142, 0.00014438969, 0.0004840973, 0.00045342057, 0.00018454377, 0.00064450694, 0.00059926877, 0.000110759574, 0.00015825938, 0.00026168817, 0.00017818579, 0.0002395158, 0.15279602, 0.0006255867, 6.888634e-05, 0.00023667478, 0.0006647324, 0.00016420521, 0.00065952423, 0.00014737406, 0.17525594, 0.45192367, 0.0005358745, 0.003233676, 0.0030163776, 0.0002890534, 0.0005055258, 0.00045257338, 0.00022954514, 0.00036984702, 0.119774215, 7.148626e-05, 0.00024498, 0.0014618606, 0.0021189132, 0.0006736716, 0.00012929288, 0.00028580517, 0.00023026123, 0.07405727, 0.00017845532, 0.000411162, 0.00018828428, 0.0002804962, 0.0003043036, 0.00027369426, 0.0005230383, 0.00018071395, 0.0010411436, 0.00062731805, 0.00015712951, 0.00077909813, 0.00017178452]",aukiri/auaadress
3980694,aukiri/auaadress,-1,-1,-1,"[0.00016361146, 0.00012514833, 4.284193e-05, 4.0546915e-05, 0.002029596, 8.590211e-05, 4.332722e-05, 0.00010133158, 4.239553e-05, 6.572852e-05, 3.5469206e-05, 4.0639752e-05, 5.3909152e-05, 8.0478065e-05, 0.00048128964, 3.9601266e-05, 0.00022150294, 3.179812e-05, 7.970031e-05, 3.324587e-05, 3.057975e-05, 0.0001231261, 0.00032923208, 0.9919328, 7.258692e-05, 0.0002624752, 0.0002825648, 4.9783983e-05, 4.633152e-05, 0.0003306562, 7.2443145e-05, 5.4011714e-05, 8.543006e-05, 0.00012945337, 3.4044173e-05, 7.012423e-05, 0.0008799564, 0.00012730621, 4.2953616e-05, 3.4643832e-05, 4.75479e-05, 8.125117e-05, 0.00017544498, 0.000117842224, 5.5185414e-05, 3.0832878e-05, 0.000109534325, 0.00011607147, 4.6183428e-05, 3.668251e-05, 2.686336e-05, 0.00011299827, 3.6507627e-05, 0.00013657166, 4.190457e-05]","[1.814262e-05, 8.7805965e-06, 1.7761737e-05, 5.680935e-06, 3.4291606e-05, 3.3893033e-05, 5.605879e-06, 8.7308135e-06, 1.3357005e-05, 5.1795055e-06, 5.5119194e-06, 1.9343133e-05, 7.010631e-06, 9.269765e-06, 0.00054336945, 1.6067124e-05, 4.30752e-06, 4.4180383e-06, 7.2734865e-06, 0.13803054, 4.3575437e-06, 3.4556192e-06, 0.0051152804, 0.023837794, 6.81919e-06, 8.42904e-05, 4.979664e-05, 1.626995e-05, 1.4922873e-05, 4.5989204e-06, 2.1595488e-06, 1.251426e-05, 8.22933e-05, 2.5177367e-05, 9.638616e-06, 4.7240574e-05, 0.006903512, 1.226194e-05, 2.3601597e-06, 6.9710855e-05, 9.059513e-06, 0.0007873535, 0.0012925606, 1.9594916e-05, 8.250095e-06, 1.9278308e-05, 6.102233e-06, 8.755946e-06, 8.002072e-06, 7.1101003e-06, 1.1968735e-05, 7.59143e-05, 4.835414e-06, 4.896025e-06, 0.8226334]",aukiri/auaadress
4082342,aukiri/auaadress,-1,-1,-1,"[2.6137008e-05, 1.9280622e-05, 5.116026e-05, 3.067902e-05, 6.7242545e-05, 3.779148e-05, 3.2592343e-05, 4.0387815e-05, 3.3496057e-05, 2.5195139e-05, 7.388749e-05, 2.3570949e-05, 5.658633e-05, 3.8971077e-05, 0.0017567308, 3.0194718e-05, 5.1637202e-05, 4.383266e-05, 0.000112688176, 8.3363615e-05, 3.1739964e-05, 9.649307e-05, 0.0022823405, 0.9913374, 4.3811095e-05, 4.629881e-05, 5.039881e-05, 2.9394421e-05, 3.1736876e-05, 5.568981e-05, 1.1168893e-05, 2.5841004e-05, 2.6864729e-05, 6.084314e-05, 2.940387e-05, 2.7593944e-05, 0.0020588457, 0.00013904693, 4.5654273e-05, 4.0414554e-05, 4.1676776e-05, 7.7309975e-05, 6.711742e-05, 7.5179916e-05, 3.5254103e-05, 3.092303e-05, 7.6694516e-05, 0.00018792732, 3.2182106e-05, 3.8379727e-05, 2.1758853e-05, 0.00010300317, 2.7835891e-05, 4.2804415e-05, 3.5526366e-05]","[0.0009748938, 0.00013253036, 0.0002680876, 8.574554e-05, 0.004288453, 0.00031696534, 6.722925e-05, 0.00048993394, 0.00020160478, 9.666187e-05, 8.319442e-05, 0.00013364587, 0.00011355387, 0.000142236, 0.03693428, 0.00024250991, 0.00035845803, 0.0001888932, 0.00017492373, 0.0059506153, 4.687092e-05, 5.478046e-05, 0.6814322, 0.01314053, 9.212322e-05, 0.003275447, 0.0138858445, 4.1599214e-05, 0.00022523908, 0.00019695733, 2.8741571e-05, 0.00024588613, 0.0037011302, 0.013324545, 0.00014548091, 0.016579889, 0.003146562, 0.00018507635, 3.539872e-05, 0.00016234505, 0.00013674007, 0.035888307, 0.15816972, 0.0002197748, 0.0001439755, 0.00010533211, 0.00014684498, 0.00013215815, 0.0001308958, 0.00010731663, 0.00018065085, 0.0012851232, 7.298355e-05, 7.8100355e-05, 0.0020110235]",aukiri/auaadress
3838819,aukiri/auaadress,-1,-1,-1,"[5.110482e-05, 7.935102e-05, 8.397675e-05, 4.8142e-05, 0.00017802836, 5.6838588e-05, 6.541052e-05, 8.639204e-05, 0.00016228738, 6.750966e-05, 6.848658e-05, 3.50317e-05, 5.045065e-05, 0.00010255201, 0.0010146261, 5.3875207e-05, 5.3270604e-05, 6.818484e-05, 8.788887e-05, 0.00036527097, 4.5903216e-05, 0.00015644086, 0.0002557211, 0.9895325, 6.493808e-05, 5.603375e-05, 0.00012903498, 5.3771062e-05, 7.7829885e-05, 0.0015687644, 0.00010270587, 6.0578437e-05, 5.6848992e-05, 5.1906238e-05, 5.7896705e-05, 0.00017203836, 0.00025100852, 0.0016468167, 0.00011106275, 6.2572435e-05, 6.5885266e-05, 0.00021385417, 8.099861e-05, 0.00021045037, 6.216085e-05, 8.7327084e-05, 9.922009e-05, 0.00017670881, 0.00073963596, 4.9904414e-05, 8.1506856e-05, 0.00016991944, 4.9666498e-05, 6.585957e-05, 0.0005539219]","[0.00026035242, 3.8930026e-05, 7.874918e-05, 2.518723e-05, 5.5539174e-05, 7.9094425e-05, 1.8348284e-05, 6.894865e-05, 5.9220176e-05, 2.8393863e-05, 2.362329e-05, 2.6342528e-05, 3.3355824e-05, 3.940361e-05, 0.060090672, 7.123588e-05, 0.00018089914, 1.9684967e-05, 5.5011136e-05, 0.00051488716, 1.3768072e-05, 2.2895887e-05, 0.7144585, 0.0023281977, 2.7060689e-05, 0.019043759, 0.0006804552, 7.338229e-05, 6.616267e-05, 3.8165545e-05, 2.4775747e-05, 7.222766e-05, 8.121254e-05, 9.417796e-06, 4.273417e-05, 4.704737e-05, 0.0016548313, 5.4365057e-05, 9.599325e-06, 5.756196e-05, 4.016664e-05, 0.0035511514, 0.19516084, 6.455751e-05, 4.1935684e-05, 2.145717e-05, 5.1522773e-05, 3.882073e-05, 3.844988e-05, 3.1523636e-05, 5.3065083e-05, 8.157947e-05, 2.1438495e-05, 3.71722e-05, 0.00019236696]",aukiri/auaadress
2996200,aukiri/auaadress,-1,-1,-1,"[0.0022064904, 0.00049809716, 0.0011803814, 0.0011002464, 0.032493386, 0.0011414068, 0.0018297245, 0.0017305006, 0.0010891235, 0.001000172, 0.0020189642, 0.00080954545, 0.0009398364, 0.0010874096, 0.0077684317, 0.0013766644, 0.001045563, 0.0017300155, 0.0010703838, 0.0016798297, 0.0010615113, 0.0022395132, 0.005370971, 0.66825163, 0.0013219651, 0.0015097783, 0.020217013, 0.002476616, 0.0012941608, 0.005011194, 0.0019170786, 0.0016636516, 0.0012178856, 0.0012746734, 0.0011990269, 0.0011696107, 0.0053833653, 0.07417574, 0.0019461999, 0.0022368112, 0.001719956, 0.086811386, 0.0009277781, 0.023201002, 0.002088322, 0.0025970486, 0.0014533368, 0.0038286326, 0.0017425985, 0.0011405244, 0.0010521598, 0.0050859335, 0.0011350876, 0.0014299755, 0.001051734]","[0.00010473036, 1.5946529e-05, 3.2257267e-05, 1.0317212e-05, 9.165889e-05, 3.304365e-05, 5.928588e-06, 2.2720706e-05, 2.1732436e-05, 9.40656e-06, 9.67659e-06, 6.4567874e-05, 1.2732076e-05, 1.5877e-05, 0.004525698, 2.9179668e-05, 7.574106e-06, 4.371892e-06, 1.4606422e-05, 0.021711903, 7.913779e-06, 6.2757845e-06, 0.00830134, 0.00012461198, 1.2384397e-05, 1.5169471e-05, 0.00041465374, 5.5373357e-06, 2.7101603e-05, 1.5633384e-05, 2.2237698e-06, 2.272727e-05, 0.00018417036, 5.3907177e-05, 1.7504786e-05, 7.45298e-05, 0.0041704215, 2.2269052e-05, 4.6430114e-06, 1.0308253e-05, 1.6453072e-05, 0.9577341, 0.0007680131, 3.558656e-05, 1.5757741e-05, 1.2094964e-05, 1.1725402e-05, 1.5901762e-05, 1.4532655e-05, 1.2912737e-05, 2.173656e-05, 0.0009056365, 8.781652e-06, 1.406596e-05, 0.00017551567]",kava
4033990,aukiri/auaadress,-1,-1,-1,"[6.439242e-05, 4.3835582e-05, 6.401085e-05, 5.2840063e-05, 0.00088498474, 7.735392e-05, 0.00019269086, 6.590794e-05, 5.3754182e-05, 8.884103e-05, 5.0185852e-05, 0.0001200033, 4.256842e-05, 6.742652e-05, 0.00072834874, 0.000105685926, 0.0012661049, 8.366794e-05, 0.0017866598, 0.00021005851, 6.215906e-05, 0.00011346459, 0.0039056037, 0.98593396, 7.556916e-05, 6.044829e-05, 6.116212e-05, 3.5359524e-05, 6.91824e-05, 5.2661824e-05, 3.340558e-05, 5.5730026e-05, 7.4042946e-05, 6.5927554e-05, 6.465474e-05, 9.240164e-05, 0.0005712897, 0.00022903412, 9.822998e-05, 5.779155e-05, 0.00011721135, 0.00022067543, 6.153164e-05, 3.1834603e-05, 5.1611987e-05, 0.00010773032, 4.6859484e-05, 0.0009997013, 9.10862e-05, 5.4774468e-05, 5.0614348e-05, 0.00017929822, 7.096595e-05, 9.166945e-05, 6.305729e-05]","[4.4452423e-05, 2.3950368e-05, 4.8447742e-05, 1.5495598e-05, 6.883029e-05, 5.704659e-05, 2.3693447e-05, 0.00011286869, 3.6433226e-05, 2.694315e-05, 1.8440584e-05, 3.4969824e-05, 2.052101e-05, 2.5284642e-05, 0.6123347, 4.3825443e-05, 5.463426e-05, 1.1329969e-05, 3.1211617e-05, 0.00015885782, 2.6434045e-05, 1.3993566e-05, 0.015483022, 0.35318983, 1.6398986e-05, 7.029496e-05, 0.0001993641, 0.0001374813, 4.0704326e-05, 2.3480048e-05, 8.335871e-05, 3.8214625e-05, 0.0016769356, 0.00011325881, 2.629074e-05, 5.2283933e-05, 0.005641363, 2.851621e-05, 6.0763746e-06, 0.0030209695, 2.471113e-05, 0.005268134, 0.0012506808, 5.9380196e-05, 4.4771426e-05, 1.41291e-05, 1.9429033e-05, 2.7623995e-05, 2.3654975e-05, 2.7504047e-05, 3.2646512e-05, 8.086719e-05, 1.3189301e-05, 2.2868928e-05, 1.0165227e-05]",aukiri/auaadress


In [None]:
album foto
ajaleht dokument


# submission

In [16]:
if full:
    results.prediction = results.prediction.replace(type_lookup.id.to_list(), type_lookup.estonian.to_list())

    a = results.prediction.copy().tolist()

    df_submission = pd.read_csv("data/general/test.csv")
    x2 = preprocess_dataframe(df_submission, submission=True)
    # reorder columns + add missing columns + remove extra columns
    x2_labels = x2.apply(extract_label_from_comment, axis=1)    

    results.prediction = replace_predictions(x2_labels, results.prediction)
    submission = pd.DataFrame({'id': results.index ,'type': results.prediction})

    b = results.prediction.copy().tolist()

    count = 0
    for i in range(len(a)):
        if a[i] != b[i]:
            count += 1
            print(a[i], b[i])

    print(count)


    #submission.groupby('type').nunique()  # predicted classes

    submission.to_csv(f'submissions/{sub_name}', index=False)