In [1]:
import re
import os
import pickle
import numpy as np
from torch.utils.data import Dataset
from utils.utils import newsample, getId2idx, tokenize, getVocab, my_collate
from data.configs.demo import config
from torch.utils.data import DataLoader

In [8]:
class MIND_bert(Dataset):
    """ Map Style Dataset for MIND, use bert tokenizer

    Every news article is tokenized once (see ``init_news``); ``__getitem__`` then
    gathers the token ids and attention masks of the candidate and history news of
    one stored impression, truncated to ``signal_length`` tokens.

    Args:
        config: pre-defined hyper-parameter object; reads ``npratio``,
            ``signal_length``, ``his_size``, ``k``, ``embedding``, ``bert``
            (passed to ``BertTokenizerFast.from_pretrained``) and ``scale``
        news_file(str): path of news_file; must contain ``MIND/<name>_<mode>/news``,
            where ``<mode>`` (train / dev / test) selects how behaviors are stored
        behaviors_file(str): path of behaviors_file
        shuffle_pos(bool): in train mode, whether to shuffle the position of the
            positive candidate among its sampled negatives

    Raises:
        ValueError: if ``news_file`` does not match the pattern above, or its mode
            is not one of train / dev / test
    """

    # splits this dataset knows how to parse and serve
    _MODES = ('train', 'dev', 'test')

    def __init__(self, config, news_file, behaviors_file, shuffle_pos=False):
        from transformers import BertTokenizerFast

        self.npratio = config.npratio
        self.shuffle_pos = shuffle_pos
        self.signal_length = config.signal_length
        self.his_size = config.his_size
        self.k = config.k

        # e.g. ".../MIND/MINDdemo_dev/news.tsv" -> group(1) "MINDdemo_dev/", group(2) "dev"
        pat = re.search(r'MIND/(.*_(.*)/)news', news_file)
        if pat is None:
            raise ValueError(
                "news_file {} does not match 'MIND/<name>_<mode>/news'".format(news_file))
        self.mode = pat.group(2)
        # fail fast, before the expensive tokenization in init_news
        if self.mode not in self._MODES:
            raise ValueError("Mode {} not defined".format(self.mode))

        self.cache_path = '/'.join(['data/cache', config.embedding, pat.group(1)])
        self.behav_path = re.search(r'(\w*)\.tsv', behaviors_file).group(1)

        # exist_ok already tolerates the directory existing (e.g. created by another
        # worker); any other OSError would break the cache writes in init_news /
        # init_behaviors anyway, so it is surfaced here instead of swallowed
        os.makedirs(self.cache_path, exist_ok=True)

        self.news_file = news_file
        self.behaviors_file = behaviors_file
        self.col_spliter = '\t'

        # truncation limit (in tokens) of every tokenized news row, see init_news
        self.max_news_length = 512
        # every stored history is truncated / zero-padded to this many news
        self.max_his_size = 100

        self.tokenizer = BertTokenizerFast.from_pretrained(config.bert)

        self.nid2index = getId2idx(
            'data/dictionaries/nid2idx_{}_{}.json'.format(config.scale, self.mode))
        self.uid2index = getId2idx(
            'data/dictionaries/uid2idx_{}.json'.format(config.scale))

        self.init_news()
        self.init_behaviors()

    def init_news(self):
        """
            Tokenize every article of ``self.news_file`` into ``self.encoded_news``
            (token ids) and ``self.attn_mask``, and cache both in
            ``<cache_path>news.pkl``.

            Row 0 is a padding document; row i (i >= 1) is the i-th line of news.tsv.
        """

        # VERY IMPORTANT!!! FIXME
        # The nid2idx dictionary must follow the original order of news in news.tsv

        # Row 0 is the padding document that history index 0 (used for padding in
        # init_behaviors) points to. It spells out max_news_length '[PAD]' tokens,
        # presumably so that padding=True below pads every row to max_news_length.
        # NOTE(review): these are text tokens rather than tokenizer padding, so this
        # row's attention mask is presumably all ones — confirm this is intended.
        documents = ['[PAD]' * self.max_news_length]

        with open(self.news_file, "r", encoding='utf-8') as rd:
            for line in rd:
                _, vert, subvert, title, ab, _, _, _ = line.strip("\n").split(self.col_spliter)
                # concat the textual fields to form the document
                documents.append(' '.join([title, ab, vert, subvert]))

        encoded_dict = self.tokenizer(documents, add_special_tokens=False, padding=True, truncation=True, max_length=self.max_news_length, return_tensors='np')
        self.encoded_news = encoded_dict.input_ids
        self.attn_mask = encoded_dict.attention_mask

        with open(self.cache_path + 'news.pkl', 'wb') as f:
            pickle.dump(
                {
                    'encoded_news': self.encoded_news,
                    'attn_mask': self.attn_mask
                },
                f
            )

    def init_behaviors(self):
        """
            Parse ``self.behaviors_file`` into per-impression lists and cache them in
            ``<cache_path><behav_path>.pkl``.

            Always sets ``histories`` (history news indexes truncated / padded with 0
            to ``max_his_size``), ``his_sizes`` (untruncated history lengths),
            ``uindexes`` and ``imprs``. What ``imprs`` holds depends on the mode:

            - train: one ``(impr_index, clicked news index)`` per clicked news; also
              sets ``negatives[impr_index]`` = unclicked news indexes
            - dev: one ``(impr_index, news indexes, labels)`` per impression
            - test: one ``(impr_index, news indexes)`` per impression (no labels)

            Raises:
                ValueError: if ``self.mode`` is not train / dev / test
        """
        if self.mode not in self._MODES:
            raise ValueError("Mode {} not defined".format(self.mode))

        histories = []
        uindexes = []
        his_sizes = []
        imprs = []
        negatives = {}

        with open(self.behaviors_file, "r", encoding='utf-8') as rd:
            for impr_index, line in enumerate(rd):
                _, uid, _, history, impr = line.strip("\n").split(self.col_spliter)

                history = [self.nid2index[i] for i in history.split()]
                his_sizes.append(len(history))
                # tailor user's history or pad 0
                history = history[:self.max_his_size] + [0] * (self.max_his_size - len(history))

                if self.mode == 'test':
                    # test impressions list bare news ids without click labels
                    impr_news = [self.nid2index[i] for i in impr.split()]
                    labels = None
                else:
                    # each entry looks like "<nid>-<label>"
                    entries = [i.split("-") for i in impr.split()]
                    impr_news = [self.nid2index[e[0]] for e in entries]
                    labels = [int(e[1]) for e in entries]

                # user will always in uid2index
                uindex = self.uid2index[uid]

                if self.mode == 'train':
                    # only store positive behavior; unclicked news form the negative pool
                    imprs.extend(
                        (impr_index, news) for news, label in zip(impr_news, labels) if label == 1)
                    negatives[impr_index] = [
                        news for news, label in zip(impr_news, labels) if label != 1]
                elif self.mode == 'dev':
                    imprs.append((impr_index, impr_news, labels))
                else:
                    imprs.append((impr_index, impr_news))

                # 1 impression correspond to 1 of each of the following properties
                histories.append(history)
                uindexes.append(uindex)

        self.imprs = imprs
        self.histories = histories
        self.his_sizes = his_sizes
        self.uindexes = uindexes

        save_dict = {
            'imprs': self.imprs,
            'histories': self.histories,
            'his_sizes': self.his_sizes,
        }
        if self.mode == 'train':
            self.negatives = negatives
            save_dict['negatives'] = self.negatives
        save_dict['uindexes'] = self.uindexes

        with open(self.cache_path + self.behav_path + '.pkl', 'wb') as f:
            pickle.dump(save_dict, f)

    def __len__(self):
        """
            return length of the whole dataset
        """
        return len(self.imprs)

    def _history(self, impr_index):
        """ Return ``(his_ids, his_mask)`` of one impression.

        ``his_ids`` is the first ``his_size`` stored history indexes; ``his_mask`` is a
        bool array of length ``his_size`` that is True for real history news and
        False for padded positions.
        """
        his_ids = self.histories[impr_index][:self.his_size]
        his_mask = np.zeros((self.his_size), dtype=bool)
        his_mask[:self.his_sizes[impr_index]] = True
        return his_ids, his_mask

    def _tokens(self, news_ids):
        """ Return ``(token ids, attention mask)`` of ``news_ids``, cut to ``signal_length`` tokens. """
        return (self.encoded_news[news_ids][:, :self.signal_length],
                self.attn_mask[news_ids][:, :self.signal_length])

    def __getitem__(self, index):
        """ return data
        Args:
            index: the index for stored impression

        Returns:
            back_dic: dictionary of data slice with keys
                - "impression_index" (dev / test only): impr_index + 1
                - "user_index": user index wrapped in a length-1 array
                - "cdd_id" / "his_id": candidate / history news indexes
                - "cdd_encoded_index" / "his_encoded_index": their token ids,
                  truncated to ``signal_length`` tokens
                - "cdd_attn_mask" / "his_attn_mask": matching attention masks
                - "his_mask": bool, True for real history news, False for padding
                - "label" (train only): position of the positive candidate
                - "labels" (dev only): labels of every candidate, shape (1, n)

        Raises:
            ValueError: if ``self.mode`` is not train / dev / test
        """
        impr = self.imprs[index]  # (impression_index, news_index[, labels])
        impr_index = impr[0]
        impr_news = impr[1]

        user_index = [self.uindexes[impr_index]]

        if self.mode == 'train':
            # one positive plus npratio negatives drawn from the same impression
            neg_list, _ = newsample(self.negatives[impr_index], self.npratio)
            cdd_ids = [impr_news] + neg_list
            label = np.asarray([1] + [0] * self.npratio)

            if self.shuffle_pos:
                s = np.arange(0, len(label), 1)
                np.random.shuffle(s)
                cdd_ids = np.asarray(cdd_ids)[s]
                label = np.asarray(label)[s]

            # collapse the one-hot vector into the position of the positive candidate
            label = np.arange(0, len(cdd_ids), 1)[label == 1][0]

        elif self.mode in ('dev', 'test'):
            # every candidate of the impression; impr_news is already a flat list,
            # so it must not be wrapped again or [:, :signal_length] hits the news axis
            cdd_ids = impr_news

        else:
            raise ValueError("Mode {} not defined".format(self.mode))

        his_ids, his_mask = self._history(impr_index)
        cdd_encoded_index, cdd_attn_mask = self._tokens(cdd_ids)
        his_encoded_index, his_attn_mask = self._tokens(his_ids)

        back_dic = {}
        if self.mode != 'train':
            # impr_index is 0-based; expose a 1-based impression number
            back_dic["impression_index"] = impr_index + 1
        back_dic["user_index"] = np.asarray(user_index)
        back_dic['cdd_id'] = np.asarray(cdd_ids)
        back_dic['his_id'] = np.asarray(his_ids)
        back_dic["cdd_encoded_index"] = cdd_encoded_index
        back_dic["his_encoded_index"] = his_encoded_index
        back_dic["cdd_attn_mask"] = cdd_attn_mask
        back_dic["his_attn_mask"] = his_attn_mask
        back_dic["his_mask"] = his_mask

        if self.mode == 'train':
            back_dic["label"] = label
        elif self.mode == 'dev':
            back_dic["labels"] = np.asarray([impr[2]])

        return back_dic


In [11]:
# Build the dev-split dataset. MIND_bert infers the mode from the directory name
# ("MINDdemo_dev" -> 'dev'), so this path determines how impressions are served.
config.embedding = 'bert'
config.signal_length = 20   # tokens kept per news article
config.his_size = 10        # history news kept per user
dev_path = config.path + 'MIND/MINDdemo_dev/'
a = MIND_bert(config, dev_path + 'news.tsv', dev_path + 'behaviors.tsv')

In [12]:
# Inspect the first dev impression: all of its candidates plus the user's first
# his_size history news, each truncated to signal_length tokens.
a[0]

{'impression_index': 1,
 'user_index': array([1929]),
 'cdd_id': array([36180, 41328, 41034, 39776, 34983, 37322, 37327, 36307, 36185,
        36349, 38581, 39227, 37368, 33705, 39921, 39640, 36275, 38848,
         7014, 37943, 39086, 24209]),
 'his_id': array([ 7499, 34171, 15796, 32731, 32139, 18228,  1736, 30270, 35479,
        30833]),
 'cdd_encoded_index': array([[13240, 12134,  2000,  6701, 18466,  1010,  2655, 27056,  9674,
          1005,  1055,  4506,  1005, 21873,  1005,  2867,  2411,  6985,
          2037, 13220],
        [ 1045,  1005,  2310,  2042,  3015,  2055,  4714,  5014,  2005,
          1037,  2095,  1998,  2633,  2985,  1016,  6385,  1999,  1037,
          3998,  1011],
        [ 5448,  1024,  6972, 10556, 13699, 11795,  6799,  2003,  2055,
          2000,  2131,  2054,  2002, 17210,  1024,  1037,  3382,  1996,
          2203,  2089],
        [ 1996, 10556, 13639,  6182,  6962,  2227, 25748,  2058,  1005,
         16021,  6132, 13043,  1005,  2155,  2833,  2954,  19

In [7]:
from torch.utils.data.distributed import DistributedSampler

# Shard of `a` that rank 0 of 2 replicas would see, kept in dataset order.
s = DistributedSampler(
    a,
    num_replicas=2,
    rank=0,
    shuffle=False,
)

In [18]:
# Default collation over the dev dataset, one impression per batch. Dev impressions
# carry every logged candidate, so their counts presumably differ and batch_size=1
# keeps the default collate from having to stack ragged arrays.
loader_dev = DataLoader(a, batch_size=1, shuffle=False, num_workers=0,
                        pin_memory=False, drop_last=False)
record = next(iter(loader_dev))

In [16]:
# The same impression after default collation: array fields become tensors with a
# leading batch dimension of 1.
record

{'impression_index': [1],
 'user_index': tensor([[1929]]),
 'cdd_id': tensor([[36180, 41328, 41034, 39776, 34983, 37322, 37327, 36307, 36185, 36349,
          38581, 39227, 37368, 33705, 39921, 39640, 36275, 38848,  7014, 37943,
          39086, 24209]]),
 'his_id': tensor([[ 7499, 34171, 15796, 32731, 32139, 18228,  1736, 30270, 35479, 30833]]),
 'cdd_encoded_index': tensor([[[13240, 12134,  2000,  6701, 18466,  1010,  2655, 27056,  9674,  1005,
            1055,  4506,  1005, 21873,  1005,  2867,  2411,  6985,  2037, 13220],
          [ 1045,  1005,  2310,  2042,  3015,  2055,  4714,  5014,  2005,  1037,
            2095,  1998,  2633,  2985,  1016,  6385,  1999,  1037,  3998,  1011],
          [ 5448,  1024,  6972, 10556, 13699, 11795,  6799,  2003,  2055,  2000,
            2131,  2054,  2002, 17210,  1024,  1037,  3382,  1996,  2203,  2089],
          [ 1996, 10556, 13639,  6182,  6962,  2227, 25748,  2058,  1005, 16021,
            6132, 13043,  1005,  2155,  2833,  2954,  1999, 

In [8]:
# Train-split dataset (mode 'train'): each item holds 1 positive + npratio sampled
# negatives, which presumably explains the 5 candidates in the shapes shown below.
# NOTE(review): `b` was never defined in this notebook; confirm MINDdemo_train is the
# dataset that was intended here.
train_dir = config.path + 'MIND/MINDdemo_train/'
mind_train = MIND_bert(config, train_dir + 'news.tsv', train_dir + 'behaviors.tsv')
loader_train = DataLoader(mind_train, batch_size=config.batch_size, pin_memory=False, num_workers=0,
                          drop_last=False, shuffle=False, collate_fn=my_collate, sampler=None)
record = next(iter(loader_train))

In [11]:
tuple(record[key].shape for key in ('cdd_attn_mask', 'cdd_encoded_index'))

(torch.Size([10, 5, 20]), torch.Size([10, 5, 20]))