# In [None]:
# coding: utf-8
__author__ = 'ZFTurbo: https://kaggle.com/zfturbo'


# Run one worker Process per chunk (True) or the whole job in-process (False).
USE_PARALLEL = False
# When True, candidate InChIs are checked with RDKit in a child process.
USE_RDKIT = False
# Upper bound on ranked candidates checked per image when USE_RDKIT is set.
LIMIT_CHECKS = 500
# Number of chunks (and worker processes) used when USE_PARALLEL is set.
PROC_NUM = 16
# Passed as `nrows` to pd.read_csv when > 0; values <= 0 read whole files.
DEBUG_LIMIT = 10 ** 8


import sys
if USE_RDKIT:
    from rdkit import Chem
    from rdkit import RDLogger

if __name__ == '__main__':
    # Child-process mode used by check_if_valid_exec_v2: when invoked as
    # `python <this script> <InChI>`, print 1 if RDKit parses the InChI and 0
    # otherwise (no trailing newline), then stop before the rest of the script.
    if len(sys.argv) == 2:
        inchi = sys.argv[1].strip()
        try:
            mol = Chem.MolFromInchi(inchi)
            print(0 if mol is None else 1, end='')
        except Exception:
            # Also covers NameError when USE_RDKIT is False and Chem was never imported.
            print(0, end='')
        # sys.exit rather than the interactive-only `exit` helper from `site`.
        sys.exit()


import Levenshtein
import warnings
import subprocess
import itertools
from multiprocessing import Process
import gc
import numpy as np
import pandas as pd
import os
import time
import glob
import operator


# Ignore every warning category for the rest of the run.
warnings.simplefilter("ignore")
# Redundant after the blanket "ignore" above; kept so warnings.filters is unchanged.
warnings.filterwarnings("ignore", message='ERROR:')


def get_avg_strings(s1, s2):
    """Return the intermediate strings visited while editing `s1` into `s2`.

    At every step a fresh ``Levenshtein.editops`` script from the current
    string to `s2` is computed and only its first operation is applied, so
    consecutive strings differ by a single edit. The loop runs
    ``Levenshtein.distance(s1, s2)`` times; the endpoints `s1` and `s2` are
    excluded from the result, which is empty when the distance is 0 or 1.

    Raises:
        RuntimeError: if the walk does not end exactly at `s2`. Previously
            this printed a message and terminated the interpreter.
    """
    current = s1
    path = [current]
    for _ in range(Levenshtein.distance(s1, s2)):
        ops = Levenshtein.editops(current, s2)
        current = Levenshtein.apply_edit(ops[:1], current, s2)
        path.append(current)
    if Levenshtein.distance(s2, current) != 0:
        raise RuntimeError('get_avg_strings: failed to transform {!r} into {!r}'.format(s1, s2))
    return path[1:-1]


def check_if_valid_exec_v2(inchi: str, timeout=None) -> int:
    """Check `inchi` with RDKit in a child process; return 1 if valid, else 0.

    Re-runs this script as ``python <script> <inchi>``. The script's
    ``__main__`` guard prints ``1`` or ``0`` for the InChI and exits. Because
    the check happens in a separate process, an RDKit failure cannot take
    down the caller.

    Args:
        inchi: candidate InChI string.
        timeout: optional limit in seconds for the child process. If it
            expires, the candidate is reported and treated as invalid.
            ``None`` (default) waits indefinitely, as before.

    Returns:
        1 when the child printed exactly ``1``, otherwise 0. Any failure is
        printed and yields 0: a non-zero exit, a timeout, or ``__file__``
        being undefined (e.g. inside a notebook).
    """
    from subprocess import check_output
    try:
        script_path = os.path.realpath(__file__)
        ret = check_output([sys.executable, script_path, inchi],
                           universal_newlines=True, timeout=timeout)
    except Exception as e:
        print('Exception: {}'.format(e))
        return 0
    return 1 if ret == '1' else 0


def normalize_inchi(inchi):
    """Regenerate `inchi` from the RDKit molecule it parses to.

    Returns the string ``'bad'`` in two cases: RDKit returns ``None`` for the
    input, or any exception is raised. The latter includes NameError when
    RDKit was not imported because USE_RDKIT is False. KeyboardInterrupt and
    SystemExit are no longer swallowed.
    """
    try:
        mol = Chem.MolFromInchi(inchi)
        if mol is None:
            return 'bad'
        return Chem.MolToInchi(mol)
    except Exception:
        return 'bad'


def merge_all_variants(str1):
    """Build every full string obtainable by picking one candidate per layer.

    Args:
        str1: list of per-layer candidate lists; each candidate is a tuple
            whose item 0 is the layer text and item 1 is its score.

    Returns:
        List of ``(joined, total)`` pairs sorted by ascending ``total``, where
        ``joined`` is the chosen layer texts joined with '/' and ``total`` is
        the sum of their scores. When two combinations produce the same
        string, the later combination's score overwrites the earlier one.
    """
    totals = {}
    for combo in itertools.product(*str1):
        joined = '/'.join(item[0] for item in combo)
        totals[joined] = sum(item[1] for item in combo)
    return sort_dict_by_values(totals, reverse=False)


def sort_dict_by_values(a, reverse=True):
    """Return the ``(key, value)`` pairs of `a` ordered by value.

    Descending by default; pass ``reverse=False`` for ascending order. The
    sort is stable, so keys with equal values keep the dict's insertion order.
    """
    pairs = list(a.items())
    pairs.sort(key=lambda kv: kv[1], reverse=reverse)
    return pairs


def get_score(y_true, y_pred):
    """Mean Levenshtein distance between position-aligned true and predicted strings.

    Distances are collected into a float32 array before averaging. Positions
    run over ``len(y_true)``, so `y_pred` must be at least that long. An
    empty `y_true` gives NaN.
    """
    distances = np.array(
        [Levenshtein.distance(y_true[idx], y_pred[idx]) for idx in range(len(y_true))],
        dtype=np.float32,
    )
    return np.mean(distances)


def _part_bounds(n_rows, parts_number, current_part):
    """Return the half-open row range ``[start, end)`` of chunk `current_part`.

    Rows are cut into `parts_number` chunks of ``n_rows // parts_number`` rows;
    the last chunk also takes the remainder.
    """
    chunk = n_rows // parts_number
    start = current_part * chunk
    if current_part == parts_number - 1:
        end = n_rows
    else:
        end = (current_part + 1) * chunk
    return start, end


def _read_sorted_submission(path, exclude_ids):
    """Read a submission CSV, drop `exclude_ids`, sort by image_id and reset the index.

    Every read goes through this helper, so DEBUG_LIMIT (used as ``nrows``
    when > 0) is applied consistently. That keeps chunk boundaries computed
    on separate reads of the same file in agreement.
    """
    if DEBUG_LIMIT > 0:
        frame = pd.read_csv(path, nrows=DEBUG_LIMIT)
    else:
        frame = pd.read_csv(path)
    if exclude_ids is not None:
        frame = frame[~frame['image_id'].isin(exclude_ids)]
    return frame.sort_values('image_id').reset_index(drop=True)


def _print_pairwise_diffs(st, title):
    """Print the mean Levenshtein distance between every pair of submissions."""
    print(title)
    for i in range(len(st)):
        for j in range(i + 1, len(st)):
            diff = get_score(st[i]['prediction'].values, st[j]['prediction'].values)
            print('Diff between {} and {}: {:.4f}'.format(i, j, diff))


def _split_layers(predictions):
    """Split each InChI on '/' and count how often each layout occurs.

    Returns:
        ``(parts, lengths, indexes_order)``:
        - ``parts[i]`` is the list of layers of ``predictions[i]``.
        - ``lengths`` is ``[(n_layers, count), ...]``.
        - ``indexes_order`` is ``[(first_chars, count), ...]``, where
          ``first_chars`` holds the first character of each layer from
          index 2 on ('' for empty layers).
        Both lists are sorted by count, most frequent first, with ties
        kept in first-seen order.
    """
    parts = []
    lengths = {}
    indexes_order = {}
    for inchi in predictions:
        layers = inchi.split('/')
        parts.append(layers)
        lengths[len(layers)] = lengths.get(len(layers), 0) + 1
        prefixes = tuple(layer[0] if len(layer) > 0 else '' for layer in layers[2:])
        indexes_order[prefixes] = indexes_order.get(prefixes, 0) + 1
    return parts, sort_dict_by_values(lengths), sort_dict_by_values(indexes_order)


def _layer_candidates(parts, k):
    """Rank the distinct k-th layers by summed Levenshtein distance to the others.

    A submission's score is the sum of distances between its k-th layer and
    the k-th layer of every other submission that has one. Submissions with
    no k-th layer get the score 1000000000.

    Returns:
        ``(partial_res, res)``. ``partial_res`` lists unique
        ``(layer, score)`` tuples in ``res.argsort()`` order (lowest score
        first). ``res`` is the int32 array of per-submission scores.
    """
    res = np.zeros(len(parts), dtype=np.int32)
    for i, layers_i in enumerate(parts):
        if k >= len(layers_i):
            res[i] = 1000000000
            continue
        val = 0
        for j, layers_j in enumerate(parts):
            if i == j or k >= len(layers_j):
                continue
            val += Levenshtein.distance(layers_i[k], layers_j[k])
        res[i] = int(val)

    partial_res = []
    for idx in res.argsort():
        if k < len(parts[idx]):
            candidate = (parts[idx][k], res[idx])
            if candidate not in partial_res:
                partial_res.append(candidate)
    return partial_res, res


def ensemble_v5_partial_with_rdkit(subm_list, use_valid_check, validation, parts_number=16, current_part=None,
                                   exclude_ids=None):
    """Ensemble InChI predictions of several submissions for one chunk of rows.

    Steps:
    1. Each submission is read, filtered by `exclude_ids` and sorted by
       image_id. When `current_part` is given, it is cut to the
       `current_part`-th of `parts_number` contiguous chunks.
    2. Only image_ids present in every submission are kept, and only rows
       where the submissions disagree are re-voted.
    3. Re-voting splits each InChI on '/'. Layer 0 comes from the first
       submission. Each later layer, up to the most common layer count, is
       ranked by its summed Levenshtein distance to the other submissions'
       layers.
    4. All layer combinations are scored by merge_all_variants and the
       lowest-scored one is chosen. With USE_RDKIT set, the first of the top
       LIMIT_CHECKS candidates accepted by check_if_valid_exec_v2 is chosen
       instead.
    Rows where everybody agreed keep the first submission's value.

    Args:
        subm_list: non-empty list of ``(csv_path, weight)`` pairs; the weights
            are currently ignored.
        use_valid_check: only used in printed messages and output file names.
        validation: if truthy, the CSVs must have ``InChI`` (ground truth) and
            ``prediction`` columns, and scores are printed. Otherwise
            ``InChI`` holds the predictions.
        parts_number: number of chunks the rows are split into.
        current_part: index of the chunk to process, or None for all rows.
        exclude_ids: image_ids to drop from every submission, or None.

    Side effects:
        Writes ``<first csv basename>_fixed_valid_check_...csv`` to the
        current working directory and prints progress.

    Raises:
        ValueError: if `subm_list` is empty.
    """
    if not subm_list:
        raise ValueError('subm_list must contain at least one (path, weight) pair')
    # Normalise once so loading, scoring and output all take the same branch
    # (mixing `is True`/`is False` with truthiness disagreed for e.g. 0/1).
    validation = bool(validation)
    show_diffs = False  # set True to print pairwise submission distances
    error_fixed = 0
    error_not_fixed = 0
    no_error = 0

    print('Start part: {} from {}'.format(current_part, parts_number))
    start_time = time.time()

    st = []
    # None until the first submission is read. An empty set must stay empty
    # instead of being replaced by the next file's IDs.
    ids_intersection = None
    for path, _weight in subm_list:
        s = _read_sorted_submission(path, exclude_ids)
        if current_part is not None:
            start, end = _part_bounds(len(s), parts_number, current_part)
            print('Process: {:2d} Start: {:8d} End: {:8d} Subm: {}'.format(current_part, start, end,
                                                                           os.path.basename(path)))
            s = s.iloc[start:end].copy()
        if not validation:
            # Test submissions only carry 'InChI'; expose it under the common column name.
            s['prediction'] = s['InChI']
        st.append(s.reset_index(drop=True))
        part_ids = set(s['image_id'])
        ids_intersection = part_ids if ids_intersection is None else ids_intersection & part_ids
        s = None
        gc.collect()

    print('IDs to use: {}'.format(len(ids_intersection)))

    for i in range(len(st)):
        if validation:
            score1 = get_score(st[i]['InChI'].values, st[i]['prediction'].values)
        st[i] = st[i][st[i]['image_id'].isin(ids_intersection)].reset_index(drop=True)
        if validation:
            score2 = get_score(st[i]['InChI'].values, st[i]['prediction'].values)
            print("Score full: {:.6f} Score same IDs: {:.6f} File: {}".format(
                score1, score2, os.path.basename(subm_list[i][0])))

    if show_diffs:
        _print_pairwise_diffs(st, 'Init diff')

    # Keep rows where some submission differs from the first one; that is
    # equivalent to "some pair differs". NaN != NaN keeps NaN rows as well.
    reference = st[0]['prediction']
    cond = reference != reference
    for other in st[1:]:
        cond |= other['prediction'] != reference

    for i in range(len(st)):
        st[i] = st[i][cond].copy()
        print(len(st[i]))

    if show_diffs:
        _print_pairwise_diffs(st, 'After removing same')

    s = st[0][['image_id', 'InChI']].copy()
    feats = []
    for i, sub in enumerate(st):
        s[str(i)] = sub['prediction']
        feats.append(str(i))

    good = []
    ids = s['image_id'].values
    matrix = s[feats].values

    print('After exclude same inchis. IDs to process: {} Size of DF: {}'.format(len(ids), len(s)))
    for image_number, image_id in enumerate(ids):
        print('ID: {} Number: {} from {} Time: {:.2f} sec'.format(image_id, image_number, len(ids),
                                                                   time.time() - start_time))
        parts, lengths, indexes_order = _split_layers(matrix[image_number])
        for layers in parts:
            print(layers)
        print(lengths, indexes_order)

        # Layer 0 (text before the first '/') is taken from the first submission as-is.
        new_parts = [[(parts[0][0], 0)]]
        # Vote on layers 1 .. (most common layer count - 1).
        for k in range(1, lengths[0][0]):
            partial_res, res = _layer_candidates(parts, k)
            print('Best match: {} [{}]'.format(partial_res, res))
            new_parts.append(partial_res)

        candidates = merge_all_variants(new_parts)
        print('Length of candidates: {}'.format(len(candidates)))

        if USE_RDKIT:
            good_index = None
            for i in range(min(len(candidates), LIMIT_CHECKS)):
                if check_if_valid_exec_v2(candidates[i][0]):
                    good_index = i
                    break
        else:
            good_index = 0

        if good_index is None:
            print('Error not fixed!')
            error_not_fixed += 1
            good_index = 0
        elif good_index == 0:
            print('No error!')
            no_error += 1
        else:
            print('Error was fixed!')
            error_fixed += 1

        good.append(candidates[good_index][0])
        print(good[-1])

    # Rows outside the re-voted set come from the first submission's chunk.
    s_final = _read_sorted_submission(subm_list[0][0], exclude_ids)
    if current_part is not None:
        start, end = _part_bounds(len(s_final), parts_number, current_part)
        print('Start: {} End: {}'.format(start, end))
        s_final = s_final.iloc[start:end]
    base_name = os.path.basename(subm_list[0][0])[:-4]

    if validation:
        s['prediction'] = good
        s_final = s_final[s_final['image_id'].isin(ids_intersection)]
        s_final = s_final[~s_final['image_id'].isin(s['image_id'].values)]
        print(list(s.columns.values))
        print(list(s_final.columns.values))
        cols = ['image_id', 'InChI', 'prediction']
        # Save and score the merged chunk (agreed rows + re-voted rows); it
        # used to be overwritten by `s`, keeping only the re-voted rows.
        s_final = pd.concat((s_final[cols], s[cols]), axis=0)
        print(len(s_final))
        # Written to the working directory like the test branch; input dirs may be read-only.
        s_final[cols].to_csv(
            base_name + '_fixed_valid_check_{}_part_{}.csv'.format(use_valid_check, current_part),
            index=False)
        score = get_score(s_final['InChI'].values, s_final['prediction'].values)
        print('Ensemble score: {:.6f}'.format(score))
    else:
        s['InChI'] = good
        s_final = s_final[~s_final['image_id'].isin(s['image_id'].values)]
        s_final = pd.concat((s_final, s[['image_id', 'InChI']]), axis=0)
        print(len(s_final))
        out_path = base_name + '_fixed_valid_check_{}_part_{}_from_{}.csv'.format(
            use_valid_check, current_part, parts_number)
        s_final[['image_id', 'InChI']].to_csv(out_path, index=False)

    print('Stat for validation check. No error {} Error fixed: {} Error not fixed: {}'.format(no_error, error_fixed,
                                                                                              error_not_fixed))
    print('Valid check enabled: {} Time: {:.2f} sec'.format(use_valid_check, time.time() - start_time))


if __name__ == '__main__':
    # Driver: ensemble the listed submissions either in PROC_NUM worker
    # processes (one chunk each, then merged) or in a single in-process pass.
    parts = PROC_NUM

    subm_list = [
        ('../input/pl-bms-molecular-translation/submission.csv', 1),  # 3.63
        ('../input/efficientnet-multi-layer-lstm-inference/submission.csv', 1),  # 3.69
        ('../input/tensorflow-tpu-training-baseline-predictions/submission.csv', 1),  # 3.7
        ('../input/bms-efficientnetv2-tpu-32-epocs-final-lr-42/submission.csv', 1),  # 4.2
    ]
    exclude_ids = np.array([])

    if USE_PARALLEL:
        plist = [
            Process(target=ensemble_v5_partial_with_rdkit,
                    args=(subm_list, False, False, parts, current_part, exclude_ids))
            for current_part in range(parts)
        ]

        for i, proc in enumerate(plist):
            print('Start {}'.format(i))
            proc.start()

        for proc in plist:
            proc.join()

        # A crashed worker would otherwise leave its chunk out of the merged file.
        failed = [i for i, proc in enumerate(plist) if proc.exitcode != 0]
        if failed:
            raise RuntimeError('Worker processes failed for parts: {}'.format(failed))

        # Merge all parts in part order, reading exactly the files the workers
        # wrote, so a missing chunk fails loudly instead of being skipped.
        base_name = os.path.basename(subm_list[0][0])[:-4]
        files = [base_name + '_fixed_valid_check_{}_part_{}_from_{}.csv'.format(False, current_part, parts)
                 for current_part in range(parts)]
        all_parts = []
        for f in files:
            print('Read: {}'.format(f))
            all_parts.append(pd.read_csv(f))
        s = pd.concat(all_parts, axis=0)
        s.to_csv(base_name + '_fixed_valid_check_merged.csv', index=False)
    else:
        ensemble_v5_partial_with_rdkit(subm_list, False, False, 1, 0, exclude_ids)