In [1]:
import re
import os
import pickle
import numpy as np
import logging
import torch
import math
from torch.utils.data import Dataset
from utils.utils import newsample, getId2idx, tokenize, getVocab, my_collate, Partition_Sampler
from data.configs.demo import config
from torch.utils.data import DataLoader
from collections import defaultdict

from transformers import BertTokenizer,BertModel
from utils.MIND import MIND

logger = logging.getLogger(__name__)

In [2]:
class MIND(Dataset):
    """ Map Style Dataset for MIND, use bert tokenizer

    Args:
        config(dict): pre-defined dictionary of hyper parameters
        news_file(str): path of news_file
        behaviors_file(str): path of behaviors_file
        shuffle(bool): whether to shuffle the order of impressions
    """

    def __init__(self, config, news_file, behaviors_file, shuffle_pos=False):
        # initiate the whole iterator
        self.npratio = config.npratio
        self.shuffle_pos = shuffle_pos
        self.signal_length = config.signal_length
        self.his_size = config.his_size
        self.impr_size = config.impr_size
        self.k = config.k
        self.order_history = config.order_history
        pat = re.search('MIND/(.*_(.*)/)news', news_file)
        self.mode = pat.group(2)

        self.cache_directory = '/'.join(['data/cache', config.embedding, pat.group(1)])
        self.behav_path = self.cache_directory + '{}/{}'.format(self.impr_size, re.search('(\w*\.)tsv', behaviors_file).group(1) + '.pkl')

        if os.path.exists(self.behav_path):
            logger.info('using cached user behavior from {}'.format(self.behav_path))
            with open(self.behav_path, 'rb') as f:
                behaviors = pickle.load(f)
                for k,v in behaviors.items():
                    setattr(self, k, v)

        else:
            if config.rank in [-1, 0]:
                logger.info("encoding user behaviors of {}...".format(behaviors_file))
                os.makedirs(self.cache_directory + str(self.impr_size), exist_ok=True)
                self.behaviors_file = behaviors_file
                self.nid2index = getId2idx('data/dictionaries/nid2idx_{}_{}.json'.format(config.scale, self.mode))
                self.uid2index = getId2idx('data/dictionaries/uid2idx_{}.json'.format(config.scale))

                self.init_behaviors()

        self.reducer = config.reducer

        if config.reducer == 'bm25':
            self.news_path = self.cache_directory + 'news_bm25.pkl'
            if os.path.exists(self.news_path):
                logger.info('using cached news tokenization from {}'.format(self.news_path))
                with open(self.news_path, 'rb') as f:
                    news = pickle.load(f)
                    for k,v in news.items():
                        setattr(self, k, v)
            else:
                if config.rank in [-1, 0]:
                    from transformers import BertTokenizerFast
                    from utils.utils import BM25
                    logger.info("encoding news of {}...".format(news_file))
                    self.news_file = news_file
                    self.max_news_length = 512
                    # there are only two types of vocabulary
                    self.tokenizer = BertTokenizerFast.from_pretrained(config.bert, cache=config.path + 'bert_cache/')
                    self.nid2index = getId2idx('data/dictionaries/nid2idx_{}_{}.json'.format(config.scale, self.mode))
                    self.init_news(reducer=BM25())

        else:
            self.news_path = self.cache_directory + 'news.pkl'
            if os.path.exists(self.news_path):
                logger.info('using cached news tokenization from {}'.format(self.news_path))
                with open(self.news_path, 'rb') as f:
                    news = pickle.load(f)
                    for k,v in news.items():
                        setattr(self, k, v)
            else:
                if config.rank in [-1, 0]:
                    from transformers import BertTokenizerFast
                    logger.info("encoding news of {}...".format(news_file))
                    self.news_file = news_file
                    self.max_news_length = 512
                    # there are only two types of vocabulary
                    self.tokenizer = BertTokenizerFast.from_pretrained(config.bert, cache=config.path + 'bert_cache/')
                    self.nid2index = getId2idx('data/dictionaries/nid2idx_{}_{}.json'.format(config.scale, self.mode))
                    self.init_news()


    def init_news(self, reducer=None):
        """
            init news information given news file, such as news_title_array.

        Args:
            bm25: whether to sort the terms by bm25 score
        """

        # VERY IMPORTANT!!! FIXME
        # The nid2idx dictionary must follow the original order of news in news.tsv

        documents = ['']

        with open(self.news_file, "r", encoding='utf-8') as rd:
            for idx in rd:
                nid, vert, subvert, title, ab, url, _, _ = idx.strip("\n").split('\t')
                documents.append(' '.join(['[CLS]', title, ab, vert, subvert]))

        if reducer:
            encoded_dict = self.tokenizer(documents, add_special_tokens=False, padding=True, truncation=True, max_length=self.max_news_length, return_tensors='np')
            self.encoded_news = encoded_dict.input_ids
            self.attn_mask = encoded_dict.attention_mask

            documents_sorted, attn_mask_sorted = reducer(self.encoded_news)
            print(attn_mask_sorted)

            self.encoded_news_sorted = documents_sorted
            self.attn_mask_sorted = attn_mask_sorted * self.attn_mask

            with open(self.news_path, 'wb') as f:
                pickle.dump(
                    {
                        'encoded_news': self.encoded_news,
                        'encoded_news_sorted': self.encoded_news_sorted,
                        'attn_mask': self.attn_mask,
                        'attn_mask_sorted': self.attn_mask_sorted
                    },
                    f
                )
        else:
            encoded_dict = self.tokenizer(documents, add_special_tokens=False, padding=True, truncation=True, max_length=self.max_news_length, return_tensors='np')
            self.encoded_news = encoded_dict.input_ids
            self.attn_mask = encoded_dict.attention_mask

            with open(self.news_path, 'wb') as f:
                pickle.dump(
                    {
                        'encoded_news': self.encoded_news,
                        'attn_mask': self.attn_mask
                    },
                    f
                )


    def init_behaviors(self):
        """
            init behavior logs given behaviors file.
        """
        # list of list of history news index
        histories = []
        # list of user index
        uindexes = []
        # list of impression indexes
        # self.impr_indexes = []

        impr_index = 0

        # only store positive behavior
        if self.mode == 'train':
            # list of lists, each list represents a
            imprs = []
            negatives = []

            with open(self.behaviors_file, "r", encoding='utf-8') as rd:
                for idx in rd:
                    _, uid, time, history, impr = idx.strip("\n").split('\t')

                    history = [self.nid2index[i] for i in history.split()]

                    impr_news = [self.nid2index[i.split("-")[0]] for i in impr.split()]
                    labels = [int(i.split("-")[1]) for i in impr.split()]

                    # user will always in uid2index
                    uindex = self.uid2index[uid]
                    # store negative samples of each impression
                    negative = []

                    for news, label in zip(impr_news, labels):
                        if label == 1:
                            imprs.append((impr_index, news))
                        else:
                            negative.append(news)

                    # 1 impression correspond to 1 of each of the following properties
                    histories.append(history)
                    negatives.append(negative)
                    uindexes.append(uindex)

                    impr_index += 1

            self.imprs = imprs
            self.histories = histories
            self.negatives = negatives
            self.uindexes = uindexes

            save_dict = {
                'imprs': self.imprs,
                'histories': self.histories,
                'negatives': self.negatives,
                'uindexes': self.uindexes
            }

        # store every behavior
        elif self.mode == 'dev':
            # list of every cdd news index along with its impression index and label
            imprs = []

            with open(self.behaviors_file, "r", encoding='utf-8') as rd:
                for idx in rd:
                    _, uid, time, history, impr = idx.strip("\n").split('\t')

                    history = [self.nid2index[i] for i in history.split()]

                    impr_news = [self.nid2index[i.split("-")[0]] for i in impr.split()]
                    labels = [int(i.split("-")[1]) for i in impr.split()]
                    # user will always in uid2index
                    uindex = self.uid2index[uid]

                    # store every impression
                    for i in range(0, len(impr_news), self.impr_size):
                        imprs.append((impr_index, impr_news[i:i+self.impr_size], labels[i:i+self.impr_size]))

                    # 1 impression correspond to 1 of each of the following properties
                    histories.append(history)
                    uindexes.append(uindex)

                    impr_index += 1

            self.imprs = imprs
            self.histories = histories
            self.uindexes = uindexes

            save_dict = {
                'imprs': self.imprs,
                'histories': self.histories,
                'uindexes': self.uindexes
            }

        # store every behavior
        elif self.mode == 'test':
            # list of every cdd news index along with its impression index and label
            imprs = []

            with open(self.behaviors_file, "r", encoding='utf-8') as rd:
                for idx in rd:
                    _, uid, time, history, impr = idx.strip("\n").split('\t')

                    history = [self.nid2index[i] for i in history.split()]

                    impr_news = [self.nid2index[i] for i in impr.split()]
                    # user will always in uid2index
                    uindex = self.uid2index[uid]

                    # store every impression
                    for i in range(0, len(impr_news), self.impr_size):
                        imprs.append((impr_index, impr_news[i:i+self.impr_size]))

                    # 1 impression correspond to 1 of each of the following properties
                    histories.append(history)
                    uindexes.append(uindex)

                    impr_index += 1

            self.imprs = imprs
            self.histories = histories
            self.uindexes = uindexes

            save_dict = {
                'imprs': self.imprs,
                'histories': self.histories,
                'uindexes': self.uindexes
            }

        with open(self.behav_path, 'wb') as f:
            pickle.dump(save_dict, f)


    def __len__(self):
        """
            return length of the whole dataset
        """
        return len(self.imprs)

    def __getitem__(self,index):
        """ return data
        Args:
            index: the index for stored impression

        Returns:
            back_dic: dictionary of data slice
        """

        impr = self.imprs[index] # (impression_index, news_index)
        impr_index = impr[0]
        impr_news = impr[1]


        user_index = [self.uindexes[impr_index]]

        # each time called to return positive one sample and its negative samples
        if self.mode == 'train':
            # user's unhis news in the same impression
            negs = self.negatives[impr_index]
            neg_list, neg_pad = newsample(negs, self.npratio)

            cdd_ids = [impr_news] + neg_list
            label = np.asarray([1] + [0] * self.npratio)

            if self.shuffle_pos:
                s = np.arange(0, len(label), 1)
                np.random.shuffle(s)
                cdd_ids = np.asarray(cdd_ids)[s]
                label = np.asarray(label)[s]

            label = np.arange(0, len(cdd_ids), 1)[label == 1][0]

            his_ids = self.histories[impr_index][:self.his_size]
            # true means the corresponding history news is padded
            his_mask = np.zeros((self.his_size), dtype=bool)
            his_mask[:len(his_ids)] = 1

            if self.order_history:
                his_ids = his_ids + [0] * (self.his_size - len(his_ids))
            else:
                his_ids = his_ids[::-1] + [0] * (self.his_size - len(his_ids))

            # pad in cdd
            # cdd_mask = [1] * neg_pad + [0] * (self.npratio + 1 - neg_pad)

            cdd_encoded_index = self.encoded_news[cdd_ids][:, :self.signal_length]
            cdd_attn_mask = self.attn_mask[cdd_ids][:, :self.signal_length]
            if self.reducer == 'bm25':
                his_encoded_index = self.encoded_news_sorted[his_ids][:, :self.k + 1]
                his_attn_mask = self.attn_mask_sorted[his_ids][:, :self.k + 1]
            else:
                his_encoded_index = self.encoded_news[his_ids][:, :self.signal_length]
                his_attn_mask = self.attn_mask[his_ids][:, :self.signal_length]
                his_attn_mask[:, :self.k+1] = 1

            back_dic = {
                "user_index": np.asarray(user_index),
                # "cdd_mask": np.asarray(neg_pad),
                'cdd_id': np.asarray(cdd_ids),
                'his_id': np.asarray(his_ids),
                "cdd_encoded_index": cdd_encoded_index,
                "his_encoded_index": his_encoded_index,
                "cdd_attn_mask": cdd_attn_mask,
                "his_attn_mask": his_attn_mask,
                "his_mask": his_mask,
                "label": label
            }

            return back_dic

        # each time called return one sample, and no labels
        elif self.mode == 'dev':
            cdd_ids = impr_news

            his_ids = self.histories[impr_index][:self.his_size]
            # true means the corresponding history news is padded
            his_mask = np.zeros((self.his_size), dtype=bool)
            his_mask[:len(his_ids)] = 1

            if self.order_history:
                his_ids = his_ids + [0] * (self.his_size - len(his_ids))
            else:
                his_ids = his_ids[::-1] + [0] * (self.his_size - len(his_ids))

            user_index = [self.uindexes[impr_index]]
            label = impr[2]

            cdd_encoded_index = self.encoded_news[cdd_ids][:, :self.signal_length]
            cdd_attn_mask = self.attn_mask[cdd_ids][:, :self.signal_length]
            if self.reducer == 'bm25':
                his_encoded_index = self.encoded_news_sorted[his_ids][:, :self.k + 1]
                his_attn_mask = self.attn_mask_sorted[his_ids][:, :self.k + 1]
            else:
                his_encoded_index = self.encoded_news[his_ids][:, :self.signal_length]
                his_attn_mask = self.attn_mask[his_ids][:, :self.signal_length]
                his_attn_mask[:, :self.k+1] = 1

            back_dic = {
                "impr_index": impr_index + 1,
                "user_index": np.asarray(user_index),
                'cdd_id': np.asarray(cdd_ids),
                'his_id': np.asarray(his_ids),
                "cdd_encoded_index": cdd_encoded_index,
                "his_encoded_index": his_encoded_index,
                "cdd_attn_mask": cdd_attn_mask,
                "his_attn_mask": his_attn_mask,
                "his_mask": his_mask,
                "label": np.asarray(label)
            }
            return back_dic

        elif self.mode == 'test':
            cdd_ids = impr_news

            his_ids = self.histories[impr_index][:self.his_size]
            # true means the corresponding history news is padded
            his_mask = np.zeros((self.his_size), dtype=bool)
            his_mask[:len(his_ids)] = 1

            if self.order_history:
                his_ids = his_ids + [0] * (self.his_size - len(his_ids))
            else:
                his_ids = his_ids[::-1] + [0] * (self.his_size - len(his_ids))

            user_index = [self.uindexes[impr_index]]

            cdd_encoded_index = self.encoded_news[cdd_ids][:, :self.signal_length]
            cdd_attn_mask = self.attn_mask[cdd_ids][:, :self.signal_length]
            if self.reducer == 'bm25':
                his_encoded_index = self.encoded_news_sorted[his_ids][:, :self.k + 1]
                his_attn_mask = self.attn_mask_sorted[his_ids][:, :self.k + 1]
            else:
                his_encoded_index = self.encoded_news[his_ids][:, :self.signal_length]
                his_attn_mask = self.attn_mask[his_ids][:, :self.signal_length]
                his_attn_mask[:, :self.k+1] = 1

            back_dic = {
                "impr_index": impr_index + 1,
                "user_index": np.asarray(user_index),
                'cdd_id': np.asarray(cdd_ids),
                'his_id': np.asarray(his_ids),
                "cdd_encoded_index": cdd_encoded_index,
                "his_encoded_index": his_encoded_index,
                "cdd_attn_mask": cdd_attn_mask,
                "his_attn_mask": his_attn_mask,
                "his_mask": his_mask
            }
            return back_dic

        else:
            raise ValueError("Mode {} not defined".format(self.mode))


In [3]:
config.reducer = 'bm25'
# config.reducer = 'matching'

path = config.path + 'MIND/MINDdemo_dev/'
a = MIND(config, path + 'news.tsv', path + 'behaviors.tsv')

[2021-08-19 22:06:27,570] INFO (__main__) using cached user behavior from data/cache/bert/MINDdemo_dev/10/behaviors..pkl
[2021-08-19 22:06:27,576] INFO (__main__) encoding news of ../../../Data/MIND/MINDdemo_dev/news.tsv...
[2021-08-19 22:06:47,171] INFO (utils.utils) computing BM25 scores...


[[1 0 0 ... 0 0 0]
 [1 1 1 ... 0 0 0]
 [1 1 1 ... 0 0 0]
 ...
 [1 1 1 ... 0 0 0]
 [1 1 1 ... 0 0 0]
 [1 1 1 ... 0 0 0]]


In [4]:
a[1]

{'impr_index': 1,
 'user_index': array([1929]),
 'cdd_id': array([38581, 39227, 37368, 33705, 39921, 39640, 36275, 38848,  7014,
        37943]),
 'his_id': array([36210,  1692, 30302, 26650, 12951, 30833, 35479, 30270,  1736,
        18228, 32139, 32731, 15796, 34171,  7499,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0]),
 'cdd_encoded_index': array([[  101,  1996,  9832,  2732,  1997,  2026,  2155,  1005,  1055,
         15060,  2795,  2009,  1005,  1055,  2524,  2000,  3342,  2166,
          2077, 22953, 21408,  3669,  8808,  5785, 16220, 10624,  2571,
          1012,  2030,  1010,  1038,  1012,  1038,  1012,  1039,  1012,
          1054,  1012,  1039,  1012,  1041,  1012,  1006,  2077, 22953,
         21408,  3669,  8808,  5785, 16220, 10624,  2571,  3690

In [7]:
a.attn_mask_sorted

array([[0, 0, 0, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]])

In [8]:
t = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
t('')

In [4]:
loader1 = DataLoader(a, batch_size=1, pin_memory=False, num_workers=0, drop_last=False, shuffle=False)
records1 = list(loader1)

In [5]:
records1[0]['his_id'], records1[0]['his_mask'] 

(tensor([[36210,  1692, 30302, 26650, 12951, 30833, 35479, 30270,  1736, 18228,
          32139, 32731, 15796, 34171,  7499,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),
 tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
           True,  True,  True,  True,  True, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False,
          False, False, False, False, False, False, False, False, False, False]]))

In [None]:
t = BertTokenizer.from_pretrained('bert-base-uncased')