In [None]:
# default_exp transforms.sampling

# Sampling
> Data sampling methods.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
from abc import *
from pathlib import Path
import pickle
import os
from tqdm import trange
from collections import Counter
import numpy as np
import pandas as pd

In [None]:
#export
def simple_negative_sampling(data,
                             num_negatives=4,
                             binarization=False,
                             feedback_column='RATING'):
  """Augment interaction data with randomly sampled negative (label 0) rows.

  For every distinct (USERID, ITEMID, feedback) row in `data`, the positive
  row is emitted followed by `num_negatives` rows pairing the same user with
  an item drawn uniformly from the items present in `data` that the user has
  NOT interacted with.

  Args:
    data: DataFrame with 'USERID' and 'ITEMID' columns, plus
      `feedback_column` unless `binarization` is True.
    num_negatives: number of negative rows generated per positive row.
    binarization: if True, every positive row gets label 1 instead of the
      value stored in `feedback_column`. The input frame is left untouched.
    feedback_column: name of the label column in the input and the output.

  Returns:
    DataFrame with columns ['USERID', 'ITEMID', feedback_column] holding the
    positives and the sampled negatives. Row order follows set iteration
    order and the draws use the global `np.random` state.

  Notes:
    A user who has interacted with every item in `data` has no valid
    negative; such users get their positive rows only.
  """
  # Get a list of all Item IDs
  all_itemsIds = data['ITEMID'].unique()

  # Placeholders that will hold the data
  users, items, labels = [], [], []

  # Binarize on the fly rather than overwriting the caller's column.
  if binarization:
    feedback = [1] * len(data)
  else:
    feedback = data[feedback_column]

  user_item_set = set(zip(data['USERID'], data['ITEMID'], feedback))

  # Items each user interacted with. The rejection test must look at
  # (user, item) only: the triples above also carry the label, so a 2-tuple
  # lookup against them can never match.
  seen_by_user = {}
  for u, i in zip(data['USERID'], data['ITEMID']):
    seen_by_user.setdefault(u, set()).add(i)

  for (u, i, r) in user_item_set:
    users.append(u)
    items.append(i)
    labels.append(r)
    seen = seen_by_user[u]
    # No unseen item exists for this user: rejection sampling would never end.
    if len(seen) >= len(all_itemsIds):
      continue
    for _ in range(num_negatives):
      # randomly select an item
      negative_item = np.random.choice(all_itemsIds)
      # check that the user has not interacted with this item
      while negative_item in seen:
        negative_item = np.random.choice(all_itemsIds)
      users.append(u)
      items.append(negative_item)
      labels.append(0) # items not interacted with are negative
  ns_data = pd.DataFrame(list(zip(users, items, labels)),
                         columns=['USERID','ITEMID',feedback_column])
  return ns_data

In [None]:
#export
class AbstractNegativeSampler(metaclass=ABCMeta):
    """Base class for negative samplers that pickle their result to disk.

    Subclasses implement `code()` and `generate_negative_samples()`;
    `get_negative_samples()` runs the generation and stores the result.

    Args:
        train, val, test: per-user interaction containers. This class only
            stores them; the concrete samplers in this file read
            `train[user]`, `val[user]` and `test[user]` as iterables of items.
        user_count: number of users.
        item_count: number of items.
        sample_size: number of negatives to draw per user.
        seed: seed for the random generator used by subclasses.
        flag: tag embedded in the pickle file name.
        save_folder: directory where the pickle file is written; it is
            created on demand. Stored as `self.save_path`.
    """

    def __init__(self, train, val, test, user_count, item_count, sample_size, seed, flag, save_folder):
        self.train = train
        self.val = val
        self.test = test
        self.user_count = user_count
        self.item_count = item_count
        self.sample_size = sample_size
        self.seed = seed
        self.flag = flag
        # The parameter is `save_folder`; the attribute keeps the name
        # `save_path` because `_get_save_path` reads it under that name.
        self.save_path = save_folder

    @classmethod
    @abstractmethod
    def code(cls):
        """Return a short string identifying the sampling strategy."""
        pass

    @abstractmethod
    def generate_negative_samples(self):
        """Return a `(seen_samples, negative_samples)` pair."""
        pass

    def get_negative_samples(self):
        """Generate negatives, pickle them, and return them.

        The samples are regenerated on every call (no existing file is
        loaded) and any previous file at the target path is overwritten.

        Returns:
            The `(seen_samples, negative_samples)` pair produced by
            `generate_negative_samples()`; the same two objects are pickled
            as a two-element list.
        """
        savefile_path = self._get_save_path()
        print("Negative samples don't exist. Generating.")
        seen_samples, negative_samples = self.generate_negative_samples()
        with savefile_path.open('wb') as f:
            pickle.dump([seen_samples, negative_samples], f)
        return seen_samples, negative_samples

    def _get_save_path(self):
        """Return the pickle path, creating the target folder if missing.

        The file name depends only on `flag`, so runs that differ in sampler
        type, sample size or seed but share a flag and folder use the same
        file.
        """
        folder = Path(self.save_path)
        # exist_ok avoids the check-then-create race of is_dir() + mkdir().
        folder.mkdir(parents=True, exist_ok=True)
        filename = 'negative_samples_{}.pkl'.format(self.flag)
        return folder.joinpath(filename)

In [None]:
#export
class RandomNegativeSampler(AbstractNegativeSampler):
    """Negative sampler that draws item ids uniformly at random."""

    @classmethod
    def code(cls):
        """Return the identifier of this sampling strategy."""
        return 'random'

    def generate_negative_samples(self):
        """Draw `sample_size` distinct unseen items for every user.

        Users are taken to be the ids 1..user_count and candidate items the
        ids 1..item_count. A pool of `2 * user_count * sample_size` candidates
        is drawn once after seeding `np.random`, then consumed in order and
        reused cyclically across users.

        Returns:
            (seen_samples, negative_samples): two dicts keyed by user id.
            `seen_samples[user]` is the set of items found in the user's
            train/val/test entries; `negative_samples[user]` is a list of
            `sample_size` distinct pool items not in that set.

        Raises:
            AssertionError: if `seed` is None.
            ValueError: if the pool does not hold enough distinct unseen
                items for some user. Without this check the loop would
                cycle through the pool forever.
        """
        assert self.seed is not None, 'Specify seed for random sampling'
        np.random.seed(self.seed)
        num_samples = 2 * self.user_count * self.sample_size
        # choice() yields 0..item_count-1; shift to the 1-based id range.
        all_samples = np.random.choice(self.item_count, num_samples) + 1

        seen_samples = {}
        negative_samples = {}
        print('Sampling negative items randomly...')
        j = 0
        for i in trange(self.user_count):
            user = i + 1
            seen = set(self.train[user])
            seen.update(self.val[user])
            seen.update(self.test[user])
            seen_samples[user] = seen

            samples = []
            picked = set()  # mirrors `samples` for O(1) duplicate checks
            attempts = 0
            while len(samples) < self.sample_size:
                # After num_samples consecutive draws the whole pool has been
                # examined for this user; further draws can only repeat.
                if attempts >= num_samples:
                    raise ValueError(
                        'Cannot draw {} distinct unseen items for user {}: '
                        'the candidate pool is exhausted'.format(self.sample_size, user))
                item = all_samples[j % num_samples]
                j += 1
                attempts += 1
                if item in seen or item in picked:
                    continue
                samples.append(item)
                picked.add(item)
            negative_samples[user] = samples

        return seen_samples, negative_samples

In [None]:
#export
class PopularNegativeSampler(AbstractNegativeSampler):
    """Negative sampler that draws items in proportion to their popularity."""

    @classmethod
    def code(cls):
        """Return the identifier of this sampling strategy."""
        return 'popular'

    def generate_negative_samples(self):
        """Draw `sample_size` distinct unseen items per user, by popularity.

        Each item's probability is its interaction count (see
        `items_by_popularity`) divided by the total count. A pool of
        `2 * user_count * sample_size` candidates is drawn once after seeding
        `np.random`, then consumed in order and reused cyclically across
        users, which are taken to be the ids 1..user_count.

        Returns:
            (seen_samples, negative_samples): two dicts keyed by user id.
            `seen_samples[user]` is the set of items found in the user's
            train/val/test entries; `negative_samples[user]` is a list of
            `sample_size` distinct pool items not in that set.

        Raises:
            AssertionError: if `seed` is None.
            ValueError: if the pool does not hold enough distinct unseen
                items for some user. Without this check the loop would
                cycle through the pool forever.
        """
        assert self.seed is not None, 'Specify seed for random sampling'
        np.random.seed(self.seed)
        popularity = self.items_by_popularity()
        items = list(popularity.keys())
        # Normalise the raw counts into a probability vector aligned with `items`.
        total = sum(popularity.values())
        probs = [count / total for count in popularity.values()]
        num_samples = 2 * self.user_count * self.sample_size
        all_samples = np.random.choice(items, num_samples, p=probs)

        seen_samples = {}
        negative_samples = {}
        print('Sampling negative items by popularity...')
        j = 0
        for i in trange(self.user_count):
            user = i + 1
            seen = set(self.train[user])
            seen.update(self.val[user])
            seen.update(self.test[user])
            seen_samples[user] = seen

            samples = []
            picked = set()  # mirrors `samples` for O(1) duplicate checks
            attempts = 0
            while len(samples) < self.sample_size:
                # After num_samples consecutive draws the whole pool has been
                # examined for this user; further draws can only repeat.
                if attempts >= num_samples:
                    raise ValueError(
                        'Cannot draw {} distinct unseen items for user {}: '
                        'the candidate pool is exhausted'.format(self.sample_size, user))
                item = all_samples[j % num_samples]
                j += 1
                attempts += 1
                if item in seen or item in picked:
                    continue
                samples.append(item)
                picked.add(item)
            negative_samples[user] = samples

        return seen_samples, negative_samples

    def items_by_popularity(self):
        """Count item occurrences over train/val/test for every train user.

        Also stores the sorted train keys on `self.users`.

        Returns:
            dict mapping item -> occurrence count, ordered from most to
            least frequent (ties keep first-encountered order).
        """
        popularity = Counter()
        self.users = sorted(self.train.keys())
        for user in self.users:
            popularity.update(self.train[user])
            popularity.update(self.val[user])
            popularity.update(self.test[user])

        # sorted() is stable, so equal counts keep their insertion order.
        ranked = sorted(popularity.items(), key=lambda pair: pair[1], reverse=True)
        return dict(ranked)

In [None]:
#hide
# Install and load the watermark extension, then print the author, machine
# details, versions of the imported packages, and the last-updated date/time.
!pip install -q watermark
%reload_ext watermark
%watermark -a "Sparsh A." -m -iv -u -t -d

Author: Sparsh A.

Last updated: 2021-12-18 09:51:57

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.4.104+
Machine     : x86_64
Processor   : x86_64
CPU cores   : 2
Architecture: 64bit

pandas : 1.1.5
IPython: 5.5.0
numpy  : 1.19.5

