In [25]:
import torch.nn as nn
import torch
from transformers import BertModel, BertTokenizer, RobertaModel, RobertaTokenizer
import pdb
from torch.nn import functional

from gensim import models
import torch.nn.functional as F
import copy
import math
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import logging
from glob import glob
from torch.autograd import Variable
import numpy as np
import os
import sys

import warnings
warnings.filterwarnings("ignore")

from copy import deepcopy
import re

import pandas as pd
import json

import random
import copy
import nltk
import argparse

import time
import torch.optim
from collections import OrderedDict
try:
	import cPickle as pickle
except ImportError:
	import pickle
    
import math
import torch.nn.functional as f

## Components

In [26]:
#########################################
# contextual_embeddings.py
#########################################

class BertEncoder(nn.Module):
	'''
	Wraps a pretrained BERT model together with its tokenizer.

	Args:
		bert_model: name/path handed to `from_pretrained` for both model and tokenizer.
		device: device the token id tensor is moved to in `bertify_input`.
		freeze_bert: when True, `requires_grad` is switched off on every BERT parameter.
	'''
	def __init__(self, bert_model = 'bert-base-uncased',device = 'cuda:0', freeze_bert = False):
		super(BertEncoder, self).__init__()
		self.bert_layer = BertModel.from_pretrained(bert_model)
		self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model)
		# NOTE: the default used to be 'cuda:0 ' (trailing space), which is not a
		# valid device string for `.to(...)`.
		self.device = device

		if freeze_bert:
			for p in self.bert_layer.parameters():
				p.requires_grad = False

	def bertify_input(self, sentences):
		'''
		Preprocess the input sentences using bert tokenizer and converts them to a torch tensor containing token ids

		Returns:
			token_ids: LongTensor (batch, max_len) on `self.device`, padded with '[PAD]'.
			attn_masks: LongTensor (batch, max_len), 1 where the id differs from the '[PAD]' id.
			input_lengths: list with the unpadded token count of every sentence
				(including '[CLS]' and '[SEP]').
			index_retrieve: per sentence, the positions >= 1 whose token is not a
				'##' word-piece continuation ('[CLS]' is skipped, '[SEP]' is kept).
		'''
		#Tokenize the input sentences for feeding into BERT
		all_tokens  = [['[CLS]'] + self.bert_tokenizer.tokenize(sentence) + ['[SEP]'] for sentence in sentences]

		# Word-piece continuations are marked with a leading '##'; a lone '#'
		# is an ordinary token and must keep its index.
		index_retrieve = [
			[j for j in range(1, len(sent)) if not sent[j].startswith('##')]
			for sent in all_tokens
		]

		#Pad all the sentences to a maximum length
		input_lengths = [len(tokens) for tokens in all_tokens]
		max_length    = max(input_lengths)
		padded_tokens = [tokens + ['[PAD]'] * (max_length - len(tokens)) for tokens in all_tokens]

		#Convert tokens to token ids
		token_ids = torch.tensor([self.bert_tokenizer.convert_tokens_to_ids(tokens) for tokens in padded_tokens]).to(self.device)

		#Obtain attention masks
		pad_token = self.bert_tokenizer.convert_tokens_to_ids('[PAD]')
		attn_masks = (token_ids != pad_token).long()

		return token_ids, attn_masks, input_lengths, index_retrieve

	def forward(self, sentences):
		'''
		Feed the batch of sentences to a BERT encoder to obtain contextualized representations of each token

		Returns:
			(cont_reps, input_lengths, token_ids, index_retrieve) where cont_reps is
			the first output of the BERT layer and the rest come from `bertify_input`.
		'''
		#Preprocess sentences
		token_ids, attn_masks, input_lengths, index_retrieve = self.bertify_input(sentences)

		#Feed through bert
		# Index instead of tuple-unpacking: positional indexing works for a plain
		# tuple and for a dict-like model output, whereas unpacking the latter
		# would yield its keys.
		bert_outputs = self.bert_layer(token_ids, attention_mask = attn_masks)
		cont_reps = bert_outputs[0]

		return cont_reps, input_lengths, token_ids, index_retrieve

class RobertaEncoder(nn.Module):
	'''
	Wraps a pretrained RoBERTa model together with its tokenizer.

	Args:
		roberta_model: name/path handed to `from_pretrained` for both model and tokenizer.
		device: device the token id tensor is moved to in `robertify_input`.
		freeze_roberta: when True, `requires_grad` is switched off on every RoBERTa parameter.
	'''
	def __init__(self, roberta_model = 'roberta-base', device = 'cuda:0', freeze_roberta = False):
		super(RobertaEncoder, self).__init__()
		self.roberta_layer = RobertaModel.from_pretrained(roberta_model)
		self.roberta_tokenizer = RobertaTokenizer.from_pretrained(roberta_model)
		# NOTE: the default used to be 'cuda:0 ' (trailing space), which is not a
		# valid device string for `.to(...)`.
		self.device = device

		if freeze_roberta:
			for p in self.roberta_layer.parameters():
				p.requires_grad = False

	def robertify_input(self, sentences):
		'''
		Preprocess the input sentences using roberta tokenizer and converts them to a torch tensor containing token ids

		Returns:
			token_ids: LongTensor (batch, max_len) on `self.device`, padded with '<pad>'.
			attn_masks: LongTensor (batch, max_len), 1 where the id differs from the '<pad>' id.
			input_lengths: list with the unpadded token count of every sentence
				(including '<s>' and '</s>').
			index_retrieve: per sentence, position 1 followed by every position >= 2
				whose token starts with '\u0120'.
		'''
		# Tokenize the input sentences for feeding into RoBERTa
		all_tokens  = [['<s>'] + self.roberta_tokenizer.tokenize(sentence) + ['</s>'] for sentence in sentences]

		# Position 1 is always recorded; later positions only when the token
		# carries the '\u0120' prefix.
		index_retrieve = [
			[1] + [j for j in range(2, len(sent)) if sent[j][0] == '\u0120']
			for sent in all_tokens
		]

		# Pad all the sentences to a maximum length
		input_lengths = [len(tokens) for tokens in all_tokens]
		max_length    = max(input_lengths)
		padded_tokens = [tokens + ['<pad>'] * (max_length - len(tokens)) for tokens in all_tokens]

		# Convert tokens to token ids
		token_ids = torch.tensor([self.roberta_tokenizer.convert_tokens_to_ids(tokens) for tokens in padded_tokens]).to(self.device)

		# Obtain attention masks
		pad_token = self.roberta_tokenizer.convert_tokens_to_ids('<pad>')
		attn_masks = (token_ids != pad_token).long()

		return token_ids, attn_masks, input_lengths, index_retrieve

	def forward(self, sentences):
		'''
		Feed the batch of sentences to a RoBERTa encoder to obtain contextualized representations of each token

		Returns:
			(cont_reps, input_lengths, token_ids, index_retrieve) where cont_reps is
			the `last_hidden_state` of the RoBERTa output and the rest come from
			`robertify_input`.
		'''
		# Preprocess sentences
		token_ids, attn_masks, input_lengths, index_retrieve = self.robertify_input(sentences)

		# Feed through RoBERTa
		cont_reps = self.roberta_layer(token_ids, attention_mask = attn_masks)
		cont_reps = cont_reps.last_hidden_state
		return cont_reps, input_lengths, token_ids, index_retrieve
    
#########################################
# masked_cross_entropy.py
#########################################

def sequence_mask(sequence_length, max_len=None):
    """Build a boolean mask marking the valid positions of each sequence.

    Args:
        sequence_length: 1-D integer tensor of size (batch,) with the length
            of every sequence.
        max_len: number of mask columns; defaults to the largest value in
            ``sequence_length``.

    Returns:
        Tensor of shape (batch, max_len), True where position < length. It
        lives on the same device as ``sequence_length``.
    """
    if max_len is None:
        max_len = int(sequence_length.max())
    # Create the position range directly on the lengths' device; a bare
    # `.cuda()` would pick the default CUDA device, which may be a different GPU.
    positions = torch.arange(0, max_len, device=sequence_length.device).long()
    # (1, max_len) < (batch, 1) broadcasts to (batch, max_len).
    return positions.unsqueeze(0) < sequence_length.unsqueeze(1)


def masked_cross_entropy(logits, target, length):
    """Length-masked cross entropy computed from unnormalized logits.

    Args:
        logits: FloatTensor of size (batch, max_len, num_classes) which
            contains the unnormalized score for each class.
        target: LongTensor of size (batch, max_len) which contains the index
            of the true class for each corresponding step.
        length: sequence of ints (or integer tensor) of size (batch,) which
            contains the length of each data in a batch.

    Returns:
        loss: scalar tensor, the sum of the per-step losses inside each
        sequence's length divided by the sum of all lengths.
    """
    # Keep the lengths on the logits' device. Moving them to CUDA merely
    # because CUDA is available breaks when the inputs live on the CPU.
    length = torch.as_tensor(length, dtype=torch.long, device=logits.device)

    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = functional.log_softmax(logits_flat, dim=1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)

    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    # Zero the contribution of padded steps before averaging.
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss


def masked_cross_entropy_without_logit(logits, target, length):
    """Length-masked cross entropy for inputs that are already probabilities.

    Unlike ``masked_cross_entropy`` no softmax is applied: the log is taken
    directly (with a 1e-12 offset to avoid log(0)).

    Args:
        logits: FloatTensor of size (batch, max_len, num_classes) whose
            entries are used as class probabilities.
        target: LongTensor of size (batch, max_len) which contains the index
            of the true class for each corresponding step.
        length: sequence of ints (or integer tensor) of size (batch,) which
            contains the length of each data in a batch.

    Returns:
        loss: scalar tensor, the sum of the per-step losses inside each
        sequence's length divided by the sum of all lengths.
    """
    # Keep the lengths on the inputs' device. Moving them to CUDA merely
    # because CUDA is available breaks when the inputs live on the CPU.
    length = torch.as_tensor(length, dtype=torch.long, device=logits.device)

    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))

    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = torch.log(logits_flat + 1e-12)

    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)

    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())

    # mask: (batch, max_len)
    mask = sequence_mask(sequence_length=length, max_len=target.size(1))
    # Zero the contribution of padded steps before averaging.
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss


#########################################
# models.py
#########################################

class Embedding(nn.Module):
	'''
	Token embedding layer followed by dropout.

	When `config.embedding == 'word2vec'` the table is initialised from the
	binary file at `config.word2vec_bin` (and `config.embedding_size` is set
	to 300); otherwise a trainable `nn.Embedding` with padding index 0 is used.
	'''
	def __init__(self, config, input_lang, input_size, embedding_size, dropout=0.5):
		super(Embedding, self).__init__()

		self.config = config
		self.input_lang = input_lang
		self.input_size = input_size
		self.embedding_size = embedding_size

		wants_word2vec = self.config.embedding == 'word2vec'
		if not wants_word2vec:
			self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=0)
		else:
			self.config.embedding_size = 300
			pretrained = torch.FloatTensor(self._form_embeddings(self.config.word2vec_bin))
			self.embedding = nn.Embedding.from_pretrained(pretrained, freeze = self.config.freeze_emb)
		self.em_dropout = nn.Dropout(dropout)

	def _form_embeddings(self, file_path):
		'''
		Build an (input_size, config.embedding_size) table: rows start as random
		normal values and are overwritten with the word2vec vector for every
		vocabulary word found in the loaded file.
		'''
		keyed_vectors = models.KeyedVectors.load_word2vec_format(file_path, limit=200000, binary=True)
		table = torch.randn(self.input_size, self.config.embedding_size)
		vocabulary = self.input_lang.index2word
		for row in range(len(vocabulary)):
			word = vocabulary[row]
			if word in keyed_vectors:
				table[row] = torch.FloatTensor(keyed_vectors[word])

		return table

	def forward(self, input_seqs):
		'''Look up the embeddings of `input_seqs` and apply dropout.'''
		return self.em_dropout(self.embedding(input_seqs))  # S x B x E

class EncoderRNN(nn.Module):
	'''
	Embedding + dropout + bidirectional multi-layer GRU encoder whose two
	directions are summed into a single hidden_size-wide output.
	'''
	def __init__(self, input_size, embedding_size, hidden_size, n_layers=2, dropout=0.5):
		super(EncoderRNN, self).__init__()

		# Hyper-parameters kept on the instance for reference.
		self.input_size = input_size
		self.embedding_size = embedding_size
		self.hidden_size = hidden_size
		self.n_layers = n_layers
		self.dropout = dropout

		self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=0)
		self.em_dropout = nn.Dropout(dropout)
		self.gru = nn.GRU(embedding_size, hidden_size, n_layers, dropout=dropout, bidirectional=True)

	def forward(self, input_seqs, input_lengths, hidden=None):
		'''
		Encode a padded batch (S x B token ids) with the given lengths.

		Returns:
			outputs: S x B x H, forward and backward GRU outputs summed.
			hidden: final hidden state as returned by the GRU.
		'''
		rnn_utils = torch.nn.utils.rnn
		dropped = self.em_dropout(self.embedding(input_seqs))  # S x B x E
		packed_input = rnn_utils.pack_padded_sequence(dropped, input_lengths)
		packed_output, hidden = self.gru(packed_input, hidden)
		# Back to a padded tensor; the returned lengths are not needed here.
		unpacked, _ = rnn_utils.pad_packed_sequence(packed_output)
		forward_half = unpacked[:, :, :self.hidden_size]
		backward_half = unpacked[:, :, self.hidden_size:]
		# S x B x H
		return forward_half + backward_half, hidden


class Attn(nn.Module):
	'''
	Additive attention: scores every encoder position against a query with
	a tanh MLP and normalises the scores with a softmax over positions.
	'''
	def __init__(self, hidden_size):
		super(Attn, self).__init__()
		self.hidden_size = hidden_size
		self.attn = nn.Linear(hidden_size * 2, hidden_size)
		self.score = nn.Linear(hidden_size, 1, bias=False)
		self.softmax = nn.Softmax(dim=1)

	def forward(self, hidden, encoder_outputs, seq_mask=None):
		'''
		Args:
			hidden: query, repeated along dim 0 to match the source length.
			encoder_outputs: S x B x H encoder states.
			seq_mask: optional B x S mask; True positions get energy -1e12.

		Returns:
			Attention weights of shape B x 1 x S.
		'''
		src_len, batch = encoder_outputs.size(0), encoder_outputs.size(1)
		tile = [1] * hidden.dim()
		tile[0] = src_len
		tiled_hidden = hidden.repeat(*tile)  # S x B x H
		# Pair the query with every encoder position and flatten to rows.
		pairs = torch.cat((tiled_hidden, encoder_outputs), 2).view(-1, 2 * self.hidden_size)
		energies = self.score(torch.tanh(self.attn(pairs))).squeeze(1)  # (S x B)
		energies = energies.view(src_len, batch).transpose(0, 1)  # B x S
		if seq_mask is not None:
			energies = energies.masked_fill_(seq_mask, -1e12)
		weights = self.softmax(energies)
		# B x 1 x S
		return weights.unsqueeze(1)


class AttnDecoderRNN(nn.Module):
	'''
	Single-step GRU decoder with attention over the encoder outputs.
	'''
	def __init__(
			self, hidden_size, embedding_size, input_size, output_size, n_layers=2, dropout=0.5):
		super(AttnDecoderRNN, self).__init__()

		# Hyper-parameters kept on the instance for reference.
		self.embedding_size = embedding_size
		self.hidden_size = hidden_size
		self.input_size = input_size
		self.output_size = output_size
		self.n_layers = n_layers
		self.dropout = dropout

		# Layers
		self.em_dropout = nn.Dropout(dropout)
		self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=0)
		self.gru = nn.GRU(hidden_size + embedding_size, hidden_size, n_layers, dropout=dropout)
		self.concat = nn.Linear(hidden_size * 2, hidden_size)
		self.out = nn.Linear(hidden_size, output_size)
		# Attention module used to build the context vector
		self.attn = Attn(hidden_size)

	def forward(self, input_seq, last_hidden, encoder_outputs, seq_mask):
		'''
		Run one decoding step.

		Args:
			input_seq: token ids of the previous output, one per batch element.
			last_hidden: previous GRU hidden state; its last layer is the attention query.
			encoder_outputs: S x B x N encoder states.
			seq_mask: mask forwarded to the attention module.

		Returns:
			output: scores over the output vocabulary for this step.
			hidden: updated GRU hidden state.
		'''
		batch_size = input_seq.size(0)
		# Embed the previous output word and shape it as a length-1 sequence.
		word_vec = self.em_dropout(self.embedding(input_seq))
		word_vec = word_vec.view(1, batch_size, self.embedding_size)  # S=1 x B x N

		# Attend with the top-layer hidden state and take the weighted
		# average of the encoder outputs.
		query = last_hidden[-1].unsqueeze(0)
		attn_weights = self.attn(query, encoder_outputs, seq_mask)
		context = attn_weights.bmm(encoder_outputs.transpose(0, 1))  # B x S=1 x N

		# Advance the GRU on [embedding ; context].
		rnn_input = torch.cat((word_vec, context.transpose(0, 1)), 2)
		rnn_output, hidden = self.gru(rnn_input, last_hidden)

		# Attentional vector from RNN output and context (Luong eq. 5).
		combined = torch.cat((rnn_output.squeeze(0), context.squeeze(1)), 1)
		output = self.out(torch.tanh(self.concat(combined)))

		return output, hidden


class TreeNode:
	'''Plain record for one tree node: its embedding plus a left_flag marker.'''
	def __init__(self, embedding, left_flag=False):
		# left_flag defaults to False; the constructor stores both values untouched.
		self.embedding, self.left_flag = embedding, left_flag


class Score(nn.Module):
	'''
	Scores a set of candidate embeddings against a query with a tanh MLP.
	No normalisation is applied to the returned scores.
	'''
	def __init__(self, input_size, hidden_size):
		super(Score, self).__init__()
		self.input_size = input_size
		self.hidden_size = hidden_size
		self.attn = nn.Linear(hidden_size + input_size, hidden_size)
		self.score = nn.Linear(hidden_size, 1, bias=False)

	def forward(self, hidden, num_embeddings, num_mask=None):
		'''
		Args:
			hidden: query, repeated along dim 1 to match the candidate count.
			num_embeddings: B x O x hidden_size candidate embeddings.
			num_mask: optional B x O mask; True positions are set to -1e12.

		Returns:
			Raw scores of shape B x O.
		'''
		batch, candidates = num_embeddings.size(0), num_embeddings.size(1)
		tile = [1] * hidden.dim()
		tile[1] = candidates
		tiled_hidden = hidden.repeat(*tile)  # B x O x H
		# Pair the query with every candidate and flatten to rows.
		pairs = torch.cat((tiled_hidden, num_embeddings), 2).view(-1, self.input_size + self.hidden_size)
		logits = self.score(torch.tanh(self.attn(pairs))).squeeze(1)  # (B x O)
		logits = logits.view(batch, -1)  # B x O
		if num_mask is None:
			return logits
		return logits.masked_fill_(num_mask, -1e12)


class TreeAttn(nn.Module):
	'''
	Additive attention whose query width (input_size) may differ from the
	encoder width (hidden_size); weights are softmax-normalised over positions.
	'''
	def __init__(self, input_size, hidden_size):
		super(TreeAttn, self).__init__()
		self.input_size = input_size
		self.hidden_size = hidden_size
		self.attn = nn.Linear(hidden_size + input_size, hidden_size)
		self.score = nn.Linear(hidden_size, 1)

	def forward(self, hidden, encoder_outputs, seq_mask=None):
		'''
		Args:
			hidden: query, repeated along dim 0 to match the source length.
			encoder_outputs: S x B x hidden_size encoder states.
			seq_mask: optional B x S mask; True positions get energy -1e12.

		Returns:
			Attention weights of shape B x 1 x S.
		'''
		src_len, batch = encoder_outputs.size(0), encoder_outputs.size(1)

		tile = [1] * hidden.dim()
		tile[0] = src_len
		tiled_hidden = hidden.repeat(*tile)  # S x B x H

		# Pair the query with every encoder position and flatten to rows.
		pairs = torch.cat((tiled_hidden, encoder_outputs), 2)
		pairs = pairs.view(-1, self.input_size + self.hidden_size)

		energies = self.score(torch.tanh(self.attn(pairs))).squeeze(1)  # (S x B)
		energies = energies.view(src_len, batch).transpose(0, 1)  # B x S
		if seq_mask is not None:
			energies = energies.masked_fill_(seq_mask, -1e12)
		weights = nn.functional.softmax(energies, dim=1)  # B x S

		return weights.unsqueeze(1)


class EncoderSeq(nn.Module):
	'''
	Bidirectional recurrent encoder (LSTM, GRU or vanilla RNN) over
	pre-computed embeddings, followed by a graph module.

	Args:
		cell_type: 'lstm' or 'gru'; any other value selects a tanh `nn.RNN`.
		embedding_size: width of the incoming embeddings.
		hidden_size: hidden width per direction (also the graph module width).
		n_layers: number of stacked recurrent layers.
		dropout: inter-layer dropout of the recurrent stack.
	'''
	def __init__(self, cell_type, embedding_size, hidden_size, n_layers=2, dropout=0.5):
		super(EncoderSeq, self).__init__()

		self.embedding_size = embedding_size
		self.hidden_size = hidden_size
		self.n_layers = n_layers
		self.dropout = dropout

		# Recurrent dropout only acts between layers, so it is disabled for a
		# single layer. The GRU branch previously passed the raw value, unlike
		# the LSTM/RNN branches.
		rnn_dropout = 0 if self.n_layers == 1 else self.dropout

		if cell_type == 'lstm':
			self.rnn = nn.LSTM(self.embedding_size, self.hidden_size,
							   num_layers=self.n_layers,
							   dropout=rnn_dropout,
							   bidirectional=True)
		elif cell_type == 'gru':
			self.rnn = nn.GRU(self.embedding_size, self.hidden_size,
							  num_layers=self.n_layers,
							  dropout=rnn_dropout,
							  bidirectional=True)
		else:
			self.rnn = nn.RNN(self.embedding_size, self.hidden_size,
							  num_layers=self.n_layers,
							  nonlinearity='tanh',							# ['relu', 'tanh']
							  dropout=rnn_dropout,
							  bidirectional=True)

		self.gcn = Graph_Module(hidden_size, hidden_size, hidden_size)

	def forward(self, embedded, input_lengths, orig_idx, batch_graph, hidden=None):
		'''
		Args:
			embedded: padded, already-embedded batch fed to `pack_padded_sequence`
				(S x B x E with the default time-major layout).
			input_lengths: lengths handed to `pack_padded_sequence`.
			orig_idx: optional index tensor; when given, the batch dimension of the
				RNN outputs is re-ordered with `index_select(1, orig_idx)`.
			batch_graph: graph tensor forwarded to the graph module.
			hidden: optional initial hidden state for the recurrent stack.

		Returns:
			pade_outputs: graph-module features with dims 0 and 1 swapped.
			problem_output: forward half of the last time step plus backward half
				of the first time step of the (padded) RNN outputs.
		'''
		# Note: we run this all at once (over multiple batches of multiple sequences)
		packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
		pade_hidden = hidden
		pade_outputs, pade_hidden = self.rnn(packed, pade_hidden)
		pade_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(pade_outputs)

		if orig_idx is not None:
			pade_outputs = pade_outputs.index_select(1, orig_idx)

		problem_output = pade_outputs[-1, :, :self.hidden_size] + pade_outputs[0, :, self.hidden_size:]
		# Sum the two directions into a single hidden_size-wide representation.
		pade_outputs = pade_outputs[:, :, :self.hidden_size] + pade_outputs[:, :, self.hidden_size:]  # S x B x H
		# The graph module also returns the adjacency it used; only the features are kept.
		_, pade_outputs = self.gcn(pade_outputs, batch_graph)
		pade_outputs = pade_outputs.transpose(0, 1)
		return pade_outputs, problem_output


class Prediction(nn.Module):
	'''
	Seq2tree decoder step with problem-aware dynamic encoding: from the node
	on top of each stack it produces scores over numbers/constants and over
	operators, together with the node state and its attention context.
	'''

	def __init__(self, hidden_size, op_nums, input_size, dropout=0.5):
		super(Prediction, self).__init__()

		# Hyper-parameters kept on the instance for reference.
		self.hidden_size = hidden_size
		self.input_size = input_size
		self.op_nums = op_nums

		self.dropout = nn.Dropout(dropout)

		# Learned embeddings for the `input_size` generated (constant) entries.
		self.embedding_weight = nn.Parameter(torch.randn(1, input_size, hidden_size))

		# Gated transforms: *_l when there is no left child, *_r when there is one.
		self.concat_l = nn.Linear(hidden_size, hidden_size)
		self.concat_r = nn.Linear(hidden_size * 2, hidden_size)
		self.concat_lg = nn.Linear(hidden_size, hidden_size)
		self.concat_rg = nn.Linear(hidden_size * 2, hidden_size)

		self.ops = nn.Linear(hidden_size * 2, op_nums)

		self.attn = TreeAttn(hidden_size, hidden_size)
		self.score = Score(hidden_size * 2, hidden_size)

	def forward(self, node_stacks, left_childs, encoder_outputs, num_pades, padding_hidden, seq_mask, mask_nums):
		'''
		Args:
			node_stacks: one stack (list) per batch element; the top node's
				`embedding` is used, or `padding_hidden` when the stack is empty.
			left_childs: per batch element, the left sub-tree embedding or None.
			encoder_outputs: S x B x N encoder states.
			num_pades: embeddings of the numbers in each problem, concatenated
				after the learned constant embeddings along dim 1.
			padding_hidden: embedding used for empty stacks.
			seq_mask: mask forwarded to the attention module.
			mask_nums: mask forwarded to the score module.

		Returns:
			(num_score, op, current_node, current_context, embedding_weight)
		'''
		# Current goal vector of every batch element.
		node_states = [
			padding_hidden if len(stack) == 0 else stack[-1].embedding
			for stack in node_stacks
		]

		gated_nodes = []
		for left, state in zip(left_childs, node_states):
			if left is not None:
				left = self.dropout(left)
				state = self.dropout(state)
				joined = torch.cat((left, state), 1)
				candidate = torch.tanh(self.concat_r(joined))
				gate = torch.sigmoid(self.concat_rg(joined))
			else:
				state = self.dropout(state)
				candidate = torch.tanh(self.concat_l(state))
				gate = torch.sigmoid(self.concat_lg(state))
			gated_nodes.append(candidate * gate)

		current_node = torch.stack(gated_nodes)

		current_embeddings = self.dropout(current_node)

		current_attn = self.attn(current_embeddings.transpose(0, 1), encoder_outputs, seq_mask)
		current_context = current_attn.bmm(encoder_outputs.transpose(0, 1))  # B x 1 x N

		# Candidate table: learned constants repeated per batch element,
		# followed by the problem's own number embeddings.
		batch_size = current_embeddings.size(0)
		tile = [1] * self.embedding_weight.dim()
		tile[0] = batch_size
		constants = self.embedding_weight.repeat(*tile)  # B x input_size x N
		embedding_weight = torch.cat((constants, num_pades), dim=1)  # B x O x N

		leaf_input = torch.cat((current_node, current_context), 2).squeeze(1)
		leaf_input = self.dropout(leaf_input)

		# Scores over candidate numbers/constants (unnormalised).
		dropped_candidates = self.dropout(embedding_weight)
		num_score = self.score(leaf_input.unsqueeze(1), dropped_candidates, mask_nums)

		# Scores over operators.
		op = self.ops(leaf_input)

		return num_score, op, current_node, current_context, embedding_weight


class GenerateNode(nn.Module):
	'''
	Produces gated left and right child states from a node state, its
	attention context and the embedding of the node's label.
	'''
	def __init__(self, hidden_size, op_nums, embedding_size, dropout=0.5):
		super(GenerateNode, self).__init__()

		self.embedding_size = embedding_size
		self.hidden_size = hidden_size

		self.embeddings = nn.Embedding(op_nums, embedding_size)
		self.em_dropout = nn.Dropout(dropout)
		# All four projections read [node ; context ; label embedding].
		self.generate_l = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
		self.generate_r = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
		self.generate_lg = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
		self.generate_rg = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)

	def forward(self, node_embedding, node_label, current_context):
		'''
		Args:
			node_embedding: node state; dim 1 is squeezed away.
			node_label: label ids looked up in `self.embeddings`.
			current_context: attention context; dim 1 is squeezed away.

		Returns:
			(l_child, r_child, label_embedding) where the label embedding is
			returned without dropout.
		'''
		label_embedding = self.embeddings(node_label)
		dropped_label = self.em_dropout(label_embedding)
		node_vec = self.em_dropout(node_embedding.squeeze(1))
		context_vec = self.em_dropout(current_context.squeeze(1))

		# Shared input of the four projections.
		features = torch.cat((node_vec, context_vec, dropped_label), 1)

		l_child = torch.tanh(self.generate_l(features)) * torch.sigmoid(self.generate_lg(features))
		r_child = torch.tanh(self.generate_r(features)) * torch.sigmoid(self.generate_rg(features))
		return l_child, r_child, label_embedding


class Merge(nn.Module):
	'''
	Combines a node embedding with its two sub-tree representations into a
	single gated hidden_size-wide sub-tree representation.
	'''
	def __init__(self, hidden_size, embedding_size, dropout=0.5):
		super(Merge, self).__init__()

		self.embedding_size = embedding_size
		self.hidden_size = hidden_size

		self.em_dropout = nn.Dropout(dropout)
		# Both projections read [node embedding ; sub-tree 1 ; sub-tree 2].
		self.merge = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)
		self.merge_g = nn.Linear(hidden_size * 2 + embedding_size, hidden_size)

	def forward(self, node_embedding, sub_tree_1, sub_tree_2):
		'''Return tanh(merge(x)) * sigmoid(merge_g(x)) for the concatenated, dropped-out inputs.'''
		first = self.em_dropout(sub_tree_1)
		second = self.em_dropout(sub_tree_2)
		node = self.em_dropout(node_embedding)

		features = torch.cat((node, first, second), 1)
		candidate = torch.tanh(self.merge(features))
		gate = torch.sigmoid(self.merge_g(features))
		return candidate * gate
	
	
	
# Graph Module
def clones(module, N):
	"Return an nn.ModuleList holding N independent deep copies of module."
	replicas = nn.ModuleList()
	for _ in range(N):
		replicas.append(copy.deepcopy(module))
	return replicas

class LayerNorm(nn.Module):
	"Layer normalisation over the last dimension with a learned gain (a_2) and bias (b_2)."
	def __init__(self, features, eps=1e-6):
		super(LayerNorm, self).__init__()
		self.a_2 = nn.Parameter(torch.ones(features))
		self.b_2 = nn.Parameter(torch.zeros(features))
		self.eps = eps

	def forward(self, x):
		# eps is added to the standard deviation (not the variance).
		centered = x - x.mean(-1, keepdim=True)
		spread = x.std(-1, keepdim=True) + self.eps
		return self.a_2 * centered / spread + self.b_2

class PositionwiseFeedForward(nn.Module):
	"Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear."
	def __init__(self, d_model, d_ff,d_out, dropout=0.1):
		super(PositionwiseFeedForward, self).__init__()
		self.w_1 = nn.Linear(d_model, d_ff)
		self.w_2 = nn.Linear(d_ff, d_out)
		self.dropout = nn.Dropout(dropout)

	def forward(self, x):
		activated = F.relu(self.w_1(x))
		dropped = self.dropout(activated)
		return self.w_2(dropped)

class Graph_Module(nn.Module):
	'''
	Multi-head graph encoder: several GCN heads run on (possibly different)
	adjacency matrices, their outputs are concatenated, layer-normalised and
	passed through a residual feed-forward block.

	## Variables:
	- indim: dimensionality of input node features
	- hiddim: dimensionality of the joint hidden embedding
	- outdim: dimensionality of the output node features
	- dropout: dropout used by the GCN heads and the feed-forward block
	'''
	def __init__(self, indim, hiddim, outdim, dropout=0.3):
		super(Graph_Module, self).__init__()
		self.in_dim = indim

		# Number of GCN heads; each produces outdim // h features.
		self.h = 4
		self.d_k = outdim//self.h

		self.graph = clones(GCN(indim, hiddim, self.d_k, dropout), self.h)

		self.feed_foward = PositionwiseFeedForward(indim, hiddim, outdim, dropout)
		self.norm = LayerNorm(outdim)

	def get_adj(self, graph_nodes):
		'''
		## Inputs:
		- graph_nodes (batch_size, K, in_feat_dim): input features
		## Returns:
		- adjacency matrix (batch_size, K, K)

		NOTE: this path needs `edge_layer_1`, `edge_layer_2` and `combined_dim`,
		none of which `__init__` creates, so calling it on an instance built by
		`__init__` raises AttributeError.
		'''
		missing = [name for name in ('edge_layer_1', 'edge_layer_2', 'combined_dim')
				   if not hasattr(self, name)]
		if missing:
			raise AttributeError(
				'Graph_Module.get_adj requires attributes that are not defined: '
				+ ', '.join(missing))

		self.K = graph_nodes.size(1)
		graph_nodes = graph_nodes.contiguous().view(-1, self.in_dim)

		# layer 1
		h = self.edge_layer_1(graph_nodes)
		h = F.relu(h)

		# layer 2
		h = self.edge_layer_2(h)
		h = F.relu(h)

		# outer product
		h = h.view(-1, self.K, self.combined_dim)
		adjacency_matrix = torch.matmul(h, h.transpose(1, 2))

		adjacency_matrix = self.b_normal(adjacency_matrix)

		return adjacency_matrix

	def normalize(self, A, symmetric=True):
		'''
		Add self-loops to A and degree-normalise it.

		## Inputs:
		- adjacency matrix (K, K) : A
		- symmetric: D^{-1/2} (A + I) D^{-1/2} when True, D^{-1} (A + I) otherwise
		## Returns:
		- adjacency matrix (K, K)
		'''
		# Build the identity on A's device instead of forcing CUDA, so the
		# method also works on CPU tensors.
		A = A + torch.eye(A.size(0), device=A.device).float()
		d = A.sum(1)
		if symmetric:
			# D = D^{-1/2}
			D = torch.diag(torch.pow(d, -0.5))
			return D.mm(A).mm(D)
		else :
			D = torch.diag(torch.pow(d,-1))
			return D.mm(A)

	def b_normal(self, adj):
		'''Normalise every (K, K) matrix of a batch in place and return the batch.'''
		batch = adj.size(0)
		for i in range(batch):
			adj[i] = self.normalize(adj[i])
		return adj

	def forward(self, graph_nodes, graph):
		'''
		## Inputs:
		- graph_nodes (batch_size, K, in_feat_dim): input features; transposed on
		  dims 0/1 when its first dimension differs from graph's first dimension
		- graph: batched adjacency tensor, indexed at [:, 1] and [:, 4]
		  (assumes at least 5 matrices per sample along dim 1 -- TODO confirm);
		  an empty tensor selects the learned-adjacency path (`get_adj`)
		## Returns:
		- adj: the adjacency tensor that was used
		- graph_encode_features (batch_size, K, out_feat_dim)
		'''
		nbatches = graph_nodes.size(0)
		mbatches = graph.size(0)
		if nbatches != mbatches:
			graph_nodes = graph_nodes.transpose(0, 1)
		# adj (batch_size, K, K): adjacency matrix
		if not bool(graph.numel()):
			adj = self.get_adj(graph_nodes)
			adj_list = [adj,adj,adj,adj]
		else:
			adj = graph.float()
			# Heads 0/1 share matrix 1, heads 2/3 share matrix 4.
			adj_list = [adj[:,1,:],adj[:,1,:],adj[:,4,:],adj[:,4,:]]

		g_feature = \
			tuple([l(graph_nodes,x) for l, x in zip(self.graph,adj_list)])

		# Concatenate the heads, normalise, and add the residual connection.
		g_feature = self.norm(torch.cat(g_feature,2)) + graph_nodes

		graph_encode_features = self.feed_foward(g_feature) + g_feature

		return adj, graph_encode_features

# GCN
class GCN(nn.Module):
	'''
	Two-layer graph convolutional network.

	## Inputs:
	- graph_nodes (batch_size, K, in_feat_dim): input features
	- adjacency matrix (batch_size, K, K)
	## Returns:
	- gcn_enhance_feature (batch_size, K, out_feat_dim)
	'''
	def __init__(self, in_feat_dim, nhid, out_feat_dim, dropout):
		super(GCN, self).__init__()
		self.gc1 = GraphConvolution(in_feat_dim, nhid)
		self.gc2 = GraphConvolution(nhid, out_feat_dim)
		# Dropout probability applied between the two layers.
		self.dropout = dropout

	def forward(self, x, adj):
		hidden = F.relu(self.gc1(x, adj))
		# Only active in training mode.
		hidden = F.dropout(hidden, self.dropout, training=self.training)
		return self.gc2(hidden, adj)
	
# Graph_Conv
class GraphConvolution(Module):
	"""
	Simple GCN layer, similar to https://arxiv.org/abs/1609.02907

	Computes adj @ (input @ weight), plus the bias when one is present.
	"""

	def __init__(self, in_features, out_features, bias=True):
		super(GraphConvolution, self).__init__()
		self.in_features = in_features
		self.out_features = out_features
		self.weight = Parameter(torch.FloatTensor(in_features, out_features))
		if not bias:
			self.register_parameter('bias', None)
		else:
			self.bias = Parameter(torch.FloatTensor(out_features))
		self.reset_parameters()

	def reset_parameters(self):
		"""Fill weight (and bias, if any) uniformly in +-1/sqrt(out_features)."""
		bound = 1. / math.sqrt(self.weight.size(1))
		for param in (self.weight, self.bias):
			if param is not None:
				param.data.uniform_(-bound, bound)

	def forward(self, input, adj):
		projected = torch.matmul(input, self.weight)
		aggregated = torch.matmul(adj, projected)

		if self.bias is None:
			return aggregated
		return aggregated + self.bias

	def __repr__(self):
		return '{} ({} -> {})'.format(
			self.__class__.__name__, self.in_features, self.out_features)


## utils

In [27]:
#########################################
# expression_transfer.py
#########################################

# An expression tree node
class Et:
    """Binary expression-tree node holding a value and two child links."""

    def __init__(self, value):
        self.value = value
        # Children start out unset; the tree builder fills them in.
        self.left = self.right = None


# Returns root of constructed tree for given postfix expression
def construct_exp_tree(postfix):
    """Build an expression tree from a postfix token sequence.

    Every token becomes an ``Et`` node. Tokens equal to one of
    ``+ - * / ^`` take the two most recently completed sub-trees as their
    right and left children (in that order of popping).

    Returns:
        The root ``Et`` node (the last node left on the stack).
    """
    operators = ("+", "-", "*", "/", "^")
    pending = []

    for symbol in postfix:
        node = Et(symbol)
        if symbol in operators:
            # The top of the stack is the right operand, the next one the left.
            node.right = pending.pop()
            node.left = pending.pop()
        pending.append(node)

    return pending.pop()


def from_infix_to_postfix(expression):
    """Convert an infix token list to postfix order.

    Supports the operators ``+ - * / ^`` and both ``()`` and ``[]`` as
    grouping symbols. Operators of equal precedence are popped before the
    new one is pushed (left-to-right grouping). Any other token is copied
    to the output as an operand.
    """
    precedence = {"+": 0, "-": 0, "*": 1, "/": 1, "^": 2}
    openers = ("(", "[")
    opener_for = {")": "(", "]": "["}
    pending = []
    output = []

    for token in expression:
        if token in openers:
            pending.append(token)
        elif token in opener_for:
            # Flush operators until the matching opening bracket shows up.
            wanted = opener_for[token]
            top = pending.pop()
            while top != wanted:
                output.append(top)
                top = pending.pop()
        elif token in precedence:
            while (pending
                   and pending[-1] not in openers
                   and precedence[token] <= precedence[pending[-1]]):
                output.append(pending.pop())
            pending.append(token)
        else:
            output.append(token)

    # Whatever is still pending goes out in stack (last-in, first-out) order.
    output.extend(reversed(pending))
    return output


def from_infix_to_prefix(expression):
    """Convert an infix token list to prefix (Polish) order.

    Works on a reversed deep copy of the input, so the caller's list is left
    untouched. Closing brackets open a group in the reversed scan and the
    matching opening bracket closes it; the collected output is reversed back
    before being returned.
    """
    precedence = {"+": 0, "-": 0, "*": 1, "/": 1, "^": 2}
    matching_close = {"(": ")", "[": "]"}
    ops = []
    output = []

    tokens = deepcopy(expression)
    tokens.reverse()
    for token in tokens:
        if token == ")" or token == "]":
            ops.append(token)
        elif token in matching_close:
            closer = matching_close[token]
            top = ops.pop()
            while top != closer:
                output.append(top)
                top = ops.pop()
        elif token in precedence:
            # Strict comparison: equal-precedence operators stay on the stack,
            # which keeps the original left-to-right grouping after reversal.
            while ops and ops[-1] not in (")", "]") and precedence[token] < precedence[ops[-1]]:
                output.append(ops.pop())
            ops.append(token)
        else:
            output.append(token)
    while ops:
        output.append(ops.pop())
    output.reverse()
    return output


def out_expression_list(test, output_lang, num_list, num_stack=None):
    """Translate output-vocabulary indices back into expression tokens.

    Args:
        test: iterable of indices into ``output_lang.index2word``.
        output_lang: object exposing ``n_words`` and ``index2word``.
        num_list: numbers of the problem; ``N<k>`` tokens resolve to ``num_list[k]``.
        num_stack: stack of candidate position lists, popped (and therefore
            mutated) for every index at or above ``n_words - 1``.

    Returns:
        The token list, or ``None`` if an ``N<k>`` token points past ``num_list``.
    """
    last_regular = output_lang.n_words - 1
    tokens = []
    for index in test:
        if index >= last_regular:
            # Final vocabulary slot: take the first candidate position from the stack.
            candidates = num_stack.pop()
            tokens.append(num_list[candidates[0]])
            continue
        word = output_lang.index2word[index]
        if word[0] != "N":
            tokens.append(word)
            continue
        position = int(word[1:])
        if position >= len(num_list):
            return None
        tokens.append(num_list[position])
    return tokens


def compute_postfix_expression(post_fix):
    """Evaluate a postfix (reverse Polish) token list and return the result.

    Operand tokens may be plain numbers, percentages such as ``"50%"`` (divided
    by 100) or mixed numbers such as ``"3(1/2)"`` (read as ``3+(1/2)``).

    For every binary operator the operand pushed first is the left-hand side,
    so ``["2", "3", "^"]`` is ``2 ** 3`` -- consistent with ``-`` and ``/``.

    Returns:
        The computed value, or ``None`` when the expression is malformed (an
        operator without two operands, leftover operands) or divides by zero.
    """
    operators = ("+", "-", "^", "*", "/")
    stack = []
    for token in post_fix:
        if token not in operators:
            mixed = re.search(r"\d+\(", token)
            if mixed:
                # NOTE(review): eval() executes arbitrary code; tokens must come
                # from trusted data only.
                stack.append(eval(token[mixed.start(): mixed.end() - 1] + "+" + token[mixed.end() - 1:]))
            elif token[-1] == "%":
                stack.append(float(token[:-1]) / 100)
            else:
                stack.append(eval(token))
            continue

        if len(stack) < 2:
            return None
        right = stack.pop()
        left = stack.pop()
        if token == "+":
            stack.append(left + right)
        elif token == "*":
            stack.append(left * right)
        elif token == "/":
            if right == 0:
                return None
            stack.append(left / right)
        elif token == "-":
            stack.append(left - right)
        else:
            # "^": base is the earlier operand, exponent the later one.
            stack.append(left ** right)
    if len(stack) == 1:
        return stack.pop()
    return None


def compute_prefix_expression(pre_fix):
    """Evaluate a prefix (Polish) token list and return the result.

    The tokens are scanned right to left without modifying the caller's
    sequence. Operand tokens may be plain numbers, percentages such as
    ``"50%"`` (divided by 100) or mixed numbers such as ``"3(1/2)"`` (read as
    ``3+(1/2)``).

    Exponentiation is only accepted with an exponent of 2 or 3; any other
    exponent yields ``None``.

    Returns:
        The computed value, or ``None`` when the expression is malformed (an
        operator without two operands, leftover operands), divides by zero, or
        uses a disallowed exponent.
    """
    operators = ("+", "-", "^", "*", "/")
    stack = []
    for token in reversed(pre_fix):
        if token not in operators:
            mixed = re.search(r"\d+\(", token)
            if mixed:
                # NOTE(review): eval() executes arbitrary code; tokens must come
                # from trusted data only.
                stack.append(eval(token[mixed.start(): mixed.end() - 1] + "+" + token[mixed.end() - 1:]))
            elif token[-1] == "%":
                stack.append(float(token[:-1]) / 100)
            else:
                stack.append(eval(token))
            continue

        if len(stack) < 2:
            return None
        # In the reversed scan the left operand ends up on top of the stack.
        left = stack.pop()
        right = stack.pop()
        if token == "+":
            stack.append(left + right)
        elif token == "*":
            stack.append(left * right)
        elif token == "/":
            if right == 0:
                return None
            stack.append(left / right)
        elif token == "-":
            stack.append(left - right)
        else:
            # Operands are already numbers here, so compare them directly
            # instead of eval()-ing them; only squares and cubes are allowed.
            if float(right) not in (2.0, 3.0):
                return None
            stack.append(left ** right)
    if len(stack) == 1:
        return stack.pop()
    return None


#########################################
# helper.py
#########################################

def gpu_init_pytorch(gpu_num):
	'''
		Select the compute device for the requested GPU index.

		Args:
			gpu_num: index of the CUDA device to use (int or numeric string)
		Returns:
			torch.device for "cuda:<gpu_num>" when CUDA is available,
			otherwise the CPU device
	'''
	# Check availability first: torch.cuda.set_device raises on machines
	# without CUDA, which used to make the CPU fallback unreachable.
	if not torch.cuda.is_available():
		return torch.device("cpu")
	torch.cuda.set_device(int(gpu_num))
	return torch.device("cuda:{}".format(gpu_num))

def create_save_directories(path):
	'''
		Create directory `path` (including missing parents) unless something
		already exists at that location.
	'''
	if os.path.exists(path):
		return
	os.makedirs(path)

def stack_to_string(stack):
	'''
		Join the string tokens of `stack` with single spaces.

		Empty strings at the very beginning are skipped; empty strings that
		follow a non-empty token are kept and still receive a separator.
	'''
	tokens = list(stack)
	first = 0
	# While nothing has been emitted yet, an empty token adds nothing at all.
	while first < len(tokens) and tokens[first] == "":
		first += 1
	return ' '.join(tokens[first:])

def index_batch_to_words(input_batch, input_length, lang):
	'''
		Turn a batch of index sequences back into space-separated sentences.

		Args:
			input_batch: List of BS x Max_len
			input_length: List of BS (number of valid positions per row)
			lang: vocabulary passed on to sentence_from_indexes
		Return:
			contextual_input: List of BS strings
	'''
	return [
		# Only the first input_length[row] entries of each row are decoded.
		stack_to_string(sentence_from_indexes(lang, sequence[:input_length[row]]))
		for row, sequence in enumerate(input_batch)
	]

def sort_by_len(seqs, input_len, device=None, dim=1):
	'''
		Reorder `seqs` along `dim` by decreasing length.

		Args:
			seqs: tensor whose `dim` axis enumerates the sequences
			input_len: length of each sequence, indexable by position
			device: optional device the index tensors are moved to
			dim: axis of `seqs` to permute (default 1)
		Returns:
			sorted_seqs: `seqs` permuted along `dim`
			sorted_lens: list of lengths in the new order
			orig_idx: LongTensor such that
				sorted_seqs.index_select(dim, orig_idx) restores `seqs`
	'''
	positions = list(range(seqs.size(dim)))

	# Stable descending sort: equal lengths keep their original relative order.
	order = sorted(positions, key=lambda k: input_len[k], reverse=True)
	sorted_lens = [input_len[k] for k in order]

	# Inverse permutation, used by callers to restore the original order.
	restore = sorted(positions, key=lambda k: order[k])

	sorted_idx = torch.LongTensor(order)
	orig_idx = torch.LongTensor(restore)
	if device:
		sorted_idx = sorted_idx.to(device)
		orig_idx = orig_idx.to(device)

	# Permute along the requested axis (this used to be hard-coded to axis 1).
	sorted_seqs = seqs.index_select(dim, sorted_idx)
	return sorted_seqs, sorted_lens, orig_idx

def save_checkpoint(state, epoch, logger, model_path, ckpt):
	'''
		Serialize `state` to "<model_path>/<ckpt>.pt" with torch.save and log
		the destination. The file-name format matters to the loading helpers,
		so keep it unchanged.

		Args:
			state: object to serialize (model state)
			epoch: epoch number (accepted but not used here)
			logger: logger variable
			model_path: directory to save models
			ckpt: checkpoint name, without the ".pt" extension
	'''
	destination = os.path.join(model_path, '{}.pt'.format(ckpt))
	logger.info('Saving Checkpoint at : {}'.format(destination))
	torch.save(state, destination)

def load_checkpoint(config, embedding, encoder, predict, generate, merge, mode, ckpt_path, logger, device,
					embedding_optimizer = None, encoder_optimizer = None, predict_optimizer = None, generate_optimizer = None, merge_optimizer = None,
					embedding_scheduler = None, encoder_scheduler = None, predict_scheduler = None, generate_scheduler = None, merge_scheduler = None
					):
	'''
		Restore the five sub-models from a checkpoint file and, when mode is
		'train', their optimizers and schedulers as well.

		The modules are moved to `device` and switched to train() for mode
		'train' or eval() for any other mode. `config` is accepted but not used.

		Returns:
			(start_epoch, min_train_loss, max_train_acc, max_val_acc,
			 equation_acc, best_epoch, generate_nums) as stored in the checkpoint
	'''
	# Keep storages where they were deserialized; modules are moved to `device` below.
	checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

	modules = (
		('embedding_state_dict', embedding),
		('encoder_state_dict', encoder),
		('predict_state_dict', predict),
		('generate_state_dict', generate),
		('merge_state_dict', merge),
	)
	for key, module in modules:
		module.load_state_dict(checkpoint[key])

	training = mode == 'train'
	if training:
		optimizers = (
			('embedding_optimizer_state_dict', embedding_optimizer),
			('encoder_optimizer_state_dict', encoder_optimizer),
			('predict_optimizer_state_dict', predict_optimizer),
			('generate_optimizer_state_dict', generate_optimizer),
			('merge_optimizer_state_dict', merge_optimizer),
		)
		for key, optimizer in optimizers:
			optimizer.load_state_dict(checkpoint[key])

		schedulers = (
			('embedding_scheduler_state_dict', embedding_scheduler),
			('encoder_scheduler_state_dict', encoder_scheduler),
			('predict_scheduler_state_dict', predict_scheduler),
			('generate_scheduler_state_dict', generate_scheduler),
			('merge_scheduler_state_dict', merge_scheduler),
		)
		for key, scheduler in schedulers:
			scheduler.load_state_dict(checkpoint[key])

	start_epoch = checkpoint['epoch']
	min_train_loss = checkpoint['min_train_loss']
	max_train_acc = checkpoint['max_train_acc']
	max_val_acc = checkpoint['max_val_acc']
	equation_acc = checkpoint['equation_acc']
	best_epoch = checkpoint['best_epoch']
	generate_nums = checkpoint['generate_nums']

	for _, module in modules:
		module.to(device)

	logger.info('Successfully Loaded Checkpoint from {}, with epoch number: {} for {}'.format(ckpt_path, start_epoch, mode))

	for _, module in modules:
		if training:
			module.train()
		else:
			module.eval()

	return start_epoch, min_train_loss, max_train_acc, max_val_acc, equation_acc, best_epoch, generate_nums

def get_latest_checkpoint(model_path, logger):
	'''
		Pick a checkpoint file from the directory "model_path".

		All "*.pt" files are sorted by path and the first one is returned;
		epoch numbers in the file names are not parsed.

		Args:
			model_path: directory including the run_name
			logger variable: to log messages
		Returns:
			checkpoint: path to the selected checkpoint, or None (after a
			warning) when the directory holds no ".pt" file
	'''
	candidates = sorted(glob('{}/*.pt'.format(model_path)))

	if not candidates:
		logger.warning('No Checkpoints Found')
		return None

	chosen = candidates[0]
	logger.debug('Checkpoint found at : {}'.format(chosen))
	return chosen
    
#########################################
# logger.py
#########################################

def get_logger(name, log_file_path='./logs/temp.log', logging_level=logging.INFO, log_format='%(asctime)s | %(levelname)s | %(filename)s: %(lineno)s : %(funcName)s() ::\t %(message)s'):
	'''
		Return the logger called `name`, writing to both a file and the console.

		Args:
			name: logger name handed to logging.getLogger
			log_file_path: log file, truncated on every call (mode 'w'); its
				parent directory is created when missing
			logging_level: level applied to the logger and both handlers
			log_format: format string shared by both handlers
		Returns:
			the configured logging.Logger, carrying exactly one file handler
			and one stream handler
	'''
	logger = logging.getLogger(name)
	logger.setLevel(logging_level)
	formatter = logging.Formatter(log_format)

	# logging.getLogger hands back the same object for the same name, so a
	# second call would otherwise stack handlers and print every record twice.
	for stale in list(logger.handlers):
		logger.removeHandler(stale)
		stale.close()

	# FileHandler does not create missing directories by itself.
	log_dir = os.path.dirname(log_file_path)
	if log_dir:
		os.makedirs(log_dir, exist_ok=True)

	file_handler = logging.FileHandler(log_file_path, mode='w')
	stream_handler = logging.StreamHandler()
	for handler in (file_handler, stream_handler):
		handler.setLevel(logging_level)
		handler.setFormatter(formatter)
		logger.addHandler(handler)

	return logger

def print_log(logger, dict):
	'''
		Log all key/value pairs of `dict` in a single info record, one pair per
		line, with underscores in the keys shown as spaces.
	'''
	entries = [
		'\n {}: {}\t'.format(key.replace('_', ' '), value)
		for key, value in dict.items()
	]
	logger.info(''.join(entries))

def store_results(config, max_train_acc, max_val_acc, eq_acc, min_train_loss, best_epoch):
	'''
		Assemble the result record of a run and place it, keyed by run name,
		into the results loaded from config.result_path.

		NOTE: writing the merged results back to disk is currently disabled, so
		the record only exists in memory and the function returns None.

		Any failure while reading the results file falls back to an empty
		dict; `min_train_loss` is converted with .item() when that works.
	'''
	try:
		with open(config.result_path) as handle:
			results = json.load(handle)
	except:
		results = {}

	try:
		min_train_loss = min_train_loss.item()
	except:
		pass

	run_key = str(config.run_name)
	record = {
		'run name': run_key,
		'max val acc': str(max_val_acc),
		'equation acc': str(eq_acc),
		'max train acc': str(max_train_acc),
		'min train loss': str(min_train_loss),
		'best epoch': str(best_epoch),
		'epochs': config.epochs,
		'dataset': config.dataset,
		'embedding': config.embedding,
		'embedding_size': config.embedding_size,
		'embedding_lr': config.emb_lr,
		'freeze_emb': config.freeze_emb,
		'cell_type': config.cell_type,
		'hidden_size': config.hidden_size,
		'depth': config.depth,
		'lr': config.lr,
		'batch_size': config.batch_size,
		'dropout': config.dropout,
	}
	results[str(config.run_name)] = record

	# Persisting is intentionally left disabled, as before:
	# with open(config.result_path, 'w', encoding='utf-8') as f:
	# 	json.dump(results, f, ensure_ascii= False, indent= 4)

def store_val_results(config, acc_score, folds_scores):
	'''
		Assemble the 5-fold validation record of a run and place it, keyed by
		run name, into the results loaded from config.val_result_path.

		NOTE: writing the merged results back to disk is currently disabled, so
		the function returns None without persisting anything. If building the
		record fails (for example fewer than five fold scores) the debugger is
		entered via pdb.set_trace().
	'''
	try:
		with open(config.val_result_path) as handle:
			results = json.load(handle)
	except:
		results = {}

	try:
		record = {
			'run_name': str(config.run_name),
			'5-fold avg acc score': str(acc_score),
			'Fold0 acc': folds_scores[0],
			'Fold1 acc': folds_scores[1],
			'Fold2 acc': folds_scores[2],
			'Fold3 acc': folds_scores[3],
			'Fold4 acc': folds_scores[4],
			'epochs': config.epochs,
			'embedding': config.embedding,
			'embedding_size': config.embedding_size,
			'embedding_lr': config.emb_lr,
			'freeze_emb': config.freeze_emb,
			'cell_type': config.cell_type,
			'hidden_size': config.hidden_size,
			'depth': config.depth,
			'lr': config.lr,
			'batch_size': config.batch_size,
			'dropout': config.dropout,
		}
		results[str(config.run_name)] = record

		# Persisting is intentionally left disabled, as before:
		# with open(config.val_result_path, 'w', encoding='utf-8') as f:
		# 	json.dump(results, f, ensure_ascii= False, indent= 4)
	except:
		pdb.set_trace()
        
        
#########################################
# pre_data.py
#########################################

# Index used to fill padded positions (see pad_seq); the vocabularies built by
# Lang.build_input_lang / Lang.build_output_lang place "PAD" at this index.
PAD_token = 0


class Lang:
	"""
	Vocabulary holder keeping the word->index dict, the index->word list and
	per-word counts.

	Tokens that contain a number placeholder or digits (anything matching
	``N<digits>``, ``NUM`` or a digit run) are never added by add_sen_to_vocab;
	the build_* methods add the reserved and number tokens explicitly.
	"""
	def __init__(self):
		self.word2index = {}
		self.word2count = {}
		self.index2word = []
		self.n_words = 0  # Count word tokens
		self.num_start = 0  # set by build_output_lang_for_tree

	def add_sen_to_vocab(self, sentence):  # add words of sentence to vocab
		"""Register every non-numeric word of `sentence` and update its count."""
		for word in sentence:
			if re.search(r"N\d+|NUM|\d+", word):
				continue
			# word2index mirrors index2word after every method of this class,
			# so the dict gives the same answer as scanning the list, in O(1).
			if word not in self.word2index:
				self.word2index[word] = self.n_words
				self.word2count[word] = 1
				self.index2word.append(word)
				self.n_words += 1
			else:
				# Reserved tokens added by build_* have no count entry yet.
				self.word2count[word] = self.word2count.get(word, 0) + 1

	def trim(self, min_count):  # trim words below a certain count threshold
		"""Keep only words seen at least `min_count` times and re-index them."""
		keep_words = [k for k, v in self.word2count.items() if v >= min_count]

		total = len(self.index2word)
		# An empty vocabulary must not crash the progress report.
		ratio = len(keep_words) / total if total else 0.0
		print('keep_words %s / %s = %.4f' % (len(keep_words), total, ratio))

		# Reinitialize dictionaries (counts are kept on purpose)
		self.word2index = {}
		self.index2word = []
		self.n_words = 0  # Count default tokens

		for word in keep_words:
			self.word2index[word] = self.n_words
			self.index2word.append(word)
			self.n_words += 1

	def build_input_lang(self, logger, trim_min_count):  # build the input lang vocab and dict
		"""Finalize the input vocabulary.

		With a positive `trim_min_count` rare words are trimmed and "UNK" is
		added next to "PAD" and "NUM"; otherwise only "PAD" and "NUM" are
		prepended. `logger` is accepted but not used.
		"""
		if trim_min_count > 0:
			self.trim(trim_min_count)
			self.index2word = ["PAD", "NUM", "UNK"] + self.index2word
		else:
			self.index2word = ["PAD", "NUM"] + self.index2word
		self.word2index = {}
		self.n_words = len(self.index2word)
		for i, j in enumerate(self.index2word):
			self.word2index[j] = i

	def build_output_lang(self, generate_num, copy_nums):  # build the output lang vocab and dict
		"""Finalize the sequence-output vocabulary.

		Layout: PAD, EOS, collected words, `generate_num` constants,
		N0..N<copy_nums-1>, SOS, UNK.
		"""
		self.index2word = ["PAD", "EOS"] + self.index2word + generate_num + ["N" + str(i) for i in range(copy_nums)] +\
						  ["SOS", "UNK"]
		self.n_words = len(self.index2word)
		for i, j in enumerate(self.index2word):
			self.word2index[j] = i

	def build_output_lang_for_tree(self, generate_num, copy_nums):  # build the output lang vocab and dict
		"""Finalize the tree-output vocabulary.

		Layout: collected words, `generate_num` constants, N0..N<copy_nums-1>,
		UNK. `num_start` records the index of the first number token.
		"""
		self.num_start = len(self.index2word)

		self.index2word = self.index2word + generate_num + ["N" + str(i) for i in range(copy_nums)] + ["UNK"]
		self.n_words = len(self.index2word)

		for i, j in enumerate(self.index2word):
			self.word2index[j] = i

def load_raw_data(data_path, dataset, is_train = True):  # load the data to list(dict())
	'''
		Read "<data_path>/<dataset>/dev.csv" and, when `is_train` is set, also
		"train.csv", each as a list of row dicts.

		Returns:
			(train_ls, dev_ls); train_ls is None when `is_train` is False
	'''
	def read_split(file_name):
		# NOTE(review): the converter eval()s every 'group_nums' cell, so only
		# trusted CSV files may be loaded here.
		frame = pd.read_csv(os.path.join(data_path, dataset, file_name), converters={'group_nums': eval})
		return frame.to_dict('records')

	train_ls = read_split('train.csv') if is_train else None
	dev_ls = read_split('dev.csv')

	return train_ls, dev_ls


# remove the superfluous brackets
def remove_brackets(x):
	'''
		Strip one pair of round brackets that wraps the whole expression.

		The outer pair is removed only when, scanning the inside, no closing
		bracket appears before its opening partner; otherwise `x` is returned
		unchanged.
	'''
	if not (x[0] == "(" and x[-1] == ")"):
		return x

	inner = x[1:-1]
	depth = 0
	for symbol in inner:
		if symbol == ")":
			depth -= 1
			if depth < 0:
				# The first bracket closes early, e.g. "(a)+(b)": keep it.
				return x
		elif symbol == "(":
			depth += 1
	return inner


def load_mawps_data(filename):  # load the json data to list(dict()) for MAWPS
	'''
		Load MAWPS-style JSON and keep the records whose single equation can be
		reduced to a plain arithmetic expression.

		A record is kept when, after removing spaces, its only equation has the
		form "<var>=<expr>" or "<expr>=<var>", where <var> is the record's
		single query variable or x/X and <expr> consists solely of digits, ".",
		"+-*/" and round brackets. The kept copy stores <expr> (a string) under
		"lEquations"; all other records are dropped.

		Returns:
			list of shallow record copies, in file order
	'''
	print("Reading lines...")
	# Close the file deterministically instead of leaking the handle.
	with open(filename, encoding="utf-8") as f:
		data = json.load(f)

	arithmetic_chars = set("0123456789.+-*/()")

	def candidate_expressions(record, x):
		# Possible right-hand sides, in the order they take precedence.
		found = []
		if "lQueryVars" in record and len(record["lQueryVars"]) == 1: # When Equations are annotated with variables
			v = record["lQueryVars"][0]
			if x.startswith(v + "="): # eqn of the form 'Var=...'
				found.append(x[len(v)+1:])
			if x.endswith("=" + v): # eqn of the form '...=Var'
				found.append(x[:-len(v)-1])
		if x[:2] == "x=" or x[:2] == "X=":
			found.append(x[2:])
		if x[-2:] == "=x" or x[-2:] == "=X":
			found.append(x[:-2])
		return found

	out_data = []
	for d in data:
		if "lEquations" not in d or len(d["lEquations"]) != 1: # Only single equations
			continue
		x = d["lEquations"][0].replace(" ", "")

		for xt in candidate_expressions(d, x):
			if not (set(xt) - arithmetic_chars):
				temp = d.copy()
				temp["lEquations"] = xt
				out_data.append(temp)
				break
	return out_data


def load_roth_data(filename):  # load the json data to dict(dict()) for roth data
	'''
		Load Roth-style JSON and keep the records whose single equation can be
		reduced to a plain arithmetic expression.

		A record is kept when, after removing spaces, its only equation has the
		form "<var>=<expr>" or "<expr>=<var>", where <var> is the record's
		single query variable or x/X and <expr> consists solely of digits, ".",
		"+-*/" and round brackets. In the kept copy "lEquations" holds <expr>
		with a redundant outer bracket pair removed, and "sQuestion" has a
		trailing ",", "." or "?" split off each word as its own token.

		Returns:
			dict mapping each kept record's "iIndex" to the modified copy
	'''
	print("Reading lines...")
	# Close the file deterministically instead of leaking the handle.
	with open(filename, encoding="utf-8") as f:
		data = json.load(f)

	arithmetic_chars = set("0123456789.+-*/()")

	def candidate_expressions(record, x):
		# Possible right-hand sides, in the order they take precedence.
		found = []
		if "lQueryVars" in record and len(record["lQueryVars"]) == 1:
			v = record["lQueryVars"][0]
			if x.startswith(v + "="):
				found.append(x[len(v)+1:])
			if x.endswith("=" + v):
				found.append(x[:-len(v)-1])
		if x[:2] == "x=" or x[:2] == "X=":
			found.append(x[2:])
		if x[-2:] == "=x" or x[-2:] == "=X":
			# The expression is everything before "=x"; the previous code
			# sliced x[2:] here and stored a mangled equation.
			found.append(x[:-2])
		return found

	def split_punctuation(question):
		# Detach a trailing ",", "." or "?" from every multi-character word.
		pieces = []
		for s in question.strip().split(" "):
			if len(s) > 1 and (s[-1] == "," or s[-1] == "." or s[-1] == "?"):
				pieces.append(s[:-1] + " " + s[-1:])
			else:
				pieces.append(s)
		return " ".join(pieces)

	out_data = {}
	for d in data:
		if "lEquations" not in d or len(d["lEquations"]) != 1:
			continue
		x = d["lEquations"][0].replace(" ", "")

		for xt in candidate_expressions(d, x):
			if not (set(xt) - arithmetic_chars):
				temp = d.copy()
				temp["lEquations"] = remove_brackets(xt)
				temp["sQuestion"] = split_punctuation(temp["sQuestion"])
				out_data[temp["iIndex"]] = temp
				break
	return out_data


def transfer_num(train_ls, dev_ls, chall = False):  # transfer num into "NUM"
	'''
		Convert raw records into (input tokens, equation tokens, ...) tuples.

		For every record the question is tokenized with nltk; tokens of the
		form "number<digit>..." become "NUM" and their positions are recorded.
		Equation tokens "number<k>" become "N<k>"; every other non-operator
		equation token is counted as a constant.

		Args:
			train_ls: list of training records, or None to skip training data
			dev_ls: list of dev records
			chall: when True, dev tuples also carry the 'Type',
				'Variation Type', 'Annotator' and 'Alternate' fields
		Returns:
			train_pairs: list of (input_seq, out_seq, nums, idxs, group_nums),
				or None when train_ls is None
			dev_pairs: same layout (plus the extra fields when `chall`)
			temp_g: constants used at least 5 times across all equations
			copy_nums: largest number count found in a single record
	'''
	print("Transfer numbers...")
	digits = ('0','1','2','3','4','5','6','7','8','9')
	operators = ('+', '-', '*', '/')
	generate_nums_dict = {}
	copy_nums = 0

	def convert(d):
		# Shared by the train and dev passes: tokenize one record.
		nums = d['Numbers'].split()
		# Tokenizer errors now propagate; the old debugger trap left `seg`
		# undefined (or stale from the previous record) once resumed.
		seg = nltk.word_tokenize(d["Question"].strip())
		equation = d["Equation"].split()

		input_seq = []
		idxs = []
		for position, token in enumerate(seg):
			if len(token) >= 7 and token[:6] == "number" and token[6] in digits:
				input_seq.append("NUM")
				idxs.append(position)
			else:
				input_seq.append(token)

		out_seq = []
		for e1 in equation:
			if len(e1) >= 7 and e1[:6] == "number":
				out_seq.append('N'+e1[6:])
				continue
			if e1 not in operators:
				generate_nums_dict[e1] = generate_nums_dict.get(e1, 0) + 1
			out_seq.append(e1)
		return input_seq, out_seq, nums, idxs

	train_pairs = None
	if train_ls is not None:
		train_pairs = []
		for d in train_ls:
			input_seq, out_seq, nums, idxs = convert(d)
			copy_nums = max(copy_nums, len(nums))
			train_pairs.append((input_seq, out_seq, nums, idxs, d['group_nums']))

	dev_pairs = []
	for d in dev_ls:
		input_seq, out_seq, nums, idxs = convert(d)
		copy_nums = max(copy_nums, len(nums))
		if chall:
			dev_pairs.append((input_seq, out_seq, nums, idxs, d['group_nums'], d['Type'], d['Variation Type'], d['Annotator'], d['Alternate']))
		else:
			dev_pairs.append((input_seq, out_seq, nums, idxs, d['group_nums']))

	temp_g = [g for g in generate_nums_dict if generate_nums_dict[g] >= 5]
	return train_pairs, dev_pairs, temp_g, copy_nums


def _link_equation_number(literal, nums, generate_nums):
	'''
		Decide which token represents one numeric literal of an equation.

		Numbers of the question (`nums`) that equal the literal within 1e-4 are
		rewritten in place to the literal's spelling. Exactly one match yields
		"N<index>"; several matches yield the literal itself. Without a match
		the literal is counted in `generate_nums` (a new entry starts at 0, an
		existing close-enough entry is incremented and its spelling returned).
	'''
	matches = []
	for n_idx, n in enumerate(nums):
		if abs(float(n) - float(literal)) < 1e-4:
			matches.append(n_idx)
			if n != literal:
				nums[n_idx] = literal
	if len(matches) == 1:
		return "N" + str(matches[0])
	if matches:
		return literal

	known = False
	for gn in generate_nums:
		if abs(float(gn) - float(literal)) < 1e-4:
			generate_nums[gn] += 1
			literal = gn
			known = True
	if not known:
		generate_nums[literal] = 0
	return literal


def transfer_english_num(data):  # transfer num into "NUM"
	'''
		Replace the numbers of each question by "NUM" and express the equation
		in terms of the question's numbers.

		Args:
			data: iterable of records with "sQuestion" (text) and "lEquations"
				(expression string over digits, ".", "+-*/" and brackets)
		Returns:
			pairs: list of (input_seq, eq_segs, nums, num_pos) for every record
				containing at least one number
			temp_g: constants missing from the questions whose counter reached 5
			copy_nums: largest number count found in a single question
	'''
	print("Transfer numbers...")
	pattern = re.compile(r"\d+,\d+|\d+\.\d+|\d+")
	pairs = []
	generate_nums = {} # Unmentioned numbers used in eqns in atleast 5 examples
	copy_nums = 0 # Maximum number of numbers in a single sentence
	for d in data:
		nums = []
		input_seq = []
		seg = d["sQuestion"].strip().split(" ")
		equations = d["lEquations"]

		for s in seg:
			pos = pattern.search(s)
			if not pos:
				input_seq.append(s)
				continue
			# Only the first number of a word is replaced; surrounding text
			# is kept as separate tokens.
			if pos.start() > 0:
				input_seq.append(s[:pos.start()])
			nums.append(s[pos.start(): pos.end()].replace(",", ""))
			input_seq.append("NUM")
			if pos.end() < len(s):
				input_seq.append(s[pos.end():])

		copy_nums = max(copy_nums, len(nums))

		eq_segs = []
		temp_eq = ""
		for e in equations:
			if e not in "()+-*/":
				temp_eq += e
				continue
			if temp_eq != "":
				# An operator or bracket ends the literal collected so far.
				eq_segs.append(_link_equation_number(temp_eq, nums, generate_nums))
				temp_eq = ""
			eq_segs.append(e)
		if temp_eq != "":
			eq_segs.append(_link_equation_number(temp_eq, nums, generate_nums))

		num_pos = [i for i, j in enumerate(input_seq) if j == "NUM"]
		if len(nums) != 0:
			pairs.append((input_seq, eq_segs, nums, num_pos))

	temp_g = [g for g in generate_nums if generate_nums[g] >= 5]

	return pairs, temp_g, copy_nums


def _link_roth_equation_number(literal, nums, generate_nums):
	'''
		Decide which token represents one numeric literal of an equation.

		Numbers of the question (`nums`) that equal the literal within 1e-4 are
		rewritten in place to the literal's spelling. Exactly one match yields
		"N<index>"; several matches yield the literal itself. Without a match
		the literal is counted in `generate_nums` (a new entry starts at 0, an
		existing close-enough entry is incremented and its spelling returned).
	'''
	matches = []
	for n_idx, n in enumerate(nums):
		if abs(float(n) - float(literal)) < 1e-4:
			matches.append(n_idx)
			if n != literal:
				nums[n_idx] = literal
	if len(matches) == 1:
		return "N" + str(matches[0])
	if matches:
		return literal

	known = False
	for gn in generate_nums:
		if abs(float(gn) - float(literal)) < 1e-4:
			generate_nums[gn] += 1
			literal = gn
			known = True
	if not known:
		generate_nums[literal] = 0
	return literal


def transfer_roth_num(data):  # transfer num into "NUM"
	'''
		Replace the numbers of each question by "NUM" and express the equation
		in terms of the question's numbers, keeping the records keyed.

		Args:
			data: dict mapping a key to a record with "sQuestion" (text) and
				"lEquations" (expression string over digits, ".", "+-*/" and
				brackets)
		Returns:
			pairs: dict key -> (input_seq, eq_segs, nums, num_pos) for every
				record containing at least one number
			temp_g: constants missing from the questions whose counter reached 5
			copy_nums: largest number count found in a single question
	'''
	print("Transfer numbers...")
	pattern = re.compile(r"\d+,\d+|\d+\.\d+|\d+")
	pairs = {}
	generate_nums = {}
	copy_nums = 0
	for key in data:
		d = data[key]
		nums = []
		input_seq = []
		seg = d["sQuestion"].strip().split(" ")
		equations = d["lEquations"]

		for s in seg:
			pos = pattern.search(s)
			if not pos:
				input_seq.append(s)
				continue
			# Only the first number of a word is replaced; surrounding text
			# is kept as separate tokens.
			if pos.start() > 0:
				input_seq.append(s[:pos.start()])
			nums.append(s[pos.start(): pos.end()].replace(",", ""))
			input_seq.append("NUM")
			if pos.end() < len(s):
				input_seq.append(s[pos.end():])

		copy_nums = max(copy_nums, len(nums))

		eq_segs = []
		temp_eq = ""
		for e in equations:
			if e not in "()+-*/":
				temp_eq += e
				continue
			if temp_eq != "":
				# An operator or bracket ends the literal collected so far.
				eq_segs.append(_link_roth_equation_number(temp_eq, nums, generate_nums))
				temp_eq = ""
			eq_segs.append(e)
		if temp_eq != "":
			eq_segs.append(_link_roth_equation_number(temp_eq, nums, generate_nums))

		num_pos = [i for i, j in enumerate(input_seq) if j == "NUM"]
		if len(nums) != 0:
			pairs[key] = (input_seq, eq_segs, nums, num_pos)

	temp_g = [g for g in generate_nums if generate_nums[g] >= 5]

	return pairs, temp_g, copy_nums


# Return a list of indexes, one for each word in the sentence, plus EOS
def indexes_from_sentence(lang, sentence, tree=False):
	'''
		Map the words of `sentence` to vocabulary indices.

		Empty words are skipped and unknown words map to the "UNK" index. The
		"EOS" index is appended when the vocabulary contains "EOS" and `tree`
		is False.
	'''
	lookup = lang.word2index
	res = [
		lookup[word] if word in lookup else lookup["UNK"]
		for word in sentence
		if len(word) != 0
	]
	if "EOS" in lang.index2word and not tree:
		res.append(lookup["EOS"])
	return res

def sentence_from_indexes(lang, indexes):
	'''Look up the word stored in lang.index2word for every index, in order.'''
	return [lang.index2word[ind] for ind in indexes]


def prepare_data(config, logger, pairs_trained, pairs_tested, trim_min_count, generate_nums, copy_nums, input_lang=None, output_lang=None, tree=False):
	'''
		Build (or reuse) the vocabularies and index the train/test pairs.

		When `pairs_trained` is given, both vocabularies are filled from it and
		finalized; for 'bert' / 'roberta' embeddings the test questions are
		added to the input vocabulary first. With `tree` set, only pairs whose
		last element is truthy contribute words.

		Returns:
			input_lang, output_lang, train_pairs, test_pairs, where each
			indexed pair is (input ids, input length, output ids, output
			length, nums, num positions, num_stack, group_nums) and test
			pairs carry four extra fields when config.challenge_disp is set.
			train_pairs is None when `pairs_trained` is None.
	'''
	if input_lang is None:
		input_lang = Lang()
	if output_lang is None:
		output_lang = Lang()

	def contributes(pair):
		# Without tree decoding every pair counts; otherwise the pair's last field decides.
		return (not tree) or pair[-1]

	def collect_num_stack(pair):
		# One entry per equation token missing from the output vocabulary:
		# the positions in the number list holding that token, or every
		# position when the token is not one of the numbers.
		stack = []
		for word in pair[1]:
			if word in output_lang.index2word:
				continue
			positions = [i for i, j in enumerate(pair[2]) if j == word]
			if len(positions) != 0:
				stack.append(positions)
			else:
				stack.append([_ for _ in range(len(pair[2]))])
		stack.reverse()
		return stack

	has_training = pairs_trained is not None
	test_pairs = []
	train_pairs = [] if has_training else None

	if has_training:
		for pair in pairs_trained:
			if contributes(pair):
				input_lang.add_sen_to_vocab(pair[0])
				output_lang.add_sen_to_vocab(pair[1])

	if config.embedding == 'bert' or config.embedding == 'roberta':
		for pair in pairs_tested:
			if contributes(pair):
				input_lang.add_sen_to_vocab(pair[0])

	if has_training:
		input_lang.build_input_lang(logger, trim_min_count)
		if tree:
			output_lang.build_output_lang_for_tree(generate_nums, copy_nums)
		else:
			output_lang.build_output_lang(generate_nums, copy_nums)

		for pair in pairs_trained:
			num_stack = collect_num_stack(pair)
			input_cell = indexes_from_sentence(input_lang, pair[0])
			output_cell = indexes_from_sentence(output_lang, pair[1], tree)
			train_pairs.append((input_cell, len(input_cell), output_cell, len(output_cell),
								pair[2], pair[3], num_stack, pair[4]))

	logger.debug('Indexed {} words in input language, {} words in output'.format(input_lang.n_words, output_lang.n_words))

	for pair in pairs_tested:
		num_stack = collect_num_stack(pair)
		input_cell = indexes_from_sentence(input_lang, pair[0])
		output_cell = indexes_from_sentence(output_lang, pair[1], tree)
		indexed = (input_cell, len(input_cell), output_cell, len(output_cell),
				   pair[2], pair[3], num_stack, pair[4])
		if config.challenge_disp:
			indexed = indexed + (pair[5], pair[6], pair[7], pair[8])
		test_pairs.append(indexed)

	return input_lang, output_lang, train_pairs, test_pairs


def prepare_de_data(pairs_trained, pairs_tested, trim_min_count, generate_nums, copy_nums, tree=False):
	'''
		Build fresh vocabularies from the training pairs and index both splits.

		Training entries keep the raw equation tokens (with a 0 in the
		output-length slot) and are lists; test entries hold indexed equations
		and are tuples.

		Returns:
			input_lang, output_lang, train_pairs, test_pairs with
			train entry: [input ids, input length, equation tokens, 0, nums,
				num positions, num_stack]
			test entry: (input ids, input length, output ids, output length,
				nums, num positions, num_stack)
	'''
	input_lang = Lang()
	output_lang = Lang()

	def collect_num_stack(pair):
		# One entry per equation token missing from the output vocabulary:
		# the positions in the number list holding that token, or every
		# position when the token is not one of the numbers.
		stack = []
		for word in pair[1]:
			if word in output_lang.index2word:
				continue
			positions = [i for i, j in enumerate(pair[2]) if j == word]
			stack.append(positions if positions else list(range(len(pair[2]))))
		stack.reverse()
		return stack

	print("Indexing words...")
	for pair in pairs_trained:
		input_lang.add_sen_to_vocab(pair[0])
		output_lang.add_sen_to_vocab(pair[1])

	# build_input_lang takes (logger, trim_min_count) and never uses the
	# logger; passing trim_min_count alone raised a TypeError.
	input_lang.build_input_lang(None, trim_min_count)

	if tree:
		output_lang.build_output_lang_for_tree(generate_nums, copy_nums)
	else:
		output_lang.build_output_lang(generate_nums, copy_nums)

	train_pairs = []
	for pair in pairs_trained:
		num_stack = collect_num_stack(pair)
		input_cell = indexes_from_sentence(input_lang, pair[0])
		train_pairs.append([input_cell, len(input_cell), pair[1], 0, pair[2], pair[3], num_stack])
	print('Indexed %d words in input language, %d words in output' % (input_lang.n_words, output_lang.n_words))
	print('Number of training data %d' % (len(train_pairs)))

	test_pairs = []
	for pair in pairs_tested:
		num_stack = collect_num_stack(pair)
		input_cell = indexes_from_sentence(input_lang, pair[0])
		output_cell = indexes_from_sentence(output_lang, pair[1], tree)
		test_pairs.append((input_cell, len(input_cell), output_cell, len(output_cell),
						   pair[2], pair[3], num_stack))
	print('Number of testind data %d' % (len(test_pairs)))
	return input_lang, output_lang, train_pairs, test_pairs


# Pad a with the PAD symbol
def pad_seq(seq, seq_len, max_length):
	"""Right-pad ``seq`` in place with ``PAD_token`` and return it.

	``max_length - seq_len`` padding entries are appended; nothing is added
	when that difference is zero or negative. The list passed in is mutated,
	and the same object is returned.
	"""
	seq.extend(PAD_token for _ in range(max_length - seq_len))
	return seq

def change_num(num):
	"""Convert textual number tokens into floats.

	Each item of ``num`` is handled by the first rule that matches:

	* contains ``/``: treated as a fraction. The text before the first ``)``
	  is taken and, when it contains a ``(``, only the part after that ``(``
	  is kept, so ``"(3/4)"`` and ``"1(3/4)"`` both give ``0.75``. A bare
	  ``"3/4"`` is accepted as well.
	* contains ``%``: the last character is dropped and the rest is divided
	  by 100.
	* otherwise: the item is passed to ``float``.

	Returns a new list of floats in the same order as ``num``. Errors raised
	by malformed tokens (``ValueError``, ``ZeroDivisionError``) propagate.
	"""
	values = []
	for item in num:
		if '/' in item:
			fraction = item.split(')')[0]
			# A bare fraction has no "(" to strip; indexing the split result
			# unconditionally would raise IndexError for such tokens.
			if '(' in fraction:
				fraction = fraction.split('(')[1]
			parts = fraction.split('/')
			values.append(float(parts[0]) / float(parts[1]))
		elif '%' in item:
			values.append(float(item[0:-1]) / 100)
		else:
			values.append(float(item))
	return values

# num net graph
def get_lower_num_graph(max_len, sentence_length, num_list, id_num_list,contain_zh_flag=True):
	"""Build the "lower-or-equal" comparison graph between numbers.

	Returns a ``(max_len, max_len)`` float array whose first
	``sentence_length`` diagonal entries are 1. Unless ``contain_zh_flag`` is
	false, every ordered pair of numbers ``(a, b)`` adds one edge between
	their positions from ``id_num_list``: from ``a`` to ``b`` when
	``value[a] <= value[b]``, otherwise from ``b`` to ``a``. Values come from
	``change_num(num_list)``.
	"""
	graph = np.zeros((max_len, max_len))
	values = change_num(num_list)
	for k in range(sentence_length):
		graph[k][k] = 1
	if not contain_zh_flag:
		return graph
	for a, pos_a in enumerate(id_num_list):
		for b, pos_b in enumerate(id_num_list):
			if float(values[a]) <= float(values[b]):
				graph[pos_a][pos_b] = 1
			else:
				graph[pos_b][pos_a] = 1
	return graph

def get_greater_num_graph(max_len, sentence_length, num_list, id_num_list,contain_zh_flag=True):
	"""Build the "greater" comparison graph between numbers.

	Returns a ``(max_len, max_len)`` float array whose first
	``sentence_length`` diagonal entries are 1. Unless ``contain_zh_flag`` is
	false, every ordered pair of numbers ``(a, b)`` adds one edge between
	their positions from ``id_num_list``: from ``a`` to ``b`` when
	``value[a] > value[b]``, otherwise from ``b`` to ``a`` (so equal values
	end up connected in both directions). Values come from
	``change_num(num_list)``.
	"""
	graph = np.zeros((max_len, max_len))
	values = change_num(num_list)
	for k in range(sentence_length):
		graph[k][k] = 1
	if not contain_zh_flag:
		return graph
	for a, pos_a in enumerate(id_num_list):
		for b, pos_b in enumerate(id_num_list):
			if float(values[a]) > float(values[b]):
				graph[pos_a][pos_b] = 1
			else:
				graph[pos_b][pos_a] = 1
	return graph

# attribute between graph
def get_attribute_between_graph(input_batch, max_len, id_num_list, sentence_length, quantity_cell_list,contain_zh_flag=True):
	"""Build the attribute-between graph for one sentence.

	Returns a ``(max_len, max_len)`` float array whose first
	``sentence_length`` diagonal entries are 1. Unless ``contain_zh_flag`` is
	false, two kinds of symmetric edges are added:

	* between a number position (``id_num_list``) and a cell position
	  (``quantity_cell_list``) that is not itself a number position and lies
	  fewer than 4 positions away;
	* between any two cell positions whose entries in ``input_batch`` compare
	  equal.

	Positions at or beyond ``max_len`` are ignored.
	"""
	graph = np.zeros((max_len, max_len))
	for k in range(sentence_length):
		graph[k][k] = 1
	if not contain_zh_flag:
		return graph
	for num_idx in id_num_list:
		for cell_idx in quantity_cell_list:
			if num_idx < max_len and cell_idx < max_len:
				if cell_idx not in id_num_list and abs(num_idx - cell_idx) < 4:
					graph[num_idx][cell_idx] = 1
					graph[cell_idx][num_idx] = 1
	for first in quantity_cell_list:
		for second in quantity_cell_list:
			if first < max_len and second < max_len and input_batch[first] == input_batch[second]:
				graph[first][second] = 1
				graph[second][first] = 1
	return graph

# quantity between graph
def get_quantity_between_graph(max_len, id_num_list, sentence_length, quantity_cell_list,contain_zh_flag=True):
	"""Build the quantity-between graph for one sentence.

	Returns a ``(max_len, max_len)`` float array whose first
	``sentence_length`` diagonal entries are 1. Unless ``contain_zh_flag`` is
	false, symmetric edges are added

	* between a number position (``id_num_list``) and a cell position
	  (``quantity_cell_list``) that is not itself a number position and lies
	  fewer than 4 positions away (positions at or beyond ``max_len`` are
	  ignored), and
	* between every pair of number positions (no ``max_len`` check is made
	  for these).
	"""
	graph = np.zeros((max_len, max_len))
	for k in range(sentence_length):
		graph[k][k] = 1
	if not contain_zh_flag:
		return graph
	for num_idx in id_num_list:
		for cell_idx in quantity_cell_list:
			if num_idx < max_len and cell_idx < max_len:
				if cell_idx not in id_num_list and abs(num_idx - cell_idx) < 4:
					graph[num_idx][cell_idx] = 1
					graph[cell_idx][num_idx] = 1
	for first in id_num_list:
		for second in id_num_list:
			graph[first][second] = 1
			graph[second][first] = 1
	return graph

# quantity cell graph
def get_quantity_cell_graph(max_len, id_num_list, sentence_length, quantity_cell_list,contain_zh_flag=True):
	"""Build the quantity-cell graph for one sentence.

	Returns a ``(max_len, max_len)`` float array whose first
	``sentence_length`` diagonal entries are 1. Unless ``contain_zh_flag`` is
	false, a symmetric edge is added between a number position
	(``id_num_list``) and a cell position (``quantity_cell_list``) when the
	cell is not itself a number position and lies fewer than 4 positions
	away. Positions at or beyond ``max_len`` are ignored.
	"""
	graph = np.zeros((max_len, max_len))
	for k in range(sentence_length):
		graph[k][k] = 1
	if contain_zh_flag:
		for num_idx in id_num_list:
			for cell_idx in quantity_cell_list:
				if num_idx < max_len and cell_idx < max_len:
					if cell_idx not in id_num_list and abs(num_idx - cell_idx) < 4:
						graph[num_idx][cell_idx] = 1
						graph[cell_idx][num_idx] = 1
	return graph

def get_single_batch_graph(input_batch, input_length,group,num_value,num_pos):
	"""Build the stacked graphs for every example of a batch.

	All graphs are sized to ``max(input_length)``. For example ``k`` the five
	graphs are stacked in this order: quantity-cell, greater, lower,
	quantity-between, attribute-between.

	Returns an ``np.ndarray`` built from one ``[5 graphs]`` list per example.
	"""
	max_len = max(input_length)
	stacked = []
	for idx, sentence_length in enumerate(input_length):
		tokens = input_batch[idx]
		cells = group[idx]
		numbers = num_value[idx]
		positions = num_pos[idx]
		cell_graph = get_quantity_cell_graph(max_len, positions, sentence_length, cells)
		greater_graph = get_greater_num_graph(max_len, sentence_length, numbers, positions)
		lower_graph = get_lower_num_graph(max_len, sentence_length, numbers, positions)
		quantity_graph = get_quantity_between_graph(max_len, positions, sentence_length, cells)
		attribute_graph = get_attribute_between_graph(tokens, max_len, positions, sentence_length, cells)
		stacked.append([
			cell_graph.tolist(),
			greater_graph.tolist(),
			lower_graph.tolist(),
			quantity_graph.tolist(),
			attribute_graph.tolist(),
		])
	return np.array(stacked)

def get_single_example_graph(input_batch, input_length,group,num_value,num_pos):
	"""Build the stacked graphs for a single example.

	Single-example counterpart of ``get_single_batch_graph``: the graphs are
	sized to ``input_length`` and stacked in the same order (quantity-cell,
	greater, lower, quantity-between, attribute-between).

	Args:
		input_batch: token sequence of the example (indexed by position).
		input_length: number of tokens; used as both graph size and
			sentence length.
		group: cell positions of the example.
		num_value: textual number values of the example.
		num_pos: positions of those numbers in the sequence.

	Returns:
		``np.ndarray`` with a leading batch axis of size 1, i.e. built from
		``[[5 graphs]]``.
	"""
	max_len = input_length
	sentence_length = input_length
	graph_newc = get_quantity_cell_graph(max_len, num_pos, sentence_length, group)
	graph_quanbet = get_quantity_between_graph(max_len, num_pos, sentence_length, group)
	graph_attbet = get_attribute_between_graph(input_batch, max_len, num_pos, sentence_length, group)
	graph_greater = get_greater_num_graph(max_len, sentence_length, num_value, num_pos)
	# The "lower" slot must use the lower-comparison builder; using the
	# greater builder here would duplicate graph_greater and disagree with
	# what get_single_batch_graph produces for the same example.
	graph_lower = get_lower_num_graph(max_len, sentence_length, num_value, num_pos)
	graph_total = [
		graph_newc.tolist(),
		graph_greater.tolist(),
		graph_lower.tolist(),
		graph_quanbet.tolist(),
		graph_attbet.tolist(),
	]
	return np.array([graph_total])

# prepare the batches
def prepare_train_batch(pairs_to_batch, batch_size):
	"""Shuffle the training pairs and cut them into padded batches.

	Works on a deep copy of ``pairs_to_batch``. Each pair is unpacked as
	eight fields: ``(input ids, input length, output ids, output length,
	nums, num_pos, num_stack, group)``. Within a batch the pairs are sorted
	by decreasing input length, inputs are padded to the longest input and
	outputs to the longest output of that batch.

	Returns, as parallel per-batch lists: input batches, input lengths,
	output batches, output lengths, number counts, num stacks, number
	positions, number-position counts, number values, graph batches and
	group batches.
	"""
	# NOTE(review): every pair must carry 8 fields (the last being ``group``);
	# confirm the caller's pair layout matches.
	pairs = copy.deepcopy(pairs_to_batch)
	random.shuffle(pairs)

	# Cut into consecutive chunks; the final chunk takes whatever remains.
	batches = []
	start = 0
	while start + batch_size < len(pairs):
		batches.append(pairs[start:start + batch_size])
		start += batch_size
	batches.append(pairs[start:])

	input_batches, input_lengths = [], []
	output_batches, output_lengths = [], []
	nums_batches, num_stack_batches = [], []
	num_pos_batches, num_size_batches = [], []
	num_value_batches, group_batches, graph_batches = [], [], []

	for chunk in batches:
		chunk = sorted(chunk, key=lambda entry: entry[1], reverse=True)
		input_length = []
		output_length = []
		for _, in_len, _, out_len, _, _, _, _ in chunk:
			input_length.append(in_len)
			output_length.append(out_len)
		input_lengths.append(input_length)
		output_lengths.append(output_length)
		# The chunk is sorted by input length, so the first entry is the longest.
		input_len_max = input_length[0]
		output_len_max = max(output_length)

		input_batch, output_batch = [], []
		num_batch, num_stack_batch = [], []
		num_pos_batch, num_size_batch = [], []
		group_batch, num_value_batch = [], []
		for in_ids, in_len, out_ids, out_len, num, num_pos, num_stack, group in chunk:
			num_batch.append(len(num))
			input_batch.append(pad_seq(in_ids, in_len, input_len_max))
			output_batch.append(pad_seq(out_ids, out_len, output_len_max))
			num_stack_batch.append(num_stack)
			num_pos_batch.append(num_pos)
			num_size_batch.append(len(num_pos))
			num_value_batch.append(num)
			group_batch.append(group)

		input_batches.append(input_batch)
		nums_batches.append(num_batch)
		output_batches.append(output_batch)
		num_stack_batches.append(num_stack_batch)
		num_pos_batches.append(num_pos_batch)
		num_size_batches.append(num_size_batch)
		num_value_batches.append(num_value_batch)
		group_batches.append(group_batch)
		graph_batches.append(get_single_batch_graph(input_batch, input_length,group_batch,num_value_batch,num_pos_batch))

	return (input_batches, input_lengths, output_batches, output_lengths, nums_batches,
			num_stack_batches, num_pos_batches, num_size_batches, num_value_batches,
			graph_batches, group_batches)

def get_num_stack(eq, output_lang, num_pos):
	"""Collect candidate number indices for out-of-vocabulary equation tokens.

	For every token of ``eq`` that is not in ``output_lang.index2word``, one
	entry is recorded: the indices at which the token occurs in ``num_pos``,
	or all indices of ``num_pos`` when it does not occur at all. The entries
	are returned in reverse token order.
	"""
	stack = []
	for token in eq:
		if token in output_lang.index2word:
			continue
		matches = [idx for idx, candidate in enumerate(num_pos) if candidate == token]
		if matches:
			stack.append(matches)
		else:
			stack.append(list(range(len(num_pos))))
	stack.reverse()
	return stack


def prepare_de_train_batch(pairs_to_batch, batch_size, output_lang, rate, english=False):
	"""Build shuffled, padded training batches with equation augmentation.

	Works on a deep copy of ``pairs_to_batch``. Each pair is a 7-field
	mutable sequence unpacked below as ``(input ids, input length, equation,
	output length, nums, num_pos, num_stack)``; on entry field 2 holds the
	equation as a token list. For every pair the bracket-normalised equation
	is indexed and kept, and up to three rewritten variants are added
	(``allocation`` of the original, ``exchange`` of the original, and
	``allocation`` applied to the exchanged form) whenever the rewrite
	differs from its starting point. For the variants, field 6 is rebuilt
	with ``get_num_stack``.

	Args:
		pairs_to_batch: training pairs as described above.
		batch_size: number of pairs per batch (the last batch may be smaller).
		output_lang: vocabulary object passed to ``indexes_from_sentence``
			and ``get_num_stack``.
		rate: probability forwarded to ``exchange`` and ``allocation``.
		english: forwarded to ``check_bracket``.

	Returns:
		Seven parallel per-batch lists: input batches, input lengths, output
		batches, output lengths, number counts, num stacks, number positions.
	"""
	pairs = []
	b_pairs = copy.deepcopy(pairs_to_batch)
	for pair in b_pairs:
		# p is snapshotted before pair[2] is normalised; only p[2]/p[3] are
		# overwritten below, so p keeps the original num_stack (field 6).
		p = copy.deepcopy(pair)
		pair[2] = check_bracket(pair[2], english)

		# Operand-swapped variant of the normalised equation.
		temp_out = exchange(pair[2], rate)
		temp_out = check_bracket(temp_out, english)

		# Always keep the (normalised) original equation.
		p[2] = indexes_from_sentence(output_lang, pair[2])
		p[3] = len(p[2])
		pairs.append(p)

		# Distributed variant of the normalised equation.
		temp_out_a = allocation(pair[2], rate)
		temp_out_a = check_bracket(temp_out_a, english)

		if temp_out_a != pair[2]:
			p = copy.deepcopy(pair)
			p[6] = get_num_stack(temp_out_a, output_lang, p[4])
			p[2] = indexes_from_sentence(output_lang, temp_out_a)
			p[3] = len(p[2])
			pairs.append(p)

		if temp_out != pair[2]:
			p = copy.deepcopy(pair)
			p[6] = get_num_stack(temp_out, output_lang, p[4])
			p[2] = indexes_from_sentence(output_lang, temp_out)
			p[3] = len(p[2])
			pairs.append(p)

			# Only when the plain allocation changed something, also try
			# allocation on top of the exchanged equation.
			if temp_out_a != pair[2]:
				p = copy.deepcopy(pair)
				temp_out_a = allocation(temp_out, rate)
				temp_out_a = check_bracket(temp_out_a, english)
				if temp_out_a != temp_out:
					p[6] = get_num_stack(temp_out_a, output_lang, p[4])
					p[2] = indexes_from_sentence(output_lang, temp_out_a)
					p[3] = len(p[2])
					pairs.append(p)
	print("this epoch training data is", len(pairs))
	random.shuffle(pairs)  # randomise the order of originals and variants
	pos = 0
	input_lengths = []
	output_lengths = []
	nums_batches = []
	batches = []
	input_batches = []
	output_batches = []
	num_stack_batches = []  # per-batch lists of num_stack entries
	num_pos_batches = []
	# Cut into consecutive chunks; the final chunk takes whatever remains.
	while pos + batch_size < len(pairs):
		batches.append(pairs[pos:pos+batch_size])
		pos += batch_size
	batches.append(pairs[pos:])

	for batch in batches:
		# Longest input first, so input_length[0] below is the batch maximum.
		batch = sorted(batch, key=lambda tp: tp[1], reverse=True)
		input_length = []
		output_length = []
		for _, i, _, j, _, _, _ in batch:
			input_length.append(i)
			output_length.append(j)
		input_lengths.append(input_length)
		output_lengths.append(output_length)
		input_len_max = input_length[0]
		output_len_max = max(output_length)
		input_batch = []
		output_batch = []
		num_batch = []
		num_stack_batch = []
		num_pos_batch = []
		for i, li, j, lj, num, num_pos, num_stack in batch:
			num_batch.append(len(num))
			# pad_seq pads the sequences in place up to the batch maxima.
			input_batch.append(pad_seq(i, li, input_len_max))
			output_batch.append(pad_seq(j, lj, output_len_max))
			num_stack_batch.append(num_stack)
			num_pos_batch.append(num_pos)
		input_batches.append(input_batch)
		nums_batches.append(num_batch)
		output_batches.append(output_batch)
		num_stack_batches.append(num_stack_batch)
		num_pos_batches.append(num_pos_batch)
	return input_batches, input_lengths, output_batches, output_lengths, nums_batches, num_stack_batches, num_pos_batches


# Multiplication exchange rate
def exchange(ex_copy, rate):
	"""Randomly swap the operands of ``+`` and ``*`` in an infix token list.

	Works on a deep copy of ``ex_copy``. Scanning from index 1, each ``+`` or
	``*`` token is selected with probability ``rate`` (one
	``random.random()`` draw per candidate). For a selected operator the
	extent of the left operand (``lidx`` .. ``idx - 1``) and of the right
	operand (``idx + 1`` .. ``ridx``) is found by walking outwards until the
	list ends, an enclosing bracket is crossed, or a delimiting operator is
	met at bracket depth 0 (``+``/``-`` for ``+``; ``+``/``-``/``*``/``/``
	for ``*``). The two operands are then exchanged and scanning resumes
	after the right operand.

	Returns the rewritten token list.
	"""
	ex = copy.deepcopy(ex_copy)
	idx = 1
	while idx < len(ex):
		s = ex[idx]
		if (s == "*" or s == "+") and random.random() < rate:
			lidx = idx - 1
			ridx = idx + 1
			if s == "+":
				# Walk left; flag tracks bracket depth (closing -1, opening +1),
				# so flag == 1 means an enclosing opening bracket was just passed.
				flag = 0
				while not (lidx == -1 or ((ex[lidx] == "+" or ex[lidx] == "-") and flag == 0) or flag == 1):
					if ex[lidx] == ")" or ex[lidx] == "]":
						flag -= 1
					elif ex[lidx] == "(" or ex[lidx] == "[":
						flag += 1
					lidx -= 1
				# Step back to the first token of the left operand.
				if flag == 1:
					lidx += 2
				else:
					lidx += 1

				# Walk right; flag == -1 means an enclosing closing bracket was just passed.
				flag = 0
				while not (ridx == len(ex) or ((ex[ridx] == "+" or ex[ridx] == "-") and flag == 0) or flag == -1):
					if ex[ridx] == ")" or ex[ridx] == "]":
						flag -= 1
					elif ex[ridx] == "(" or ex[ridx] == "[":
						flag += 1
					ridx += 1
				# Step back to the last token of the right operand.
				if flag == -1:
					ridx -= 2
				else:
					ridx -= 1
			else:
				# Same two walks for "*", but any of + - * / at depth 0 ends an operand.
				flag = 0
				while not (lidx == -1
						   or ((ex[lidx] == "+" or ex[lidx] == "-" or ex[lidx] == "*" or ex[lidx] == "/") and flag == 0)
						   or flag == 1):
					if ex[lidx] == ")" or ex[lidx] == "]":
						flag -= 1
					elif ex[lidx] == "(" or ex[lidx] == "[":
						flag += 1
					lidx -= 1
				if flag == 1:
					lidx += 2
				else:
					lidx += 1

				flag = 0
				while not (ridx == len(ex)
						   or ((ex[ridx] == "+" or ex[ridx] == "-" or ex[ridx] == "*" or ex[ridx] == "/") and flag == 0)
						   or flag == -1):
					if ex[ridx] == ")" or ex[ridx] == "]":
						flag -= 1
					elif ex[ridx] == "(" or ex[ridx] == "[":
						flag += 1
					ridx += 1
				if flag == -1:
					ridx -= 2
				else:
					ridx -= 1
			if lidx > 0 and ((s == "+" and ex[lidx - 1] == "-") or (s == "*" and ex[lidx - 1] == "/")):
				# The left operand is preceded by the inverse operator ("-" before
				# a "+" swap, "/" before a "*" swap): move that operator together
				# with the left operand, and the selected operator together with
				# the right operand.
				lidx -= 1
				ex = ex[:lidx] + ex[idx:ridx + 1] + ex[lidx:idx] + ex[ridx + 1:]
			else:
				# Plain swap: right operand, operator, left operand.
				ex = ex[:lidx] + ex[idx + 1:ridx + 1] + [s] + ex[lidx:idx] + ex[ridx + 1:]
			# Continue after the tokens that were just rearranged.
			idx = ridx
		idx += 1
	return ex


def check_bracket(x, english=False):
	"""Normalise the brackets of an infix equation token list.

	With ``english=True`` only round brackets are used: ``[`` becomes ``(``
	and ``]`` (as well as ``}``) becomes ``)`` in place. Redundant round
	brackets are then removed: a group that opens the expression, and any
	group that directly follows a ``+``, is unwrapped unless it is followed
	by ``*`` or ``/``. The unwrapped result is a new list. An empty list is
	returned unchanged.

	With ``english=False`` the list is edited in place and returned: a round
	bracket pair that contains no ``[`` between its ends is rewritten to a
	square pair.

	Args:
		x: list of equation tokens.
		english: selects the normalisation described above.

	Returns:
		The normalised token list.
	"""
	if english:
		if not x:
			# Nothing to normalise; also avoids indexing x[0] on an empty list.
			return x
		for idx, s in enumerate(x):
			if s == '[':
				x[idx] = '('
			elif s == ']' or s == '}':
				# "]" is the closing partner of "["; leaving it unconverted makes
				# the bracket matching below run off the end of the list. "}" is
				# still accepted for backward compatibility.
				x[idx] = ')'
		s = x[0]
		idx = 0
		if s == "(":
			# Find the position just after the bracket that closes x[0].
			flag = 1
			temp_idx = idx + 1
			while flag > 0 and temp_idx < len(x):
				if x[temp_idx] == ")":
					flag -= 1
				elif x[temp_idx] == "(":
					flag += 1
				temp_idx += 1
			if temp_idx == len(x):
				x = x[idx + 1:temp_idx - 1]
			elif x[temp_idx] != "*" and x[temp_idx] != "/":
				x = x[idx + 1:temp_idx - 1] + x[temp_idx:]
		# Repeatedly unwrap "+ ( ... )" groups until a pass changes nothing.
		while True:
			y = len(x)
			for idx, s in enumerate(x):
				if s == "+" and idx + 1 < len(x) and x[idx + 1] == "(":
					flag = 1
					temp_idx = idx + 2
					while flag > 0 and temp_idx < len(x):
						if x[temp_idx] == ")":
							flag -= 1
						elif x[temp_idx] == "(":
							flag += 1
						temp_idx += 1
					if temp_idx == len(x):
						x = x[:idx + 1] + x[idx + 2:temp_idx - 1]
						break
					elif x[temp_idx] != "*" and x[temp_idx] != "/":
						x = x[:idx + 1] + x[idx + 2:temp_idx - 1] + x[temp_idx:]
						break
			if y == len(x):
				break
		return x

	lx = len(x)
	for idx, s in enumerate(x):
		if s == "[":
			flag_b = 0
			flag = False
			temp_idx = idx
			while temp_idx < lx:
				if x[temp_idx] == "]":
					flag_b += 1
				elif x[temp_idx] == "[":
					flag_b -= 1
				# NOTE(review): the scan starts on the "[" itself, so flag is set
				# on the first step and the rewrite below never triggers —
				# confirm whether the scan was meant to start at idx + 1.
				if x[temp_idx] == "(" or x[temp_idx] == "[":
					flag = True
				if x[temp_idx] == "]" and flag_b == 0:
					break
				temp_idx += 1
			if not flag:
				x[idx] = "("
				x[temp_idx] = ")"
				continue
		if s == "(":
			flag_b = 0
			flag = False
			temp_idx = idx
			while temp_idx < lx:
				if x[temp_idx] == ")":
					flag_b += 1
				elif x[temp_idx] == "(":
					flag_b -= 1
				if x[temp_idx] == "[":
					flag = True
				if x[temp_idx] == ")" and flag_b == 0:
					break
				temp_idx += 1
			# No square bracket inside this round pair: rewrite it as a square pair.
			if not flag:
				x[idx] = "["
				x[temp_idx] = "]"
	return x


# Multiplication allocation rate
def allocation(ex_copy, rate):
	"""Randomly distribute a factor over a bracketed sum in an infix token list.

	Works on a deep copy of ``ex_copy`` and scans it from index 1. Two
	patterns are examined at each position:

	* ``( a + b ) * c`` / ``( a + b ) / c`` — an ``*`` or ``/`` directly
	  after a closing bracket. If the bracket holds a ``+``/``-`` at its top
	  level and is not itself preceded by ``/``, then with probability
	  ``rate`` the right operand is applied to every term of the bracket and
	  the rewritten list is returned immediately.
	* ``c * ( a + b )`` — an ``*`` directly before an opening bracket (see
	  the review notes in the body).

	At most one rewrite is performed per call. When nothing is rewritten the
	copied token list is returned unchanged.
	"""
	ex = copy.deepcopy(ex_copy)
	idx = 1
	lex = len(ex)
	while idx < len(ex):
		# Pattern 1: closing bracket followed by "*" or "/".
		if (ex[idx] == "/" or ex[idx] == "*") and (ex[idx - 1] == "]" or ex[idx - 1] == ")"):
			ridx = idx + 1
			r_allo = []   # tokens of the right operand that will be distributed
			r_last = []   # tokens that follow the right operand
			flag = 0      # bracket depth while scanning to the right
			flag_mmd = False
			while ridx < lex:
				if ex[ridx] == "(" or ex[ridx] == "[":
					flag += 1
				elif ex[ridx] == ")" or ex[ridx] == "]":
					flag -= 1
				if flag == 0:
					if ex[ridx] == "+" or ex[ridx] == "-":
						r_last = ex[ridx:]
						r_allo = ex[idx + 1: ridx]
						break
					elif ex[ridx] == "*" or ex[ridx] == "/":
						# Another "*" or "/" follows: the distributed result gets
						# wrapped in brackets (opened below, closed here).
						flag_mmd = True
						r_last = [")"] + ex[ridx:]
						r_allo = ex[idx + 1: ridx]
						break
				elif flag == -1:
					# Left the enclosing bracket: the operand ends here.
					r_last = ex[ridx:]
					r_allo = ex[idx + 1: ridx]
					break
				ridx += 1
			# No delimiter found: the right operand runs to the end of the list.
			if len(r_allo) == 0:
				r_allo = ex[idx + 1:]
			flag = 0
			lidx = idx - 1
			flag_al = False   # True when the bracket has a top-level "+" or "-"
			flag_md = False
			# Walk left to the opening bracket matching ex[idx - 1].
			while lidx > 0:
				if ex[lidx] == "(" or ex[lidx] == "[":
					flag -= 1
				elif ex[lidx] == ")" or ex[lidx] == "]":
					flag += 1
				if flag == 1:
					if ex[lidx] == "+" or ex[lidx] == "-":
						flag_al = True
				if flag == 0:
					break
				lidx -= 1
			# A bracket that is itself a divisor is never distributed.
			if lidx != 0 and ex[lidx - 1] == "/":
				flag_al = False
			if not flag_al:
				idx += 1
				continue
			elif random.random() < rate:
				temp_idx = lidx + 1
				temp_res = ex[:lidx]
				if flag_mmd:
					temp_res += ["("]
				if lidx - 1 > 0:
					# Preceded by "-", "*" or "/": keep the distributed terms grouped.
					if ex[lidx - 1] == "-" or ex[lidx - 1] == "*" or ex[lidx - 1] == "/":
						flag_md = True
						temp_res += ["("]
				flag = 0
				lidx += 1
				# Emit "<term> <op> <right operand>" for every top-level term
				# of the bracket; lidx marks the start of the current term.
				while temp_idx < idx - 1:
					if ex[temp_idx] == "(" or ex[temp_idx] == "[":
						flag -= 1
					elif ex[temp_idx] == ")" or ex[temp_idx] == "]":
						flag += 1
					if flag == 0:
						if ex[temp_idx] == "+" or ex[temp_idx] == "-":
							temp_res += ex[lidx: temp_idx] + [ex[idx]] + r_allo + [ex[temp_idx]]
							lidx = temp_idx + 1
					temp_idx += 1
				# Last term of the bracket.
				temp_res += ex[lidx: temp_idx] + [ex[idx]] + r_allo
				if flag_md:
					temp_res += [")"]
				temp_res += r_last
				return temp_res
		# Pattern 2: "*" followed by an opening bracket.
		# NOTE(review): ex[idx + 1] raises IndexError if "*" is the last token —
		# presumably never the case for well-formed equations; confirm.
		if ex[idx] == "*" and (ex[idx + 1] == "[" or ex[idx + 1] == "("):
			lidx = idx - 1
			l_allo = []   # tokens of the left operand
			temp_res = []
			flag = 0
			flag_md = False  # flag for x or /
			while lidx > 0:
				if ex[lidx] == "(" or ex[lidx] == "[":
					flag += 1
				elif ex[lidx] == ")" or ex[lidx] == "]":
					flag -= 1
				if flag == 0:
					if ex[lidx] == "+":
						temp_res = ex[:lidx + 1]
						l_allo = ex[lidx + 1: idx]
						break
					elif ex[lidx] == "-":
						flag_md = True  # flag for -
						temp_res = ex[:lidx] + ["("]
						l_allo = ex[lidx + 1: idx]
						break
				elif flag == 1:
					temp_res = ex[:lidx + 1]
					l_allo = ex[lidx + 1: idx]
					break
				lidx -= 1
			# No delimiter found: the left operand starts at the beginning.
			if len(l_allo) == 0:
				l_allo = ex[:idx]
			flag = 0
			ridx = idx + 1
			flag_al = False
			all_res = []
			# NOTE(review): opening brackets decrement flag here, so inside the
			# bracket that starts at idx + 1 flag is negative and the
			# ``flag == 1`` test below looks unreachable for balanced input;
			# flag_al would then stay False and this pattern never rewrites —
			# confirm the intended sign convention.
			while ridx < lex:
				if ex[ridx] == "(" or ex[ridx] == "[":
					flag -= 1
				elif ex[ridx] == ")" or ex[ridx] == "]":
					flag += 1
				if flag == 1:
					if ex[ridx] == "+" or ex[ridx] == "-":
						flag_al = True
				if flag == 0:
					break
				ridx += 1
			if not flag_al:
				idx += 1
				continue
			elif random.random() < rate:
				temp_idx = idx + 1
				flag = 0
				lidx = temp_idx + 1
				# NOTE(review): temp_idx starts at idx + 1, so the condition
				# ``temp_idx < idx - 1`` is already false on entry and this loop
				# body does not run — confirm the intended bound.
				while temp_idx < idx - 1:
					if ex[temp_idx] == "(" or ex[temp_idx] == "[":
						flag -= 1
					elif ex[temp_idx] == ")" or ex[temp_idx] == "]":
						flag += 1
					if flag == 1:
						if ex[temp_idx] == "+" or ex[temp_idx] == "-":
							all_res += l_allo + [ex[idx]] + ex[lidx: temp_idx] + [ex[temp_idx]]
							lidx = temp_idx + 1
					if flag == 0:
						break
					temp_idx += 1
				if flag_md:
					temp_res += all_res + [")"]
				elif ex[temp_idx + 1] == "*" or ex[temp_idx + 1] == "/":
					temp_res += ["("] + all_res + [")"]
				temp_res += ex[temp_idx + 1:]
				return temp_res
		idx += 1
	return ex

## args.py

In [28]:
def build_parser():
	"""Create the command-line parser for training and testing runs.

	Options use a single leading dash (e.g. ``-batch_size 8``). Boolean
	switches come in ``-name`` / ``-no-name`` pairs that write to the same
	destination, with the default fixed by ``set_defaults``.

	Returns:
		argparse.ArgumentParser: the configured parser.
	"""
	# Data loading parameters
	parser = argparse.ArgumentParser(description='Run Single sequence model')

	parser.add_argument('-mode', type=str, default='train', choices=['train', 'test'], help='Modes: train, test')

	# Run Config
	parser.add_argument('-run_name', type=str, default='debug', help='run name for logs')
	parser.add_argument('-dataset', type=str, default='asdiv-a_fold0_final', help='Dataset')
	parser.add_argument('-outputs', dest='outputs', action='store_true', help='Show full validation outputs')
	parser.add_argument('-no-outputs', dest='outputs', action='store_false', help='Do not show full validation outputs')
	parser.set_defaults(outputs=True)
	parser.add_argument('-results', dest='results', action='store_true', help='Store results')
	parser.add_argument('-no-results', dest='results', action='store_false', help='Do not store results')
	parser.set_defaults(results=True)

	# Meta Attributes
	# parser.add_argument('-vocab_size', type=int, default=30000, help='Vocabulary size to consider')
	parser.add_argument('-trim_threshold', type=int, default=1, help='Remove words with frequency less than this from vocab')

	# Device Configuration
	parser.add_argument('-gpu', type=int, default=2, help='Specify the gpu to use')
	parser.add_argument('-seed', type=int, default=6174, help='Default seed to set')
	parser.add_argument('-logging', type=int, default=1, help='Set to 0 if you do not require logging')
	parser.add_argument('-ckpt', type=str, default='model', help='Checkpoint file name')
	parser.add_argument('-save_model', dest='save_model',action='store_true', help='To save the model')
	parser.add_argument('-no-save_model', dest='save_model', action='store_false', help='Dont save the model')
	parser.set_defaults(save_model=True)
	# parser.add_argument('-log_fmt', type=str, default='%(asctime)s | %(levelname)s | %(name)s | %(message)s', help='Specify format of the logger')

	# Model parameters
	# parser.add_argument('-cell_type', type=str, default='gru', help='RNN cell for encoder, default: gru')
	parser.add_argument('-embedding', type=str, default='roberta', choices=['bert', 'roberta', 'word2vec', 'random'], help='Embeddings')
	parser.add_argument('-emb_name', type=str, default='roberta-base', choices=['bert-base-uncased', 'roberta-base'], help='Which pre-trained model')
	parser.add_argument('-embedding_size', type=int, default=768, help='Embedding dimensions of inputs')
	parser.add_argument('-emb_lr', type=float, default=1e-5, help='Larning rate to train embeddings')
	parser.add_argument('-freeze_emb', dest='freeze_emb', action='store_true', help='Freeze embedding weights')
	parser.add_argument('-no-freeze_emb', dest='freeze_emb', action='store_false', help='Train embedding weights')
	parser.set_defaults(freeze_emb=False)
	parser.add_argument('-word2vec_bin', type=str, default='/datadrive/global_files/GoogleNews-vectors-negative300.bin', help='Binary file of word2vec')

	parser.add_argument('-cell_type', type=str, default='lstm', help='RNN cell for encoder and decoder, default: lstm')
	parser.add_argument('-hidden_size', type=int, default=384, help='Number of hidden units in each layer')
	parser.add_argument('-depth', type=int, default=2, help='Number of layers in each encoder')
	parser.add_argument('-lr', type=float, default=1e-3, help='Learning rate')
	parser.add_argument('-batch_size', type=int, default=8, help='Batch size')
	parser.add_argument('-weight_decay', type=float, default=1e-5, help='Weight Decay')
	# A beam size is a count: parse it as int so a value given on the command
	# line has the same type as the integer default (type=float yielded 3.0).
	parser.add_argument('-beam_size', type=int, default=5, help='Beam Size')
	parser.add_argument('-epochs', type=int, default=70, help='Maximum # of training epochs')
	parser.add_argument('-dropout', type=float, default=0.5, help= 'Dropout probability for input/output/state units (0.0: no dropout)')

	# parser.add_argument('-max_length', type=int, default=100, help='Specify max decode steps: Max length string to output')
	# parser.add_argument('-init_range', type=float, default=0.08, help='Initialization range for seq2seq model')
	# parser.add_argument('-bidirectional', dest='bidirectional', action='store_true', help='Bidirectionality in LSTMs')
	# parser.add_argument('-no-bidirectional', dest='bidirectional', action='store_false', help='Bidirectionality in LSTMs')
	# parser.set_defaults(bidirectional=False)

	# parser.add_argument('-max_grad_norm', type=float, default=0.25, help='Clip gradients to this norm')
	# parser.add_argument('-opt', type=str, default='adam', choices=['adam', 'adadelta', 'sgd', 'asgd'], help='Optimizer for training')

	# parser.add_argument('-grade_disp', dest='grade_disp', action='store_true', help='Display grade information in validation outputs')
	# parser.add_argument('-no-grade_disp', dest='grade_disp', action='store_false', help='Don\'t display grade information')
	# parser.set_defaults(grade_disp=True)
	# parser.add_argument('-type_disp', dest='type_disp', action='store_true', help='Display Type information in validation outputs')
	# parser.add_argument('-no-type_disp', dest='type_disp', action='store_false', help='Don\'t display Type information')
	# parser.set_defaults(type_disp=True)
	parser.add_argument('-nums_disp', dest='nums_disp', action='store_true', help='Display number of numbers information in validation outputs')
	parser.add_argument('-no-nums_disp', dest='nums_disp', action='store_false', help='Don\'t display number of numbers information')
	parser.set_defaults(nums_disp=True)
	parser.add_argument('-challenge_disp', dest='challenge_disp', action='store_true', help='Display information in validation outputs')
	parser.add_argument('-no-challenge_disp', dest='challenge_disp', action='store_false', help='Don\'t display information')
	parser.set_defaults(challenge_disp=False)

	parser.add_argument('-show_train_acc', dest='show_train_acc', action='store_true', help='Calculate the train accuracy')
	parser.add_argument('-no-show_train_acc', dest='show_train_acc', action='store_false', help='Don\'t calculate the train accuracy')
	parser.set_defaults(show_train_acc=True)

	parser.add_argument('-full_cv', dest='full_cv', action='store_true', help='5-fold CV')
	parser.add_argument('-no-full_cv', dest='full_cv', action='store_false', help='No 5-fold CV')
	parser.set_defaults(full_cv=False)

	parser.add_argument('-len_generate_nums', type=int, default=0, help='store length of generate_nums')
	parser.add_argument('-copy_nums', type=int, default=0, help='store copy_nums')

	return parser

def parse_arguments(arg_dict=None):
    """Return the run configuration as an ``argparse.Namespace``.

    Args:
        arg_dict: optional mapping of option name to value. When given
            (including an empty mapping) the parser defaults are used as the
            base, the command line is ignored, and every key of the mapping
            is set on the namespace. Keys are not validated against the
            parser's options.

    Returns:
        The populated namespace; when ``arg_dict`` is ``None`` it is the
        result of parsing the process command line.
    """
    parser = build_parser()
    # Test against None rather than truthiness: an empty dict still means
    # "use the defaults", not "read sys.argv" (which, inside a notebook
    # kernel, holds the kernel's own arguments).
    if arg_dict is not None:
        # Start from the defaults, then override with the provided values
        args = parser.parse_args([])
        for key, value in arg_dict.items():
            setattr(args, key, value)
        return args
    return parser.parse_args()  # No dictionary provided: parse the command line

## train and evaluate

In [29]:
# Length limits used by the train/evaluate code. Their consumers are outside
# this section — presumably caps on the decoded equation length and on the
# input sequence length; confirm at the use sites.
MAX_OUTPUT_LENGTH = 45
MAX_INPUT_LENGTH = 120
# True when PyTorch reports that CUDA is available on this machine.
USE_CUDA = torch.cuda.is_available()


class Beam:
	"""Plain record holding the state of one beam-search node.

	Stores the four constructor arguments unchanged as attributes of the
	same name: ``score``, ``input_var``, ``hidden`` and ``all_output``.
	"""

	def __init__(self, score, input_var, hidden, all_output):
		self.score, self.input_var = score, input_var
		self.hidden, self.all_output = hidden, all_output


def time_since(s):
	"""Format a duration of ``s`` seconds as ``'<hours>h <minutes>m <seconds>s'``.

	Each component is rendered with ``%d``, so fractional seconds are
	truncated in the output.
	"""
	total_minutes = math.floor(s / 60)
	seconds = s - total_minutes * 60
	hours = math.floor(total_minutes / 60)
	minutes = total_minutes - hours * 60
	return '%dh %dm %ds' % (hours, minutes, seconds)


def generate_rule_mask(decoder_input, nums_batch, word2index, batch_size, nums_start, copy_nums, generate_nums,
					   english):
	"""Build a mask of output symbols allowed after the previous infix token.

	The result is a float tensor of shape ``(batch_size, nums_start +
	copy_nums)`` filled with ``-1e12``; entries of symbols allowed as the
	next token for a row are set to 0. Which symbols are allowed depends on
	the row's entry in ``decoder_input`` (number, bracket, operator, EOS/PAD
	or, for the whole batch, SOS) and on ``english``: the English rule set
	uses round brackets only, the other also handles ``[``, ``]`` and ``^``.
	Copied numbers of row ``i`` occupy indices ``nums_start`` ..
	``nums_start + nums_batch[i] - 1``; ``generate_nums`` lists the indices
	of constant numbers.
	"""
	rule_mask = torch.FloatTensor(batch_size, nums_start + copy_nums).fill_(-float("1e12"))

	def copied_numbers(row):
		# Indices of the numbers row `row` may copy.
		return [_ for _ in range(nums_start, nums_start + nums_batch[row])]

	def unmask(row, allowed):
		for col in allowed:
			rule_mask[row, col] = 0

	if english:
		if decoder_input[0] == word2index["SOS"]:
			# First step: an expression starts with a number or "(".
			for row in range(batch_size):
				unmask(row, copied_numbers(row) + [word2index["("]] + generate_nums)
			return rule_mask

		def after_operand():
			return [word2index[")"], word2index["+"], word2index["-"],
					word2index["/"], word2index["*"], word2index["EOS"]]

		for row in range(batch_size):
			prev = decoder_input[row]
			allowed = []
			if prev >= nums_start or prev in generate_nums:
				allowed = after_operand()
			elif prev == word2index["EOS"] or prev == PAD_token:
				allowed = [PAD_token]
			elif prev == word2index["("]:
				allowed = copied_numbers(row) + [word2index["("]] + generate_nums
			elif prev == word2index[")"]:
				allowed = after_operand()
			elif prev in [word2index["+"], word2index["-"], word2index["/"], word2index["*"]]:
				allowed = copied_numbers(row) + [word2index["("]] + generate_nums
			unmask(row, allowed)
		return rule_mask

	if decoder_input[0] == word2index["SOS"]:
		# First step: an expression starts with a number, "[" or "(".
		for row in range(batch_size):
			unmask(row, copied_numbers(row) + [word2index["["], word2index["("]] + generate_nums)
		return rule_mask

	def after_operand():
		return [word2index["]"], word2index[")"], word2index["+"],
				word2index["-"], word2index["/"], word2index["^"],
				word2index["*"], word2index["EOS"]]

	for row in range(batch_size):
		prev = decoder_input[row]
		allowed = []
		if prev >= nums_start or prev in generate_nums:
			allowed = after_operand()
		elif prev == word2index["EOS"] or prev == PAD_token:
			allowed = [PAD_token]
		elif prev == word2index["["] or prev == word2index["("]:
			allowed = copied_numbers(row) + [word2index["("]] + generate_nums
		elif prev == word2index[")"]:
			allowed = after_operand()
		elif prev == word2index["]"]:
			allowed = [word2index["+"], word2index["*"], word2index["-"], word2index["/"], word2index["EOS"]]
		elif prev in [word2index["+"], word2index["-"], word2index["/"],
					  word2index["*"], word2index["^"]]:
			allowed = copied_numbers(row) + [word2index["["], word2index["("]] + generate_nums
		unmask(row, allowed)
	return rule_mask


def generate_pre_tree_seq_rule_mask(decoder_input, nums_batch, word2index, batch_size, nums_start, copy_nums,
									generate_nums, english):
	"""Build a mask of output symbols allowed after the previous prefix token.

	The result is a float tensor of shape ``(batch_size, nums_start +
	copy_nums)`` filled with ``-1e12``; entries of symbols allowed as the
	next token for a row are set to 0. No brackets are involved: after SOS
	or an operator, numbers and operators are allowed; after a number, EOS
	is allowed as well; after EOS or PAD only PAD is allowed. The
	non-English rule set additionally includes ``^`` among the operators.
	"""
	rule_mask = torch.FloatTensor(batch_size, nums_start + copy_nums).fill_(-float("1e12"))

	def numbers(row):
		# Copied numbers of this row followed by the constant numbers.
		return [_ for _ in range(nums_start, nums_start + nums_batch[row])] + generate_nums

	def basic_ops():
		return [word2index["+"], word2index["-"], word2index["/"], word2index["*"]]

	def unmask(row, allowed):
		for col in allowed:
			rule_mask[row, col] = 0

	if english:
		if decoder_input[0] == word2index["SOS"]:
			for row in range(batch_size):
				unmask(row, numbers(row) + basic_ops())
			return rule_mask
		for row in range(batch_size):
			prev = decoder_input[row]
			allowed = []
			if prev >= nums_start or prev in generate_nums:
				allowed = numbers(row) + basic_ops() + [word2index["EOS"]]
			elif prev == word2index["EOS"] or prev == PAD_token:
				allowed = [PAD_token]
			elif prev in basic_ops():
				allowed = numbers(row) + basic_ops()
			unmask(row, allowed)
		return rule_mask

	if decoder_input[0] == word2index["SOS"]:
		for row in range(batch_size):
			unmask(row, numbers(row) + basic_ops() + [word2index["^"]])
		return rule_mask
	for row in range(batch_size):
		prev = decoder_input[row]
		allowed = []
		if prev >= nums_start or prev in generate_nums:
			allowed = numbers(row) + basic_ops() + [word2index["EOS"], word2index["^"]]
		elif prev == word2index["EOS"] or prev == PAD_token:
			allowed = [PAD_token]
		elif prev in basic_ops() + [word2index["^"]]:
			allowed = numbers(row) + basic_ops() + [word2index["^"]]
		unmask(row, allowed)
	return rule_mask


def generate_post_tree_seq_rule_mask(decoder_input, nums_batch, word2index, batch_size, nums_start, copy_nums,
									 generate_nums, english):
	"""Build the additive rule mask for the next token of a postfix expression.

	Every cell starts at -1e12 (forbidden); cells of tokens that may follow the
	previous token are set to 0 so the mask can be added to log-probabilities.

	Args:
		decoder_input: previous output token per example (indexable by batch position).
		nums_batch: per-example count of copied numbers.
		word2index: output vocabulary mapping (must contain the operators, "SOS" and "EOS").
		batch_size: number of rows in the mask.
		nums_start: index of the first copied-number token.
		copy_nums: number of copy slots; the mask has ``nums_start + copy_nums`` columns.
		generate_nums: list of token ids of generated constants.
		english: when falsy, "^" is also treated as an operator.

	Returns:
		FloatTensor of shape (batch_size, nums_start + copy_nums).
	"""
	def operator_ids():
		# Looked up lazily so vocabulary keys are only touched where the rules need them.
		ids = [word2index["+"], word2index["-"], word2index["/"], word2index["*"]]
		if not english:
			ids.append(word2index["^"])
		return ids

	rule_mask = torch.FloatTensor(batch_size, nums_start + copy_nums).fill_(-float("1e12"))

	# First step: a postfix expression can only start with a number.
	if decoder_input[0] == word2index["SOS"]:
		for row in range(batch_size):
			allowed = list(range(nums_start, nums_start + nums_batch[row])) + generate_nums
			for col in allowed:
				rule_mask[row, col] = 0
		return rule_mask

	for row in range(batch_size):
		previous = decoder_input[row]
		allowed = []
		if previous >= nums_start or previous in generate_nums:
			# After a number: another number or an operator.
			allowed += list(range(nums_start, nums_start + nums_batch[row])) + generate_nums + operator_ids()
		elif previous == word2index["EOS"] or previous == PAD_token:
			# After the end of the sequence only padding is allowed.
			allowed += [PAD_token]
		elif previous in operator_ids():
			# After an operator: a number, an operator, or the end of the sequence.
			allowed += list(range(nums_start, nums_start + nums_batch[row])) + generate_nums + \
				operator_ids() + [word2index["EOS"]]
		for col in allowed:
			rule_mask[row, col] = 0
	return rule_mask


def generate_tree_input(target, decoder_output, nums_stack_batch, num_start, unk):
	"""Resolve UNK targets and derive the generator input for one decoding step.

	A target equal to ``unk`` stands for a number that occurs at several
	positions; the candidate with the highest score in ``decoder_output`` is
	chosen. ``target`` is updated in place and one candidate list is popped
	from ``nums_stack_batch`` for every resolved UNK.

	Args:
		target: list of target token ids, one per example (mutated).
		decoder_output: score matrix indexed as ``[example, token]``.
		nums_stack_batch: per-example stacks of candidate number indices.
		num_start: index of the first number token.
		unk: token id marking an ambiguous number.

	Returns:
		(resolved targets, generator inputs) as LongTensors; generator inputs
		replace every original id >= ``num_start`` by 0.
	"""
	target_input = copy.deepcopy(target)
	for row, token in enumerate(target):
		if token == unk:
			candidates = nums_stack_batch[row].pop()
			best_score = -float("1e12")
			for candidate in candidates:
				candidate_score = decoder_output[row, num_start + candidate]
				if candidate_score > best_score:
					target[row] = candidate + num_start
					best_score = candidate_score
		# The generator only embeds operators, so number ids collapse to 0.
		if target_input[row] >= num_start:
			target_input[row] = 0
	return torch.LongTensor(target), torch.LongTensor(target_input)


def generate_decoder_input(target, decoder_output, nums_stack_batch, num_start, unk):
	"""Replace UNK targets by the best-scoring candidate number, in place.

	For every position of ``target`` equal to ``unk`` one candidate list is
	popped from ``nums_stack_batch`` and the candidate with the highest score
	in ``decoder_output`` is written back into ``target``.

	Args:
		target: 1-D tensor of target token ids (mutated and returned).
		decoder_output: score matrix indexed as ``[example, token]``.
		nums_stack_batch: per-example stacks of candidate number indices.
		num_start: index of the first number token.
		unk: token id marking an ambiguous number.

	Returns:
		The same ``target`` tensor with UNK entries resolved.
	"""
	scores = decoder_output.cpu() if USE_CUDA else decoder_output
	for row in range(target.size(0)):
		if target[row] != unk:
			continue
		candidates = nums_stack_batch[row].pop()
		best_score = -float("1e12")
		for candidate in candidates:
			candidate_score = scores[row, num_start + candidate]
			if candidate_score > best_score:
				target[row] = candidate + num_start
				best_score = candidate_score
	return target


def mask_num(encoder_outputs, decoder_input, embedding_size, nums_start, copy_nums, num_pos):
	"""Locate the encoder states of decoder inputs that are copied numbers.

	Args:
		encoder_outputs: encoder states, indexed as (seq_len, batch, hidden).
		decoder_input: 1-D tensor of previous output tokens.
		embedding_size: width the "not a number" mask is tiled to.
		nums_start: index of the first copied-number token.
		copy_nums: number of copy slots.
		num_pos: per-example list of number positions in the input sentence.

	Returns:
		(num_mask, num_encoder, num_mask_encoder): a boolean mask of inputs that
		are copied numbers, the encoder state selected for every example, and
		the inverted mask tiled to (batch, embedding_size).
	"""
	at_or_after_start = decoder_input >= nums_start
	before_end = decoder_input < (nums_start + copy_nums)
	num_mask = at_or_after_start == before_end

	# Inverted mask, expanded along a new second axis to the embedding width.
	not_a_number = (num_mask < 1).unsqueeze(1)
	tiling = [1] * not_a_number.dim()
	tiling[1] = embedding_size
	num_mask_encoder = not_a_number.repeat(*tiling)

	# S x B x H -> (B x S) x H so a single flat index addresses one state.
	flat_states = encoder_outputs.transpose(0, 1).contiguous().view(-1, encoder_outputs.size(2))

	# Slot 0 is used for non-number inputs; their result is masked by the caller.
	slots = ((decoder_input - nums_start) * num_mask.long()).tolist()
	positions = torch.LongTensor([num_pos[k][slot] for k, slot in enumerate(slots)])
	if USE_CUDA:
		positions = positions.cuda()

	batch_size = decoder_input.size(0)
	sen_len = encoder_outputs.size(0)
	row_offsets = torch.LongTensor(range(batch_size)) * sen_len
	if USE_CUDA:
		row_offsets = row_offsets.cuda()

	num_encoder = flat_states.index_select(0, row_offsets + positions)
	return num_mask, num_encoder, num_mask_encoder


def out_equation(test, output_lang, num_list, num_stack=None):
	"""Render a sequence of output token ids as a Python expression string.

	The last element of ``test`` is dropped before rendering. "^" becomes
	"**", "[" / "]" become parentheses, ``N<k>`` is replaced by
	``num_list[k]`` (percentages are rewritten as ``(x/100)``) and every other
	word is copied verbatim. Ids at or beyond the last vocabulary index are
	resolved through ``num_stack``.

	Args:
		test: sequence of output token ids.
		output_lang: object exposing ``index2word``.
		num_list: number strings of the problem.
		num_stack: stack of candidate number positions used for out-of-range
			ids; it is popped in place. ``None`` is treated as an empty stack.

	Returns:
		The expression string, ``None`` if an ``N<k>`` token refers to a number
		that does not exist, or ``""`` if an out-of-range id cannot be resolved
		because no stack entry is left.
	"""
	max_index = len(output_lang.index2word) - 1
	pieces = []
	for i in test[:-1]:
		if i < max_index:
			c = output_lang.index2word[i]
			if c == "^":
				pieces.append("**")
			elif c == "[":
				pieces.append("(")
			elif c == "]":
				pieces.append(")")
			elif c[0] == "N":
				slot = int(c[1:])
				if slot >= len(num_list):
					return None
				x = num_list[slot]
				if x[-1] == "%":
					pieces.append("(" + x[:-1] + "/100" + ")")
				else:
					pieces.append(x)
			else:
				pieces.append(c)
		else:
			# The default num_stack=None used to crash on len(None); an absent
			# stack is handled exactly like an exhausted one.
			if not num_stack:
				print("".join(pieces), num_list)
				return ""
			n_pos = num_stack.pop()
			pieces.append(num_list[n_pos[0]])
	return "".join(pieces)


def compute_prefix_tree_result(test_res, test_tar, output_lang, num_list, num_stack):
	"""Compare a predicted prefix expression with the target expression.

	Args:
		test_res: predicted token ids.
		test_tar: target token ids.
		output_lang: output vocabulary, forwarded to ``out_expression_list``.
		num_list: numbers of the problem, forwarded to ``out_expression_list``.
		num_stack: candidate stack for the target; a deep copy is handed to
			``out_expression_list`` so the caller's stack is left untouched.

	Returns:
		``(value_correct, equation_correct, prediction, target)``. When the
		stack is empty and the id sequences are equal the raw id sequences are
		returned, otherwise the outputs of ``out_expression_list``.
	"""
	if len(num_stack) == 0 and test_res == test_tar:
		return True, True, test_res, test_tar
	test = out_expression_list(test_res, output_lang, num_list)
	tar = out_expression_list(test_tar, output_lang, num_list, copy.deepcopy(num_stack))
	if test is None:
		return False, False, test, tar
	if test == tar:
		return True, True, test, tar
	try:
		if abs(compute_prefix_expression(test) - compute_prefix_expression(tar)) < 1e-4:
			return True, False, test, tar
		else:
			return False, False, test, tar
	except Exception:
		# Any evaluation failure counts as a wrong answer; KeyboardInterrupt and
		# SystemExit are deliberately no longer swallowed.
		return False, False, test, tar


def compute_postfix_tree_result(test_res, test_tar, output_lang, num_list, num_stack):
	"""Compare a predicted postfix expression with the target expression.

	Args:
		test_res: predicted token ids.
		test_tar: target token ids.
		output_lang: output vocabulary, forwarded to ``out_expression_list``.
		num_list: numbers of the problem, forwarded to ``out_expression_list``.
		num_stack: candidate stack for the target; a deep copy is handed to
			``out_expression_list`` so the caller's stack is left untouched.

	Returns:
		``(value_correct, equation_correct, prediction, target)``. When the
		stack is empty and the id sequences are equal the raw id sequences are
		returned, otherwise the outputs of ``out_expression_list``.
	"""
	if len(num_stack) == 0 and test_res == test_tar:
		return True, True, test_res, test_tar
	test = out_expression_list(test_res, output_lang, num_list)
	tar = out_expression_list(test_tar, output_lang, num_list, copy.deepcopy(num_stack))
	if test is None:
		return False, False, test, tar
	if test == tar:
		return True, True, test, tar
	try:
		if abs(compute_postfix_expression(test) - compute_postfix_expression(tar)) < 1e-4:
			return True, False, test, tar
		else:
			return False, False, test, tar
	except Exception:
		# Any evaluation failure counts as a wrong answer; KeyboardInterrupt and
		# SystemExit are deliberately no longer swallowed.
		return False, False, test, tar


def compute_result(test_res, test_tar, output_lang, num_list, num_stack):
	"""Compare a predicted infix equation with the target equation.

	Args:
		test_res: predicted token ids.
		test_tar: target token ids.
		output_lang: output vocabulary, forwarded to ``out_equation``.
		num_list: numbers of the problem, forwarded to ``out_equation``.
		num_stack: candidate stack for the target; a deep copy is handed to
			``out_equation`` so the caller's stack is left untouched.

	Returns:
		``(value_correct, equation_correct)``.
	"""
	if len(num_stack) == 0 and test_res == test_tar:
		return True, True
	# The prediction has no candidate stack. Passing an explicit empty one makes
	# an out-of-vocabulary id render as "" (a wrong answer) instead of raising
	# inside out_equation before the try block below is reached.
	test = out_equation(test_res, output_lang, num_list, [])
	tar = out_equation(test_tar, output_lang, num_list, copy.deepcopy(num_stack))
	if test is None:
		return False, False
	if test == tar:
		return True, True
	try:
		# NOTE(review): eval() runs text assembled from vocabulary words and
		# dataset numbers; keep this away from untrusted vocabularies or data.
		if abs(eval(test) - eval(tar)) < 1e-4:
			return True, False
		else:
			return False, False
	except Exception:
		# Malformed or non-numeric expressions count as wrong; KeyboardInterrupt
		# and SystemExit are deliberately no longer swallowed.
		return False, False


def get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size, num_size, hidden_size):
	"""Gather the encoder state of every number in every example.

	Args:
		encoder_outputs: encoder states, indexed as (seq_len, batch, hidden).
		num_pos: per-example list of number positions in the input sentence.
		batch_size: number of examples.
		num_size: number slots per example; examples with fewer numbers are padded.
		hidden_size: width of one encoder state.

	Returns:
		Tensor of shape (batch_size, num_size, hidden_size) in which the padded
		slots are filled with 0.0.
	"""
	sen_len = encoder_outputs.size(0)
	keep_row = [0 for _ in range(hidden_size)]
	zero_out_row = [1 for _ in range(hidden_size)]

	flat_indices = []
	mask_rows = []
	for b in range(batch_size):
		positions = num_pos[b]
		padding = num_size - len(positions)
		# Real numbers: flat index into the (B x S) x H view, state is kept.
		flat_indices.extend(pos + b * sen_len for pos in positions)
		mask_rows.extend(keep_row for _ in positions)
		# Padded slots point at row 0 and are zeroed by the mask afterwards.
		flat_indices.extend([0] * padding)
		mask_rows.extend([zero_out_row] * padding)

	indices = torch.LongTensor(flat_indices)
	masked_index = torch.BoolTensor(mask_rows).view(batch_size, num_size, hidden_size)
	if USE_CUDA:
		indices = indices.cuda()
		masked_index = masked_index.cuda()

	# S x B x H -> (B x S) x H
	flat_states = encoder_outputs.transpose(0, 1).contiguous().view(-1, encoder_outputs.size(2))
	all_num = flat_states.index_select(0, indices).view(batch_size, num_size, hidden_size)
	return all_num.masked_fill_(masked_index, 0.0)


def train_attn(input_batch, input_length, target_batch, target_length, num_batch, nums_stack_batch, copy_nums,
			   generate_nums, encoder, decoder, encoder_optimizer, decoder_optimizer, output_lang, clip=0,
			   use_teacher_forcing=1, beam_size=1, english=False):
	"""Run one optimisation step of the attention seq2seq model on a batch.

	With probability ``use_teacher_forcing`` the decoder is fed the (UNK
	resolved) gold token at every step; otherwise it is unrolled with a beam
	search restricted by ``generate_rule_mask`` and the outputs of the best
	beam are used for the loss.

	Args:
		input_batch: padded input id lists (batch x max_len).
		input_length: per-example input lengths.
		target_batch: padded target id lists (batch x max_len).
		target_length: per-example target lengths.
		num_batch: per-example number counts, forwarded to ``generate_rule_mask``.
		nums_stack_batch: per-example candidate stacks used to resolve UNK
			targets (consumed by ``generate_decoder_input``).
		copy_nums: number of copy slots in the output vocabulary.
		generate_nums: token ids of generated constants.
		encoder, decoder: the modules being trained.
		encoder_optimizer, decoder_optimizer: their optimizers.
		output_lang: output vocabulary (``n_words``, ``word2index``).
		clip: gradient-norm clipping threshold; a falsy value disables clipping.
		use_teacher_forcing: probability of using teacher forcing.
		beam_size: beam width of the non-teacher-forced branch.
		english: forwarded to ``generate_rule_mask``.

	Returns:
		The batch loss as a Python float.
	"""
	# Attention mask: 1 marks padded input positions.
	seq_mask = []
	max_len = max(input_length)
	for i in input_length:
		seq_mask.append([0 for _ in range(i)] + [1 for _ in range(i, max_len)])
	seq_mask = torch.BoolTensor(seq_mask)

	num_start = output_lang.n_words - copy_nums - 2
	unk = output_lang.word2index["UNK"]
	# Turn padded arrays into (batch_size x max_len) tensors, transpose into (max_len x batch_size)
	input_var = torch.LongTensor(input_batch).transpose(0, 1)
	target = torch.LongTensor(target_batch).transpose(0, 1)

	batch_size = len(input_length)

	encoder.train()
	decoder.train()

	if USE_CUDA:
		input_var = input_var.cuda()
		seq_mask = seq_mask.cuda()

	# Zero gradients of both optimizers
	encoder_optimizer.zero_grad()
	decoder_optimizer.zero_grad()
	# Run words through encoder
	encoder_outputs, encoder_hidden = encoder(input_var, input_length, None)

	# Prepare input and output variables
	decoder_input = torch.LongTensor([output_lang.word2index["SOS"]] * batch_size)

	decoder_hidden = encoder_hidden[:decoder.n_layers]  # Use last (forward) hidden state from encoder

	max_target_length = max(target_length)
	all_decoder_outputs = torch.zeros(max_target_length, batch_size, decoder.output_size)

	if USE_CUDA:
		all_decoder_outputs = all_decoder_outputs.cuda()

	if random.random() < use_teacher_forcing:
		# Teacher forcing: feed the resolved gold token back at every step.
		for t in range(max_target_length):
			if USE_CUDA:
				decoder_input = decoder_input.cuda()

			decoder_output, decoder_hidden = decoder(
				decoder_input, decoder_hidden, encoder_outputs, seq_mask)
			all_decoder_outputs[t] = decoder_output
			decoder_input = generate_decoder_input(
				target[t], decoder_output, nums_stack_batch, num_start, unk)
			target[t] = decoder_input
	else:
		# Free running: batched beam search over the decoder's own predictions.
		beam_list = []
		score = torch.zeros(batch_size)
		if USE_CUDA:
			score = score.cuda()
		beam_list.append(Beam(score, decoder_input, decoder_hidden, all_decoder_outputs))
		for t in range(max_target_length):
			beam_len = len(beam_list)
			# Scores, hidden states and output histories of all beams, laid out
			# beam after beam so one flat index identifies (beam, example).
			beam_scores = torch.zeros(batch_size, decoder.output_size * beam_len)
			all_hidden = torch.zeros(decoder_hidden.size(0), batch_size * beam_len, decoder_hidden.size(2))
			all_outputs = torch.zeros(max_target_length, batch_size * beam_len, decoder.output_size)
			if USE_CUDA:
				beam_scores = beam_scores.cuda()
				all_hidden = all_hidden.cuda()
				all_outputs = all_outputs.cuda()

			for b_idx in range(len(beam_list)):
				decoder_input = beam_list[b_idx].input_var
				decoder_hidden = beam_list[b_idx].hidden

				rule_mask = generate_rule_mask(decoder_input, num_batch, output_lang.word2index, batch_size,
											   num_start, copy_nums, generate_nums, english)
				if USE_CUDA:
					rule_mask = rule_mask.cuda()
					decoder_input = decoder_input.cuda()

				decoder_output, decoder_hidden = decoder(
					decoder_input, decoder_hidden, encoder_outputs, seq_mask)

				# Accumulated beam score + masked log-probability of every next token.
				score = f.log_softmax(decoder_output, dim=1) + rule_mask
				beam_score = beam_list[b_idx].score
				beam_score = beam_score.unsqueeze(1)
				repeat_dims = [1] * beam_score.dim()
				repeat_dims[1] = score.size(1)
				beam_score = beam_score.repeat(*repeat_dims)
				score += beam_score
				beam_scores[:, b_idx * decoder.output_size: (b_idx + 1) * decoder.output_size] = score
				all_hidden[:, b_idx * batch_size:(b_idx + 1) * batch_size, :] = decoder_hidden

				beam_list[b_idx].all_output[t] = decoder_output
				all_outputs[:, batch_size * b_idx: batch_size * (b_idx + 1), :] = \
					beam_list[b_idx].all_output
			topv, topi = beam_scores.topk(beam_size, dim=1)
			beam_list = []

			for k in range(beam_size):
				temp_topk = topi[:, k]
				temp_input = temp_topk % decoder.output_size
				temp_input = temp_input.data
				if USE_CUDA:
					temp_input = temp_input.cpu()
				# Floor division keeps the source-beam index an integer tensor;
				# "/" is true division on current PyTorch and yields floats that
				# cannot be added in place to the Long ``indices`` below.
				temp_beam_pos = temp_topk // decoder.output_size

				indices = torch.LongTensor(range(batch_size))
				if USE_CUDA:
					indices = indices.cuda()
				indices += temp_beam_pos * batch_size

				temp_hidden = all_hidden.index_select(1, indices)
				temp_output = all_outputs.index_select(1, indices)

				beam_list.append(Beam(topv[:, k], temp_input, temp_hidden, temp_output))
		all_decoder_outputs = beam_list[0].all_output

		for t in range(max_target_length):
			target[t] = generate_decoder_input(
				target[t], all_decoder_outputs[t], nums_stack_batch, num_start, unk)

	# Loss calculation and backpropagation
	if USE_CUDA:
		target = target.cuda()

	loss = masked_cross_entropy(
		all_decoder_outputs.transpose(0, 1).contiguous(),  # -> batch x seq
		target.transpose(0, 1).contiguous(),  # -> batch x seq
		target_length
	)

	loss.backward()
	return_loss = loss.item()

	# Clip gradient norms
	if clip:
		torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
		torch.nn.utils.clip_grad_norm_(decoder.parameters(), clip)

	# Update parameters with optimizers
	encoder_optimizer.step()
	decoder_optimizer.step()

	return return_loss


def evaluate_attn(input_seq, input_length, num_list, copy_nums, generate_nums, encoder, decoder, output_lang,
				  beam_size=1, english=False, max_length=MAX_OUTPUT_LENGTH):
	"""Decode a single input sequence with the attention decoder using beam search.

	The encoder and decoder are switched to eval mode. Decoding runs for at
	most ``max_length`` steps and stops early once the last token of every
	beam is EOS.

	Args:
		input_seq: token ids of one example (wrapped in a LongTensor below).
		input_length: length of the example; sizes the attention mask and is
			passed to the encoder.
		num_list, generate_nums, english: only referenced by the commented-out
			rule mask; unused by the active code.
		copy_nums: only used to derive ``num_start``, which the active code
			does not read.
		encoder, decoder: modules called as shown below; their contracts are
			defined elsewhere in the file.
		output_lang: output vocabulary (``n_words``, ``word2index``).
		beam_size: number of hypotheses kept after every step.
		max_length: maximum number of decoding steps.

	Returns:
		``all_output`` (the list of generated token ids) of the first beam of
		the score-sorted beam list.
	"""
	# Nothing is padded for a single example, so the mask is all zeros.
	seq_mask = torch.BoolTensor(1, input_length).fill_(0)
	num_start = output_lang.n_words - copy_nums - 2

	# Add a batch axis of size 1 (assumes input_seq is a flat list of ids — TODO confirm)
	input_var = torch.LongTensor(input_seq).unsqueeze(1)
	if USE_CUDA:
		input_var = input_var.cuda()
		seq_mask = seq_mask.cuda()

	# Set to not-training mode to disable dropout
	encoder.eval()
	decoder.eval()

	# Run through encoder
	encoder_outputs, encoder_hidden = encoder(input_var, [input_length], None)

	# Create starting vectors for decoder
	decoder_input = torch.LongTensor([output_lang.word2index["SOS"]])  # SOS
	decoder_hidden = encoder_hidden[:decoder.n_layers]  # Use last (forward) hidden state from encoder
	beam_list = list()
	score = 0
	beam_list.append(Beam(score, decoder_input, decoder_hidden, []))

	# Run through decoder
	for di in range(max_length):
		# Finished beams (last token EOS) are carried over unchanged.
		temp_list = list()
		beam_len = len(beam_list)
		for xb in beam_list:
			if int(xb.input_var[0]) == output_lang.word2index["EOS"]:
				temp_list.append(xb)
				beam_len -= 1
		if beam_len == 0:
			return beam_list[0].all_output
		# One contiguous score segment and one hidden-state slot per unfinished beam.
		beam_scores = torch.zeros(decoder.output_size * beam_len)
		hidden_size_0 = decoder_hidden.size(0)
		hidden_size_2 = decoder_hidden.size(2)
		all_hidden = torch.zeros(beam_len, hidden_size_0, 1, hidden_size_2)
		if USE_CUDA:
			beam_scores = beam_scores.cuda()
			all_hidden = all_hidden.cuda()
		all_outputs = []
		# Index among the unfinished beams only (finished ones are skipped below).
		current_idx = -1

		for b_idx in range(len(beam_list)):
			decoder_input = beam_list[b_idx].input_var
			if int(decoder_input[0]) == output_lang.word2index["EOS"]:
				continue
			current_idx += 1
			decoder_hidden = beam_list[b_idx].hidden

			# rule_mask = generate_rule_mask(decoder_input, [num_list], output_lang.word2index,
			#                                1, num_start, copy_nums, generate_nums, english)
			if USE_CUDA:
				# rule_mask = rule_mask.cuda()
				decoder_input = decoder_input.cuda()

			decoder_output, decoder_hidden = decoder(
				decoder_input, decoder_hidden, encoder_outputs, seq_mask)
			# score = f.log_softmax(decoder_output, dim=1) + rule_mask.squeeze()
			# Candidate score = accumulated beam score + log-probability of the next token.
			score = f.log_softmax(decoder_output, dim=1)
			score += beam_list[b_idx].score
			beam_scores[current_idx * decoder.output_size: (current_idx + 1) * decoder.output_size] = score
			all_hidden[current_idx] = decoder_hidden
			all_outputs.append(beam_list[b_idx].all_output)
		topv, topi = beam_scores.topk(beam_size)

		for k in range(beam_size):
			# A flat index encodes (unfinished beam, token): token = idx % V, beam = idx / V.
			word_n = int(topi[k])
			word_input = word_n % decoder.output_size
			temp_input = torch.LongTensor([word_input])
			indices = int(word_n / decoder.output_size)

			temp_hidden = all_hidden[indices]
			temp_output = all_outputs[indices]+[word_input]
			temp_list.append(Beam(float(topv[k]), temp_input, temp_hidden, temp_output))

		# Finished and newly expanded beams compete on score; keep the best beam_size.
		temp_list = sorted(temp_list, key=lambda x: x.score, reverse=True)

		if len(temp_list) < beam_size:
			beam_list = temp_list
		else:
			beam_list = temp_list[:beam_size]
	return beam_list[0].all_output


def copy_list(l):
	"""Return a copy of ``l`` in which nested lists are copied recursively.

	Only objects whose exact type is ``list`` are duplicated; every other
	element (tensors, tree nodes, tuples, ...) is shared with the original.
	"""
	if len(l) == 0:
		return []
	return [copy_list(entry) if type(entry) is list else entry for entry in l]


class TreeBeam:
	"""One hypothesis of the tree beam search.

	The stacks are snapshotted with ``copy_list`` (nested lists are copied,
	their elements are shared) and ``out`` is deep-copied, so later changes to
	the caller's containers do not leak into the stored beam.
	"""

	def __init__(self, score, node_stack, embedding_stack, left_childs, out):
		self.score = score
		self.node_stack = copy_list(node_stack)
		self.embedding_stack = copy_list(embedding_stack)
		self.left_childs = copy_list(left_childs)
		self.out = copy.deepcopy(out)


class TreeEmbedding:
	"""Entry of the embedding stack used while a tree is being built.

	``embedding`` holds the stored value and ``terminal`` flags whether the
	entry represents a completed subtree (True) or a pending operator (False).
	"""

	def __init__(self, embedding, terminal=False):
		self.embedding, self.terminal = embedding, terminal


def train_tree(config, input_batch, input_length, target_batch, target_length, nums_stack_batch, num_size_batch, num_value_batch, group_batch, generate_nums,
			   embedding, encoder, predict, generate, merge, embedding_optimizer, encoder_optimizer, predict_optimizer, generate_optimizer,
			   merge_optimizer, input_lang, output_lang, num_pos, batch_graph, english=False):
	"""Run one optimisation step of the graph-encoder / tree-decoder model.

	Args:
		config: run configuration; ``config.embedding`` selects the contextual
			('bert' / 'roberta') or the plain embedding path and ``config.gpu``
			is forwarded to ``gpu_init_pytorch``.
		input_batch: padded input id lists (batch x max_len).
		input_length: per-example input lengths.
		target_batch: padded target id lists (batch x max_len).
		target_length: per-example target lengths.
		nums_stack_batch: per-example candidate stacks used to resolve UNK
			targets (consumed by ``generate_tree_input``).
		num_size_batch: per-example count of numbers in the problem.
		num_value_batch, group_batch: inputs of ``get_single_batch_graph`` when
			the graph is rebuilt on the contextual path.
		generate_nums: generated constants; only its length is used here.
		embedding, encoder, predict, generate, merge: the modules being trained.
		embedding_optimizer, encoder_optimizer, predict_optimizer,
		generate_optimizer, merge_optimizer: their optimizers.
		input_lang: input vocabulary, used to map ids back to words.
		output_lang: output vocabulary (``word2index``, ``num_start``).
		num_pos: per-example positions of the numbers in the input.
		batch_graph: precomputed graph for the batch (replaced on the
			contextual path).
		english: unused by this function.

	Returns:
		The batch loss as a Python float.
	"""
	# Number mask: 1 marks number slots beyond an example's own numbers + constants.
	num_mask = []
	max_num_size = max(num_size_batch) + len(generate_nums)
	for i in num_size_batch:
		d = i + len(generate_nums)
		num_mask.append([0] * d + [1] * (max_num_size - d))
	num_mask = torch.BoolTensor(num_mask)

	unk = output_lang.word2index["UNK"]

	# Turn padded arrays into (batch_size x max_len) tensors, transpose into (max_len x batch_size)
	input_var = torch.LongTensor(input_batch).transpose(0, 1)

	target = torch.LongTensor(target_batch).transpose(0, 1)
	batch_graph = torch.LongTensor(batch_graph)

	padding_hidden = torch.FloatTensor([0.0 for _ in range(predict.hidden_size)]).unsqueeze(0)
	batch_size = len(input_length)

	embedding.train()
	encoder.train()
	predict.train()
	generate.train()
	merge.train()

	if USE_CUDA:
		input_var = input_var.cuda()
		padding_hidden = padding_hidden.cuda()
		num_mask = num_mask.cuda()

	# Zero gradients of all optimizers
	embedding_optimizer.zero_grad()
	encoder_optimizer.zero_grad()
	predict_optimizer.zero_grad()
	generate_optimizer.zero_grad()
	merge_optimizer.zero_grad()

	# Run words through the embedding layer
	orig_idx = None
	embedded = None
	if config.embedding == 'bert' or config.embedding == 'roberta':
		contextual_input = index_batch_to_words(input_batch, input_length, input_lang)
		input_seq1, input_len1, token_ids, index_retrieve = embedding(contextual_input)

		# Remap group indices through index_retrieve, dropping indices that fall
		# outside it. Errors now propagate: the former bare
		# `except: pdb.set_trace()` blocked unattended runs and, when resumed,
		# silently left the batch one entry short.
		new_group_batch = []
		for bat in range(len(group_batch)):
			new_group_batch.append([index_retrieve[bat][index1] for index1 in group_batch[bat] if index1 < len(index_retrieve[bat])])

		# The graph has to be rebuilt for the re-tokenised input.
		batch_graph = get_single_batch_graph(token_ids.cpu().tolist(), input_len1, new_group_batch, num_value_batch, num_pos)
		batch_graph = torch.LongTensor(batch_graph)

		input_seq1 = input_seq1.transpose(0,1)
		embedded, input_length, orig_idx = sort_by_len(input_seq1, input_len1, gpu_init_pytorch(config.gpu))
	else:
		embedded = embedding(input_var)

	if USE_CUDA:
		batch_graph = batch_graph.cuda()

	encoder_outputs, problem_output = encoder(embedded, input_length, orig_idx, batch_graph)

	# sequence mask for attention (1 marks padded positions); built after the
	# embedding step because the contextual path replaces input_length
	seq_mask = []
	max_len = max(input_length)
	for i in input_length:
		seq_mask.append([0 for _ in range(i)] + [1 for _ in range(i, max_len)])
	seq_mask = torch.BoolTensor(seq_mask)

	if USE_CUDA:
		seq_mask = seq_mask.cuda()

	# Every example starts with a single root node built from its problem vector.
	node_stacks = [[TreeNode(_)] for _ in problem_output.split(1, dim=0)]

	max_target_length = max(target_length)

	all_node_outputs = []

	copy_num_len = [len(_) for _ in num_pos]
	num_size = max(copy_num_len)
	all_nums_encoder_outputs = get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size, num_size,
															  encoder.hidden_size)

	num_start = output_lang.num_start
	embeddings_stacks = [[] for _ in range(batch_size)]
	left_childs = [None for _ in range(batch_size)]
	for t in range(max_target_length):
		num_score, op, current_embeddings, current_context, current_nums_embeddings = predict(
			node_stacks, left_childs, encoder_outputs, all_nums_encoder_outputs, padding_hidden, seq_mask, num_mask)

		outputs = torch.cat((op, num_score), 1)
		all_node_outputs.append(outputs)

		target_t, generate_input = generate_tree_input(target[t].tolist(), outputs, nums_stack_batch, num_start, unk)
		target[t] = target_t
		if USE_CUDA:
			generate_input = generate_input.cuda()
		left_child, right_child, node_label = generate(current_embeddings, generate_input, current_context)
		left_childs = []
		for idx, l, r, node_stack, i, o in zip(range(batch_size), left_child.split(1), right_child.split(1),
											   node_stacks, target[t].tolist(), embeddings_stacks):
			# Examples whose tree is already complete have an empty stack.
			if len(node_stack) != 0:
				node_stack.pop()
			else:
				left_childs.append(None)
				continue

			if i < num_start:
				# Operator: expand into a right and a left goal (left is handled first).
				node_stack.append(TreeNode(r))
				node_stack.append(TreeNode(l, left_flag=True))
				o.append(TreeEmbedding(node_label[idx].unsqueeze(0), False))
			else:
				# Number: fold it with every completed left subtree waiting on the stack.
				current_num = current_nums_embeddings[idx, i - num_start].unsqueeze(0)
				while len(o) > 0 and o[-1].terminal:
					sub_stree = o.pop()
					op_node = o.pop()
					current_num = merge(op_node.embedding, sub_stree.embedding, current_num)
				o.append(TreeEmbedding(current_num, True))
			if len(o) > 0 and o[-1].terminal:
				left_childs.append(o[-1].embedding)
			else:
				left_childs.append(None)

	all_node_outputs = torch.stack(all_node_outputs, dim=1)  # B x S x N

	target = target.transpose(0, 1).contiguous()
	if USE_CUDA:
		all_node_outputs = all_node_outputs.cuda()
		target = target.cuda()

	loss = masked_cross_entropy(all_node_outputs, target, target_length)
	loss.backward()

	# Update parameters with optimizers
	embedding_optimizer.step()
	encoder_optimizer.step()
	predict_optimizer.step()
	generate_optimizer.step()
	merge_optimizer.step()
	return loss.item()


# def evaluate_tree(input_batch, input_length, generate_nums, encoder, predict, generate, merge, output_lang, num_pos, batch_graph, beam_size=5, english=False, max_length=MAX_OUTPUT_LENGTH):
def evaluate_tree(config, input_batch, input_length, generate_nums, embedding, encoder, predict, generate, merge, input_lang, output_lang, num_value, num_pos, batch_graph, group_example, beam_size=5, english=False, max_length=MAX_OUTPUT_LENGTH):
	"""Decode one problem with the tree decoder using beam search.

	All modules are switched to eval mode. Decoding runs for at most
	``max_length`` steps and stops as soon as every kept beam has an empty
	node stack.

	Args:
		config: run configuration; ``config.embedding`` selects the contextual
			('bert' / 'roberta') or the plain embedding path and ``config.gpu``
			is forwarded to ``gpu_init_pytorch``.
		input_batch: input ids of the single example.
		input_length: length of the example.
		generate_nums: generated constants; only its length is used here.
		embedding, encoder, predict, generate, merge: the model's modules.
		input_lang: input vocabulary, used to map ids back to words.
		output_lang: output vocabulary (``num_start``).
		num_value, group_example: inputs of ``get_single_example_graph`` when
			the graph is rebuilt on the contextual path.
		num_pos: positions of the numbers in the input.
		batch_graph: precomputed graph (replaced on the contextual path).
		beam_size: number of hypotheses kept after every step.
		english: unused by this function.
		max_length: maximum number of decoding steps.

	Returns:
		The list of output token ids of the best-scoring beam.
	"""
	# Add a batch axis of size 1 to the single example.
	input_var = torch.LongTensor(input_batch).unsqueeze(1)
	batch_graph = torch.LongTensor(batch_graph)

	num_mask = torch.BoolTensor(1, len(num_pos) + len(generate_nums)).fill_(0)

	# Set to not-training mode to disable dropout
	embedding.eval()
	encoder.eval()
	predict.eval()
	generate.eval()
	merge.eval()

	padding_hidden = torch.FloatTensor([0.0 for _ in range(predict.hidden_size)]).unsqueeze(0)

	batch_size = 1

	if USE_CUDA:
		input_var = input_var.cuda()
		padding_hidden = padding_hidden.cuda()
		num_mask = num_mask.cuda()

	# Run words through the embedding layer
	embedded = None
	orig_idx = None
	if config.embedding == 'bert' or config.embedding == 'roberta':
		contextual_input = index_batch_to_words([input_batch], [input_length], input_lang)
		input_seq1, input_len1, token_ids, index_retrieve = embedding(contextual_input)

		# Remap group indices through index_retrieve, dropping indices that fall
		# outside it. Errors now propagate: the former bare
		# `except: pdb.set_trace()` blocked unattended evaluation and, when
		# resumed, left new_group_example undefined.
		new_group_example = [index_retrieve[0][index1] for index1 in group_example if index1 < len(index_retrieve[0])]

		# The graph has to be rebuilt for the re-tokenised input.
		batch_graph = get_single_example_graph(token_ids.cpu().tolist()[0], input_len1[0], new_group_example, num_value, num_pos)
		batch_graph = torch.LongTensor(batch_graph)

		input_seq1 = input_seq1.transpose(0,1)
		embedded, input_length, orig_idx = sort_by_len(input_seq1, input_len1, gpu_init_pytorch(config.gpu))
		input_length = input_length[0]
	else:
		embedded = embedding(input_var)

	if USE_CUDA:
		batch_graph = batch_graph.cuda()

	encoder_outputs, problem_output = encoder(embedded, [input_length], orig_idx, batch_graph)

	# Nothing is padded for a single example, so the mask is all zeros.
	seq_mask = torch.BoolTensor(1, input_length).fill_(0)

	if USE_CUDA:
		seq_mask = seq_mask.cuda()

	# The tree starts with a single root node built from the problem vector.
	node_stacks = [[TreeNode(_)] for _ in problem_output.split(1, dim=0)]

	num_size = len(num_pos)
	all_nums_encoder_outputs = get_all_number_encoder_outputs(encoder_outputs, [num_pos], batch_size, num_size,
															  encoder.hidden_size)
	num_start = output_lang.num_start
	embeddings_stacks = [[] for _ in range(batch_size)]
	left_childs = [None for _ in range(batch_size)]

	beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_childs, [])]

	for t in range(max_length):
		current_beams = []
		while len(beams) > 0:
			b = beams.pop()
			# A beam with an empty node stack is a finished tree; keep it as is.
			if len(b.node_stack[0]) == 0:
				current_beams.append(b)
				continue
			left_childs = b.left_childs

			num_score, op, current_embeddings, current_context, current_nums_embeddings = predict(
				b.node_stack, left_childs, encoder_outputs, all_nums_encoder_outputs, padding_hidden,
				seq_mask, num_mask)

			out_score = nn.functional.log_softmax(torch.cat((op, num_score), dim=1), dim=1)

			topv, topi = out_score.topk(min(beam_size, out_score.size()[1]))

			# Expand the beam once for each of its top-scoring next tokens.
			for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)):
				current_node_stack = copy_list(b.node_stack)
				current_left_childs = []
				current_embeddings_stacks = copy_list(b.embedding_stack)
				current_out = copy.deepcopy(b.out)

				out_token = int(ti)
				current_out.append(out_token)

				current_node_stack[0].pop()

				if out_token < num_start:
					# Operator: expand into a right and a left goal (left is handled first).
					generate_input = torch.LongTensor([out_token])
					if USE_CUDA:
						generate_input = generate_input.cuda()
					left_child, right_child, node_label = generate(current_embeddings, generate_input, current_context)

					current_node_stack[0].append(TreeNode(right_child))
					current_node_stack[0].append(TreeNode(left_child, left_flag=True))

					current_embeddings_stacks[0].append(TreeEmbedding(node_label[0].unsqueeze(0), False))
				else:
					# Number: fold it with every completed left subtree waiting on the stack.
					current_num = current_nums_embeddings[0, out_token - num_start].unsqueeze(0)

					while len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
						sub_stree = current_embeddings_stacks[0].pop()
						op_node = current_embeddings_stacks[0].pop()
						current_num = merge(op_node.embedding, sub_stree.embedding, current_num)
					current_embeddings_stacks[0].append(TreeEmbedding(current_num, True))
				if len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal:
					current_left_childs.append(current_embeddings_stacks[0][-1].embedding)
				else:
					current_left_childs.append(None)
				current_beams.append(TreeBeam(b.score+float(tv), current_node_stack, current_embeddings_stacks,
											  current_left_childs, current_out))
		beams = sorted(current_beams, key=lambda x: x.score, reverse=True)
		beams = beams[:beam_size]
		# Stop once every surviving beam is a finished tree.
		flag = True
		for b in beams:
			if len(b.node_stack[0]) != 0:
				flag = False
		if flag:
			break

	return beams[0].out


def topdown_train_tree(input_batch, input_length, target_batch, target_length, nums_stack_batch, num_size_batch,
					   generate_nums, encoder, predict, generate, encoder_optimizer, predict_optimizer,
					   generate_optimizer, output_lang, num_pos, english=False):
	"""Run one optimisation step of the top-down tree decoder (no merge module).

	Unlike ``train_tree`` this variant keeps no embedding stack: ``left_childs``
	stays a list of ``None`` for the whole unroll.

	Args:
		input_batch: padded input id lists (batch x max_len).
		input_length: per-example input lengths.
		target_batch: padded target id lists (batch x max_len).
		target_length: per-example target lengths.
		nums_stack_batch: per-example candidate stacks used to resolve UNK
			targets (consumed by ``generate_tree_input``).
		num_size_batch: per-example count of numbers in the problem.
		generate_nums: generated constants; only its length is used here.
		encoder, predict, generate: the modules being trained.
		encoder_optimizer, predict_optimizer, generate_optimizer: their optimizers.
		output_lang: output vocabulary (``word2index``, ``num_start``).
		num_pos: per-example positions of the numbers in the input.
		english: unused by this function.

	Returns:
		The batch loss as a Python float.
	"""
	# sequence mask for attention (1 marks padded positions)
	seq_mask = []
	max_len = max(input_length)
	for i in input_length:
		seq_mask.append([0 for _ in range(i)] + [1 for _ in range(i, max_len)])
	seq_mask = torch.BoolTensor(seq_mask)

	# number mask: 1 marks slots beyond an example's own numbers + constants
	num_mask = []
	max_num_size = max(num_size_batch) + len(generate_nums)
	for i in num_size_batch:
		d = i + len(generate_nums)
		num_mask.append([0] * d + [1] * (max_num_size - d))
	num_mask = torch.BoolTensor(num_mask)

	unk = output_lang.word2index["UNK"]

	# Turn padded arrays into (batch_size x max_len) tensors, transpose into (max_len x batch_size)
	input_var = torch.LongTensor(input_batch).transpose(0, 1)

	target = torch.LongTensor(target_batch).transpose(0, 1)

	padding_hidden = torch.FloatTensor([0.0 for _ in range(predict.hidden_size)]).unsqueeze(0)
	batch_size = len(input_length)

	encoder.train()
	predict.train()
	generate.train()

	if USE_CUDA:
		input_var = input_var.cuda()
		seq_mask = seq_mask.cuda()
		padding_hidden = padding_hidden.cuda()
		num_mask = num_mask.cuda()

	# Zero gradients of all optimizers
	encoder_optimizer.zero_grad()
	predict_optimizer.zero_grad()
	generate_optimizer.zero_grad()
	# Run words through encoder

	encoder_outputs, problem_output = encoder(input_var, input_length)
	# Every example starts with a single root node built from its problem vector
	node_stacks = [[TreeNode(_)] for _ in problem_output.split(1, dim=0)]

	max_target_length = max(target_length)

	all_node_outputs = []
	# all_leafs = []

	copy_num_len = [len(_) for _ in num_pos]
	num_size = max(copy_num_len)
	all_nums_encoder_outputs = get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size, num_size,
															  encoder.hidden_size)

	num_start = output_lang.num_start
	left_childs = [None for _ in range(batch_size)]
	for t in range(max_target_length):
		num_score, op, current_embeddings, current_context, current_nums_embeddings = predict(
			node_stacks, left_childs, encoder_outputs, all_nums_encoder_outputs, padding_hidden, seq_mask, num_mask)

		# all_leafs.append(p_leaf)
		# Operator scores first, number scores after
		outputs = torch.cat((op, num_score), 1)
		all_node_outputs.append(outputs)

		# Resolve UNK targets against the current scores; the resolved ids are written back into target
		target_t, generate_input = generate_tree_input(target[t].tolist(), outputs, nums_stack_batch, num_start, unk)
		target[t] = target_t
		if USE_CUDA:
			generate_input = generate_input.cuda()
		left_child, right_child, node_label = generate(current_embeddings, generate_input, current_context)
		for idx, l, r, node_stack, i in zip(range(batch_size), left_child.split(1), right_child.split(1),
											node_stacks, target[t].tolist()):
			# Examples whose tree is already complete have an empty stack
			if len(node_stack) != 0:
				node = node_stack.pop()
			else:
				continue

			# Operator target: expand into a right and a left goal (left is popped first);
			# a number target simply consumes the popped node
			if i < num_start:
				node_stack.append(TreeNode(r))
				node_stack.append(TreeNode(l, left_flag=True))

	# all_leafs = torch.stack(all_leafs, dim=1)  # B x S x 2
	all_node_outputs = torch.stack(all_node_outputs, dim=1)  # B x S x N

	target = target.transpose(0, 1).contiguous()
	if USE_CUDA:
		# all_leafs = all_leafs.cuda()
		all_node_outputs = all_node_outputs.cuda()
		target = target.cuda()

	# op_target = target < num_start
	# loss_0 = masked_cross_entropy_without_logit(all_leafs, op_target.long(), target_length)
	loss = masked_cross_entropy(all_node_outputs, target, target_length)
	# loss = loss_0 + loss_1
	loss.backward()
	# clip the grad
	# torch.nn.utils.clip_grad_norm_(encoder.parameters(), 5)
	# torch.nn.utils.clip_grad_norm_(predict.parameters(), 5)
	# torch.nn.utils.clip_grad_norm_(generate.parameters(), 5)

	# Update parameters with optimizers
	encoder_optimizer.step()
	predict_optimizer.step()
	generate_optimizer.step()
	return loss.item()  # , loss_0.item(), loss_1.item()


def topdown_evaluate_tree(input_batch, input_length, generate_nums, encoder, predict, generate, output_lang, num_pos,
						  beam_size=5, english=False, max_length=MAX_OUTPUT_LENGTH):
	"""Beam-search decode a single example top-down and return the best output sequence.

	Args:
		input_batch: integer sequence for one example; wrapped in a LongTensor and fed to ``encoder``.
		input_length: length of that sequence; used as the mask width and passed to ``encoder`` as ``[input_length]``.
		generate_nums: only its length is used here (it sizes the number mask).
		encoder, predict, generate: the model components; all are switched to eval mode.
		output_lang: supplies ``num_start``; predicted indices below it trigger child generation.
		num_pos: number positions of the example, forwarded to ``get_all_number_encoder_outputs``.
		beam_size: how many candidates are expanded per hypothesis and kept per step.
		english: unused in this function.
		max_length: upper bound on the number of decoding steps.

	Returns:
		The ``out`` list (predicted output indices) of the highest-scoring beam.
	"""
	batch_size = 1  # evaluation decodes exactly one example at a time

	# Set to not-training mode to disable dropout
	for module in (encoder, predict, generate):
		module.eval()

	# All-False masks sized for the single sequence and for its number slots
	seq_mask = torch.BoolTensor(1, input_length).fill_(0)
	num_mask = torch.BoolTensor(1, len(num_pos) + len(generate_nums)).fill_(0)
	# unsqueeze(1) inserts a size-1 dimension at index 1 before the encoder call
	input_var = torch.LongTensor(input_batch).unsqueeze(1)
	padding_hidden = torch.FloatTensor([0.0] * predict.hidden_size).unsqueeze(0)

	if USE_CUDA:
		input_var, seq_mask, padding_hidden, num_mask = (
			input_var.cuda(), seq_mask.cuda(), padding_hidden.cuda(), num_mask.cuda())

	# Run words through encoder
	encoder_outputs, problem_output = encoder(input_var, [input_length])

	# Each stack starts with a single root node built from the encoder's problem output
	node_stacks = [[TreeNode(root)] for root in problem_output.split(1, dim=0)]

	all_nums_encoder_outputs = get_all_number_encoder_outputs(
		encoder_outputs, [num_pos], batch_size, len(num_pos), encoder.hidden_size)
	num_start = output_lang.num_start

	embeddings_stacks = [[] for _ in range(batch_size)]
	left_childs = [None] * batch_size

	beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_childs, [])]

	for _ in range(max_length):
		candidates = []
		# Walk the hypotheses back-to-front, the same order a pop()-driven loop would use
		for beam in reversed(beams):
			if len(beam.node_stack[0]) == 0:
				# Nothing left to expand: carry the finished hypothesis over unchanged
				candidates.append(beam)
				continue

			num_score, op, current_embeddings, current_context, _ = predict(
				beam.node_stack, left_childs, encoder_outputs, all_nums_encoder_outputs, padding_hidden,
				seq_mask, num_mask)

			# Joint log-probabilities over the concatenated (op, number) scores
			out_score = nn.functional.log_softmax(torch.cat((op, num_score), dim=1), dim=1)
			topv, topi = out_score.topk(beam_size)

			for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)):
				next_stack = copy_list(beam.node_stack)
				next_out = copy.deepcopy(beam.out)

				out_token = int(ti)
				next_out.append(out_token)

				# The node being expanded leaves the stack
				next_stack[0].pop()

				if out_token < num_start:
					# Indices below num_start get two children; the left one ends up on top of the stack
					generate_input = torch.LongTensor([out_token])
					if USE_CUDA:
						generate_input = generate_input.cuda()
					left_child, right_child, _ = generate(current_embeddings, generate_input, current_context)

					next_stack[0].append(TreeNode(right_child))
					next_stack[0].append(TreeNode(left_child, left_flag=True))

				candidates.append(TreeBeam(beam.score + float(tv), next_stack, embeddings_stacks, left_childs,
										   next_out))

		# Keep the beam_size best-scoring hypotheses
		beams = sorted(candidates, key=lambda x: x.score, reverse=True)[:beam_size]

		# Stop as soon as every surviving hypothesis has an empty node stack
		if all(len(beam.node_stack[0]) == 0 for beam in beams):
			break

	return beams[0].out


## main.py

In [30]:
# Base directories that the script below joins with the run name (logs, checkpoints, outputs)
log_folder, model_folder, outputs_folder = 'logs', 'models', 'outputs'
# Where result JSON files are written, and the base for config.board_path
result_folder, board_path = './out/', './runs/'
# Root directory of the datasets
data_path = '/kaggle/input/svampdata/data/'

def read_json(path):
	"""Load and return the JSON document stored at ``path``.

	The file is decoded as UTF-8 explicitly so the result does not depend on
	the platform's default text encoding.

	Args:
		path: location of the JSON file.

	Returns:
		The deserialized Python object (whatever ``json.load`` produces).
	"""
	with open(path, 'r', encoding='utf-8') as f:
		return json.load(f)

# Gate for the explicit ``.cuda()`` transfers in this file. Derived from the runtime
# instead of being forced on, so a CPU-only host does not crash on the first transfer.
USE_CUDA = torch.cuda.is_available()

def get_new_fold(data,pairs,group):
	"""Return each pair as a tuple with its group's ``'group_num'`` value appended.

	``data`` only takes part in the zip, so the result is as long as the
	shortest of the three inputs; its items are otherwise unused.
	"""
	return [
		tuple(pair) + (g['group_num'],)
		for _, pair, g in zip(data, pairs, group)
	]

def change_num(num):
	"""Convert a list of number strings into floats.

	Three spellings are recognised, checked in this order:
	  * fractions containing ``'/'`` -- either parenthesised (``'(1/2)'``) or
	    bare (``'1/2'``); the value is numerator / denominator;
	  * percentages containing ``'%'`` -- the last character is dropped and the
	    rest is divided by 100;
	  * anything else is passed straight to ``float``.

	Args:
		num: iterable of strings.

	Returns:
		A list of floats, one per input string.

	Raises:
		ValueError: if a piece of text cannot be parsed by ``float``.
		ZeroDivisionError: if a fraction has a zero denominator.
		IndexError: if the extracted fraction text does not contain ``'/'``.
	"""
	new_num = []
	for item in num:
		if '/' in item:
			# Keep the text before the first ')' and, when a '(' is present, after the
			# first '('. Bare fractions have no '(' and are used as they are.
			# NOTE(review): anything before '(' (e.g. the 1 in '1(1/2)') is discarded,
			# exactly as before -- confirm that is intended for mixed numbers.
			fraction = item.split(')')[0]
			if '(' in fraction:
				fraction = fraction.split('(')[1]
			parts = fraction.split('/')
			new_num.append(float(parts[0]) / float(parts[1]))
		elif '%' in item:
			# Assumes the percent sign is the final character
			new_num.append(float(item[0:-1]) / 100)
		else:
			new_num.append(float(item))
	return new_num

# Notebook replacement for command-line flags; passed to parse_arguments() to build the config.
kaggle_args = dict(
    debug=False,
    mode='train',
    gpu=0,
    dropout=0.1,
    heads=4,
    encoder_layers=1,
    decoder_layers=1,
    d_model=768,
    d_ff=256,
    lr=0.0001,
    emb_lr=1e-5,
    batch_size=32,
    epochs=10,
    embedding='roberta',
    emb_name='roberta-base',
    mawps_vocab=True,
    dataset='mawps-asdiv-a_svamp',
    run_name='mawps_try1',
    logging=0,
)




# Build the run configuration from the notebook-level argument dict
config = parse_arguments(kaggle_args)

mode = config.mode
# Anything other than 'train' is treated as a non-training run
is_train = (mode == 'train')

# Set seed for reproducibility (NumPy, PyTorch and the stdlib RNG)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
random.seed(config.seed)

# GPU initialization
device = gpu_init_pytorch(config.gpu)

if config.full_cv:
    # Cross-validation run: train and validate on the five folds 'fold0'..'fold4' found
    # under <data_path>/<dataset>/, keep each fold's best validation result, then store
    # the per-fold and pooled accuracy.
    # data_path is a module-level name, so it is rebound directly at this level.
    data_name = config.dataset
    data_path = data_path + data_name + '/'
    config.val_result_path = os.path.join(result_folder, 'CV_results_{}.json'.format(data_name))
    fold_acc_score = 0.0
    folds_scores = []
    best_acc = []  # one (correct, total) tuple per fold, taken at that fold's best epoch
    for z in range(5):
        run_name = config.run_name + '_fold' + str(z)
        config.dataset = 'fold' + str(z)
        config.log_path = os.path.join(log_folder, run_name)
        config.model_path = os.path.join(model_folder, run_name)
        config.board_path = os.path.join(board_path, run_name)
        config.outputs_path = os.path.join(outputs_folder, run_name)

        vocab1_path = os.path.join(config.model_path, 'vocab1.p')
        vocab2_path = os.path.join(config.model_path, 'vocab2.p')
        config_file = os.path.join(config.model_path, 'config.p')
        log_file = os.path.join(config.log_path, 'log.txt')

        if config.results:
            config.result_path = os.path.join(result_folder, 'val_results_{}.json'.format(config.dataset))

        create_save_directories(config.log_path)
        create_save_directories(config.model_path)
        create_save_directories(config.outputs_path)

        logger = get_logger(run_name, log_file, logging.DEBUG)

        logger.info('Experiment Name: {}'.format(config.run_name))
        logger.debug('Created Relevant Directories')

        logger.info('Loading Data...')

        train_ls, dev_ls = load_raw_data(data_path, config.dataset, is_train)

        pairs_trained, pairs_tested, generate_nums, copy_nums = transfer_num(train_ls, dev_ls, config.challenge_disp)

        logger.debug('Data Loaded...')
        logger.debug('Number of Training Examples: {}'.format(len(pairs_trained)))
        logger.debug('Number of Testing Examples: {}'.format(len(pairs_tested)))
        logger.debug('Extra Numbers: {}'.format(generate_nums))
        logger.debug('Maximum Number of Numbers: {}'.format(copy_nums))

        logger.info('Creating Vocab...')
        # Vocabularies are rebuilt from scratch for every fold
        input_lang = None
        output_lang = None

        input_lang, output_lang, train_pairs, test_pairs = prepare_data(config, logger, pairs_trained, pairs_tested, config.trim_threshold, generate_nums, copy_nums, input_lang, output_lang, tree=True)

        checkpoint = get_latest_checkpoint(config.model_path, logger)

        # Persist both vocabularies next to the model files
        with open(vocab1_path, 'wb') as f:
            pickle.dump(input_lang, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open(vocab2_path, 'wb') as f:
            pickle.dump(output_lang, f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.debug('Vocab saved at {}'.format(vocab1_path))

        # Output-vocabulary indices of the extra (generated) numbers
        generate_num_ids = []
        for num in generate_nums:
            generate_num_ids.append(output_lang.word2index[num])

        config.len_generate_nums = len(generate_nums)
        config.copy_nums = copy_nums

        with open(config_file, 'wb') as f:
            pickle.dump(vars(config), f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.debug('Config File Saved')

        logger.info('Initializing Models...')

        # Initialize models
        embedding = None
        if config.embedding == 'bert':
            embedding = BertEncoder(config.emb_name, device, config.freeze_emb)
        elif config.embedding == 'roberta':
            embedding = RobertaEncoder(config.emb_name, device, config.freeze_emb)
        else:
            embedding = Embedding(config, input_lang, input_size=input_lang.n_words, embedding_size=config.embedding_size, dropout=config.dropout)

        encoder = EncoderSeq(cell_type=config.cell_type, embedding_size=config.embedding_size, hidden_size=config.hidden_size, n_layers=config.depth, dropout=config.dropout)
        predict = Prediction(hidden_size=config.hidden_size, op_nums=output_lang.n_words - copy_nums - 1 - len(generate_nums), input_size=len(generate_nums), dropout=config.dropout)
        generate = GenerateNode(hidden_size=config.hidden_size, op_nums=output_lang.n_words - copy_nums - 1 - len(generate_nums), embedding_size=config.embedding_size, dropout=config.dropout)
        merge = Merge(hidden_size=config.hidden_size, embedding_size=config.embedding_size, dropout=config.dropout)
        # the embedding layer is  only for generated number embeddings, operators, and paddings

        logger.debug('Models Initialized')
        logger.info('Initializing Optimizers...')

        # The embedding model gets its own learning rate (emb_lr); the others share lr
        embedding_optimizer = torch.optim.Adam(embedding.parameters(), lr=config.emb_lr, weight_decay=config.weight_decay)
        encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        predict_optimizer = torch.optim.Adam(predict.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        generate_optimizer = torch.optim.Adam(generate.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        merge_optimizer = torch.optim.Adam(merge.parameters(), lr=config.lr, weight_decay=config.weight_decay)

        logger.debug('Optimizers Initialized')
        logger.info('Initializing Schedulers...')

        # Halve every learning rate after each 20 scheduler steps (one step per epoch below)
        embedding_scheduler = torch.optim.lr_scheduler.StepLR(embedding_optimizer, step_size=20, gamma=0.5)
        encoder_scheduler = torch.optim.lr_scheduler.StepLR(encoder_optimizer, step_size=20, gamma=0.5)
        predict_scheduler = torch.optim.lr_scheduler.StepLR(predict_optimizer, step_size=20, gamma=0.5)
        generate_scheduler = torch.optim.lr_scheduler.StepLR(generate_optimizer, step_size=20, gamma=0.5)
        merge_scheduler = torch.optim.lr_scheduler.StepLR(merge_optimizer, step_size=20, gamma=0.5)

        logger.debug('Schedulers Initialized')

        logger.info('Loading Models on GPU {}...'.format(config.gpu))

        # Move models to GPU
        if USE_CUDA:
            embedding.to(device)
            encoder.to(device)
            predict.to(device)
            generate.to(device)
            merge.to(device)

        logger.debug('Models loaded on GPU {}'.format(config.gpu))

        max_value_corr = 0
        len_total_eval = 0
        max_val_acc = 0.0
        max_train_acc = 0.0
        eq_acc = 0.0
        best_epoch = -1
        min_train_loss = float('inf')

        logger.info('Starting Training Procedure')

        for epoch in range(config.epochs):
            loss_total = 0
            input_batches, input_lengths, output_batches, output_lengths, nums_batches, num_stack_batches, num_pos_batches, num_size_batches, num_value_batches, graph_batches, group_batches = prepare_train_batch(train_pairs, config.batch_size)

            od = OrderedDict()
            od['Epoch'] = epoch + 1
            print_log(logger, od)

            start = time.time()
            for idx in range(len(input_lengths)):
                loss = train_tree(
                    config, input_batches[idx], input_lengths[idx], output_batches[idx], output_lengths[idx],
                    num_stack_batches[idx], num_size_batches[idx], num_value_batches[idx], group_batches[idx], generate_num_ids, embedding, encoder, predict, generate, merge,
                    embedding_optimizer, encoder_optimizer, predict_optimizer, generate_optimizer, merge_optimizer, input_lang, output_lang,
                    num_pos_batches[idx], graph_batches[idx])
                loss_total += loss
                print("Completed {} / {}...".format(idx, len(input_lengths)), end = '\r', flush = True)

            embedding_scheduler.step()
            encoder_scheduler.step()
            predict_scheduler.step()
            generate_scheduler.step()
            merge_scheduler.step()

            logger.debug('Training for epoch {} completed...\nTime Taken: {}'.format(epoch, time_since(time.time() - start)))

            if loss_total / len(input_lengths) < min_train_loss:
                min_train_loss = loss_total / len(input_lengths)

            train_value_ac = 0
            train_equation_ac = 0
            # Stays 1 when train accuracy is not computed, so the divisions below remain valid
            train_eval_total = 1
            if config.show_train_acc:
                train_eval_total = 0
                logger.info('Computing Train Accuracy')
                start = time.time()
                with torch.no_grad():
                    for train_batch in train_pairs:
                        batch_graph = get_single_example_graph(train_batch[0], train_batch[1], train_batch[7], train_batch[4], train_batch[5])
                        # Every per-example field comes from train_batch, including index 7
                        train_res = evaluate_tree(config, train_batch[0], train_batch[1], generate_num_ids, embedding, encoder, predict, generate,
                                                 merge, input_lang, output_lang, train_batch[4], train_batch[5], batch_graph, train_batch[7], beam_size=config.beam_size)
                        train_val_ac, train_equ_ac, _, _ = compute_prefix_tree_result(train_res, train_batch[2], output_lang, train_batch[4], train_batch[6])

                        if train_val_ac:
                            train_value_ac += 1
                        if train_equ_ac:
                            train_equation_ac += 1
                        train_eval_total += 1

                logger.debug('Train Accuracy Computed...\nTime Taken: {}'.format(time_since(time.time() - start)))

            logger.info('Starting Validation')

            value_ac = 0
            equation_ac = 0
            eval_total = 0
            start = time.time()

            # Epoch header in the per-run outputs file; the with-block closes the file
            with open(config.outputs_path + '/outputs.txt', 'a') as f_out:
                f_out.write('---------------------------------------\n')
                f_out.write('Epoch: ' + str(epoch) + '\n')
                f_out.write('---------------------------------------\n')

            ex_num = 0
            for test_batch in test_pairs:
                batch_graph = get_single_example_graph(test_batch[0], test_batch[1], test_batch[7], test_batch[4], test_batch[5])
                test_res = evaluate_tree(config, test_batch[0], test_batch[1], generate_num_ids, embedding, encoder, predict, generate,
                                         merge, input_lang, output_lang, test_batch[4], test_batch[5], batch_graph, test_batch[7], beam_size=config.beam_size)
                val_ac, equ_ac, _, _ = compute_prefix_tree_result(test_res, test_batch[2], output_lang, test_batch[4], test_batch[6])

                cur_result = 0
                if val_ac:
                    value_ac += 1
                    cur_result = 1
                if equ_ac:
                    equation_ac += 1
                eval_total += 1

                # Append this example's source, target, prediction and result to the outputs file
                with open(config.outputs_path + '/outputs.txt', 'a') as f_out:
                    f_out.write('Example: ' + str(ex_num) + '\n')
                    f_out.write('Source: ' + stack_to_string(sentence_from_indexes(input_lang, test_batch[0])) + '\n')
                    f_out.write('Target: ' + stack_to_string(sentence_from_indexes(output_lang, test_batch[2])) + '\n')
                    f_out.write('Generated: ' + stack_to_string(sentence_from_indexes(output_lang, test_res)) + '\n')
                    if config.nums_disp:
                        src_nums = len(test_batch[4])
                        tgt_nums = 0
                        pred_nums = 0
                        # Every token that is not one of the four operators is counted as a number
                        for k_tgt in sentence_from_indexes(output_lang, test_batch[2]):
                            if k_tgt not in ['+', '-', '*', '/']:
                                tgt_nums += 1
                        for k_pred in sentence_from_indexes(output_lang, test_res):
                            if k_pred not in ['+', '-', '*', '/']:
                                pred_nums += 1
                        f_out.write('Numbers in question: ' + str(src_nums) + '\n')
                        f_out.write('Numbers in Target Equation: ' + str(tgt_nums) + '\n')
                        f_out.write('Numbers in Predicted Equation: ' + str(pred_nums) + '\n')
                    f_out.write('Result: ' + str(cur_result) + '\n' + '\n')

                ex_num+=1

            if float(train_value_ac) / train_eval_total > max_train_acc:
                max_train_acc = float(train_value_ac) / train_eval_total

            # New best validation accuracy for this fold: record it and snapshot the state
            if float(value_ac) / eval_total > max_val_acc:
                max_value_corr = value_ac
                len_total_eval = eval_total
                max_val_acc = float(value_ac) / eval_total
                eq_acc = float(equation_ac) / eval_total
                best_epoch = epoch+1

                state = {
                        'epoch' : epoch,
                        'best_epoch': best_epoch-1,
                        'embedding_state_dict': embedding.state_dict(),
                        'encoder_state_dict': encoder.state_dict(),
                        'predict_state_dict': predict.state_dict(),
                        'generate_state_dict': generate.state_dict(),
                        'merge_state_dict': merge.state_dict(),
                        'embedding_optimizer_state_dict': embedding_optimizer.state_dict(),
                        'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
                        'predict_optimizer_state_dict': predict_optimizer.state_dict(),
                        'generate_optimizer_state_dict': generate_optimizer.state_dict(),
                        'merge_optimizer_state_dict': merge_optimizer.state_dict(),
                        'embedding_scheduler_state_dict': embedding_scheduler.state_dict(),
                        'encoder_scheduler_state_dict': encoder_scheduler.state_dict(),
                        'predict_scheduler_state_dict': predict_scheduler.state_dict(),
                        'generate_scheduler_state_dict': generate_scheduler.state_dict(),
                        'merge_scheduler_state_dict': merge_scheduler.state_dict(),
                        'voc1': input_lang,
                        'voc2': output_lang,
                        'train_loss_epoch' : loss_total / len(input_lengths),
                        'min_train_loss' : min_train_loss,
                        'val_acc_epoch' : float(value_ac) / eval_total,
                        'max_val_acc' : max_val_acc,
                        'equation_acc' : eq_acc,
                        'max_train_acc' : max_train_acc,
                        'generate_nums' : generate_nums
                    }

                if config.save_model:
                    save_checkpoint(state, epoch, logger, config.model_path, config.ckpt)

            od = OrderedDict()
            od['Epoch'] = epoch + 1
            od['best_epoch'] = best_epoch
            od['train_loss_epoch'] = loss_total / len(input_lengths)
            od['min_train_loss'] = min_train_loss
            od['train_acc_epoch'] = float(train_value_ac) / train_eval_total
            od['max_train_acc'] = max_train_acc
            od['val_acc_epoch'] = float(value_ac) / eval_total
            od['equation_acc_epoch'] = float(equation_ac) / eval_total
            od['max_val_acc'] = max_val_acc
            od['equation_acc'] = eq_acc
            print_log(logger, od)

            logger.debug('Validation Completed...\nTime Taken: {}'.format(time_since(time.time() - start)))

        if config.results:
            store_results(config, max_train_acc, max_val_acc, eq_acc, min_train_loss, best_epoch)
            logger.info('Scores saved at {}'.format(config.result_path))

        best_acc.append((max_value_corr, len_total_eval))

    # Per-fold accuracy plus the accuracy pooled over all folds' validation examples
    total_value_corr = 0
    total_len = 0
    for fold_corr, fold_len in best_acc:
        folds_scores.append(float(fold_corr) / fold_len)
        total_value_corr += fold_corr
        total_len += fold_len
    fold_acc_score = float(total_value_corr)/total_len

    store_val_results(config, fold_acc_score, folds_scores)
    logger.info('Final Val score: {}'.format(fold_acc_score))

else:
    run_name = config.run_name
    config.log_path = os.path.join(log_folder, run_name)
    config.model_path = os.path.join(model_folder, run_name)
    config.board_path = os.path.join(board_path, run_name)
    config.outputs_path = os.path.join(outputs_folder, run_name)

    vocab1_path = os.path.join(config.model_path, 'vocab1.p')
    vocab2_path = os.path.join(config.model_path, 'vocab2.p')
    config_file = os.path.join(config.model_path, 'config.p')
    log_file = os.path.join(config.log_path, 'log.txt')

    if config.results:
        config.result_path = os.path.join(result_folder, 'val_results_{}.json'.format(config.dataset))

    if is_train:
        create_save_directories(config.log_path)
        create_save_directories(config.model_path)
        create_save_directories(config.outputs_path)
    else:
        create_save_directories(config.log_path)
        create_save_directories(config.result_path)

    logger = get_logger(run_name, log_file, logging.DEBUG)

    logger.info('Experiment Name: {}'.format(config.run_name))
    logger.debug('Created Relevant Directories')

    logger.info('Loading Data...')

    train_ls, dev_ls = load_raw_data(data_path, config.dataset, is_train)

    pairs_trained, pairs_tested, generate_nums, copy_nums = transfer_num(train_ls, dev_ls, config.challenge_disp)

    logger.debug('Data Loaded...')
    if is_train:
        logger.debug('Number of Training Examples: {}'.format(len(pairs_trained)))
    logger.debug('Number of Testing Examples: {}'.format(len(pairs_tested)))
    logger.debug('Extra Numbers: {}'.format(generate_nums))
    logger.debug('Maximum Number of Numbers: {}'.format(copy_nums))

    if is_train:
        logger.info('Creating Vocab...')
        input_lang = None
        output_lang = None
    else:
        logger.info('Loading Vocab File...')

        with open(vocab1_path, 'rb') as f:
            input_lang = pickle.load(f)
        with open(vocab2_path, 'rb') as f:
            output_lang = pickle.load(f)

        logger.info('Vocab Files loaded from {}\nNumber of Words: {}'.format(vocab1_path, input_lang.n_words))

    input_lang, output_lang, train_pairs, test_pairs = prepare_data(config, logger, pairs_trained, pairs_tested, config.trim_threshold, generate_nums, copy_nums, input_lang, output_lang, tree=True)

    checkpoint = get_latest_checkpoint(config.model_path, logger)

    if is_train:
        with open(vocab1_path, 'wb') as f:
            pickle.dump(input_lang, f, protocol=pickle.HIGHEST_PROTOCOL)
        with open(vocab2_path, 'wb') as f:
            pickle.dump(output_lang, f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.debug('Vocab saved at {}'.format(vocab1_path))

        generate_num_ids = []
        for num in generate_nums:
            generate_num_ids.append(output_lang.word2index[num])

        config.len_generate_nums = len(generate_nums)
        config.copy_nums = copy_nums

        with open(config_file, 'wb') as f:
            pickle.dump(vars(config), f, protocol=pickle.HIGHEST_PROTOCOL)

        logger.debug('Config File Saved')

        logger.info('Initializing Models...')

        # Initialize models
        embedding = None
        if config.embedding == 'bert':
            embedding = BertEncoder(config.emb_name, device, config.freeze_emb)
        elif config.embedding == 'roberta':
            embedding = RobertaEncoder(config.emb_name, device, config.freeze_emb)
        else:
            embedding = Embedding(config, input_lang, input_size=input_lang.n_words, embedding_size=config.embedding_size, dropout=config.dropout)

        encoder = EncoderSeq(cell_type=config.cell_type, embedding_size=config.embedding_size, hidden_size=config.hidden_size, n_layers=config.depth, dropout=config.dropout)
        predict = Prediction(hidden_size=config.hidden_size, op_nums=output_lang.n_words - copy_nums - 1 - len(generate_nums), input_size=len(generate_nums), dropout=config.dropout)
        generate = GenerateNode(hidden_size=config.hidden_size, op_nums=output_lang.n_words - copy_nums - 1 - len(generate_nums), embedding_size=config.embedding_size, dropout=config.dropout)
        merge = Merge(hidden_size=config.hidden_size, embedding_size=config.embedding_size, dropout=config.dropout)
        # the embedding layer is  only for generated number embeddings, operators, and paddings

        logger.debug('Models Initialized')
        logger.info('Initializing Optimizers...')

        embedding_optimizer = torch.optim.Adam(embedding.parameters(), lr=config.emb_lr, weight_decay=config.weight_decay)
        encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        predict_optimizer = torch.optim.Adam(predict.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        generate_optimizer = torch.optim.Adam(generate.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        merge_optimizer = torch.optim.Adam(merge.parameters(), lr=config.lr, weight_decay=config.weight_decay)

        logger.debug('Optimizers Initialized')
        logger.info('Initializing Schedulers...')

        embedding_scheduler = torch.optim.lr_scheduler.StepLR(embedding_optimizer, step_size=20, gamma=0.5)
        encoder_scheduler = torch.optim.lr_scheduler.StepLR(encoder_optimizer, step_size=20, gamma=0.5)
        predict_scheduler = torch.optim.lr_scheduler.StepLR(predict_optimizer, step_size=20, gamma=0.5)
        generate_scheduler = torch.optim.lr_scheduler.StepLR(generate_optimizer, step_size=20, gamma=0.5)
        merge_scheduler = torch.optim.lr_scheduler.StepLR(merge_optimizer, step_size=20, gamma=0.5)

        logger.debug('Schedulers Initialized')

        logger.info('Loading Models on GPU {}...'.format(config.gpu))

        # Move models to GPU
        if USE_CUDA:
            embedding.to(device)
            encoder.to(device)
            predict.to(device)
            generate.to(device)
            merge.to(device)

        logger.debug('Models loaded on GPU {}'.format(config.gpu))

        max_val_acc = 0.0
        max_train_acc = 0.0
        eq_acc = 0.0
        best_epoch = -1
        min_train_loss = float('inf')

        logger.info('Starting Training Procedure')

        for epoch in range(config.epochs):
            loss_total = 0
            input_batches, input_lengths, output_batches, output_lengths, nums_batches, num_stack_batches, num_pos_batches, num_size_batches, num_value_batches, graph_batches, group_batches = prepare_train_batch(train_pairs, config.batch_size)

            od = OrderedDict()
            od['Epoch'] = epoch + 1
            print_log(logger, od)

            start = time.time()
            for idx in range(len(input_lengths)):
                loss = train_tree(
                    config, input_batches[idx], input_lengths[idx], output_batches[idx], output_lengths[idx],
                    num_stack_batches[idx], num_size_batches[idx], num_value_batches[idx], group_batches[idx], generate_num_ids, embedding, encoder, predict, generate, merge,
                    embedding_optimizer, encoder_optimizer, predict_optimizer, generate_optimizer, merge_optimizer, input_lang, output_lang, 
                    num_pos_batches[idx], graph_batches[idx])
                loss_total += loss
                print("Completed {} / {}...".format(idx, len(input_lengths)), end = '\r', flush = True)

            embedding_scheduler.step()
            encoder_scheduler.step()
            predict_scheduler.step()
            generate_scheduler.step()
            merge_scheduler.step()

            logger.debug('Training for epoch {} completed...\nTime Taken: {}'.format(epoch, time_since(time.time() - start)))

            if loss_total / len(input_lengths) < min_train_loss:
                min_train_loss = loss_total / len(input_lengths)

            train_value_ac = 0
            train_equation_ac = 0
            train_eval_total = 1
            if config.show_train_acc:
                train_eval_total = 0
                logger.info('Computing Train Accuracy')
                start = time.time()
                with torch.no_grad():
                    for train_batch in train_pairs:
                        batch_graph = get_single_example_graph(train_batch[0], train_batch[1], train_batch[7], train_batch[4], train_batch[5])
                        train_res = evaluate_tree(config, train_batch[0], train_batch[1], generate_num_ids, embedding, encoder, predict, generate,
                                                 merge, input_lang, output_lang, train_batch[4], train_batch[5], batch_graph, train_batch[7], beam_size=config.beam_size)
                        train_val_ac, train_equ_ac, _, _ = compute_prefix_tree_result(train_res, train_batch[2], output_lang, train_batch[4], train_batch[6])

                        if train_val_ac:
                            train_value_ac += 1
                        if train_equ_ac:
                            train_equation_ac += 1
                        train_eval_total += 1

                logger.debug('Train Accuracy Computed...\nTime Taken: {}'.format(time_since(time.time() - start)))

            logger.info('Starting Validation')

            value_ac = 0
            equation_ac = 0
            eval_total = 0
            start = time.time()

            with open(config.outputs_path + '/outputs.txt', 'a') as f_out:
                f_out.write('---------------------------------------\n')
                f_out.write('Epoch: ' + str(epoch) + '\n')
                f_out.write('---------------------------------------\n')
                f_out.close()

            ex_num = 0
            for test_batch in test_pairs:
                batch_graph = get_single_example_graph(test_batch[0], test_batch[1], test_batch[7], test_batch[4], test_batch[5])
                test_res = evaluate_tree(config, test_batch[0], test_batch[1], generate_num_ids, embedding, encoder, predict, generate,
                                         merge, input_lang, output_lang, test_batch[4], test_batch[5], batch_graph, test_batch[7], beam_size=config.beam_size)
                val_ac, equ_ac, _, _ = compute_prefix_tree_result(test_res, test_batch[2], output_lang, test_batch[4], test_batch[6])

                cur_result = 0
                if val_ac:
                    value_ac += 1
                    cur_result = 1
                if equ_ac:
                    equation_ac += 1
                eval_total += 1

                with open(config.outputs_path + '/outputs.txt', 'a') as f_out:
                    f_out.write('Example: ' + str(ex_num) + '\n')
                    f_out.write('Source: ' + stack_to_string(sentence_from_indexes(input_lang, test_batch[0])) + '\n')
                    f_out.write('Target: ' + stack_to_string(sentence_from_indexes(output_lang, test_batch[2])) + '\n')
                    f_out.write('Generated: ' + stack_to_string(sentence_from_indexes(output_lang, test_res)) + '\n')
                    if config.challenge_disp:
                        f_out.write('Type: ' + test_batch[8] + '\n')
                        f_out.write('Variation Type: ' + test_batch[9] + '\n')
                        f_out.write('Annotator: ' + test_batch[10] + '\n')
                        f_out.write('Alternate: ' + str(test_batch[11]) + '\n')
                    if config.nums_disp:
                        src_nums = len(test_batch[4])
                        tgt_nums = 0
                        pred_nums = 0
                        for k_tgt in sentence_from_indexes(output_lang, test_batch[2]):
                            if k_tgt not in ['+', '-', '*', '/']:
                                tgt_nums += 1
                        for k_pred in sentence_from_indexes(output_lang, test_res):
                            if k_pred not in ['+', '-', '*', '/']:
                                pred_nums += 1
                        f_out.write('Numbers in question: ' + str(src_nums) + '\n')
                        f_out.write('Numbers in Target Equation: ' + str(tgt_nums) + '\n')
                        f_out.write('Numbers in Predicted Equation: ' + str(pred_nums) + '\n')
                    f_out.write('Result: ' + str(cur_result) + '\n' + '\n')
                    f_out.close()

                ex_num+=1

            if float(train_value_ac) / train_eval_total > max_train_acc:
                    max_train_acc = float(train_value_ac) / train_eval_total

            if float(value_ac) / eval_total > max_val_acc:
                max_val_acc = float(value_ac) / eval_total
                eq_acc = float(equation_ac) / eval_total
                best_epoch = epoch+1

                state = {
                        'epoch' : epoch,
                        'best_epoch': best_epoch-1,
                        'embedding_state_dict': embedding.state_dict(),
                        'encoder_state_dict': encoder.state_dict(),
                        'predict_state_dict': predict.state_dict(),
                        'generate_state_dict': generate.state_dict(),
                        'merge_state_dict': merge.state_dict(),
                        'embedding_optimizer_state_dict': embedding_optimizer.state_dict(),
                        'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
                        'predict_optimizer_state_dict': predict_optimizer.state_dict(),
                        'generate_optimizer_state_dict': generate_optimizer.state_dict(),
                        'merge_optimizer_state_dict': merge_optimizer.state_dict(),
                        'embedding_scheduler_state_dict': embedding_scheduler.state_dict(),
                        'encoder_scheduler_state_dict': encoder_scheduler.state_dict(),
                        'predict_scheduler_state_dict': predict_scheduler.state_dict(),
                        'generate_scheduler_state_dict': generate_scheduler.state_dict(),
                        'merge_scheduler_state_dict': merge_scheduler.state_dict(),
                        'voc1': input_lang,
                        'voc2': output_lang,
                        'train_loss_epoch' : loss_total / len(input_lengths),
                        'min_train_loss' : min_train_loss,
                        'val_acc_epoch' : float(value_ac) / eval_total,
                        'max_val_acc' : max_val_acc,
                        'equation_acc' : eq_acc,
                        'max_train_acc' : max_train_acc,
                        'generate_nums' : generate_nums
                    }

                if config.save_model:
                    save_checkpoint(state, epoch, logger, config.model_path, config.ckpt)

            od = OrderedDict()
            od['Epoch'] = epoch + 1
            od['best_epoch'] = best_epoch
            od['train_loss_epoch'] = loss_total / len(input_lengths)
            od['min_train_loss'] = min_train_loss
            od['train_acc_epoch'] = float(train_value_ac) / train_eval_total
            od['max_train_acc'] = max_train_acc
            od['val_acc_epoch'] = float(value_ac) / eval_total
            od['equation_acc_epoch'] = float(equation_ac) / eval_total
            od['max_val_acc'] = max_val_acc
            od['equation_acc'] = eq_acc
            print_log(logger, od)

            logger.debug('Validation Completed...\nTime Taken: {}'.format(time_since(time.time() - start)))

        if config.results:
            store_results(config, max_train_acc, max_val_acc, eq_acc, min_train_loss, best_epoch)
            logger.info('Scores saved at {}'.format(config.result_path))

    else:
        gpu = config.gpu
        mode = config.mode
        dataset = config.dataset
        batch_size = config.batch_size
        old_run_name = config.run_name
        with open(config_file, 'rb') as f:
            config = AttrDict(pickle.load(f))
            config.gpu = gpu
            config.mode = mode
            config.dataset = dataset
            config.batch_size = batch_size

        logger.info('Initializing Models...')

        # Initialize models
        embedding = None
        if config.embedding == 'bert':
            embedding = BertEncoder(config.emb_name, device, config.freeze_emb)
        elif config.embedding == 'roberta':
            embedding = RobertaEncoder(config.emb_name, device, config.freeze_emb)
        else:
            embedding = Embedding(config, input_lang, input_size=input_lang.n_words, embedding_size=config.embedding_size, dropout=config.dropout)

        # encoder = EncoderSeq(input_size=input_lang.n_words, embedding_size=config.embedding_size, hidden_size=config.hidden_size, n_layers=config.depth, dropout=config.dropout)
        encoder = EncoderSeq(cell_type=config.cell_type, embedding_size=config.embedding_size, hidden_size=config.hidden_size, n_layers=config.depth, dropout=config.dropout)
        predict = Prediction(hidden_size=config.hidden_size, op_nums=output_lang.n_words - config.copy_nums - 1 - config.len_generate_nums, input_size=config.len_generate_nums, dropout=config.dropout)
        generate = GenerateNode(hidden_size=config.hidden_size, op_nums=output_lang.n_words - config.copy_nums - 1 - config.len_generate_nums, embedding_size=config.embedding_size, dropout=config.dropout)
        merge = Merge(hidden_size=config.hidden_size, embedding_size=config.embedding_size, dropout=config.dropout)
        # the embedding layer is only for generated number embeddings, operators, and paddings

        logger.debug('Models Initialized')

        epoch_offset, min_train_loss, max_train_acc, max_val_acc, equation_acc, best_epoch, generate_nums = load_checkpoint(config, embedding, encoder, predict, generate, merge, config.mode, checkpoint, logger, device)

        logger.info('Prediction from')
        od = OrderedDict()
        od['epoch'] = epoch_offset
        od['min_train_loss'] = min_train_loss
        od['max_train_acc'] = max_train_acc
        od['max_val_acc'] = max_val_acc
        od['equation_acc'] = equation_acc
        od['best_epoch'] = best_epoch
        print_log(logger, od)

        generate_num_ids = []
        for num in generate_nums:
            generate_num_ids.append(output_lang.word2index[num])

        value_ac = 0
        equation_ac = 0
        eval_total = 0
        start = time.time()

        with open(config.outputs_path + '/outputs.txt', 'a') as f_out:
            f_out.write('---------------------------------------\n')
            f_out.write('Test Name: ' + old_run_name + '\n')
            f_out.write('---------------------------------------\n')
            f_out.close()

        test_res_ques, test_res_act, test_res_gen, test_res_scores = [], [], [], []

        ex_num = 0
        for test_batch in test_pairs:
            batch_graph = get_single_example_graph(test_batch[0], test_batch[1], test_batch[7], test_batch[4], test_batch[5])
            test_res = evaluate_tree(config, test_batch[0], test_batch[1], generate_num_ids, embedding, encoder, predict, generate,
                                     merge, input_lang, output_lang, test_batch[4], test_batch[5], batch_graph, test_batch[7], beam_size=config.beam_size)
            val_ac, equ_ac, _, _ = compute_prefix_tree_result(test_res, test_batch[2], output_lang, test_batch[4], test_batch[6])

            cur_result = 0
            if val_ac:
                value_ac += 1
                cur_result = 1
            if equ_ac:
                equation_ac += 1
            eval_total += 1

            with open(config.outputs_path + '/outputs.txt', 'a') as f_out:
                f_out.write('Example: ' + str(ex_num) + '\n')
                f_out.write('Source: ' + stack_to_string(sentence_from_indexes(input_lang, test_batch[0])) + '\n')
                f_out.write('Target: ' + stack_to_string(sentence_from_indexes(output_lang, test_batch[2])) + '\n')
                f_out.write('Generated: ' + stack_to_string(sentence_from_indexes(output_lang, test_res)) + '\n')
                if config.nums_disp:
                    src_nums = len(test_batch[4])
                    tgt_nums = 0
                    pred_nums = 0
                    for k_tgt in sentence_from_indexes(output_lang, test_batch[2]):
                        if k_tgt not in ['+', '-', '*', '/']:
                            tgt_nums += 1
                    for k_pred in sentence_from_indexes(output_lang, test_res):
                        if k_pred not in ['+', '-', '*', '/']:
                            pred_nums += 1
                    f_out.write('Numbers in question: ' + str(src_nums) + '\n')
                    f_out.write('Numbers in Target Equation: ' + str(tgt_nums) + '\n')
                    f_out.write('Numbers in Predicted Equation: ' + str(pred_nums) + '\n')
                f_out.write('Result: ' + str(cur_result) + '\n' + '\n')
                f_out.close()

            ex_num+=1

        results_df = pd.DataFrame([test_res_ques, test_res_act, test_res_gen, test_res_scores]).transpose()
        results_df.columns = ['Question', 'Actual Equation', 'Generated Equation', 'Score']
        csv_file_path = os.path.join(config.outputs_path, config.dataset+'.csv')
        results_df.to_csv(csv_file_path, index = False)
        logger.info('Accuracy: {}'.format(sum(test_res_scores)/len(test_res_scores)))


2024-11-16 10:53:09,111 | INFO | 3377399906.py: 416 : <module>() ::	 Experiment Name: mawps_try1
2024-11-16 10:53:09,111 | INFO | 3377399906.py: 416 : <module>() ::	 Experiment Name: mawps_try1
2024-11-16 10:53:09,111 | INFO | 3377399906.py: 416 : <module>() ::	 Experiment Name: mawps_try1
2024-11-16 10:53:09,111 | INFO | 3377399906.py: 416 : <module>() ::	 Experiment Name: mawps_try1
2024-11-16 10:53:09,111 | INFO | 3377399906.py: 416 : <module>() ::	 Experiment Name: mawps_try1
2024-11-16 10:53:09,111 | INFO | 3377399906.py: 416 : <module>() ::	 Experiment Name: mawps_try1
2024-11-16 10:53:09,118 | DEBUG | 3377399906.py: 417 : <module>() ::	 Created Relevant Directories
2024-11-16 10:53:09,118 | DEBUG | 3377399906.py: 417 : <module>() ::	 Created Relevant Directories
2024-11-16 10:53:09,118 | DEBUG | 3377399906.py: 417 : <module>() ::	 Created Relevant Directories
2024-11-16 10:53:09,118 | DEBUG | 3377399906.py: 417 : <module>() ::	 Created Relevant Directories
2024-11-16 10:53:09,11

Transfer numbers...


2024-11-16 10:53:11,529 | DEBUG | 3377399906.py: 425 : <module>() ::	 Data Loaded...
2024-11-16 10:53:11,529 | DEBUG | 3377399906.py: 425 : <module>() ::	 Data Loaded...
2024-11-16 10:53:11,529 | DEBUG | 3377399906.py: 425 : <module>() ::	 Data Loaded...
2024-11-16 10:53:11,529 | DEBUG | 3377399906.py: 425 : <module>() ::	 Data Loaded...
2024-11-16 10:53:11,529 | DEBUG | 3377399906.py: 425 : <module>() ::	 Data Loaded...
2024-11-16 10:53:11,529 | DEBUG | 3377399906.py: 425 : <module>() ::	 Data Loaded...
2024-11-16 10:53:11,535 | DEBUG | 3377399906.py: 427 : <module>() ::	 Number of Training Examples: 3138
2024-11-16 10:53:11,535 | DEBUG | 3377399906.py: 427 : <module>() ::	 Number of Training Examples: 3138
2024-11-16 10:53:11,535 | DEBUG | 3377399906.py: 427 : <module>() ::	 Number of Training Examples: 3138
2024-11-16 10:53:11,535 | DEBUG | 3377399906.py: 427 : <module>() ::	 Number of Training Examples: 3138
2024-11-16 10:53:11,535 | DEBUG | 3377399906.py: 427 : <module>() ::	 Numb

keep_words 4069 / 4069 = 1.0000


2024-11-16 10:53:13,255 | DEBUG | 1049939482.py: 1086 : prepare_data() ::	 Indexed 4072 words in input language, 21 words in output
2024-11-16 10:53:13,255 | DEBUG | 1049939482.py: 1086 : prepare_data() ::	 Indexed 4072 words in input language, 21 words in output
2024-11-16 10:53:13,255 | DEBUG | 1049939482.py: 1086 : prepare_data() ::	 Indexed 4072 words in input language, 21 words in output
2024-11-16 10:53:13,255 | DEBUG | 1049939482.py: 1086 : prepare_data() ::	 Indexed 4072 words in input language, 21 words in output
2024-11-16 10:53:13,255 | DEBUG | 1049939482.py: 1086 : prepare_data() ::	 Indexed 4072 words in input language, 21 words in output
2024-11-16 10:53:13,255 | DEBUG | 1049939482.py: 1086 : prepare_data() ::	 Indexed 4072 words in input language, 21 words in output
2024-11-16 10:53:13,358 | DEBUG | 1049939482.py: 374 : get_latest_checkpoint() ::	 Checkpoint found at : models/mawps_try1/model.pt
2024-11-16 10:53:13,358 | DEBUG | 1049939482.py: 374 : get_latest_checkpoint

Completed 98 / 99...

2024-11-16 10:54:19,955 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 0 completed...
Time Taken: 0h 0m 55s
2024-11-16 10:54:19,955 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 0 completed...
Time Taken: 0h 0m 55s
2024-11-16 10:54:19,955 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 0 completed...
Time Taken: 0h 0m 55s
2024-11-16 10:54:19,955 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 0 completed...
Time Taken: 0h 0m 55s
2024-11-16 10:54:19,955 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 0 completed...
Time Taken: 0h 0m 55s
2024-11-16 10:54:19,955 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 0 completed...
Time Taken: 0h 0m 55s
2024-11-16 10:54:19,962 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 10:54:19,962 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 10:54:19,962 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:00:45,316 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 1 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:00:45,316 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 1 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:00:45,316 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 1 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:00:45,316 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 1 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:00:45,316 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 1 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:00:45,316 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 1 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:00:45,323 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:00:45,323 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:00:45,323 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:06:33,329 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 2 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:06:33,329 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 2 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:06:33,329 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 2 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:06:33,329 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 2 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:06:33,329 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 2 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:06:33,329 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 2 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:06:33,336 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:06:33,336 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:06:33,336 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:11:54,993 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 3 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:11:54,993 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 3 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:11:54,993 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 3 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:11:54,993 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 3 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:11:54,993 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 3 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:11:54,993 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 3 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:11:55,000 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:11:55,000 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:11:55,000 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:17:13,299 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 4 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:17:13,299 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 4 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:17:13,299 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 4 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:17:13,299 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 4 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:17:13,299 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 4 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:17:13,299 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 4 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:17:13,305 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:17:13,305 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:17:13,305 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:23:06,426 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 5 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:23:06,426 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 5 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:23:06,426 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 5 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:23:06,426 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 5 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:23:06,426 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 5 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:23:06,426 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 5 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:23:06,432 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:23:06,432 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:23:06,432 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:28:41,143 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 6 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:28:41,143 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 6 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:28:41,143 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 6 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:28:41,143 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 6 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:28:41,143 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 6 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:28:41,143 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 6 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:28:41,149 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:28:41,149 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:28:41,149 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:34:28,413 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 7 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:34:28,413 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 7 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:34:28,413 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 7 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:34:28,413 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 7 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:34:28,413 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 7 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:34:28,413 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 7 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:34:28,419 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:34:28,419 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:34:28,419 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:40:24,971 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 8 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:40:24,971 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 8 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:40:24,971 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 8 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:40:24,971 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 8 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:40:24,971 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 8 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:40:24,971 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 8 completed...
Time Taken: 0h 0m 55s
2024-11-16 11:40:24,978 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:40:24,978 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:40:24,978 | INFO | 3377399906.py: 561 : <m

Completed 98 / 99...

2024-11-16 11:46:29,134 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 9 completed...
Time Taken: 0h 0m 54s
2024-11-16 11:46:29,134 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 9 completed...
Time Taken: 0h 0m 54s
2024-11-16 11:46:29,134 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 9 completed...
Time Taken: 0h 0m 54s
2024-11-16 11:46:29,134 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 9 completed...
Time Taken: 0h 0m 54s
2024-11-16 11:46:29,134 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 9 completed...
Time Taken: 0h 0m 54s
2024-11-16 11:46:29,134 | DEBUG | 3377399906.py: 551 : <module>() ::	 Training for epoch 9 completed...
Time Taken: 0h 0m 54s
2024-11-16 11:46:29,141 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:46:29,141 | INFO | 3377399906.py: 561 : <module>() ::	 Computing Train Accuracy
2024-11-16 11:46:29,141 | INFO | 3377399906.py: 561 : <m

In [31]:
# Debug marker cell: prints a fixed string.
# NOTE(review): presumably a leftover check that the kernel is responsive after
# the long-running cell above — confirm and delete if no longer needed.
print("here")

here


In [32]:
def generate_full_question(question, numbers):
    """Fill the ``NUM`` placeholders of *question* with concrete values.

    Values from *numbers* are consumed in order and each one replaces the
    left-most ``NUM`` still present in the text.  Substitution stops as soon
    as either the placeholders or the values run out, so surplus values are
    ignored and surplus placeholders are left untouched.

    Args:
        question: Question text containing zero or more ``NUM`` markers.
        numbers: Iterable of values; each one is rendered with ``str()``.

    Returns:
        The question text with the placeholders filled in.
    """
    for value in numbers:
        before, marker, after = question.partition("NUM")
        if not marker:
            # No placeholder left: the remaining values have nowhere to go.
            break
        question = before + str(value) + after
    return question


def convert_eqn(equation, numbers):
    """Replace positional ``N<i>`` placeholders in *equation* with numbers.

    ``N0`` is replaced by ``numbers[0]``, ``N1`` by ``numbers[1]`` and so on.
    All placeholders are substituted in a single pass over whole ``N<digits>``
    tokens, so a short placeholder never clobbers the prefix of a longer one
    (``N1`` inside ``N10``), and substituted text is never re-scanned.

    Args:
        equation: Equation string containing ``N<i>`` placeholders.
        numbers: Sequence of values; ``numbers[i]`` is rendered with ``str()``
            and inserted wherever ``N<i>`` appears.

    Returns:
        The equation with every known placeholder replaced.  Placeholders
        whose index has no entry in *numbers* are left unchanged.
    """
    substitutions = {f"N{i}": str(num) for i, num in enumerate(numbers)}

    def _fill(match):
        # Unknown placeholders (index out of range) are kept verbatim.
        token = match.group(0)
        return substitutions.get(token, token)

    # A callable replacement avoids backslash/group-reference processing of
    # the inserted number text.
    return re.sub(r"N\d+", _fill, equation)


def write_to_file(filename, data):
    """Write every entry of *data* to *filename*, one entry per line.

    The file is opened in ``'w'`` mode, so any existing content is
    overwritten.  Each entry must be a string; a newline is appended to it.

    Args:
        filename: Path of the file to (re)create.
        data: Iterable of strings to write.
    """
    with open(filename, 'w') as out_file:
        # The generator is consumed lazily, so entries are written in order.
        out_file.writelines(entry + '\n' for entry in data)

# Build a short human-readable report for the first few entries of test_pairs:
# for each example, decode an equation with the trained model and record the
# question, the reference equation, the decoded equation and a verdict.
# Relies on notebook-global state defined in earlier cells (config, the model
# components, input_lang/output_lang, generate_num_ids, test_pairs).
output_lines = []
# 1-based example counter; used as the question number in the report and to
# cap how many examples are processed.
ct = 1
for test_batch in test_pairs:
    batch_graph = get_single_example_graph(test_batch[0], test_batch[1], test_batch[7], test_batch[4], test_batch[5])
    test_res = evaluate_tree(config, test_batch[0], test_batch[1], generate_num_ids, embedding, encoder, predict, generate,
                             merge, input_lang, output_lang, test_batch[4], test_batch[5], batch_graph, test_batch[7], beam_size=config.beam_size)
    # NOTE(review): val_ac / equ_ac are computed but never used below; the
    # verdict written to the report comes from a string comparison instead.
    # Confirm whether val_ac was meant to drive "Predicted Result".
    val_ac, equ_ac, _, _ = compute_prefix_tree_result(test_res, test_batch[2], output_lang, test_batch[4], test_batch[6])

    # test_batch[4] supplies the values substituted for the NUM / N<i>
    # placeholders (presumably the numbers extracted from the question —
    # verify against the data-preparation code).
    numbers = test_batch[4]
    ques = generate_full_question(stack_to_string(sentence_from_indexes(input_lang, test_batch[0])), numbers)
    output_lines.append(f"Question {ct}: {ques}")
    true_eqn = convert_eqn(stack_to_string(sentence_from_indexes(output_lang, test_batch[2])), numbers)
    output_lines.append(f"True Answer: {true_eqn}")
    decode_eqn = convert_eqn(stack_to_string(sentence_from_indexes(output_lang, test_res)), numbers)
    output_lines.append(f"Decoded Answer: {decode_eqn}")
    
    # "Correct" only when the two rendered equation strings are identical.
    result_comparison = "Correct" if true_eqn == decode_eqn else "Incorrect"
    output_lines.append(f"Predicted Result: {result_comparison}")
    output_lines.append("-" * 160)
    ct += 1
    # ct starts at 1 and is incremented before this check, so the loop stops
    # after four examples have been reported.
    if(ct >= 5):
        break

# Write all collected lines to gts_eval_robert.txt (overwrites any existing file)
write_to_file("gts_eval_robert.txt", output_lines)