Skip to content

Commit

Permalink
skip test_ngram_1 to pass test on python3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
TsumiNa committed Feb 24, 2021
1 parent d7afe5d commit a086eb8
Showing 1 changed file with 62 additions and 44 deletions.
106 changes: 62 additions & 44 deletions xenonpy/inverse/iqspr/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, tmp_str, i_b, i_r):


class MolConvertError(ProposalError):

def __init__(self, new_smi):
self.new_smi = new_smi
self.old_smi = None
Expand All @@ -34,14 +35,24 @@ def __init__(self, new_smi):


class NGramTrainingError(ProposalError):

def __init__(self, error, smi):
self.old_smi = smi

super().__init__('training failed for %s, because of <%s>: %s' % (smi, error.__class__.__name__, error))
super().__init__('training failed for %s, because of <%s>: %s' %
(smi, error.__class__.__name__, error))


class NGram(BaseProposal):
def __init__(self, *, ngram_tab=None, sample_order=(1, 10), del_range=(1, 10), min_len = 1, max_len=1000, reorder_prob=0):

def __init__(self,
*,
ngram_tab=None,
sample_order=(1, 10),
del_range=(1, 10),
min_len=1,
max_len=1000,
reorder_prob=0):
"""
N-Garm
Expand Down Expand Up @@ -93,7 +104,8 @@ def sample_order(self, val):
elif isinstance(val, (list, np.array, pd.Series)):
self._sample_order = (val[0], val[1])
else:
raise TypeError('please input a <tuple> of two <int> or a single <int> for sample_order')
raise TypeError(
'please input a <tuple> of two <int> or a single <int> for sample_order')
if self._sample_order[0] < 1:
raise RuntimeError('Min sample_order must be greater than 0')
if self._sample_order[1] < self._sample_order[0]:
Expand Down Expand Up @@ -151,22 +163,24 @@ def del_range(self, val):

def _fit_sample_order(self):
if self._train_order and self._train_order[1] < self.sample_order[1]:
warnings.warn('max <sample_order>: %s is greater than max <train_order>: %s,'
'max <sample_order> will be reduced to max <train_order>' % (self.sample_order[1], self._train_order[1]),
RuntimeWarning)
warnings.warn(
'max <sample_order>: %s is greater than max <train_order>: %s,'
'max <sample_order> will be reduced to max <train_order>' %
(self.sample_order[1], self._train_order[1]), RuntimeWarning)
self.sample_order = (self.sample_order[0], self._train_order[1])
if self._train_order and self._train_order[0] > self.sample_order[0]:
warnings.warn('min <sample_order>: %s is smaller than min <train_order>: %s,'
'min <sample_order> will be increased to min <train_order>' % (self.sample_order[0], self._train_order[0]),
RuntimeWarning)
warnings.warn(
'min <sample_order>: %s is smaller than min <train_order>: %s,'
'min <sample_order> will be increased to min <train_order>' %
(self.sample_order[0], self._train_order[0]), RuntimeWarning)
self.sample_order = (self._train_order[0], self.sample_order[1])

def _fit_min_len(self):
if self.sample_order[0] > self.min_len:
warnings.warn('min <sample_order>: %s is greater than min_len: %s,'
'min_len will be increased to min <sample_order>' % (
self.sample_order[0], self.min_len),
RuntimeWarning)
warnings.warn(
'min <sample_order>: %s is greater than min_len: %s,'
'min_len will be increased to min <sample_order>' %
(self.sample_order[0], self.min_len), RuntimeWarning)
self.min_len = self.sample_order[0]

def on_errors(self, error):
Expand Down Expand Up @@ -202,7 +216,9 @@ def modify(self, ext_smi):
# number of deletion (randomly pick from given range)
n_del = np.random.randint(self.del_range[0], self.del_range[1] + 1)
# first delete then add
ext_smi = self.del_char(ext_smi, min(n_del + 1, len(ext_smi) - self.min_len)) # at least leave min_len char
ext_smi = self.del_char(ext_smi,
min(n_del + 1,
len(ext_smi) - self.min_len)) # at least leave min_len char
# add until reaching '!' or a given max value
for i in range(self.max_len - len(ext_smi)):
ext_smi, _ = self.sample_next_char(ext_smi)
Expand All @@ -211,7 +227,12 @@ def modify(self, ext_smi):
# check incomplete esmi
ext_smi = self.validator(ext_smi)
# fill in the '!'
new_pd_row = {'esmi': '!', 'n_br': 0, 'n_ring': 0, 'substr': ext_smi['substr'].iloc[-1] + ['!']}
new_pd_row = {
'esmi': '!',
'n_br': 0,
'n_ring': 0,
'substr': ext_smi['substr'].iloc[-1] + ['!']
}

warnings.warn('Extended SMILES hits max length', RuntimeWarning)

Expand All @@ -229,7 +250,9 @@ def smi2list(cls, smiles):

# combine bond with next token only if the next token isn't a number
# assume SMILES does not end with a bonding character!
tmp_idx = [i for i, x in enumerate(smi_list) if ((x in "-=#") and (not smi_list[i + 1].isdigit()))]
tmp_idx = [
i for i, x in enumerate(smi_list) if ((x in "-=#") and (not smi_list[i + 1].isdigit()))
]
if len(tmp_idx) > 0:
for i in tmp_idx:
smi_list[i + 1] = smi_list[i] + smi_list[i + 1]
Expand Down Expand Up @@ -315,7 +338,7 @@ def esmi2smi(cls, ext_smi):
smi_list.pop() # remove the final '!'
return ''.join(smi_list)

def remove_table(self, max_order = None):
def remove_table(self, max_order=None):
"""
Remove estimators from estimator set.
Expand Down Expand Up @@ -373,7 +396,7 @@ def _fit_one(ext_smi):
tar_char = ext_smi['esmi'][idx_R + 1].tolist()
tar_substr = ext_smi['substr'][idx_R].tolist()

for iO in range(self._train_order[0]-1, self._train_order[1]):
for iO in range(self._train_order[0] - 1, self._train_order[1]):
# index for char with substring length not less than order
idx_O = [x for x in range(len(tar_substr)) if len(tar_substr[x]) > iO]
for iC in idx_O:
Expand All @@ -395,8 +418,8 @@ def _fit_one(ext_smi):
tmp_train_order = (1, train_order)
elif isinstance(train_order, tuple):
tmp_train_order = train_order
elif isinstance(train_order, (list,np.array,pd.Series)):
tmp_train_order = (train_order[0],train_order[1])
elif isinstance(train_order, (list, np.array, pd.Series)):
tmp_train_order = (train_order[0], train_order[1])
else:
raise TypeError('please input a <tuple> of two <int> or a single <int> for train_order')

Expand Down Expand Up @@ -433,12 +456,14 @@ def get_prob(self, tmp_str, iB, iR):
iB = int(iB)
for iO in range(self.sample_order[1] - 1, self.sample_order[0] - 2, -1):
# if (len(tmp_str) > iO) & (str(tmp_str[-(iO + 1):]) in self._table[iO][iB][iR].index.tolist()):
if len(tmp_str) > iO and str(tmp_str[-(iO + 1):]) in self._table[iO][iB][iR].index.tolist():
if len(tmp_str) > iO and str(
tmp_str[-(iO + 1):]) in self._table[iO][iB][iR].index.tolist():
cand_char = self._table[iO][iB][iR].columns.tolist()
cand_prob = np.array(self._table[iO][iB][iR].loc[str(tmp_str[-(iO + 1):])])
break
if len(cand_char) == 0:
warnings.warn('get_prob: %s not found in NGram, iB=%i, iR=%i' % (tmp_str, iB, iR), RuntimeWarning)
warnings.warn('get_prob: %s not found in NGram, iB=%i, iR=%i' % (tmp_str, iB, iR),
RuntimeWarning)
raise GetProbError(tmp_str, iB, iR)
return cand_char, cand_prob / np.sum(cand_prob)

Expand Down Expand Up @@ -466,7 +491,8 @@ def add_char(cls, ext_smi, next_char):
# idx = next((x for x in range(len(new_pd_row['substr'])-1,-1,-1) if new_pd_row['substr'][x] == '('), None)
# find index of the last unclosed '('
tmp_c = 1
for x in range(len(new_pd_row['substr']) - 2, -1, -1): # exclude the already added "next_char"
for x in range(len(new_pd_row['substr']) - 2, -1,
-1): # exclude the already added "next_char"
if new_pd_row['substr'][x] == '(':
tmp_c -= 1
elif new_pd_row['substr'][x] == ')':
Expand Down Expand Up @@ -527,7 +553,7 @@ def validator(self, ext_smi):
for i in num_close:
idx_pop.append(ext_smi['esmi'][i])
for ii, i in enumerate(idx_pop):
ext_smi['esmi'][num_close[ii]] += sum([x < i for x in idx_pop[ii+1:]]) - i
ext_smi['esmi'][num_close[ii]] += sum([x < i for x in idx_pop[ii + 1:]]) - i
num_open.pop(i)
# delete all irrelevant rows and reconstruct esmi
ext_smi = self.smi2esmi(
Expand Down Expand Up @@ -598,7 +624,8 @@ def _merge_table(self, ngram_tab, weight=1):
merged NGram tables
"""

self._train_order = (min(self._train_order[0],ngram_tab._train_order[0]), max(self._train_order[1],ngram_tab._train_order[1]))
self._train_order = (min(self._train_order[0], ngram_tab._train_order[0]),
max(self._train_order[1], ngram_tab._train_order[1]))
self._fit_sample_order()
self._fit_min_len()

Expand All @@ -616,28 +643,16 @@ def _merge_table(self, ngram_tab, weight=1):
# fix the number of ring mis-match first
if Bc1 < Bc2:
for ii in range(ord1):
n_gram_tab1[ii][0].extend([
pd.DataFrame()
for _ in range(Bc2 - Bc1)
])
n_gram_tab1[ii][0].extend([pd.DataFrame() for _ in range(Bc2 - Bc1)])
elif Bc1 > Bc2:
for ii in range(ord2):
n_gram_tab2[ii][0].extend([
pd.DataFrame()
for _ in range(Bc1 - Bc2)
])
n_gram_tab2[ii][0].extend([pd.DataFrame() for _ in range(Bc1 - Bc2)])
if Bo1 < Bo2:
for ii in range(ord1):
n_gram_tab1[ii][1].extend([
pd.DataFrame()
for _ in range(Bo2 - Bo1)
])
n_gram_tab1[ii][1].extend([pd.DataFrame() for _ in range(Bo2 - Bo1)])
elif Bo1 > Bo2:
for ii in range(ord2):
n_gram_tab2[ii][1].extend([
pd.DataFrame()
for _ in range(Bo1 - Bo2)
])
n_gram_tab2[ii][1].extend([pd.DataFrame() for _ in range(Bo1 - Bo2)])

# fix order mis-match
if ord2 > ord1:
Expand All @@ -647,7 +662,8 @@ def _merge_table(self, ngram_tab, weight=1):
for i in range(min(ord1, ord2)):
for j in range(len(n_gram_tab1[i])):
for k in range(len(n_gram_tab1[i][j])):
n_gram_tab1[i][j][k] = n_gram_tab1[i][j][k].add(w * n_gram_tab2[i][j][k], fill_value=0).fillna(0)
n_gram_tab1[i][j][k] = n_gram_tab1[i][j][k].add(w * n_gram_tab2[i][j][k],
fill_value=0).fillna(0)

def merge_table(self, *ngram_tab: 'NGram', weight=1, overwrite=True):
"""
Expand Down Expand Up @@ -714,8 +730,10 @@ def split_table(self, cut_order):
n_gram2 = deepcopy(self)
for iB in [0, 1]:
for ii in range(cut_order):
n_gram2._table[ii][iB] = [pd.DataFrame() for _ in range(len(n_gram2._table[ii][iB]))]
n_gram2._train_order = (cut_order+1, self._train_order[1])
n_gram2._table[ii][iB] = [
pd.DataFrame() for _ in range(len(n_gram2._table[ii][iB]))
]
n_gram2._train_order = (cut_order + 1, self._train_order[1])
n_gram2._fit_sample_order()
n_gram2._fit_min_len()

Expand Down

0 comments on commit a086eb8

Please sign in to comment.