#Import Statements

In [None]:
# Mount Google Drive into this Colab runtime at /content/drive so that the
# dataset path configured in args.data_link (/content/drive/MyDrive/...) is readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:

from collections import Counter
import csv
import argparse
from keras.preprocessing import sequence
from datetime import datetime
import numpy as np
import random
import codecs
np.random.seed(1337)  # for reproducibility
random.seed(1337)
import os
from tqdm import tqdm


import time
import tensorflow as tf
import sys
from sklearn.utils import shuffle
from collections import Counter
#import cPickle as pickle
import pickle
from keras.utils import np_utils

import string
import re
import math
import operator

from collections import defaultdict
import sys
from nltk.corpus import stopwords
#from nltk.translate.bleu_score import corpus_bleu
from sklearn.metrics import mean_squared_error
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from scipy.stats import pearsonr
from scipy.stats import spearmanr


from sklearn.metrics import mean_absolute_error
from collections import defaultdict
from collections import OrderedDict
#from Rouge155_modify import Rouge155
import sys

# Special vocabulary tokens: padding, unknown word, start- and end-of-sequence.
# NOTE(review): the Model class hard-codes ids PAD_tag=0, UNK_tag=1, SOS_tag=2,
# EOS_tag=3 -- presumably these tokens map to those ids; confirm against the
# vocabulary builder.
PAD = "<PAD>"
UNK = "<UNK>"
SOS = "<SOS>"
EOS = "<EOS>"

#Util functions


In [None]:

def batchify(data, i, bsz, max_sample):
    """Return the ``i``-th mini-batch of ``data``.

    The batch is ``data[i*bsz : i*bsz + bsz]``, with the upper bound clamped
    to ``max_sample`` so the last batch may be shorter (or empty).
    """
    offset = int(i * bsz)
    stop = min(offset + bsz, max_sample)
    return data[offset:stop]

def mkdir_p(path):
    ''' Makes ``path`` (and any missing parents) if it does not exist.

    An empty ``path`` is a no-op, and an already-existing directory is not an
    error. Other failures (permission denied, ``path`` existing as a regular
    file, ...) are raised as ``OSError`` instead of being silently ignored,
    so callers do not later fail with a confusing "file not found".
    '''
    if path == '':
        return
    os.makedirs(path, exist_ok=True)
def print_args(args, path=None):
    ''' Print every argument in ``args`` as ``"  key: value"`` lines.

    Args:
        args: object whose attributes (read via ``vars``) are the settings.
        path: file to (over)write with the listing. When falsy the listing
            goes to stdout (previously this raised UnboundLocalError).

    The command line (``sys.argv``) is temporarily attached as
    ``args.command`` so it appears in the listing; it is removed again
    afterwards, even if writing fails.
    '''
    separator = '=============================================== \n'
    args.command = ' '.join(sys.argv)
    try:
        items = vars(args)
        lines = [separator]
        # Case-insensitive key order keeps the log stable and readable.
        for key in sorted(items.keys(), key=lambda s: s.lower()):
            lines.append("  " + key + ": " + str(items[key]) + "\n")
        lines.append(separator)
        if path:
            with open(path, 'w') as output_file:
                output_file.writelines(lines)
        else:
            sys.stdout.writelines(lines)
    finally:
        del args.command

def show_stats(name, x):
    """Print one line with the max, mean and min of ``x``, labelled ``name``."""
    stats = (np.max(x), np.mean(x), np.min(x))
    print("{} max={} mean={} min={}".format(name, *stats))
    
def prep_data_list(data, max_length):
    """Pad (with 0) or truncate every sequence in ``data`` to ``max_length``.

    Returns:
        (padded, lengths): ``padded[k]`` is ``data[k]`` padded/truncated to
        exactly ``max_length`` items, and ``lengths[k]`` is
        ``min(len(data[k]), max_length)``.
    """
    padded, lengths = [], []
    for seq in tqdm(data, desc='building List'):
        lengths.append(min(len(seq), max_length))
        padded.append(pad_to_max(seq, max_length))
    return padded, lengths

def pad_to_max(seq, seq_max, pad_token=0):
    ''' Return a copy of ``seq`` padded with ``pad_token`` (or truncated) so
    that it holds exactly ``seq_max`` items.

    ``seq`` itself is never modified. When ``pad_token`` is a mutable object
    (e.g. a list), every padding slot refers to that same object.
    '''
    padded = seq[:]
    missing = seq_max - len(padded)
    if missing > 0:
        padded.extend([pad_token] * missing)
    return padded[:seq_max]

def prep_hierarchical_data_list_new(data, lens, smax, dmax, threshold=True):
    """ Converts and pads hierarchical (document -> sentence -> token) data.

    Args:
        data: list of documents; each document is a list of sentences and
            each sentence a list of token ids.
        lens: indexed in parallel with ``data``; ``lens[i][j]`` is the length
            reported for sentence ``data[i][j]``.
        smax: number of tokens kept per sentence (pad with 0 / truncate).
        dmax: number of sentences kept per document (pad / truncate).
        threshold: when True, reported sentence lengths are capped at smax.

    Returns:
        (all_data, all_lengths): per document, a flat list of ``dmax * smax``
        token ids (zero-length sentences are skipped, the rest padded or
        truncated to ``smax``, then the document padded with all-zero
        sentences or truncated to ``dmax``), and a list of ``dmax`` sentence
        lengths padded with 0.
    """
    empty_sentence = [0] * smax
    all_data, all_lengths = [], []
    for i, doc in enumerate(data):
        doc_len = lens[i]
        sentences, sent_lengths = [], []
        for j, sentence in enumerate(doc):
            length = doc_len[j]
            if length == 0:
                # Empty sentences are dropped entirely.
                continue
            if threshold and length > smax:
                length = smax
            sentences.append(pad_to_max(sentence, smax))
            sent_lengths.append(length)
        sentences = pad_to_max(sentences, dmax, pad_token=empty_sentence)
        all_data.append([token for sent in sentences for token in sent])
        all_lengths.append(pad_to_max(sent_lengths, dmax, pad_token=0))
    return all_data, all_lengths

def hierarchical_flatten(embed, lengths, smax):
    """ Flattens embedding for hierarchical processing.

    Folds the document axis into the batch axis so that every sentence can
    be encoded independently.

    Args:
        embed: `tensor` whose third axis is the feature dim; presumably
            [bsz x (num_docs * smax) x dim] -- TODO(review) confirm.
        lengths: `tensor` of per-sentence lengths, e.g. [bsz x num_docs].
        smax: `int` - maximum number of words in sentence
    Returns:
        embed: `tensor` reshaped to [-1 x smax x dim], i.e.
            [(bsz * num_docs) x smax x dim] for the input shape above.
        lengths: `tensor` flattened to rank 1.
    """
    feature_dim = embed.get_shape().as_list()[2]
    flat_lengths = tf.reshape(lengths, [-1])
    flat_embed = tf.reshape(embed, [-1, smax, feature_dim])
    return flat_embed, flat_lengths


#Parent experiment class

In [None]:

class Experiment(object):
    ''' Implements a base experiment class for TensorFlow.

    Contains commonly used util functions: log-directory setup, GPU
    selection, per-epoch metric bookkeeping and dev-based model selection.
    Extend this base Experiment class. Subclasses are expected to provide
    ``self.args``, ``self.model_name``, ``self.eval_primary`` (metric name
    used for model selection) and ``self.show_metrics`` (metric names to
    print) before calling the methods that read them.
    '''

    def __init__(self):
        # Run id; also used as the name of the log sub-directory.
        self.uuid = datetime.now().strftime("%d_%H:%M:%S")
        # Per split: epoch -> list of {'metric': name, 'val': value}.
        self.eval_test = defaultdict(list)
        self.eval_train = defaultdict(list)
        self.eval_dev = defaultdict(list)
        self.eval_test2 = defaultdict(list)
        self.eval_dev2 = defaultdict(list)
        # Becomes True once a Wiggle-SGD trigger fires.
        self.wiggle = False
        # attr -> {'train'|'Dev'|'Test': {epoch: value}}; see register_to_log.
        self.loggers = defaultdict(dict)

    def register_to_log(self, set_type, epoch, attr, val):
        """Record ``val`` for metric ``attr`` at ``epoch``.

        ``set_type`` must be one of 'train', 'Dev' or 'Test' (case-sensitive).
        """
        if attr not in self.loggers:
            self.loggers[attr] = {'train': defaultdict(dict),
                                  'Dev': defaultdict(dict),
                                  'Test': defaultdict(dict)}
        self.loggers[attr][set_type][epoch] = val

    def dump_all_logs(self):
        """Write one log file per (metric, split) recorded via register_to_log."""
        for key, value in self.loggers.items():
            for set_type, data in value.items():
                self.write_log_values(data, key, set_type)

    def write_log_values(self, data, attr, set_type):
        """Write ``data`` (epoch -> value) as tab-separated rows to
        ``<out_dir>/<attr>_<set_type>.log``."""
        fp = os.path.join(self.out_dir, '{}_{}.log'.format(attr, set_type))
        # newline='' is what the csv module requires; otherwise rows get an
        # extra blank line on platforms that translate '\n'.
        with open(fp, 'w', newline='') as f:
            writer = csv.writer(f, delimiter='\t')
            writer.writerows(data.items())

    def _build_char_index(self):
        """Map every printable ASCII char to an id >= 2 (0=<pad>, 1=<unk>)."""
        all_chars = list(string.printable)
        self.char_index = {char: index + 2 for index, char in enumerate(all_chars)}
        self.char_index['<pad>'] = 0
        self.char_index['<unk>'] = 1

    def _setup(self):
        ''' Full Setup Procedure
        '''
        # Make directory for log file and saving models
        self._make_dir()
        # Select GPU
        self._designate_gpu()

    def _designate_gpu(self):
        ''' Choose GPU to use via CUDA_VISIBLE_DEVICES.

        ``args.gpu`` may be an int or a string; -1 (or '-1') hides all GPUs.
        Previously an int -1 was compared against the string '-1' and never
        matched.
        '''
        gpu = str(self.args.gpu)
        if gpu == '-1':
            os.environ["CUDA_VISIBLE_DEVICES"] = ""
        else:
            print("Selecting GPU no.{}".format(self.args.gpu))
            os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    def _make_dir(self):
        ''' Make log directories (only when ``args.log == 1``) and write the
        argument listing to ``<out_dir>/logs.txt``.
        '''
        self.hyp_str = self.uuid + '_' + self.model_name
        if self.args.log == 1:
            self.out_dir = './{}/{}/{}/{}/'.format(
                self.args.log_dir,
                self.args.dataset, self.model_name, self.uuid)
            mkdir_p(self.out_dir)
            self.mdl_path = self.out_dir + '/mdl.ckpt'
            self.path = self.out_dir + '/logs.txt'
            print_args(self.args, path=self.path)

    def write_to_file(self, txt):
        ''' A wrapper for printing to log and on CLI.

        Appending to the log file is best-effort: a missing ``args``/``path``
        attribute or an I/O error is ignored, but other errors are no longer
        swallowed by a bare ``except``.
        '''
        try:
            if self.args.log == 1:
                with open(self.path, 'a') as f:
                    f.write('{}\n'.format(txt))
        except (AttributeError, OSError):
            pass
        print(txt)

    def _register_eval_score(self, epoch, eval_type, metric, val):
        """ Registers eval metric ``metric=val`` for ``epoch``.

        ``eval_type`` is matched case-insensitively against dev, test, train,
        dev2 and test2; any other value is ignored.
        """
        targets = {
            'dev': self.eval_dev,
            'test': self.eval_test,
            'train': self.eval_train,
            'dev2': self.eval_dev2,
            'test2': self.eval_test2,
        }
        target = targets.get(eval_type.lower())
        if target is not None:
            target[epoch].append({'metric': metric, 'val': val})

    def _show_metrics(self, epoch, eval_list, show_metrics, name):
        """ Shows and outputs the metrics of ``epoch`` listed in ``show_metrics``.

        Uses ``.get`` so that looking up an unknown epoch (e.g. -1) does not
        insert an empty entry into a defaultdict.
        """
        for metric in eval_list.get(epoch, []):
            if metric['metric'] in show_metrics:
                self.write_to_file("[{}] {}={}".format(name,
                                                      metric['metric'],
                                                      metric['val']))

    def _primary_scores(self, eval_dict):
        """Return ``[[epoch, value], ...]`` of ``self.eval_primary`` in dict order."""
        scores = []
        for key, value in eval_dict.items():
            matches = [x for x in value if x['metric'] == self.eval_primary]
            if matches:
                scores.append([key, matches[0]['val']])
        return scores

    def _maybe_wiggle_on_score(self, cur_dev_score):
        """Activate Wiggle-SGD once the dev score exceeds ``args.wiggle_score``."""
        wiggle_score = getattr(self.args, 'wiggle_score', 0)
        if wiggle_score > 0 and not self.wiggle and cur_dev_score is not None:
            if cur_dev_score > wiggle_score:
                print("Cur Dev Score at {}".format(cur_dev_score))
                print("Activating Wiggle-SGD mode")
                self.wiggle = True

    def _select_test_by_dev(self, epoch, eval_dev, eval_test,
                            no_test=False, lower_is_better=False,
                            name='', has_dev=True):
        """ Outputs best test score based on dev score.

        Args:
            epoch: current epoch, used for the early-stop / wiggle checks.
            eval_dev, eval_test: epoch -> list of {'metric', 'val'} dicts as
                filled by ``_register_eval_score``.
            no_test: if True (and ``has_dev``), only report the best dev epoch.
            lower_is_better: ordering of ``self.eval_primary``.
            name: label for the ``no_test`` report.
            has_dev: if False, no dev-based selection is done (best_epoch=-1).

        Returns:
            ``(best_epoch, cur_dev_score)`` when ``no_test`` and ``has_dev``;
            otherwise ``(stop, max_epoch, best_epoch)`` where ``stop`` is True
            when early stopping triggers. An epoch that cannot be determined
            (no recorded primary metric) is reported as -1, and
            ``cur_dev_score`` as None. ``args.early_stop``,
            ``args.wiggle_after`` and ``args.wiggle_score`` default to 0
            (disabled) when absent.
        """
        self.write_to_file("====================================")
        reverse = not lower_is_better
        early_stop = getattr(self.args, 'early_stop', 0)
        wiggle_after = getattr(self.args, 'wiggle_after', 0)

        best_epoch = -1
        cur_dev_score = None
        if has_dev:
            primary_metrics = self._primary_scores(eval_dev)
            if primary_metrics:
                sorted_metrics = sorted(primary_metrics,
                                        key=operator.itemgetter(1),
                                        reverse=reverse)
                # "Current" score = most recently registered dev epoch.
                cur_dev_score = primary_metrics[-1][1]
                best_epoch = sorted_metrics[0][0]

            if no_test:
                # For MNLI or no test set
                print("[{}] Best epoch={}".format(name, best_epoch))
                self._show_metrics(best_epoch, eval_dev,
                                   self.show_metrics, name='best')
                self._maybe_wiggle_on_score(cur_dev_score)
                return best_epoch, cur_dev_score

        test_metrics = self._primary_scores(eval_test)
        sorted_test = sorted(test_metrics, key=operator.itemgetter(1),
                             reverse=reverse)
        max_epoch = sorted_test[0][0] if sorted_test else -1

        self.write_to_file("Best epoch={}".format(best_epoch))
        self._show_metrics(best_epoch, eval_test,
                           self.show_metrics, name='best')
        self.write_to_file("Maxed epoch={}".format(max_epoch))
        self._show_metrics(max_epoch, eval_test,
                           self.show_metrics, name='max')
        if early_stop > 0:
            # Use early stopping
            if epoch - best_epoch > early_stop:
                return True, max_epoch, best_epoch
        if wiggle_after > 0 and not self.wiggle:
            # use SGD wiggling
            if epoch - best_epoch > wiggle_after:
                print("Activating Wiggle-SGD mode")
                self.wiggle = True
        self._maybe_wiggle_on_score(cur_dev_score)

        return False, max_epoch, best_epoch

    

#Argument definitions

In [None]:

class args():
  """Notebook stand-in for argparse: hyper-parameters and paths of this run."""
  def __init__(self):
    self.smax = 30  # maximum number of words per review/sentence
    self.dmax = 20  # maximum number of reviews per user/item document
    self.base_encoder = 'NBOW'
    self.rnn_type = 'RAW_MSE_CAML_FN_FM'
    self.data_link = '/content/drive/MyDrive/CAML_TOY_DATA/movie100'
    self.log_dir = 'logs'
    self.log = 1  # 1 = write logs under log_dir (see Experiment._make_dir)
    self.dataset = 'Amazon_Electronics'
    self.gpu = 0  # -1 hides all GPUs
    self.dev = 1
    self.init = 0.01
    self.gmax = 30
    self.epochs = 50
    self.batch_size = 50
    self.implicit = 1
    self.learn_rate = 1e-3
    self.num_class = 6
    self.dropout = 0.2
    self.rnn_dropout = 0.2
    self.emb_dropout = 0.2
    self.dev_lr = 0
    self.decay_steps = 0
    self.decay_epoch = 0
    self.emb_size = 50
    self.latent_size = 50
    self.rnn_size = 50
    self.num_heads = 2
    self.data_prepare = 1
    self.key_word_lambda = 0.25
    self.l2_reg = 1E-6
    self.word_aggregate = 'MAX'
    self.self_num_heads = 2  # heads of the review-level self-attention
    # Read by Experiment._select_test_by_dev; a value of 0 disables each
    # feature (early stopping, Wiggle-SGD after N stale epochs / above a
    # dev score). Without them that method raised AttributeError.
    self.early_stop = 0
    self.wiggle_after = 0
    self.wiggle_score = 0


#GumbelSoftmax


In [None]:

class GumbelSoftmax(tf.keras.layers.Layer):
    """Gumbel-Softmax sampling layer with optional straight-through hard mode."""

    def __init__(self, axis=-1, **kwargs):
        """Initialization method.

        Args:
            axis (int): Axis to perform the softmax (and hard one-hot) over.
        """
        super(GumbelSoftmax, self).__init__(**kwargs)
        self.axis = axis

    def call(self, inputs, tau, hard=1):
        """Draw a Gumbel-Softmax sample from the logits ``inputs``.

        Args:
            inputs (tf.Tensor): logits.
            tau (float): Gumbel-Softmax temperature parameter.
            hard (int): when 1, the forward value is a one-hot of the maximum
                along ``self.axis``, while gradients flow through the soft
                sample (straight-through estimator).

        Returns:
            tf.Tensor with the same shape as ``inputs``.
        """
        # tf.shape (dynamic) instead of inputs.shape (static) so that unknown
        # dimensions such as a None batch size do not break the sampling.
        x = inputs + self.gumbel_distribution(tf.shape(inputs))
        y = tf.nn.softmax(x / tau, self.axis)

        if hard == 1:
            # One-hot along the same axis as the softmax (was hard-coded to
            # axis 1, which only agreed with axis=-1 for 2-D inputs).
            y_hard = tf.cast(
                tf.equal(y, tf.math.reduce_max(y, self.axis, keepdims=True)),
                y.dtype)
            y = tf.stop_gradient(y_hard - y) + y
        return y

    def gumbel_distribution(self, input_shape, eps=1e-20):
        """Samples a tensor from a standard Gumbel distribution.

        Args:
            input_shape: shape (tuple, TensorShape or 1-D int tensor).
            eps (float): guards both logarithms against log(0).

        Returns:
            A tensor of ``input_shape`` with Gumbel(0, 1) samples.
        """
        uniform_dist = tf.random.uniform(input_shape, 0, 1)
        return -1 * tf.math.log(-1 * tf.math.log(uniform_dist + eps) + eps)

#Self Attention

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.

  Args:
    q: query shape == (batch, num_heads, seq_len_q, depth)
    k: key shape == (batch, num_heads, seq_len_k, depth)
    v: value shape == (batch, num_heads, seq_len_v, depth_v)
    mask: optional keep-mask of shape (batch, seq_len_k), non-zero for valid
          key positions and 0 for padding (e.g. tf.sequence_mask output).
          It is inverted and broadcast to (batch, 1, 1, seq_len_k), which is
          why q/k/v are expected to be 4-D. None (default) disables masking;
          previously None crashed because the mask was converted before the
          None check.

  Returns:
    output (batch, num_heads, seq_len_q, depth_v),
    attention_weights (batch, num_heads, seq_len_q, seq_len_k)
  """
  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  if mask is not None:
    # Turn the keep-mask into an additive -1e9 penalty on padded keys.
    padding = tf.cast(tf.math.equal(mask, 0), tf.float32)
    padding = padding[:, tf.newaxis, tf.newaxis, :]
    scaled_attention_logits += (padding * -1e9)

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

#Coattention

In [None]:


#@title
class Coattention(tf.keras.layers.Layer):
  """Pointer-style co-attention between two sequences of vectors.

  Both inputs are projected by a ReLU dense layer, an affinity matrix
  ``y = (A W) B^T`` is built, ``y`` is pooled along each axis (max or mean),
  and the pooled scores become attention weights over each side (Gumbel
  softmax or plain softmax) that re-weight the projected inputs.
  """
  def __init__(self, input_size, args, dropout=None):
    super(Coattention, self).__init__()
    self.args = args
    self.fc_user = tf.keras.layers.Dense(input_size, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(self.args.l2_reg))
    self.fc_item = tf.keras.layers.Dense(input_size, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(self.args.l2_reg))
    self.initializer = tf.keras.initializers.GlorotUniform()
    # Bilinear affinity weights over the projected (input_size) features.
    self.weights_U = tf.Variable(self.initializer(shape=(input_size, input_size)))
    self.gumbel_softmax = GumbelSoftmax()
    if dropout is not None:
      self.dropout_user = tf.keras.layers.Dropout(dropout)
      self.dropout_item = tf.keras.layers.Dropout(dropout)
    else:
      self.dropout_user = None
      self.dropout_item = None

  def call(self, user_reviews, item_reviews, user_mask, item_mask, pooling='MAX', gumbel=True, temp=0.5, hard=1):
    """Co-attend ``user_reviews`` (batch, a_len, d) and ``item_reviews`` (batch, b_len, d).

    Args:
      user_mask, item_mask: float keep-masks (batch, a_len) / (batch, b_len),
        or None (both must be given for masking to apply).
      pooling: 'MAX' or 'MEAN' reduction of the affinity matrix.
      gumbel: use Gumbel-softmax (``temp``, ``hard``) instead of softmax.

    Returns:
      final_a, final_b: projected inputs weighted by their attention (then
        dropout when configured).
      _a1, _a2: attention over user / item positions actually applied.
      _sa1, _sa2: plain softmax of the pooled user / item scores.
      y: the (masked) affinity matrix, (batch, a_len, b_len).

    Raises:
      ValueError: for an unknown ``pooling`` (previously an unbound-name error).
    """
    has_mask = user_mask is not None and item_mask is not None

    user_reviews = self.fc_user(user_reviews)
    item_reviews = self.fc_item(item_reviews)

    # y[b, i, j] = (user_i W) . item_j. tensordot contracts the projected
    # feature axis directly; the old reshape used the *pre*-projection width
    # and so only worked when it happened to equal input_size.
    z = tf.tensordot(user_reviews, self.weights_U, axes=[[2], [0]])
    y = tf.matmul(z, item_reviews, transpose_b=True)

    if has_mask:
      mat_mask = tf.matmul(tf.expand_dims(user_mask, 2), tf.expand_dims(item_mask, 1))

    # Pooling is computed whether or not masks are supplied (it used to be
    # skipped without masks, leaving att_row / att_col undefined).
    if pooling == 'MAX':
      if has_mask:
        y = -1E+30 * (1 - mat_mask) + y
      att_row = tf.math.reduce_max(y, 1)
      att_col = tf.math.reduce_max(y, 2)
    elif pooling == 'MEAN':
      if has_mask:
        y = y * mat_mask
      att_row = tf.math.reduce_mean(y, 1)
      att_col = tf.math.reduce_mean(y, 2)
    else:
      raise ValueError("Unknown pooling: {}".format(pooling))

    if has_mask:
      # Padded positions must never be selected.
      att_row = -1E+30 * (1 - item_mask) + att_row
      att_col = -1E+30 * (1 - user_mask) + att_col

    _sa2 = tf.nn.softmax(att_row)
    _sa1 = tf.nn.softmax(att_col)

    if gumbel:
      att_row = self.gumbel_softmax(att_row, temp, hard=hard)
      att_col = self.gumbel_softmax(att_col, temp, hard=hard)
    else:
      att_row = _sa2
      att_col = _sa1

    _a2 = att_row
    _a1 = att_col

    final_a = tf.expand_dims(att_col, 2) * user_reviews
    final_b = tf.expand_dims(att_row, 2) * item_reviews

    if self.dropout_user is not None:
      final_a = self.dropout_user(final_a)
      final_b = self.dropout_item(final_b)

    return final_a, final_b, _a1, _a2, _sa1, _sa2, y


    


In [None]:

class MultiPointerCoattentionNetworks(tf.keras.layers.Layer):
  """Multi-pointer co-attention over user/item reviews and their concepts.

  Reviews on each side first go through multi-head self-attention. Then, for
  each of ``num_heads`` pointers, a review-level co-attention (Gumbel pointer,
  max pooling) selects user/item reviews, and a concept-level co-attention
  (mean pooling) runs over the concept ids of the selected reviews. All
  pooled vectors are concatenated per side and projected by ``fn1``/``fn2``.
  """
  def __init__(self, args, vocab_size, input_size=50, dropout=0.2):
    super(MultiPointerCoattentionNetworks, self).__init__()
    self.args = args
    self.vocab_size = vocab_size
    self.co_attention_review_lvl = Coattention(input_size, self.args, dropout)
    self.co_attention_concept_lvl = Coattention(input_size, self.args, dropout)
    self.concept_user_embeddings = tf.keras.layers.Embedding(self.vocab_size, self.args.emb_size, embeddings_initializer='glorot_uniform', embeddings_regularizer=tf.keras.regularizers.L2(self.args.l2_reg))
    self.concept_item_embeddings = tf.keras.layers.Embedding(self.vocab_size, self.args.emb_size, embeddings_initializer='glorot_uniform', embeddings_regularizer=tf.keras.regularizers.L2(self.args.l2_reg))

    self.fn1 = tf.keras.layers.Dense(self.args.emb_size, kernel_regularizer=tf.keras.regularizers.L2(self.args.l2_reg))
    self.fn2 = tf.keras.layers.Dense(self.args.emb_size, kernel_regularizer=tf.keras.regularizers.L2(self.args.l2_reg))

    # Per-head width of the self-attention.
    self.depth = self.args.emb_size // self.args.self_num_heads
    self.uwq = tf.keras.layers.Dense(self.args.emb_size)
    self.uwk = tf.keras.layers.Dense(self.args.emb_size)
    self.uwv = tf.keras.layers.Dense(self.args.emb_size)

    self.udense = tf.keras.layers.Dense(self.args.emb_size)

    self.iwq = tf.keras.layers.Dense(self.args.emb_size)
    self.iwk = tf.keras.layers.Dense(self.args.emb_size)
    self.iwv = tf.keras.layers.Dense(self.args.emb_size)

    self.idense = tf.keras.layers.Dense(self.args.emb_size)

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.args.self_num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def _self_attend(self, x, mask, wq, wk, wv, dense, batch_size):
    """Multi-head self-attention over ``x`` with key keep-mask ``mask``.

    Assumes ``x`` is (batch, seq_len, emb_size); returns the same shape.
    """
    q = self.split_heads(wq(x), batch_size)  # (batch, heads, seq_len, depth)
    k = self.split_heads(wk(x), batch_size)
    v = self.split_heads(wv(x), batch_size)
    attended, _ = scaled_dot_product_attention(q, k, v, mask)
    attended = tf.transpose(attended, perm=[0, 2, 1, 3])  # (batch, seq_len, heads, depth)
    concat = tf.reshape(attended, (batch_size, -1, self.args.emb_size))
    return dense(concat)

  def call(self, user_reviews, item_reviews, user_length,
           item_length, o1_embed, o2_embed, o1_len, o2_len,
           c1_inputs, c1_len, c2_inputs, c2_len, num_heads=2,
           rnn_type='RAW_MSE_CAML_FN_FM'):
    """Run ``num_heads`` review + concept co-attention pointers.

    Args:
      user_reviews, item_reviews: review encodings, presumably
        (batch, dmax, emb_size) -- TODO(review) confirm.
      user_length, item_length: number of valid reviews per example.
      c1_inputs, c2_inputs: int64 concept ids reshapeable to (batch, dmax, smax).
      c1_len, c2_len: int64 concept counts reshapeable to (batch, dmax).
      num_heads: number of pointers.
      o1_embed, o2_embed, o1_len, o2_len, rnn_type: accepted for interface
        compatibility; not used by this implementation (the previous code
        only reshaped o1_embed/o2_embed and discarded the result).

    Returns:
      (user_output, item_output, word_u, word_i): projected user/item
      vectors and, per pointer, the concept id selected on each side,
      shape (batch, num_heads).
    """
    # Per-call diagnostics, exposed as attributes.
    self.hatt1, self.hatt2 = [], []
    self.att1, self.att2 = [], []
    self.word_att1, self.word_att2 = [], []
    self.afm = []
    self.afm2 = []
    self.word_u = []
    self.word_i = []
    f1, f2 = [], []

    dmax, smax = self.args.dmax, self.args.smax
    user_mask = tf.sequence_mask(user_length, dmax, dtype=tf.float32)
    item_mask = tf.sequence_mask(item_length, dmax, dtype=tf.float32)

    batch_size = tf.shape(user_reviews)[0]

    user_reviews = self._self_attend(user_reviews, user_mask, self.uwq,
                                     self.uwk, self.uwv, self.udense, batch_size)
    item_reviews = self._self_attend(item_reviews, item_mask, self.iwq,
                                     self.iwk, self.iwv, self.idense, batch_size)

    # Loop-invariant reshapes: concept ids (batch, dmax, smax) and concept
    # counts (batch, dmax) per review.
    c1_ids = tf.reshape(c1_inputs, [-1, dmax, smax])
    c2_ids = tf.reshape(c2_inputs, [-1, dmax, smax])
    c1_counts = tf.reshape(c1_len, [-1, dmax])
    c2_counts = tf.reshape(c2_len, [-1, dmax])

    for _ in range(num_heads):
      # Review-level pointer over the dmax reviews of each side.
      (attended_user, attended_item, att_col, att_row,
       soft_att_col, soft_att_row, afm) = self.co_attention_review_lvl(
          user_reviews, item_reviews, user_mask, item_mask,
          gumbel=True, pooling='MAX')

      f1.append(tf.math.reduce_sum(attended_user, 1))
      f2.append(tf.math.reduce_sum(attended_item, 1))
      self.att1.append(soft_att_col)
      self.att2.append(soft_att_row)
      self.hatt1.append(att_col)
      self.hatt2.append(att_row)
      self.afm.append(afm)

      # Gather concept ids / counts of the pointed-to reviews. The pointer is
      # cast to int64 so it can multiply the integer id tensors.
      col_ptr = tf.cast(att_col, tf.int64)
      row_ptr = tf.cast(att_row, tf.int64)
      review_concept1 = tf.reduce_sum(c1_ids * tf.expand_dims(col_ptr, 2), 1)
      review_concept2 = tf.reduce_sum(c2_ids * tf.expand_dims(row_ptr, 2), 1)

      _o1 = self.concept_user_embeddings(review_concept1)
      _o2 = self.concept_item_embeddings(review_concept2)

      _o1_len = tf.reshape(tf.reduce_sum(c1_counts * col_ptr, 1), [-1])
      _o2_len = tf.reshape(tf.reduce_sum(c2_counts * row_ptr, 1), [-1])
      o1_mask = tf.sequence_mask(_o1_len, smax, dtype=tf.float32)
      o2_mask = tf.sequence_mask(_o2_len, smax, dtype=tf.float32)

      # Concept-level co-attention inside the selected reviews.
      (attended_user_concept, attended_item_concept, att_col_concept,
       att_row_concept, soft_att_col_concept, soft_att_row_concept,
       afm_concept) = self.co_attention_concept_lvl(
          _o1, _o2, o1_mask, o2_mask, gumbel=True, pooling='MEAN')

      f1.append(tf.reduce_sum(attended_user_concept, 1))
      f2.append(tf.reduce_sum(attended_item_concept, 1))

      self.afm2.append(afm_concept)
      self.word_att1.append(soft_att_col_concept)
      self.word_att2.append(soft_att_row_concept)
      # Concept id selected by the concept-level pointer, shape (batch, 1).
      word1 = tf.expand_dims(tf.reduce_sum(review_concept1 * tf.cast(att_col_concept, dtype=tf.int64), 1), 1)
      word2 = tf.expand_dims(tf.reduce_sum(review_concept2 * tf.cast(att_row_concept, dtype=tf.int64), 1), 1)
      self.word_u.append(word1)
      self.word_i.append(word2)

    self.word_u = tf.concat(self.word_u, 1)
    self.word_i = tf.concat(self.word_i, 1)
    user_output = self.fn1(tf.concat(f1, 1))
    item_output = self.fn2(tf.concat(f2, 1))

    return user_output, item_output, self.word_u, self.word_i


#Factorization Machine


In [None]:
#@title
class FactorizationMachine(tf.keras.layers.Layer):
  """Second-order factorization machine on concat(user_output, item_output).

  With x = dropout(concat(user_output, item_output)):
    prediction = w0 + x.w + 0.5 * sum_f [ (x V_f)^2 - (x^2)(V_f^2) ]
  where V_f is row f of fm_V, shape (fm_k, fm_p).
  """
  def __init__(self, fm_p, fm_k=5, dropout=0.2):
    """
    Args:
      fm_p: width of the concatenated input vector.
      fm_k: number of latent factors.
      dropout: dropout rate applied to the input vector.
    """
    super(FactorizationMachine, self).__init__()
    self.dropout1 = tf.keras.layers.Dropout(dropout)
    self.initializer = tf.constant_initializer(value=0.0)
    self.glorot_initializer = tf.keras.initializers.glorot_uniform()
    self.fm_w0 = tf.Variable(self.initializer([1]))        # global bias
    self.fm_w = tf.Variable(self.initializer([fm_p]))      # linear weights
    self.fm_V = tf.Variable(self.glorot_initializer([fm_k, fm_p]))  # factors

  def call(self, user_output, item_output):
    """Return the FM prediction, shape (batch, 1)."""
    input_vec = tf.concat([user_output, item_output], 1)
    input_vec = self.dropout1(input_vec)
    fm_linear_terms = self.fm_w0 + tf.matmul(input_vec, tf.expand_dims(self.fm_w, 1))
    # Pairwise interactions via the O(k * p) identity instead of O(p^2) pairs.
    fm_interactions_part1 = tf.pow(tf.matmul(input_vec, tf.transpose(self.fm_V)), 2)
    fm_interactions_part2 = tf.matmul(tf.pow(input_vec, 2),
                                      tf.transpose(tf.pow(self.fm_V, 2)))
    fm_interactions = fm_interactions_part1 - fm_interactions_part2
    fm_interactions = tf.reduce_sum(fm_interactions, 1, keepdims=True)
    fm_interactions = tf.multiply(0.5, fm_interactions)
    return tf.add(fm_linear_terms, fm_interactions)

In [None]:
class RatingPredictorMLP(tf.keras.layers.Layer):
  """Rating head: a 100-50-10 ReLU MLP followed by a linear unit, applied to
  the concatenated user and item representations."""
  def __init__(self, args):
    super(RatingPredictorMLP, self).__init__()
    self.args = args
    l2 = self.args.l2_reg
    self.fc1 = tf.keras.layers.Dense(100, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(l2))
    self.fc2 = tf.keras.layers.Dense(50, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(l2))
    self.fc3 = tf.keras.layers.Dense(10, activation='relu', kernel_regularizer=tf.keras.regularizers.L2(l2))
    self.fc4 = tf.keras.layers.Dense(1, activation='linear', kernel_regularizer=tf.keras.regularizers.L2(l2))

  def call(self, user_output, item_output):
    """Return the predicted rating, shape (batch, 1)."""
    hidden = tf.concat([user_output, item_output], 1)
    for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
      hidden = layer(hidden)
    return hidden
    



#Model


In [None]:
class Decoder(tf.keras.Model):
  """One-step GRU decoder.

  Embeds the token ids ``x``, concatenates ``hidden`` in front of the
  embedding on the feature axis (so ``x`` is expected to hold a single time
  step), runs a GRU and projects every output to vocabulary logits (dropout
  is applied to those logits).
  """
  def __init__(self, vocab_size, args):
    super(Decoder, self).__init__()
    self.args = args
    l2 = self.args.l2_reg
    self.embedding = tf.keras.layers.Embedding(vocab_size, self.args.emb_size, embeddings_regularizer=tf.keras.regularizers.L2(l2))
    self.gru = tf.keras.layers.GRU(self.args.rnn_size, return_sequences=True, return_state=True, recurrent_initializer='glorot_uniform')
    self.fc = tf.keras.layers.Dense(vocab_size, kernel_regularizer=tf.keras.regularizers.L2(l2))
    self.dropout = tf.keras.layers.Dropout(args.dropout)

  def call(self, x, hidden):
    """Return (logits flattened to (-1, vocab_size), final GRU state)."""
    embedded = self.embedding(x)
    gru_input = tf.concat([tf.expand_dims(hidden, 1), embedded], axis=-1)
    output, state = self.gru(gru_input)
    flat = tf.reshape(output, (-1, output.shape[2]))
    logits = self.dropout(self.fc(flat))
    return logits, state



In [None]:
import tensorflow as tf

class Model(tf.keras.Model):
    """Review-based rating model with a multi-pointer co-attention encoder.

    ``call`` returns a rating prediction from an MLP head, an initial decoder
    state built from the rating label and the user/item representations, and
    two extra outputs (``word_u``, ``word_i``) of the co-attention module.
    """

    def __init__(self, vocab_size, args,
                 char_vocab=0, pos_vocab=0,
                 mode='RANK', num_user=10, num_item=10):
        super(Model, self).__init__()
        self.vocab_size = vocab_size
        self.char_vocab = char_vocab
        self.pos_vocab = pos_vocab
        self.args = args

        self.inspect_op = []
        self.mode = mode
        self.write_dict = {}
        # Special-token ids (same ids CFExperiment.load_vocab assigns).
        self.PAD_tag = 0
        self.SOS_tag = 2
        self.EOS_tag = 3
        self.UNK_tag = 1
        # For interaction data only (disabled and removed from this repo)
        self.num_user = num_user
        self.num_item = num_item
        print('Creating Model in [{}] mode'.format(self.mode))
        self.feat_prop = None

        def l2():
            # A fresh L2 regularizer per layer, weighted by args.l2_reg.
            return tf.keras.regularizers.L2(self.args.l2_reg)

        # Word embeddings shared by user and item reviews.
        self.embeddings = tf.keras.layers.Embedding(
            self.vocab_size, self.args.emb_size,
            embeddings_initializer='glorot_uniform', embeddings_regularizer=l2())
        # Latent id embeddings, concatenated to the co-attended vectors in call().
        self.user_embeddings = tf.keras.layers.Embedding(
            self.num_user, self.args.latent_size,
            embeddings_initializer='glorot_uniform', embeddings_regularizer=l2())
        self.item_embeddings = tf.keras.layers.Embedding(
            self.num_item, self.args.latent_size,
            embeddings_initializer='glorot_uniform', embeddings_regularizer=l2())

        # Per-review projection towers: dropout -> dense -> dropout -> dense.
        self.user_dropout_1 = tf.keras.layers.Dropout(self.args.dropout)
        self.user_fc_1 = tf.keras.layers.Dense(self.args.rnn_size, activation='relu', kernel_regularizer=l2())
        self.user_dropout_2 = tf.keras.layers.Dropout(self.args.dropout)
        self.user_fc_2 = tf.keras.layers.Dense(self.args.rnn_size, activation='relu', kernel_regularizer=l2())

        self.item_dropout_1 = tf.keras.layers.Dropout(self.args.dropout)
        self.item_fc_1 = tf.keras.layers.Dense(self.args.rnn_size, activation='relu', kernel_regularizer=l2())
        self.item_dropout_2 = tf.keras.layers.Dropout(self.args.dropout)
        self.item_fc_2 = tf.keras.layers.Dense(self.args.rnn_size, activation='relu', kernel_regularizer=l2())

        self.multi_pointer_coattention_networks = MultiPointerCoattentionNetworks(
            input_size=self.args.emb_size, args=args, vocab_size=self.vocab_size)
        self.af_co_attend_user_dropout = tf.keras.layers.Dropout(self.args.dropout)
        self.af_co_attend_item_dropout = tf.keras.layers.Dropout(self.args.dropout)
        # Rating head (an MLP; a FactorizationMachine head was used previously).
        self.mlp = RatingPredictorMLP(self.args)

        # Layers that produce the initial decoder hidden state s0.
        self.review_user_mapping = tf.keras.layers.Dense(self.args.rnn_size, kernel_regularizer=l2())
        self.review_item_mapping = tf.keras.layers.Dense(self.args.rnn_size, kernel_regularizer=l2())
        # One embedding per rating level; call() shifts ratings 1..5 to 0..4.
        self.review_rating_embedding = tf.keras.layers.Embedding(5, self.args.rnn_size, embeddings_regularizer=l2())

    def make_hmasks(self, inputs, smax):
        """Reshape token ids to rows of ``smax`` and return a boolean mask.

        Non-zero ids become True; id 0 (PAD) becomes False.
        """
        inputs = tf.reshape(inputs, [-1, smax])
        masked_inputs = tf.cast(inputs, tf.bool)
        return masked_inputs

    def learn_single_repr(self, q1_embed, q1_len, q1_max,
                          pool=False, mask=None):
        """Sum word embeddings along axis 1, optionally zeroing masked positions.

        Args:
            q1_embed: embeddings; the tile below implies rank 3 with last
                dimension ``args.emb_size``.
            q1_len, q1_max, pool: accepted for interface compatibility; unused.
            mask: optional mask with one fewer dimension than ``q1_embed``;
                False/0 positions are zeroed before summing.

        Returns:
            (possibly masked embeddings, their sum over axis 1).
        """
        if mask is not None:
            masks = tf.cast(mask, tf.float32)
            masks = tf.expand_dims(masks, 2)
            masks = tf.tile(masks, [1, 1, self.args.emb_size])
            q1_embed = q1_embed * masks

        q1_output = tf.reduce_sum(q1_embed, 1)
        return q1_embed, q1_output

    def call(self, user_reviews, user_length, item_reviews,
             item_length, user_concepts, user_concepts_len, item_concepts, item_concepts_len, user_id, item_id, sig_labels, training=True):
        """Forward pass.

        Args:
            user_reviews, item_reviews: word ids. They are reshaped into rows
                of ``args.smax`` tokens, and later into ``args.dmax`` reviews
                per example.
            user_length, item_length: per-review lengths. The count of
                non-zero entries along axis 1 is used as the review count.
            user_concepts, user_concepts_len, item_concepts, item_concepts_len:
                passed through to the co-attention module.
            user_id, item_id: ids for the latent user/item embeddings.
            sig_labels: ratings in 1..5. They are shifted to 0..4 and embedded
                to build the initial decoder state.
            training: part of the Keras call signature; not read directly here.

        Returns:
            (prediction, state, word_u, word_i): the MLP rating prediction,
            the initial decoder state (tanh, width ``args.rnn_size``), and
            the two extra outputs of the co-attention module.
        """
        self.user_reviews = user_reviews
        self.item_reviews = item_reviews
        self.user_length = user_length
        self.item_length = item_length
        self.user_concepts = user_concepts
        self.user_concepts_len = user_concepts_len
        self.item_concepts = item_concepts
        self.item_concepts_len = item_concepts_len
        self.user_batch = self.user_embeddings(user_id)
        self.item_batch = self.item_embeddings(item_id)

        self.user_reviews_mask = tf.cast(self.user_reviews, tf.bool)
        self.item_reviews_mask = tf.cast(self.item_reviews, tf.bool)
        # Token masks with one row of smax tokens per review.
        self.user_reviews_hmask = self.make_hmasks(self.user_reviews, self.args.smax)
        self.item_reviews_hmask = self.make_hmasks(self.item_reviews, self.args.smax)

        user_reviews_embed = self.embeddings(self.user_reviews)
        item_reviews_embed = self.embeddings(self.item_reviews)
        user_reviews_embed, user_length = hierarchical_flatten(user_reviews_embed, self.user_length, self.args.smax)
        item_reviews_embed, item_length = hierarchical_flatten(item_reviews_embed, self.item_length, self.args.smax)
        self.o1_embed = user_reviews_embed
        self.o2_embed = item_reviews_embed
        self.o1_len = user_length
        self.o2_len = item_length

        # Masked sum of word embeddings -> one vector per review.
        _, user_reviews_embed = self.learn_single_repr(user_reviews_embed, user_length,
                                                       self.args.base_encoder, pool=True, mask=self.user_reviews_hmask)
        _, item_reviews_embed = self.learn_single_repr(item_reviews_embed, item_length,
                                                       self.args.base_encoder, pool=True, mask=self.item_reviews_hmask)
        _dim = user_reviews_embed.get_shape().as_list()[1]
        self.user_reviews_embed = tf.reshape(user_reviews_embed, [-1, self.args.dmax, _dim])
        self.item_reviews_embed = tf.reshape(item_reviews_embed, [-1, self.args.dmax, _dim])
        # Number of non-empty reviews per example.
        user_length = tf.cast(tf.math.count_nonzero(self.user_length, axis=1), tf.int32)
        item_length = tf.cast(tf.math.count_nonzero(self.item_length, axis=1), tf.int32)

        # Per-review projection towers.
        user_reviews_embed = self.user_dropout_1(self.user_reviews_embed)
        user_reviews_embed = self.user_fc_1(user_reviews_embed)
        user_reviews_embed = self.user_dropout_2(user_reviews_embed)
        user_reviews_embed = self.user_fc_2(user_reviews_embed)

        item_reviews_embed = self.item_dropout_1(self.item_reviews_embed)
        item_reviews_embed = self.item_fc_1(item_reviews_embed)
        item_reviews_embed = self.item_dropout_2(item_reviews_embed)
        item_reviews_embed = self.item_fc_2(item_reviews_embed)

        user_output, item_output, word_u, word_i = self.multi_pointer_coattention_networks(
            user_reviews_embed, item_reviews_embed,
            user_length, item_length, self.o1_embed, self.o2_embed,
            self.o1_len, self.o2_len, self.user_concepts, self.user_concepts_len, self.item_concepts,
            self.item_concepts_len,
            rnn_type=self.args.rnn_type, num_heads=self.args.num_heads)

        # Largest L2 norm of the co-attended user vectors, stored on
        # self.max_norm. TF2's tf.norm takes `keepdims`; the old `keep_dims`
        # spelling raised a TypeError that a bare `except` silently replaced
        # with 0.
        self.max_norm = tf.reduce_max(tf.norm(user_output, ord='euclidean', keepdims=True, axis=1))

        user_output = tf.concat([user_output, self.user_batch], 1)
        item_output = tf.concat([item_output, self.item_batch], 1)

        user_output = self.af_co_attend_user_dropout(user_output)
        item_output = self.af_co_attend_item_dropout(item_output)

        prediction = self.mlp(user_output, item_output)

        # Initial decoder state s0 from the ground-truth rating (shifted from
        # [1, 5] to [0, 4]) plus the user/item representations.
        r_input = tf.cast(sig_labels, dtype=tf.int64) - 1
        r_embed = self.review_rating_embedding(r_input)
        state = tf.keras.activations.tanh(r_embed + self.review_user_mapping(user_output) + self.review_item_mapping(item_output))
        return prediction, state, word_u, word_i

    # Inspired by code from https://github.com/3878anonymous/CAML/blob/master/tf_models/model_caml.py
    def _beam_search_infer(self, q1_output, q2_output, r_input, reuse=None):
        """Beam-search review generation (partial TF1 -> TF2 port).

        NOTE(review): this is still not a usable decoder.
          * It reads ``self.initializer`` and ``self.args.rnn_dim``, and this
            class defines neither.
          * The state projections, GRU cell and output layer are created fresh
            on every call, so they are untrained.
          * ``review_rating_embedding`` has width ``args.rnn_size``, which must
            equal ``rnn_dim`` for the sum below to work.
        Wire it to the trained ``Decoder`` before relying on its output.

        Args:
            q1_output, q2_output: user / item representations. The matmuls
                treat axis 1 as the feature axis.
            r_input: optional rating ids for ``review_rating_embedding``.
            reuse: unused; kept for interface compatibility.

        Returns:
            Int tensor of token ids with shape (beam_size * batch_size, length).
        """
        dim = q1_output.get_shape().as_list()[1]
        init = tf.initializers.GlorotUniform()

        # Locals rather than attributes. Assigning to self.review_user_mapping
        # and self.review_item_mapping here used to replace the Dense layers
        # that call() depends on.
        user_map = tf.Variable(init(shape=(dim, self.args.rnn_dim)))
        item_map = tf.Variable(init(shape=(dim, self.args.rnn_dim)))
        review_bias = tf.Variable(init(shape=(self.args.rnn_dim,)))
        # tf.nn.rnn_cell is not part of the TF2 API; the Keras cell is its successor.
        rnn_cell = tf.keras.layers.GRUCell(self.args.rnn_dim)

        projected = tf.matmul(q1_output, user_map) + tf.matmul(q2_output, item_map) + review_bias
        if r_input is not None:
            state = tf.math.tanh(self.review_rating_embedding(r_input) + projected)
        else:
            state = tf.math.tanh(projected)

        self.beam_batch = self.args.beam_size * self.args.batch_size
        self.beam_batch_max = self.args.beam_size * self.args.beam_size * self.args.batch_size

        # 1.0 at the UNK id and 0 elsewhere. It is subtracted from the softmax
        # so that UNK sorts last in top-k.
        neg_words = tf.reshape(tf.one_hot(self.UNK_tag, self.vocab_size), shape=[1, self.vocab_size])

        neg_words_batch = tf.tile(neg_words, [self.args.batch_size, 1])
        neg_words_beam_batch = tf.tile(neg_words, [self.beam_batch, 1])

        pad_batch = tf.constant(self.PAD_tag, shape=[self.beam_batch])
        eos_batch = tf.constant(self.EOS_tag, shape=[self.beam_batch])
        min_batch = tf.constant(-100.0, shape=[self.beam_batch])
        zero_batch = tf.constant(0.0, shape=[self.beam_batch])
        eos_batch_max = tf.constant(self.EOS_tag, shape=[self.beam_batch_max])

        # One projection shared by the first step and every loop step. The
        # TF1 code expressed this through variable-scope reuse.
        output_layer = tf.keras.layers.Dense(
            self.vocab_size, kernel_initializer=self.initializer,
            bias_initializer=self.initializer, name='review_output_layer')

        # First step: score the initial state and keep the top beam_size words.
        logits = output_layer(state)
        self.preds = tf.nn.softmax(logits) - neg_words_batch
        values, indices = tf.nn.top_k(self.preds, self.args.beam_size)

        init_ans = tf.reshape(indices, shape=[self.beam_batch, 1])
        init_loss = tf.math.log(tf.reshape(values, shape=[self.beam_batch]))
        init_tag = tf.cast(tf.equal(tf.reshape(indices, shape=[self.beam_batch]), eos_batch), tf.int32)
        init_end_tag = tf.reduce_sum(init_tag, axis=None)
        sum_tag = tf.constant(self.beam_batch, dtype=tf.int32)
        max_length = tf.constant(self.args.gmax, dtype=tf.int32)

        init_len = tf.constant(1, shape=[self.beam_batch])

        init_lm_inputs = tf.reshape(indices, shape=[self.beam_batch])
        init_state = tf.reshape(tf.tile(state, [1, self.args.beam_size]), shape=[self.beam_batch, self.args.rnn_dim])

        def condition(end_tag, tag, answer, lens, *args):
            # Stop when every beam is tagged finished or the answer exceeds gmax.
            return tf.logical_and(end_tag < sum_tag, tf.shape(answer)[1] <= max_length)

        def forward_one_step(end_tag, tag, answer, lens, loss, lm_inputs, state):
            self.tip_inputs_embedded = self.embeddings(lm_inputs)

            self.rnnlm_outputs, new_state = rnn_cell(self.tip_inputs_embedded, state)

            # "Keep" candidates: each current beam extended with PAD. Only
            # beams that are already finished keep their score; others get -100.
            loss_old_pad = tf.where(tf.cast(tag, tf.bool), zero_batch, min_batch)
            loss_old = loss + loss_old_pad
            tag_old = tag
            lens_old = lens
            state_old = state
            answer_old = tf.concat([answer, tf.reshape(pad_batch, shape=[self.beam_batch, 1])], axis=1)

            # "Extend" candidates: finished beams are penalised instead.
            loss_new_pad = tf.where(tf.cast(tag, tf.bool), min_batch, zero_batch)
            loss_new = loss + loss_new_pad
            loss_new = tf.reshape(tf.tile(tf.reshape(loss_new, shape=[self.beam_batch, 1]), [1, self.args.beam_size]), shape=[self.beam_batch_max])

            final_input = self.rnnlm_outputs

            logits = output_layer(final_input)

            self.preds = tf.nn.softmax(logits) - neg_words_beam_batch
            values, indices = tf.nn.top_k(self.preds, self.args.beam_size)

            values = tf.reshape(values, shape=[self.beam_batch_max])
            loss_new = loss_new + tf.math.log(values)

            answer_new = tf.reshape(tf.tile(answer, [1, self.args.beam_size]), shape=[self.beam_batch_max, -1])
            indices = tf.reshape(indices, shape=[self.beam_batch_max])
            answer_new = tf.concat([answer_new, tf.reshape(indices, shape=[self.beam_batch_max, 1])], axis=1)

            state_new = tf.reshape(tf.tile(new_state, [1, self.args.beam_size]), shape=[self.beam_batch_max, -1])
            tag_new = tf.cast(tf.equal(indices, eos_batch_max), tf.int32)
            lens_new = tf.reshape(tf.tile(tf.reshape(lens, shape=[self.beam_batch, 1]), [1, self.args.beam_size]), shape=[self.beam_batch_max])
            lens_new = lens_new + 1

            # Merge "keep" and "extend" candidates per example, then take top-k.
            merge_tag = tf.concat([tf.reshape(tag_old, shape=[self.args.batch_size, -1]), tf.reshape(tag_new, shape=[self.args.batch_size, -1])], axis=1)
            merge_lens = tf.concat([tf.reshape(lens_old, shape=[self.args.batch_size, -1]), tf.reshape(lens_new, shape=[self.args.batch_size, -1])], axis=1)
            merge_state = tf.concat([tf.reshape(state_old, shape=[self.args.batch_size, self.args.beam_size, -1]), tf.reshape(state_new, shape=[self.args.batch_size, self.args.beam_size * self.args.beam_size, -1])], axis=1)
            merge_loss = tf.concat([tf.reshape(loss_old, shape=[self.args.batch_size, -1]), tf.reshape(loss_new, shape=[self.args.batch_size, -1])], axis=1)
            merge_answer = tf.concat([tf.reshape(answer_old, shape=[self.args.batch_size, self.args.beam_size, -1]), tf.reshape(answer_new, shape=[self.args.batch_size, self.args.beam_size * self.args.beam_size, -1])], axis=1)
            merge_inputs = tf.concat([tf.reshape(lm_inputs, shape=[self.args.batch_size, -1]), tf.reshape(indices, shape=[self.args.batch_size, -1])], axis=1)

            merge_values, merge_indices = tf.nn.top_k(merge_loss, self.args.beam_size)

            merge_indices = tf.reshape(merge_indices, shape=[self.beam_batch, 1])
            range_batch = tf.reshape(tf.tile(tf.reshape(tf.range(self.args.batch_size), shape=[self.args.batch_size, 1]), [1, self.args.beam_size]), shape=[self.beam_batch, 1])
            index = tf.concat([range_batch, merge_indices], axis=1)

            new_loss = tf.reshape(tf.gather_nd(merge_loss, index), shape=[self.beam_batch])
            new_tag = tf.reshape(tf.gather_nd(merge_tag, index), shape=[self.beam_batch])
            new_lens = tf.reshape(tf.gather_nd(merge_lens, index), shape=[self.beam_batch])
            new_state = tf.reshape(tf.gather_nd(merge_state, index), shape=[self.beam_batch, -1])
            new_answer = tf.reshape(tf.gather_nd(merge_answer, index), shape=[self.beam_batch, -1])
            new_inputs = tf.reshape(tf.gather_nd(merge_inputs, index), shape=[self.beam_batch])

            sum_end = tf.reduce_sum(new_tag, axis=None)

            return sum_end, new_tag, new_answer, new_lens, new_loss, new_inputs, new_state

        sum_end, tag, answer, lens, loss, lm_inputs, state = tf.while_loop(condition, forward_one_step, [init_end_tag, init_tag, init_ans, init_len, init_loss, init_lm_inputs, init_state], shape_invariants=[init_end_tag.get_shape(), init_tag.get_shape(), tf.TensorShape([self.beam_batch, None]), init_len.get_shape(), init_loss.get_shape(), init_lm_inputs.get_shape(), init_state.get_shape()])

        return answer
        

    
    

In [None]:
import matplotlib.pyplot as plt
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')


def loss_function(real, pred):
  """Padding-masked cross-entropy over logits.

  Positions whose target id is 0 (PAD) contribute zero loss. The result is the
  mean over every position, padded ones included.
  """
  per_token = loss_object(real, pred)
  keep = tf.cast(tf.not_equal(real, 0), dtype=per_token.dtype)
  return tf.reduce_mean(per_token * keep)

def _cal_key_loss(args, pred, word_u, word_i):
  """Keyword loss: how much probability ``pred`` assigns to the given words.

  For every row, the negative log-probability of each word in
  ``concat([word_u, word_i], 1)`` is looked up in ``softmax(pred)``. These
  values are aggregated per row, averaged over the batch, and scaled.

  Args:
    args: config. Reads ``word_aggregate`` ('MEAN', 'MAX', otherwise min) and
      ``key_word_lambda``. ``args.batch_size`` is no longer read.
    pred: logits. The two-column ``gather_nd`` below requires rank 2, i.e.
      (batch, vocab).
    word_u, word_i: word ids with the same batch size as ``pred``; they are
      concatenated on axis 1.

  Returns:
    Scalar tensor: ``key_word_lambda * mean_over_batch(aggregate(-log p))``.
  """
  neg_log_probs = -tf.math.log(tf.nn.softmax(pred) + 1e-7)
  word = tf.cast(tf.concat([word_u, word_i], 1), dtype='int32')

  # Row index for each (row, word) pair. It is taken from the actual tensor
  # rather than args.batch_size, because a batch can be smaller (batchify
  # truncates the final one). A mismatch used to make the index concat fail.
  batch = tf.shape(word)[0]
  num_words = tf.shape(word)[1]
  rows = tf.tile(tf.expand_dims(tf.range(batch), 1), [1, num_words])
  indices = tf.stack([rows, word], axis=2)
  per_word = tf.gather_nd(neg_log_probs, indices)

  if args.word_aggregate == 'MEAN':
    per_row = tf.reduce_mean(per_word, 1)
  elif args.word_aggregate == 'MAX':
    per_row = tf.reduce_max(per_word, 1)
  else:
    per_row = tf.reduce_min(per_word, 1)

  return args.key_word_lambda * tf.reduce_mean(per_row)


class CFExperiment(Experiment):
    """ Main experiment class for collaborative filtering.
    Check tylib/exp/experiment.py for base class.
    """
    
    def __init__(self, inject_params=None, preprocessed=False):
        """Configure the experiment and load its training data.

        Args:
            inject_params: not used by this constructor.
            preprocessed: when it compares equal to False, the raw files are
                loaded and processed and the result is cached to disk.
                Otherwise ``data``, ``vocab``, ``num_user`` and ``num_item``
                are read back from that cache.
        """
        print("Starting Rec Experiment")
        super(CFExperiment, self).__init__()
        self.uuid = datetime.now().strftime("%d:%m:%H:%M:%S")
        self.no_text_mode = False
        self.args = args()
        self.max_val = 5
        self.min_val = 1

        self.show_metrics = ['MSE', 'RMSE', 'MAE',
                             'MSE_int', 'RMSE_int', 'MAE_int',
                             'Gen_loss', 'All_loss', 'Gen_loss',
                             'F1', 'ACC', 'Review_acc']
        self.eval_primary = 'All_loss'
        # Hierarchical setting: tokens per flattened document = smax * dmax.
        tokens_per_doc = self.args.smax * self.args.dmax
        self.args.qmax = tokens_per_doc
        self.args.amax = tokens_per_doc

        print("Setting up environment..")
        self.model_name = 'RAW_MSE_CAML_FN_FM'
        self._setup()
        print(preprocessed)

        cache_path = '/content/drive/MyDrive/CAML_TOY_DATA/movie_100/preprocessed_dict.bin'
        # `== False` (not `not preprocessed`) is kept on purpose: None must
        # take the cache-loading branch, as before.
        if preprocessed == False:
            self._load_sets()
            self.train_data = self._combine_reviews(self.train_rating_set, self.train_reviews)
            self.data = self._prepare_set(self.train_data)
            self.num_user = max(self.ui_review_dict) + 1
            self.num_item = max(self.iu_review_dict) + 1
            cache = {
                'data': self.data,
                'vocab': self.vocab,
                'num_user': self.num_user,
                'num_item': self.num_item,
            }
            with open(cache_path, 'wb') as f:
                pickle.dump(cache, f)
        else:
            # pickle: only load caches written by this class, never untrusted files.
            with open(cache_path, 'rb') as f:
                cache = pickle.load(f)
            self.data = cache['data']
            self.vocab = cache['vocab']
            self.num_user = cache['num_user']
            self.num_item = cache['num_item']
      


        
    def _load_sets(self):
        """Load the vocabulary, the rating splits and the train review dictionaries.

        In text mode, sets ``word_index``, ``index_word``, ``stop_concept``,
        ``vocab``, the ui/iu review and concept dicts, and
        ``num_users``/``num_items``. It always sets the train/dev/test rating
        sets and their reviews, and logs the split sizes.
        """
        data_link = '/content/drive/MyDrive/CAML_TOY_DATA/movie_100'

        if not self.no_text_mode:
            self.word_index = self.load_vocab(data_link)
            self.index_word = {idx: word for word, idx in self.word_index.items()}

            # Ids 0..103 (4 special tokens + the first 100 vocabulary lines,
            # presumably the most frequent words) are excluded from concepts.
            frequent_words = 100
            self.stop_concept = {self.index_word[i]: 1 for i in range(frequent_words + 4)}
            self.vocab = len(self.word_index)
            print("vocab={}".format(self.vocab))

        self.train_rating_set, self.train_reviews = self.load_dataset(data_link, 'train')
        self.dev_rating_set, self.dev_reviews = self.load_dataset(data_link, 'valid')

        if self.args.dev == 0:
            # Train on train+dev. The reviews must be extended too: rating row i
            # is paired with review i downstream (_combine_reviews). Previously
            # only the ratings grew, which caused an IndexError there.
            self.train_rating_set += self.dev_rating_set
            self.train_reviews += self.dev_reviews
        self.test_rating_set, self.test_reviews = self.load_dataset(data_link, 'test')

        if not self.no_text_mode:
            self.ui_review_dict, self.iu_review_dict = self.load_review_data(data_link, "review")
            self.ui_concept_dict, self.iu_concept_dict = self.load_review_data(data_link, "concepts")
            # Sized by max id + 1 (ids presumably index embedding tables).
            self.num_users = max(self.ui_review_dict) + 1
            self.num_items = max(self.iu_review_dict) + 1

        self.write_to_file("Train={} Dev={} Test={}".format(
            len(self.train_rating_set),
            len(self.dev_rating_set),
            len(self.test_rating_set)))

    def load_vocab(self, data_dir):
        """Read ``<data_dir>/vocabulary.txt`` into a ``{word: id}`` dict.

        Line ``i`` (0-based, whitespace-stripped) gets id ``i + 4``. Ids 0-3
        are PAD, UNK, SOS and EOS. They are inserted last, so they override
        any vocabulary line with the same spelling.
        """
        vocab = {}
        # `with` closes the file; the previous version leaked the handle.
        with open('%s/vocabulary.txt' % data_dir, 'r', encoding='utf-8') as f:
            for idx, word in enumerate(f, start=4):
                vocab[word.strip()] = idx
        vocab[PAD] = 0
        vocab[UNK] = 1
        vocab[SOS] = 2
        vocab[EOS] = 3

        return vocab

    def load_dataset(self, data_dir, dataset_type):
        """Load one split's rating rows and tokenised reviews.

        Reads ``<split>_userid.txt``, ``<split>_itemid.txt``,
        ``<split>_rating.txt`` and ``<split>_review_1.txt`` from ``data_dir``.
        Review lines are tab-separated tokens.

        Each review is scanned right-to-left. At each position the longest
        span ending there that is a (possibly multi-word) key of
        ``self.word_index`` is taken greedily, and single tokens fall back to
        themselves. Unknown spans map to UNK. The ids are wrapped as
        ``[SOS, ..., EOS]``.

        Returns:
            (output, reviews): ``output[i] = [user, item, rating]`` as ints and
            ``reviews[i]`` the id list for the same row.
        """
        def read_lines(suffix):
            # `with` closes each file; the previous version leaked four handles.
            with open('%s/%s_%s.txt' % (data_dir, dataset_type, suffix), 'r', encoding='utf-8') as f:
                return f.readlines()

        lines_user_id = read_lines('userid')
        lines_item_id = read_lines('itemid')
        lines_rating = read_lines('rating')
        lines_review = read_lines('review_1')

        # Every suffix of length >= 2 of each multi-word key, including the key
        # itself. The matcher grows a span leftwards from its right end, so the
        # spans it must be allowed to pass through are suffixes of the target
        # phrase. The previous code stored prefixes instead, so it could miss
        # phrases of 3+ words. Every match found before is still found.
        concept_suffixes = set()
        for key in self.word_index:
            words = key.split(' ')
            for start in range(len(words) - 1):
                concept_suffixes.add(" ".join(words[start:]))

        unk_id = self.word_index[UNK]
        output = []
        reviews = []
        for i in range(len(lines_rating)):
            output.append([int(lines_user_id[i].strip()), int(lines_item_id[i].strip()), int(lines_rating[i].strip())])

            tokens = lines_review[i].strip().split('\t')
            # Ids are collected right-to-left and reversed at the end.
            linedata = [self.word_index[EOS]]
            pos = len(tokens) - 1
            while pos >= 0:
                match_string = tokens[pos]
                new_pos = pos
                for j in range(pos):
                    span_start = pos - j - 1
                    span = " ".join(tokens[span_start:pos + 1])
                    if span not in concept_suffixes:
                        break
                    if span in self.word_index:
                        match_string = span
                        new_pos = span_start
                linedata.append(self.word_index.get(match_string, unk_id))
                pos = new_pos - 1

            linedata.append(self.word_index[SOS])
            reviews.append(linedata[::-1])

        return output, reviews

    def load_review_data(self, data_dir, data_type):
        """Build ``{user: {item: ids}}`` and ``{item: {user: ids}}`` from the train split.

        Reads ``train_userid.txt``, ``train_itemid.txt`` and
        ``train_<data_type>.txt`` (tab-separated tokens) from ``data_dir``.
        - Tokens are mapped through ``self.word_index``, with UNK for
          out-of-vocabulary tokens.
        - When ``data_type == "concepts"``, tokens in ``self.stop_concept``
          are dropped.
        - Each list is truncated to ``args.smax``.
        - Empty lines yield ``[]``.
        - Both dicts hold the same list object for a (user, item) pair.

        Length statistics are reported through ``show_stats``.
        """
        def read_lines(name):
            # `with` closes each file; the previous version leaked three handles.
            with open('%s/%s.txt' % (data_dir, name), 'r', encoding='utf-8') as f:
                return f.readlines()

        lines_user_id = read_lines('train_userid')
        lines_item_id = read_lines('train_itemid')
        lines_review = read_lines('train_%s' % data_type)

        stop_concept = self.stop_concept
        drop_stop = data_type == "concepts"
        unk_id = self.word_index[UNK]

        ui_dict = {}
        iu_dict = {}
        for i in range(len(lines_review)):
            user = int(lines_user_id[i].strip())
            item = int(lines_item_id[i].strip())

            linedata = []
            line = lines_review[i].strip()
            if line:
                for token in line.split('\t'):
                    if drop_stop and token in stop_concept:
                        continue
                    linedata.append(self.word_index.get(token, unk_id))
                linedata = linedata[:self.args.smax]

            ui_dict.setdefault(user, {})[item] = linedata
            iu_dict.setdefault(item, {})[user] = linedata

        length1 = [len(items) for items in ui_dict.values()]
        length2 = [len(users) for users in iu_dict.values()]
        length3 = [len(ids) for items in ui_dict.values() for ids in items.values()]
        show_stats('{}:user num review'.format(data_type), length1)
        show_stats('{}:item num review'.format(data_type), length2)
        show_stats('{}:review num word'.format(data_type), length3)

        return ui_dict, iu_dict

    def _combine_reviews(self, data, reviews=None):
        """Attach generation targets to rating rows.

        Args:
            data: rows whose first three fields are user, item and label.
            reviews: optional token-id lists aligned index-for-index with
                ``data``. They are converted with
                ``prep_data_list(reviews, args.gmax)``.

        Returns:
            A list of ``[user, item, label, gen_output, gen_len]`` rows. When
            ``reviews`` is None the last two fields are None. Previously that
            case raised NameError because the generation outputs were never
            defined.
        """
        if reviews is not None:
            gen_outputs, gen_len = prep_data_list(reviews, self.args.gmax)
        else:
            gen_outputs = gen_len = [None] * len(data)

        return [[row[0], row[1], row[2], gen_outputs[i], gen_len[i]]
                for i, row in enumerate(data)]

    def _prepare_set(self, data):

        user = [x[0] for x in data]
        items = [x[1] for x in data]
        labels = [x[2] for x in data]

        # Raw user-item ids
        user_idx = user
        item_idx = items

        user_list = []
        item_list = []
        user_concept_list = []
        item_concept_list = []
        user_len = []
        item_len = []
        user_concept_len = []
        item_concept_len = []
        for i in range(len(user)):
            user_reviews = []
            item_reviews = []
            user_concepts = []
            item_concepts = []
            user_r_len = []
            user_c_len = []
            item_r_len = []
            item_c_len = []

            if self.args.data_prepare == 1:
                if items[i] in self.ui_review_dict[user[i]]:
                    user_reviews.append(self.ui_review_dict[user[i]][items[i]])
                    user_concepts.append(self.ui_concept_dict[user[i]][items[i]])
                    user_r_len.append(len(self.ui_review_dict[user[i]][items[i]]))
                    user_c_len.append(len(self.ui_concept_dict[user[i]][items[i]]))
                for x in self.ui_review_dict[user[i]]:
                    if not x==items[i]:
                        user_reviews.append(self.ui_review_dict[user[i]][x])
                        user_concepts.append(self.ui_concept_dict[user[i]][x])
                        user_r_len.append(len(self.ui_review_dict[user[i]][x]))
                        user_c_len.append(len(self.ui_concept_dict[user[i]][x]))
                        if len(user_reviews) == self.args.dmax:
                            break
                if user[i] in self.iu_review_dict[items[i]]:
                    item_reviews.append(self.iu_review_dict[items[i]][user[i]])
                    item_concepts.append(self.iu_concept_dict[items[i]][user[i]])
                    item_r_len.append(len(self.iu_review_dict[items[i]][user[i]]))
                    item_c_len.append(len(self.iu_concept_dict[items[i]][user[i]]))
                for x in self.iu_review_dict[items[i]]:
                    if not x==user[i]:
                        item_reviews.append(self.iu_review_dict[items[i]][x])
                        item_concepts.append(self.iu_concept_dict[items[i]][x])
                        item_r_len.append(len(self.iu_review_dict[items[i]][x]))
                        item_c_len.append(len(self.iu_concept_dict[items[i]][x]))
                        if len(item_reviews) == self.args.dmax:
                            break
                user_list.append(user_reviews)
                item_list.append(item_reviews)
                user_concept_list.append(user_concepts)
                item_concept_list.append(item_concepts)
                user_len.append(user_r_len)
                item_len.append(item_r_len)
                user_concept_len.append(user_c_len)
                item_concept_len.append(item_c_len)
            elif self.args.data_prepare == -1:
                tmp = len(self.ui_review_dict[user[i]])
                for x in self.ui_review_dict[user[i]]:
                    if (not x==items[i]):
                        user_reviews.append(self.ui_review_dict[user[i]][x])
                        user_concepts.append(self.ui_concept_dict[user[i]][x])
                        user_r_len.append(len(self.ui_review_dict[user[i]][x]))
                        user_c_len.append(len(self.ui_concept_dict[user[i]][x]))
                        if len(user_reviews) == self.args.dmax:
                            break
                tmp = len(self.iu_review_dict[items[i]])
                for x in self.iu_review_dict[items[i]]:
                    if (not x==user[i]):
                        item_reviews.append(self.iu_review_dict[items[i]][x])
                        item_concepts.append(self.iu_concept_dict[items[i]][x])
                        item_r_len.append(len(self.iu_review_dict[items[i]][x]))
                        item_c_len.append(len(self.iu_concept_dict[items[i]][x]))
                        if len(item_reviews) == self.args.dmax:
                            break
                user_list.append(user_reviews)
                item_list.append(item_reviews)
                user_concept_list.append(user_concepts)
                item_concept_list.append(item_concepts)
                user_len.append(user_r_len)
                item_len.append(item_r_len)
                user_concept_len.append(user_c_len)
                item_concept_len.append(item_c_len)
            else:
                for x in self.ui_review_dict[user[i]]:
                    user_reviews.append(self.ui_review_dict[user[i]][x])
                    user_concepts.append(self.ui_concept_dict[user[i]][x])
                    user_r_len.append(len(self.ui_review_dict[user[i]][x]))
                    user_c_len.append(len(self.ui_concept_dict[user[i]][x]))
                    if len(user_reviews) == self.args.dmax:
                        break
                user_list.append(user_reviews)
                user_concept_list.append(user_concepts)
                for x in self.iu_review_dict[items[i]]:
                    item_reviews.append(self.iu_review_dict[items[i]][x])
                    item_concepts.append(self.iu_concept_dict[items[i]][x])
                    item_r_len.append(len(self.iu_review_dict[items[i]][x]))
                    item_c_len.append(len(self.iu_concept_dict[items[i]][x]))
                    if len(item_reviews) == self.args.dmax:
                        break
                item_list.append(item_reviews)
                item_concept_list.append(item_concepts)
                user_len.append(user_r_len)
                item_len.append(item_r_len)
                user_concept_len.append(user_c_len)
                item_concept_len.append(item_c_len)

        if(self.args.base_encoder!='Flat'):

            user_concept, user_concept_len = prep_hierarchical_data_list_new(user_concept_list, user_concept_len,
                                                                                              self.args.smax,
                                                                                              self.args.dmax)
            items_concept, item_concept_len = prep_hierarchical_data_list_new(item_concept_list, item_concept_len,
                                                                                              self.args.smax,
                                                                                              self.args.dmax)

            user, user_len = prep_hierarchical_data_list_new(user_list, user_len,
                                                                  self.args.smax,
                                                                  self.args.dmax)
            items, item_len = prep_hierarchical_data_list_new(item_list, item_len,
                                                                   self.args.smax,
                                                                   self.args.dmax)

        output = [user, user_len, items, item_len]

        output.append(user_concept)
        output.append(user_concept_len)
        output.append(items_concept)
        output.append(item_concept_len)
        
        if self.args.implicit == 1:
            output.append(user_idx)
            output.append(item_idx)
      
        gen_outputs = [x[3] for x in data]
        gen_len = [x[4] for x in data]

        output.append(gen_outputs)
        output.append(gen_len)
        output.append(labels)
        output = list(zip(*output))
        return output
    def infer(self, checkpoint_path):
        """Generate reviews for the test set and print the corpus BLEU score.

        Restores a ``Model`` from ``checkpoint_path``. It then runs the model
        batch by batch over the combined test data and decodes the first beam
        of every generated sequence. Outputs per test index ``idx``:

        * ``<gen_dir>/gen_review.<idx>.txt``: the generated review.
        * ``<gen_true_dir>/true_review.A.<idx>.txt``: the reference review,
          with its leading token skipped.

        Both texts are truncated at the first ``EOS`` token. At the end the
        corpus-level BLEU of generated vs. reference tokens is printed.

        Args:
            checkpoint_path: path handed to ``Model.load_weights``.
        """
        # corpus_bleu is commented out in the file-level imports; import it here
        # so the final score computation does not raise NameError.
        from nltk.translate.bleu_score import corpus_bleu

        data = self._combine_reviews(self.test_rating_set, self.test_reviews)
        data_len = len(data)
        batch_size = self.args.batch_size
        beam_size = self.args.beam_size
        num_batches = data_len // batch_size
        eos_id = self.word_index[EOS]

        mkdir_p(self.args.gen_dir)
        mkdir_p(self.args.gen_true_dir)

        model = Model(self.vocab, self.args, num_user=self.num_user, num_item=self.num_item)
        model.load_weights(checkpoint_path)

        def to_words(token_ids, skip_first=False):
            """Map token ids to vocabulary strings, stopping at the first EOS.

            The EOS test runs before the skip, so an EOS at position 0 still
            ends decoding immediately.
            """
            words = []
            for k, token in enumerate(token_ids):
                if token == eos_id:
                    break
                if skip_first and k == 0:
                    continue
                words.append(self.index_word[token])
            return words

        def split_words(words):
            # A vocabulary entry may itself contain spaces; BLEU wants single tokens.
            return [piece for word in words for piece in word.split(' ')]

        gen_sentences = []
        ref_sentences = []

        # num_batches + 1 covers the trailing partial batch; an empty one is skipped.
        for i in tqdm(range(num_batches + 1)):
            batch = batchify(data, i, batch_size, max_sample=data_len)
            if len(batch) == 0:
                continue

            batch = self._prepare_set(batch)
            # _prepare_set returns one tuple per sample; transpose to columns in
            # the same order train() unpacks them.
            columns = [np.array(col) for col in zip(*batch)]
            (user_reviews, user_len, item_reviews, item_len,
             user_concept, user_concept_len, item_concept, item_concept_len) = columns[:8]
            # userid / itemid are only emitted when args.implicit == 1 (13 columns).
            if len(columns) == 13:
                userid, itemid = columns[8], columns[9]
            else:
                userid = itemid = None  # NOTE(review): confirm Model accepts None here
            label = columns[-1]

            predicted_rating, _, _, _ = model(user_reviews, user_len, item_reviews, item_len,
                                              user_concept, user_concept_len, item_concept,
                                              item_concept_len, userid, itemid, label)

            gen_results = model._beam_search_infer(model.user_output, model.item_output,
                                                   predicted_rating)[0]

            for j in range(batch_size):
                idx = batch_size * i + j
                if idx >= data_len:
                    break

                # Generated review: first beam of sample j.
                gen_words = to_words(gen_results[j * beam_size])
                with open(self.args.gen_dir + '/gen_review.' + str(idx) + '.txt', 'w+') as f:
                    f.write(' '.join(gen_words))
                gen_sentences.append(split_words(gen_words))

                # Reference review: skip the leading token (presumably SOS).
                true_words = to_words(self.test_reviews[idx], skip_first=True)
                with open(self.args.gen_true_dir + '/true_review.A.' + str(idx) + '.txt', 'w+') as f1:
                    f1.write(' '.join(true_words))
                ref_sentences.append([split_words(true_words)])

        print('Infer finished!')
        score = corpus_bleu(ref_sentences, gen_sentences)
        print('bleu score: {}'.format(score))

    def train(self):
        """Main training loop.

        Transposes ``self.data`` (a sequence of 13-field samples) into columns
        and wraps them in a ``tf.data`` pipeline. Each of ``args.epochs``
        epochs then runs teacher-forced training over the full batches.

        The optimized loss is rating MSE + review-generation loss + key loss,
        minimized jointly with Adam. After each epoch the summed losses are
        printed, and the model weights are checkpointed whenever the epoch's
        total loss improves. Loss-curve PNGs are written at the end.
        """
        self.eval_list = []
        columns = list(zip(*self.data))
        (a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) = columns
        dataset = tf.data.Dataset.from_tensor_slices(
            (np.array(a1), np.array(a2), np.array(a3), np.array(a4), np.array(a5),
             np.array(a6), np.array(a7), np.array(a8), np.array(a9),
             np.array(a10), np.array(a11), np.array(a12), np.array(a13)))

        num_samples = len(a1)
        # loader.take(steps) below drops the trailing partial batch.
        steps = num_samples // self.args.batch_size
        loader = dataset.batch(self.args.batch_size)

        total_loss = []
        total_loss_n = []
        total_loss_c = []
        total_loss_r = []
        # len(columns) is the field count (13); report the number of samples.
        print("Training Interactions={}".format(num_samples))

        optimizer = tf.keras.optimizers.Adam(0.01)
        checkpoint_path = "training/cp-{epoch:04d}.ckpt"
        # Build model and decoder once so learned weights persist across epochs
        # (re-creating them per epoch discarded all previous training).
        model = Model(self.vocab, self.args, num_user=self.num_user, num_item=self.num_item)
        decoder = Decoder(self.vocab, self.args)
        # Start at +inf so the first epoch always produces a checkpoint.
        min_epoch_loss = float('inf')

        # 1-based so exactly args.epochs epochs run; logs/checkpoints use epoch - 1.
        for epoch in range(1, self.args.epochs + 1):
            print('EPOCH:', epoch)
            # time.clock() was removed in Python 3.8.
            t0 = time.perf_counter()
            self.write_to_file("=====================================")
            epoch_loss = []
            epoch_loss_r = []
            epoch_loss_n = []
            epoch_loss_c = []

            for (user_reviews, user_len, item_reviews, item_len, user_concept,
                 user_concept_len, item_concept, item_concept_len, userid, itemid,
                 gen_output, gen_length, label) in loader.take(steps):
                seq_len = int(gen_output.shape[1])
                loss_n = 0
                key_loss = 0
                with tf.GradientTape() as tape:
                    predicted_rating, dec_hidden, word_u, word_i = model(
                        user_reviews, user_len, item_reviews, item_len,
                        user_concept, user_concept_len, item_concept, item_concept_len,
                        userid, itemid, label)
                    # Teacher forcing: the decoder input at step t is the
                    # ground-truth token from step t-1.
                    dec_input = tf.expand_dims(gen_output[:, 0], 1)
                    for t in range(1, seq_len):
                        predictions, dec_hidden = decoder(dec_input, dec_hidden)
                        loss_n += loss_function(gen_output[:, t], predictions)
                        key_loss += _cal_key_loss(self.args, predictions, word_u, word_i)
                        dec_input = tf.expand_dims(gen_output[:, t], 1)

                    rating_loss = tf.keras.losses.MSE(label, tf.reshape(predicted_rating, [-1]))
                    loss = rating_loss + loss_n + key_loss

                variables = model.trainable_variables + decoder.trainable_variables
                gradients = tape.gradient(loss, variables)
                optimizer.apply_gradients(zip(gradients, variables))

                # Store Python floats so per-batch tensors are not kept alive.
                epoch_loss.append(float(loss))
                epoch_loss_r.append(float(rating_loss))
                epoch_loss_n.append(float(loss_n) / seq_len)
                epoch_loss_c.append(float(key_loss) / seq_len)

            epoch_total = sum(epoch_loss)
            print()
            print('Total_loss in epoch', epoch - 1)
            print(epoch_total)
            print('Rating MSE loss:')
            print(sum(epoch_loss_r))
            print('Generation loss:')
            print(sum(epoch_loss_n))
            print('Key loss:')
            print(sum(epoch_loss_c))
            print('Epoch time: {:.1f}s'.format(time.perf_counter() - t0))

            total_loss.append(epoch_total)
            total_loss_r.append(sum(epoch_loss_r))
            total_loss_n.append(sum(epoch_loss_n))
            total_loss_c.append(sum(epoch_loss_c))
            if epoch_total < min_epoch_loss:
                min_epoch_loss = epoch_total
                # NOTE(review): only `model` is checkpointed; the separately
                # trained `decoder` weights are not saved.
                model.save_weights(checkpoint_path.format(epoch=epoch - 1))

        print(total_loss)
        # NOTE(review): `plt` (matplotlib.pyplot) is not among the imports at the
        # top of the file; it must be imported elsewhere in the notebook.
        plt.ioff()
        # A fresh figure per curve so the plots are not drawn on top of each other.
        for values, filename in ((total_loss_r, "mseloss.png"),
                                 (total_loss, "totalloss.png"),
                                 (total_loss_n, "crossentropylossreview.png"),
                                 (total_loss_c, "conceptrelevanceloss.png")):
            fig = plt.figure()
            plt.plot(values)
            fig.savefig(filename)
            plt.close(fig)
               

In [None]:
# Build the experiment object. The setup messages its constructor prints
# (environment / GPU selection) are captured in the cell output below.
# NOTE(review): preprocessed=True presumably means "load already-preprocessed
# data" — confirm in CFExperiment.__init__.
c=CFExperiment(preprocessed=True)

Starting Rec Experiment
Setting up environment..
Selecting GPU no.0
True


# Run


In [None]:
# Run the main training loop defined by CFExperiment.train.
c.train()